1use axum::{Json, extract::State, response::IntoResponse};
8use hyper::StatusCode;
9use mas_axum_utils::{
10 client_authorization::{ClientAuthorization, CredentialsVerificationError},
11 record_error,
12};
13use mas_data_model::TokenType;
14use mas_iana::oauth::OAuthTokenTypeHint;
15use mas_keystore::Encrypter;
16use mas_storage::{
17 BoxClock, BoxRepository, BoxRng, RepositoryAccess,
18 queue::{QueueJobRepositoryExt as _, SyncDevicesJob},
19};
20use oauth2_types::{
21 errors::{ClientError, ClientErrorCode},
22 requests::RevocationRequest,
23};
24use thiserror::Error;
25use ulid::Ulid;
26
27use crate::{BoundActivityTracker, impl_from_error_for_route};
28
29#[derive(Debug, Error)]
30pub(crate) enum RouteError {
31 #[error(transparent)]
32 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
33
34 #[error("bad request")]
35 BadRequest,
36
37 #[error("client not found")]
38 ClientNotFound,
39
40 #[error("client not allowed")]
41 ClientNotAllowed,
42
43 #[error("invalid client credentials for client {client_id}")]
44 InvalidClientCredentials {
45 client_id: Ulid,
46 #[source]
47 source: CredentialsVerificationError,
48 },
49
50 #[error("could not verify client credentials for client {client_id}")]
51 ClientCredentialsVerification {
52 client_id: Ulid,
53 #[source]
54 source: CredentialsVerificationError,
55 },
56
57 #[error("client is unauthorized")]
58 UnauthorizedClient,
59
60 #[error("unsupported token type")]
61 UnsupportedTokenType,
62
63 #[error("unknown token")]
64 UnknownToken,
65}
66
67impl IntoResponse for RouteError {
68 fn into_response(self) -> axum::response::Response {
69 let sentry_event_id = record_error!(self, Self::Internal(_));
70 let response = match self {
71 Self::Internal(_) | Self::ClientCredentialsVerification { .. } => (
72 StatusCode::INTERNAL_SERVER_ERROR,
73 Json(ClientError::from(ClientErrorCode::ServerError)),
74 )
75 .into_response(),
76
77 Self::BadRequest => (
78 StatusCode::BAD_REQUEST,
79 Json(ClientError::from(ClientErrorCode::InvalidRequest)),
80 )
81 .into_response(),
82
83 Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
84 StatusCode::UNAUTHORIZED,
85 Json(ClientError::from(ClientErrorCode::InvalidClient)),
86 )
87 .into_response(),
88
89 Self::ClientNotAllowed | Self::UnauthorizedClient => (
90 StatusCode::UNAUTHORIZED,
91 Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
92 )
93 .into_response(),
94
95 Self::UnsupportedTokenType => (
96 StatusCode::BAD_REQUEST,
97 Json(ClientError::from(ClientErrorCode::UnsupportedTokenType)),
98 )
99 .into_response(),
100
101 Self::UnknownToken => StatusCode::OK.into_response(),
103 };
104
105 (sentry_event_id, response).into_response()
106 }
107}
108
// Lets `?` turn `mas_storage::RepositoryError` into a `RouteError` in the
// handler below. The macro is imported from the crate root; it presumably wraps
// the error in `RouteError::Internal` (confirm against its definition).
impl_from_error_for_route!(mas_storage::RepositoryError);
110
111impl From<mas_data_model::TokenFormatError> for RouteError {
112 fn from(_e: mas_data_model::TokenFormatError) -> Self {
113 Self::UnknownToken
114 }
115}
116
117#[tracing::instrument(
118 name = "handlers.oauth2.revoke.post",
119 fields(client.id = client_authorization.client_id()),
120 skip_all,
121)]
122pub(crate) async fn post(
123 clock: BoxClock,
124 mut rng: BoxRng,
125 State(http_client): State<reqwest::Client>,
126 mut repo: BoxRepository,
127 activity_tracker: BoundActivityTracker,
128 State(encrypter): State<Encrypter>,
129 client_authorization: ClientAuthorization<RevocationRequest>,
130) -> Result<impl IntoResponse, RouteError> {
131 let client = client_authorization
132 .credentials
133 .fetch(&mut repo)
134 .await?
135 .ok_or(RouteError::ClientNotFound)?;
136
137 let method = client
138 .token_endpoint_auth_method
139 .as_ref()
140 .ok_or(RouteError::ClientNotAllowed)?;
141
142 client_authorization
143 .credentials
144 .verify(&http_client, &encrypter, method, &client)
145 .await
146 .map_err(|err| {
147 if err.is_internal() {
148 RouteError::ClientCredentialsVerification {
149 client_id: client.id,
150 source: err,
151 }
152 } else {
153 RouteError::InvalidClientCredentials {
154 client_id: client.id,
155 source: err,
156 }
157 }
158 })?;
159
160 let Some(form) = client_authorization.form else {
161 return Err(RouteError::BadRequest);
162 };
163
164 let token_type = TokenType::check(&form.token)?;
165
166 let session_id = match (form.token_type_hint, token_type) {
168 (Some(OAuthTokenTypeHint::AccessToken) | None, TokenType::AccessToken) => {
169 let access_token = repo
170 .oauth2_access_token()
171 .find_by_token(&form.token)
172 .await?
173 .ok_or(RouteError::UnknownToken)?;
174
175 if !access_token.is_valid(clock.now()) {
176 return Err(RouteError::UnknownToken);
177 }
178 access_token.session_id
179 }
180
181 (Some(OAuthTokenTypeHint::RefreshToken) | None, TokenType::RefreshToken) => {
182 let refresh_token = repo
183 .oauth2_refresh_token()
184 .find_by_token(&form.token)
185 .await?
186 .ok_or(RouteError::UnknownToken)?;
187
188 if !refresh_token.is_valid() {
189 return Err(RouteError::UnknownToken);
190 }
191
192 refresh_token.session_id
193 }
194
195 (Some(OAuthTokenTypeHint::AccessToken | OAuthTokenTypeHint::RefreshToken) | None, _) => {
199 return Err(RouteError::UnknownToken);
200 }
201
202 (Some(_), _) => return Err(RouteError::UnsupportedTokenType),
203 };
204
205 let session = repo
206 .oauth2_session()
207 .lookup(session_id)
208 .await?
209 .ok_or(RouteError::UnknownToken)?;
210
211 if !session.is_valid() {
213 return Err(RouteError::UnknownToken);
214 }
215
216 if client.id != session.client_id {
219 return Err(RouteError::UnauthorizedClient);
220 }
221
222 activity_tracker
223 .record_oauth2_session(&clock, &session)
224 .await;
225
226 if let Some(user_id) = session.user_id {
229 let user = repo
231 .user()
232 .lookup(user_id)
233 .await?
234 .ok_or(RouteError::UnknownToken)?;
235
236 repo.queue_job()
238 .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
239 .await?;
240 }
241
242 repo.oauth2_session().finish(&clock, session).await?;
244
245 repo.save().await?;
246
247 Ok(())
248}
249
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use hyper::Request;
    use mas_data_model::{AccessToken, RefreshToken};
    use mas_router::SimpleRoute;
    use mas_storage::RepositoryAccess;
    use oauth2_types::{
        registration::ClientRegistrationResponse,
        requests::AccessTokenResponse,
        scope::{OPENID, Scope},
    };
    use sqlx::PgPool;

    use super::*;
    use crate::{
        oauth2::generate_token_pair,
        test_utils::{RequestBuilderExt, ResponseExt, TestState, setup},
    };

    /// End-to-end test of the revocation endpoint against a real database.
    ///
    /// It checks that:
    /// - revoking an access token invalidates it and ends its session, so the
    ///   session's refresh token can no longer be used;
    /// - revoking an already-revoked token still answers `200 OK`;
    /// - revoking tokens that were superseded by a refresh grant leaves the new
    ///   access token valid;
    /// - revoking the current refresh token invalidates the current access token.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_revoke_access_token(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Register a confidential client authenticating with `client_secret_post`.
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/callback"],
                "token_endpoint_auth_method": "client_secret_post",
                "response_types": ["code"],
                "grant_types": ["authorization_code", "refresh_token"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);

        let client_registration: ClientRegistrationResponse = response.json();

        let client_id = client_registration.client_id;
        let client_secret = client_registration.client_secret.unwrap();

        // Set up a user, a browser session and an OAuth 2.0 session directly in
        // the repository, bypassing the authorization flow.
        let mut repo = state.repository().await.unwrap();

        let user = repo
            .user()
            .add(&mut state.rng(), &state.clock, "alice".to_owned())
            .await
            .unwrap();

        let browser_session = repo
            .browser_session()
            .add(&mut state.rng(), &state.clock, &user, None)
            .await
            .unwrap();

        let client = repo
            .oauth2_client()
            .find_by_client_id(&client_id)
            .await
            .unwrap()
            .unwrap();

        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut state.rng(),
                &state.clock,
                &client,
                &browser_session,
                Scope::from_iter([OPENID]),
            )
            .await
            .unwrap();

        // Issue an access/refresh token pair; the duration is 5 minutes.
        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
            generate_token_pair(
                &mut state.rng(),
                &state.clock,
                &mut repo,
                &session,
                Duration::microseconds(5 * 60 * 1000 * 1000),
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        assert!(state.is_access_token_valid(&access_token).await);

        // Revoke the access token.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": access_token,
            "token_type_hint": "access_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        assert!(!state.is_access_token_valid(&access_token).await);

        // Revoking the same token again must still succeed (RFC 7009 §2.2).
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": access_token,
            "token_type_hint": "access_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        // The whole session was ended, so its refresh token is now rejected.
        let request =
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client_id,
                "client_secret": client_secret,
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);

        // Start over with a fresh session and token pair.
        let mut repo = state.repository().await.unwrap();
        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut state.rng(),
                &state.clock,
                &client,
                &browser_session,
                Scope::from_iter([OPENID]),
            )
            .await
            .unwrap();

        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
            generate_token_pair(
                &mut state.rng(),
                &state.clock,
                &mut repo,
                &session,
                Duration::microseconds(5 * 60 * 1000 * 1000),
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        // Use the refresh token once, which yields a new token pair.
        let request =
            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client_id,
                "client_secret": client_secret,
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        let old_access_token = access_token;
        let old_refresh_token = refresh_token;
        let AccessTokenResponse {
            access_token,
            refresh_token,
            ..
        } = response.json();
        assert!(state.is_access_token_valid(&access_token).await);
        assert!(!state.is_access_token_valid(&old_access_token).await);

        // Revoking the superseded access token succeeds but must not end the
        // session: the new access token stays valid.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": old_access_token,
            "token_type_hint": "access_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        assert!(state.is_access_token_valid(&access_token).await);

        // Same for the refresh token that was consumed by the refresh grant.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": old_refresh_token,
            "token_type_hint": "refresh_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        assert!(state.is_access_token_valid(&access_token).await);

        // Revoking the current refresh token ends the session, which also
        // invalidates the current access token.
        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
            "token": refresh_token,
            "token_type_hint": "refresh_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);

        assert!(!state.is_access_token_valid(&access_token).await);
    }
}