mas_handlers/compat/refresh.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use axum::{Json, extract::State, response::IntoResponse};
8use chrono::Duration;
9use hyper::StatusCode;
10use mas_axum_utils::record_error;
11use mas_data_model::{SiteConfig, TokenFormatError, TokenType};
12use mas_storage::{
13    BoxClock, BoxRepository, BoxRng, Clock,
14    compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
15};
16use serde::{Deserialize, Serialize};
17use serde_with::{DurationMilliSeconds, serde_as};
18use thiserror::Error;
19use ulid::Ulid;
20
21use super::MatrixError;
22use crate::{BoundActivityTracker, impl_from_error_for_route};
23
/// JSON request body accepted by the compat token refresh endpoint.
#[derive(Debug, Deserialize)]
pub struct RequestBody {
    /// The refresh token the client wants to exchange. The handler rejects it
    /// unless `TokenType::check` classifies it as a compat refresh token.
    refresh_token: String,
}
28
29#[derive(Debug, Error)]
30pub enum RouteError {
31    #[error(transparent)]
32    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
33
34    #[error("invalid token")]
35    InvalidToken(#[from] TokenFormatError),
36
37    #[error("unknown token")]
38    UnknownToken,
39
40    #[error("invalid token type {0}, expected a compat refresh token")]
41    InvalidTokenType(TokenType),
42
43    #[error("refresh token already consumed {0}")]
44    RefreshTokenConsumed(Ulid),
45
46    #[error("invalid compat session {0}")]
47    InvalidSession(Ulid),
48
49    #[error("unknown comapt session {0}")]
50    UnknownSession(Ulid),
51}
52
53impl IntoResponse for RouteError {
54    fn into_response(self) -> axum::response::Response {
55        let sentry_event_id = record_error!(self, Self::Internal(_) | Self::UnknownSession(_));
56        let response = match self {
57            Self::Internal(_) | Self::UnknownSession(_) => MatrixError {
58                errcode: "M_UNKNOWN",
59                error: "Internal error",
60                status: StatusCode::INTERNAL_SERVER_ERROR,
61            },
62            Self::InvalidToken(_)
63            | Self::UnknownToken
64            | Self::InvalidTokenType(_)
65            | Self::InvalidSession(_)
66            | Self::RefreshTokenConsumed(_) => MatrixError {
67                errcode: "M_UNKNOWN_TOKEN",
68                error: "Invalid refresh token",
69                status: StatusCode::UNAUTHORIZED,
70            },
71        };
72
73        (sentry_event_id, response).into_response()
74    }
75}
76
// Lets `?` on storage calls in the handler below produce a `RouteError`.
// NOTE(review): presumably this wraps the error in `RouteError::Internal`;
// confirm against the macro definition.
impl_from_error_for_route!(mas_storage::RepositoryError);
78
/// JSON body returned after a successful compat token refresh.
#[serde_as]
#[derive(Debug, Serialize)]
pub struct ResponseBody {
    /// The newly issued compat access token.
    access_token: String,
    /// The newly issued compat refresh token, which replaces the one presented.
    refresh_token: String,
    /// Lifetime given to the new access token, serialized as an integer
    /// number of milliseconds.
    #[serde_as(as = "DurationMilliSeconds<i64>")]
    expires_in_ms: Duration,
}
87
88#[tracing::instrument(name = "handlers.compat.refresh.post", skip_all)]
89pub(crate) async fn post(
90    mut rng: BoxRng,
91    clock: BoxClock,
92    mut repo: BoxRepository,
93    activity_tracker: BoundActivityTracker,
94    State(site_config): State<SiteConfig>,
95    Json(input): Json<RequestBody>,
96) -> Result<impl IntoResponse, RouteError> {
97    let token_type = TokenType::check(&input.refresh_token)?;
98
99    if token_type != TokenType::CompatRefreshToken {
100        return Err(RouteError::InvalidTokenType(token_type));
101    }
102
103    let refresh_token = repo
104        .compat_refresh_token()
105        .find_by_token(&input.refresh_token)
106        .await?
107        .ok_or(RouteError::UnknownToken)?;
108
109    if !refresh_token.is_valid() {
110        return Err(RouteError::RefreshTokenConsumed(refresh_token.id));
111    }
112
113    let session = repo
114        .compat_session()
115        .lookup(refresh_token.session_id)
116        .await?
117        .ok_or(RouteError::UnknownSession(refresh_token.session_id))?;
118
119    if !session.is_valid() {
120        return Err(RouteError::InvalidSession(refresh_token.session_id));
121    }
122
123    activity_tracker
124        .record_compat_session(&clock, &session)
125        .await;
126
127    let access_token = repo
128        .compat_access_token()
129        .lookup(refresh_token.access_token_id)
130        .await?
131        .filter(|t| t.is_valid(clock.now()));
132
133    let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng);
134    let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
135
136    let expires_in = site_config.compat_token_ttl;
137    let new_access_token = repo
138        .compat_access_token()
139        .add(
140            &mut rng,
141            &clock,
142            &session,
143            new_access_token_str,
144            Some(expires_in),
145        )
146        .await?;
147    let new_refresh_token = repo
148        .compat_refresh_token()
149        .add(
150            &mut rng,
151            &clock,
152            &session,
153            &new_access_token,
154            new_refresh_token_str,
155        )
156        .await?;
157
158    repo.compat_refresh_token()
159        .consume(&clock, refresh_token)
160        .await?;
161
162    if let Some(access_token) = access_token {
163        repo.compat_access_token()
164            .expire(&clock, access_token)
165            .await?;
166    }
167
168    repo.save().await?;
169
170    Ok(Json(ResponseBody {
171        access_token: new_access_token.token,
172        refresh_token: new_refresh_token.token,
173        expires_in_ms: expires_in,
174    }))
175}