// mas_handlers/oauth2/authorization/mod.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use axum::{
8    extract::{Form, State},
9    response::{IntoResponse, Response},
10};
11use hyper::StatusCode;
12use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, record_error};
13use mas_data_model::{AuthorizationCode, Pkce};
14use mas_router::{PostAuthAction, UrlBuilder};
15use mas_storage::{
16    BoxClock, BoxRepository, BoxRng,
17    oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository},
18};
19use mas_templates::Templates;
20use oauth2_types::{
21    errors::{ClientError, ClientErrorCode},
22    pkce,
23    requests::{AuthorizationRequest, GrantType, Prompt, ResponseMode},
24    response_type::ResponseType,
25};
26use rand::{Rng, distributions::Alphanumeric};
27use serde::Deserialize;
28use thiserror::Error;
29
30use self::callback::CallbackDestination;
31use crate::{BoundActivityTracker, PreferredLanguage, impl_from_error_for_route};
32
33mod callback;
34pub(crate) mod consent;
35
36#[derive(Debug, Error)]
37pub enum RouteError {
38    #[error(transparent)]
39    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
40
41    #[error("could not find client")]
42    ClientNotFound,
43
44    #[error("invalid response mode")]
45    InvalidResponseMode,
46
47    #[error("invalid parameters")]
48    IntoCallbackDestination(#[from] self::callback::IntoCallbackDestinationError),
49
50    #[error("invalid redirect uri")]
51    UnknownRedirectUri(#[from] mas_data_model::InvalidRedirectUriError),
52}
53
54impl IntoResponse for RouteError {
55    fn into_response(self) -> axum::response::Response {
56        let sentry_event_id = record_error!(self, Self::Internal(_));
57        // TODO: better error pages
58        let response = match self {
59            RouteError::Internal(e) => {
60                (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
61            }
62            RouteError::ClientNotFound => {
63                (StatusCode::BAD_REQUEST, "could not find client").into_response()
64            }
65            RouteError::InvalidResponseMode => {
66                (StatusCode::BAD_REQUEST, "invalid response mode").into_response()
67            }
68            RouteError::IntoCallbackDestination(e) => {
69                (StatusCode::BAD_REQUEST, e.to_string()).into_response()
70            }
71            RouteError::UnknownRedirectUri(e) => (
72                StatusCode::BAD_REQUEST,
73                format!("Invalid redirect URI ({e})"),
74            )
75                .into_response(),
76        };
77
78        (sentry_event_id, response).into_response()
79    }
80}
81
82impl_from_error_for_route!(mas_storage::RepositoryError);
83impl_from_error_for_route!(mas_templates::TemplateError);
84impl_from_error_for_route!(self::callback::CallbackDestinationError);
85impl_from_error_for_route!(mas_policy::LoadError);
86impl_from_error_for_route!(mas_policy::EvaluationError);
87
/// Parameters of an authorization request, extracted with `Form` by the
/// [`get`] handler.
#[derive(Deserialize)]
pub(crate) struct Params {
    /// The core OAuth 2.0 / OpenID Connect authorization request parameters.
    #[serde(flatten)]
    auth: AuthorizationRequest,

    /// Optional PKCE parameters (`code_challenge` and
    /// `code_challenge_method`). They are only accepted when the response
    /// type asks for a `code`; the handler rejects them otherwise.
    #[serde(flatten)]
    pkce: Option<pkce::AuthorizationRequest>,
}
96
97/// Given a list of response types and an optional user-defined response mode,
98/// figure out what response mode must be used, and emit an error if the
99/// suggested response mode isn't allowed for the given response types.
100fn resolve_response_mode(
101    response_type: &ResponseType,
102    suggested_response_mode: Option<ResponseMode>,
103) -> Result<ResponseMode, RouteError> {
104    use ResponseMode as M;
105
106    // If the response type includes either "token" or "id_token", the default
107    // response mode is "fragment" and the response mode "query" must not be
108    // used
109    if response_type.has_token() || response_type.has_id_token() {
110        match suggested_response_mode {
111            None => Ok(M::Fragment),
112            Some(M::Query) => Err(RouteError::InvalidResponseMode),
113            Some(mode) => Ok(mode),
114        }
115    } else {
116        // In other cases, all response modes are allowed, defaulting to "query"
117        Ok(suggested_response_mode.unwrap_or(M::Query))
118    }
119}
120
/// Handle a `GET` request to the OAuth 2.0 authorization endpoint.
///
/// The request is processed in two phases:
///
/// 1. The client, the `redirect_uri` and the response mode are resolved.
///    Until a [`CallbackDestination`] exists there is no proper place to send
///    the user back to. Failures in this phase are therefore returned as a
///    [`RouteError`] and rendered directly as an HTTP error.
/// 2. Once the callback destination is built, the remaining parameters are
///    validated. Protocol-level rejections are reported back to the client
///    through the callback destination. These cover:
///    - an unsupported `request`, `request_uri` or `registration` parameter,
///    - a `token` response type,
///    - a grant type the client may not use,
///    - `prompt=none`,
///    - PKCE parameters without a `code`.
///
///    Otherwise an authorization grant is stored and the user is redirected
///    to registration, login or consent.
///
/// Any error raised during phase 2 is logged and replaced with a
/// `server_error` delivered through the callback destination.
///
/// # Errors
///
/// Returns a [`RouteError`] if:
/// - the client lookup fails or finds nothing,
/// - the `redirect_uri` cannot be resolved for the client,
/// - the response mode is invalid for the response type,
/// - the callback destination cannot be built,
/// - rendering the fallback `server_error` callback fails.
#[tracing::instrument(
    name = "handlers.oauth2.authorization.get",
    fields(client.id = %params.auth.client_id),
    skip_all,
)]
// NOTE(review): presumably allowed because the handler is one linear sequence
// of validation steps sharing many extractor values; consider splitting.
#[allow(clippy::too_many_lines)]
pub(crate) async fn get(
    mut rng: BoxRng,
    clock: BoxClock,
    PreferredLanguage(locale): PreferredLanguage,
    State(templates): State<Templates>,
    State(url_builder): State<UrlBuilder>,
    activity_tracker: BoundActivityTracker,
    mut repo: BoxRepository,
    cookie_jar: CookieJar,
    Form(params): Form<Params>,
) -> Result<Response, RouteError> {
    // First, figure out what client it is
    let client = repo
        .oauth2_client()
        .find_by_client_id(&params.auth.client_id)
        .await?
        .ok_or(RouteError::ClientNotFound)?;

    // And resolve the redirect_uri and response_mode
    let redirect_uri = client
        .resolve_redirect_uri(&params.auth.redirect_uri)?
        .clone();
    let response_type = params.auth.response_type;
    let response_mode = resolve_response_mode(&response_type, params.auth.response_mode)?;

    // Now we have a proper callback destination to go to on error
    let callback_destination = CallbackDestination::try_new(
        &response_mode,
        redirect_uri.clone(),
        params.auth.state.clone(),
    )?;

    // Get the session info from the cookie
    let (session_info, cookie_jar) = cookie_jar.session_info();

    // One day, we will have try blocks
    // Until then, an immediately-awaited `async move` block lets `?` be used
    // for every step, with all errors collected into `res` below.
    let res: Result<Response, RouteError> = ({
        // Clones are moved into the block so the originals stay available for
        // the `server_error` fallback after it.
        let templates = templates.clone();
        let callback_destination = callback_destination.clone();
        let locale = locale.clone();
        async move {
            let maybe_session = session_info.load_active_session(&mut repo).await?;
            // A missing `prompt` parameter is treated as an empty list
            let prompt = params.auth.prompt.as_deref().unwrap_or_default();

            // Check if the request/request_uri/registration params are used. If so, reply
            // with the right error since we don't support them.
            if params.auth.request.is_some() {
                return Ok(callback_destination.go(
                    &templates,
                    &locale,
                    ClientError::from(ClientErrorCode::RequestNotSupported),
                )?);
            }

            if params.auth.request_uri.is_some() {
                return Ok(callback_destination.go(
                    &templates,
                    &locale,
                    ClientError::from(ClientErrorCode::RequestUriNotSupported),
                )?);
            }

            // Check if the client asked for a `token` response type, and bail out if it's
            // the case, since we don't support them
            if response_type.has_token() {
                return Ok(callback_destination.go(
                    &templates,
                    &locale,
                    ClientError::from(ClientErrorCode::UnsupportedResponseType),
                )?);
            }

            // If the client asked for a `id_token` response type, we must check if it can
            // use the `implicit` grant type
            if response_type.has_id_token() && !client.grant_types.contains(&GrantType::Implicit) {
                return Ok(callback_destination.go(
                    &templates,
                    &locale,
                    ClientError::from(ClientErrorCode::UnauthorizedClient),
                )?);
            }

            if params.auth.registration.is_some() {
                return Ok(callback_destination.go(
                    &templates,
                    &locale,
                    ClientError::from(ClientErrorCode::RegistrationNotSupported),
                )?);
            }

            // Fail early if prompt=none; we never let it go through
            if prompt.contains(&Prompt::None) {
                return Ok(callback_destination.go(
                    &templates,
                    &locale,
                    ClientError::from(ClientErrorCode::LoginRequired),
                )?);
            }

            // Generate an authorization code only if the response type asks
            // for one
            let code: Option<AuthorizationCode> = if response_type.has_code() {
                // Check if it is allowed to use this grant type
                if !client.grant_types.contains(&GrantType::AuthorizationCode) {
                    return Ok(callback_destination.go(
                        &templates,
                        &locale,
                        ClientError::from(ClientErrorCode::UnauthorizedClient),
                    )?);
                }

                // 32 random alphanumeric characters, about 190bit of entropy
                let code: String = (&mut rng)
                    .sample_iter(&Alphanumeric)
                    .take(32)
                    .map(char::from)
                    .collect();

                // PKCE is optional here; when present, the challenge is stored
                // alongside the code
                let pkce = params.pkce.map(|p| Pkce {
                    challenge: p.code_challenge,
                    challenge_method: p.code_challenge_method,
                });

                Some(AuthorizationCode { code, pkce })
            } else {
                // If the request had PKCE params but no code asked, it should get back with an
                // error
                if params.pkce.is_some() {
                    return Ok(callback_destination.go(
                        &templates,
                        &locale,
                        ClientError::from(ClientErrorCode::InvalidRequest),
                    )?);
                }

                None
            };

            // Record the grant; later steps (login/registration/consent)
            // refer back to it by its ID
            let grant = repo
                .oauth2_authorization_grant()
                .add(
                    &mut rng,
                    &clock,
                    &client,
                    redirect_uri.clone(),
                    params.auth.scope,
                    code,
                    params.auth.state.clone(),
                    params.auth.nonce,
                    response_mode,
                    response_type.has_id_token(),
                    params.auth.login_hint,
                )
                .await?;
            let continue_grant = PostAuthAction::continue_grant(grant.id);

            // Each arm calls `repo.save()` before redirecting, so the grant
            // created above is persisted (presumably committed) first
            let res = match maybe_session {
                None if prompt.contains(&Prompt::Create) => {
                    // Client asked for a registration, show the registration prompt
                    repo.save().await?;

                    url_builder
                        .redirect(&mas_router::Register::and_then(continue_grant))
                        .into_response()
                }

                None => {
                    // Other cases where we don't have a session, ask for a login
                    repo.save().await?;

                    url_builder
                        .redirect(&mas_router::Login::and_then(continue_grant))
                        .into_response()
                }

                Some(user_session) => {
                    // TODO: better support for prompt=create when we have a session
                    repo.save().await?;

                    activity_tracker
                        .record_browser_session(&clock, &user_session)
                        .await;
                    url_builder
                        .redirect(&mas_router::Consent(grant.id))
                        .into_response()
                }
            };

            Ok(res)
        }
    })
    .await;

    // Unexpected errors are logged in full, but the client is only told that
    // a server error happened, via the callback destination
    let response = match res {
        Ok(r) => r,
        Err(err) => {
            tracing::error!(message = &err as &dyn std::error::Error);
            callback_destination.go(
                &templates,
                &locale,
                ClientError::from(ClientErrorCode::ServerError),
            )?
        }
    };

    // Send back the cookie jar returned by `session_info()` together with the
    // response
    Ok((cookie_jar, response).into_response())
}