mas_handlers/oauth2/revoke.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use axum::{Json, extract::State, response::IntoResponse};
8use hyper::StatusCode;
9use mas_axum_utils::{
10    client_authorization::{ClientAuthorization, CredentialsVerificationError},
11    record_error,
12};
13use mas_data_model::TokenType;
14use mas_iana::oauth::OAuthTokenTypeHint;
15use mas_keystore::Encrypter;
16use mas_storage::{
17    BoxClock, BoxRepository, BoxRng, RepositoryAccess,
18    queue::{QueueJobRepositoryExt as _, SyncDevicesJob},
19};
20use oauth2_types::{
21    errors::{ClientError, ClientErrorCode},
22    requests::RevocationRequest,
23};
24use thiserror::Error;
25use ulid::Ulid;
26
27use crate::{BoundActivityTracker, impl_from_error_for_route};
28
29#[derive(Debug, Error)]
30pub(crate) enum RouteError {
31    #[error(transparent)]
32    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
33
34    #[error("bad request")]
35    BadRequest,
36
37    #[error("client not found")]
38    ClientNotFound,
39
40    #[error("client not allowed")]
41    ClientNotAllowed,
42
43    #[error("invalid client credentials for client {client_id}")]
44    InvalidClientCredentials {
45        client_id: Ulid,
46        #[source]
47        source: CredentialsVerificationError,
48    },
49
50    #[error("could not verify client credentials for client {client_id}")]
51    ClientCredentialsVerification {
52        client_id: Ulid,
53        #[source]
54        source: CredentialsVerificationError,
55    },
56
57    #[error("client is unauthorized")]
58    UnauthorizedClient,
59
60    #[error("unsupported token type")]
61    UnsupportedTokenType,
62
63    #[error("unknown token")]
64    UnknownToken,
65}
66
67impl IntoResponse for RouteError {
68    fn into_response(self) -> axum::response::Response {
69        let sentry_event_id = record_error!(self, Self::Internal(_));
70        let response = match self {
71            Self::Internal(_) | Self::ClientCredentialsVerification { .. } => (
72                StatusCode::INTERNAL_SERVER_ERROR,
73                Json(ClientError::from(ClientErrorCode::ServerError)),
74            )
75                .into_response(),
76
77            Self::BadRequest => (
78                StatusCode::BAD_REQUEST,
79                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
80            )
81                .into_response(),
82
83            Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
84                StatusCode::UNAUTHORIZED,
85                Json(ClientError::from(ClientErrorCode::InvalidClient)),
86            )
87                .into_response(),
88
89            Self::ClientNotAllowed | Self::UnauthorizedClient => (
90                StatusCode::UNAUTHORIZED,
91                Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
92            )
93                .into_response(),
94
95            Self::UnsupportedTokenType => (
96                StatusCode::BAD_REQUEST,
97                Json(ClientError::from(ClientErrorCode::UnsupportedTokenType)),
98            )
99                .into_response(),
100
101            // If the token is unknown, we still return a 200 OK response.
102            Self::UnknownToken => StatusCode::OK.into_response(),
103        };
104
105        (sentry_event_id, response).into_response()
106    }
107}
108
109impl_from_error_for_route!(mas_storage::RepositoryError);
110
111impl From<mas_data_model::TokenFormatError> for RouteError {
112    fn from(_e: mas_data_model::TokenFormatError) -> Self {
113        Self::UnknownToken
114    }
115}
116
117#[tracing::instrument(
118    name = "handlers.oauth2.revoke.post",
119    fields(client.id = client_authorization.client_id()),
120    skip_all,
121)]
122pub(crate) async fn post(
123    clock: BoxClock,
124    mut rng: BoxRng,
125    State(http_client): State<reqwest::Client>,
126    mut repo: BoxRepository,
127    activity_tracker: BoundActivityTracker,
128    State(encrypter): State<Encrypter>,
129    client_authorization: ClientAuthorization<RevocationRequest>,
130) -> Result<impl IntoResponse, RouteError> {
131    let client = client_authorization
132        .credentials
133        .fetch(&mut repo)
134        .await?
135        .ok_or(RouteError::ClientNotFound)?;
136
137    let method = client
138        .token_endpoint_auth_method
139        .as_ref()
140        .ok_or(RouteError::ClientNotAllowed)?;
141
142    client_authorization
143        .credentials
144        .verify(&http_client, &encrypter, method, &client)
145        .await
146        .map_err(|err| {
147            if err.is_internal() {
148                RouteError::ClientCredentialsVerification {
149                    client_id: client.id,
150                    source: err,
151                }
152            } else {
153                RouteError::InvalidClientCredentials {
154                    client_id: client.id,
155                    source: err,
156                }
157            }
158        })?;
159
160    let Some(form) = client_authorization.form else {
161        return Err(RouteError::BadRequest);
162    };
163
164    let token_type = TokenType::check(&form.token)?;
165
166    // Find the ID of the session to end.
167    let session_id = match (form.token_type_hint, token_type) {
168        (Some(OAuthTokenTypeHint::AccessToken) | None, TokenType::AccessToken) => {
169            let access_token = repo
170                .oauth2_access_token()
171                .find_by_token(&form.token)
172                .await?
173                .ok_or(RouteError::UnknownToken)?;
174
175            if !access_token.is_valid(clock.now()) {
176                return Err(RouteError::UnknownToken);
177            }
178            access_token.session_id
179        }
180
181        (Some(OAuthTokenTypeHint::RefreshToken) | None, TokenType::RefreshToken) => {
182            let refresh_token = repo
183                .oauth2_refresh_token()
184                .find_by_token(&form.token)
185                .await?
186                .ok_or(RouteError::UnknownToken)?;
187
188            if !refresh_token.is_valid() {
189                return Err(RouteError::UnknownToken);
190            }
191
192            refresh_token.session_id
193        }
194
195        // This case can happen if there is a mismatch between the token type hint and the guessed
196        // token type or if the token was a compat access/refresh token. In those cases, we return
197        // an unknown token error.
198        (Some(OAuthTokenTypeHint::AccessToken | OAuthTokenTypeHint::RefreshToken) | None, _) => {
199            return Err(RouteError::UnknownToken);
200        }
201
202        (Some(_), _) => return Err(RouteError::UnsupportedTokenType),
203    };
204
205    let session = repo
206        .oauth2_session()
207        .lookup(session_id)
208        .await?
209        .ok_or(RouteError::UnknownToken)?;
210
211    // Check that the session is still valid.
212    if !session.is_valid() {
213        return Err(RouteError::UnknownToken);
214    }
215
216    // Check that the client ending the session is the same as the client that
217    // created it.
218    if client.id != session.client_id {
219        return Err(RouteError::UnauthorizedClient);
220    }
221
222    activity_tracker
223        .record_oauth2_session(&clock, &session)
224        .await;
225
226    // If the session is associated with a user, make sure we schedule a device
227    // deletion job for all the devices associated with the session.
228    if let Some(user_id) = session.user_id {
229        // Fetch the user
230        let user = repo
231            .user()
232            .lookup(user_id)
233            .await?
234            .ok_or(RouteError::UnknownToken)?;
235
236        // Schedule a job to sync the devices of the user with the homeserver
237        repo.queue_job()
238            .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
239            .await?;
240    }
241
242    // Now that we checked everything, we can end the session.
243    repo.oauth2_session().finish(&clock, session).await?;
244
245    repo.save().await?;
246
247    Ok(())
248}
249
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use hyper::Request;
    use mas_data_model::{AccessToken, RefreshToken};
    use mas_router::SimpleRoute;
    use mas_storage::RepositoryAccess;
    use oauth2_types::{
        registration::ClientRegistrationResponse,
        requests::AccessTokenResponse,
        scope::{OPENID, Scope},
    };
    use serde_json::json;
    use sqlx::PgPool;

    use super::*;
    use crate::{
        oauth2::generate_token_pair,
        test_utils::{RequestBuilderExt, ResponseExt, TestState, setup},
    };

    /// Revoking either token of a grant ends the whole session. Revoking the
    /// same token twice, or revoking tokens that were already rotated away, is
    /// answered with `200 OK` and leaves the current session untouched.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_revoke_access_token(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        // Lifetime of the access tokens minted below: five minutes.
        let token_lifetime = Duration::microseconds(5 * 60 * 1000 * 1000);

        // Register a confidential client authenticating with
        // `client_secret_post`.
        let response = state
            .request(
                Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(json!({
                    "client_uri": "https://example.com/",
                    "redirect_uris": ["https://example.com/callback"],
                    "token_endpoint_auth_method": "client_secret_post",
                    "response_types": ["code"],
                    "grant_types": ["authorization_code", "refresh_token"],
                })),
            )
            .await;
        response.assert_status(StatusCode::CREATED);
        let registration: ClientRegistrationResponse = response.json();
        let client_id = registration.client_id;
        let client_secret = registration.client_secret.unwrap();

        // Request builders for the two endpoints this test hits repeatedly,
        // both authenticated with the client's secret.
        let revoke = |token: serde_json::Value, hint: &str| {
            Request::post(mas_router::OAuth2Revocation::PATH).form(json!({
                "token": token,
                "token_type_hint": hint,
                "client_id": client_id,
                "client_secret": client_secret,
            }))
        };
        let refresh = |refresh_token: &str| {
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(json!({
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client_id,
                "client_secret": client_secret,
            }))
        };

        // Provision a user with a browser session and an OAuth 2.0 session
        // straight through the repository; doing this over HTTP alone would be
        // much more involved.
        let mut repo = state.repository().await.unwrap();

        let user = repo
            .user()
            .add(&mut state.rng(), &state.clock, "alice".to_owned())
            .await
            .unwrap();
        let browser_session = repo
            .browser_session()
            .add(&mut state.rng(), &state.clock, &user, None)
            .await
            .unwrap();
        let client = repo
            .oauth2_client()
            .find_by_client_id(&client_id)
            .await
            .unwrap()
            .unwrap();

        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut state.rng(),
                &state.clock,
                &client,
                &browser_session,
                Scope::from_iter([OPENID]),
            )
            .await
            .unwrap();
        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
            generate_token_pair(
                &mut state.rng(),
                &state.clock,
                &mut repo,
                &session,
                token_lifetime,
            )
            .await
            .unwrap();
        repo.save().await.unwrap();

        assert!(state.is_access_token_valid(&access_token).await);

        // Revoking the access token ends the session...
        let response = state
            .request(revoke(json!(access_token), "access_token"))
            .await;
        response.assert_status(StatusCode::OK);
        assert!(!state.is_access_token_valid(&access_token).await);

        // ...and revoking it a second time is still answered with 200 OK.
        let response = state
            .request(revoke(json!(access_token), "access_token"))
            .await;
        response.assert_status(StatusCode::OK);

        // The refresh token of the ended session is rejected by the token
        // endpoint.
        let response = state.request(refresh(&refresh_token)).await;
        response.assert_status(StatusCode::BAD_REQUEST);

        // Second round: a fresh grant on the same browser session, which is
        // rotated once and then revoked through its refresh token.
        let mut repo = state.repository().await.unwrap();
        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut state.rng(),
                &state.clock,
                &client,
                &browser_session,
                Scope::from_iter([OPENID]),
            )
            .await
            .unwrap();
        let (
            AccessToken {
                access_token: initial_access_token,
                ..
            },
            RefreshToken {
                refresh_token: initial_refresh_token,
                ..
            },
        ) = generate_token_pair(
            &mut state.rng(),
            &state.clock,
            &mut repo,
            &session,
            token_lifetime,
        )
        .await
        .unwrap();
        repo.save().await.unwrap();

        // Rotate the grant: the refresh request yields a new token pair and
        // invalidates the initial access token.
        let response = state.request(refresh(&initial_refresh_token)).await;
        response.assert_status(StatusCode::OK);
        let AccessTokenResponse {
            access_token,
            refresh_token,
            ..
        } = response.json();
        assert!(state.is_access_token_valid(&access_token).await);
        assert!(!state.is_access_token_valid(&initial_access_token).await);

        // Revoking the tokens that were rotated away must not touch the
        // current session.
        for (stale_token, hint) in [
            (&initial_access_token, "access_token"),
            (&initial_refresh_token, "refresh_token"),
        ] {
            let response = state.request(revoke(json!(stale_token), hint)).await;
            response.assert_status(StatusCode::OK);
            assert!(state.is_access_token_valid(&access_token).await);
        }

        // Revoking the current refresh token ends the session.
        let response = state
            .request(revoke(json!(refresh_token), "refresh_token"))
            .await;
        response.assert_status(StatusCode::OK);
        assert!(!state.is_access_token_valid(&access_token).await);
    }
}