mas_handlers/oauth2/device/authorize.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
6
7use axum::{Json, extract::State, response::IntoResponse};
8use axum_extra::typed_header::TypedHeader;
9use chrono::Duration;
10use headers::{CacheControl, Pragma};
11use hyper::StatusCode;
12use mas_axum_utils::{
13    client_authorization::{ClientAuthorization, CredentialsVerificationError},
14    record_error,
15};
16use mas_keystore::Encrypter;
17use mas_router::UrlBuilder;
18use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2DeviceCodeGrantParams};
19use oauth2_types::{
20    errors::{ClientError, ClientErrorCode},
21    requests::{DeviceAuthorizationRequest, DeviceAuthorizationResponse, GrantType},
22    scope::ScopeToken,
23};
24use rand::distributions::{Alphanumeric, DistString};
25use thiserror::Error;
26use ulid::Ulid;
27
28use crate::{BoundActivityTracker, impl_from_error_for_route};
29
/// Errors returned by the device authorization endpoint handler (`post`).
///
/// Each variant is mapped to an HTTP status and an OAuth 2.0 error code by the
/// `IntoResponse` implementation for this type.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// Unexpected server-side failure, answered with `500 server_error`.
    /// NOTE(review): presumably this is what `impl_from_error_for_route!`
    /// converts storage errors into — confirm against the macro definition.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Fetching the client identified by the request's credentials returned
    /// nothing; answered with `401 invalid_client`.
    #[error("client not found")]
    ClientNotFound,

    /// The client has no token endpoint authentication method, or its grant
    /// types do not include the device code grant; answered with
    /// `401 unauthorized_client`.
    #[error("client {0} is not allowed to use the device code grant")]
    ClientNotAllowed(Ulid),

    /// Credential verification failed for a non-internal reason (the client
    /// presented bad credentials); answered with `401 invalid_client`.
    #[error("invalid client credentials for client {client_id}")]
    InvalidClientCredentials {
        /// ID of the client whose credentials were rejected.
        client_id: Ulid,
        /// The verification error, for which `is_internal()` was false.
        #[source]
        source: CredentialsVerificationError,
    },

    /// Credential verification could not be completed because of an internal
    /// failure (`is_internal()` was true); answered with `500 server_error`.
    #[error("could not verify client credentials for client {client_id}")]
    ClientCredentialsVerification {
        /// ID of the client whose credentials could not be verified.
        client_id: Ulid,
        /// The underlying internal verification error.
        #[source]
        source: CredentialsVerificationError,
    },
}

// Allows `?` on repository calls inside the handler (e.g. `repo.save().await?`).
// NOTE(review): presumably wraps the error into `RouteError::Internal` — confirm
// against the macro definition in the crate root.
impl_from_error_for_route!(mas_storage::RepositoryError);
57
58impl IntoResponse for RouteError {
59    fn into_response(self) -> axum::response::Response {
60        let sentry_event_id = record_error!(self, Self::Internal(_));
61
62        let response = match self {
63            Self::Internal(_) | Self::ClientCredentialsVerification { .. } => (
64                StatusCode::INTERNAL_SERVER_ERROR,
65                Json(ClientError::from(ClientErrorCode::ServerError)),
66            ),
67            Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
68                StatusCode::UNAUTHORIZED,
69                Json(ClientError::from(ClientErrorCode::InvalidClient)),
70            ),
71            Self::ClientNotAllowed(_) => (
72                StatusCode::UNAUTHORIZED,
73                Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
74            ),
75        };
76
77        (sentry_event_id, response).into_response()
78    }
79}
80
81#[tracing::instrument(
82    name = "handlers.oauth2.device.request.post",
83    fields(client.id = client_authorization.client_id()),
84    skip_all,
85)]
86pub(crate) async fn post(
87    mut rng: BoxRng,
88    clock: BoxClock,
89    mut repo: BoxRepository,
90    user_agent: Option<TypedHeader<headers::UserAgent>>,
91    activity_tracker: BoundActivityTracker,
92    State(url_builder): State<UrlBuilder>,
93    State(http_client): State<reqwest::Client>,
94    State(encrypter): State<Encrypter>,
95    client_authorization: ClientAuthorization<DeviceAuthorizationRequest>,
96) -> Result<impl IntoResponse, RouteError> {
97    let client = client_authorization
98        .credentials
99        .fetch(&mut repo)
100        .await?
101        .ok_or(RouteError::ClientNotFound)?;
102
103    // Reuse the token endpoint auth method to verify the client
104    let method = client
105        .token_endpoint_auth_method
106        .as_ref()
107        .ok_or(RouteError::ClientNotAllowed(client.id))?;
108
109    client_authorization
110        .credentials
111        .verify(&http_client, &encrypter, method, &client)
112        .await
113        .map_err(|err| {
114            if err.is_internal() {
115                RouteError::ClientCredentialsVerification {
116                    client_id: client.id,
117                    source: err,
118                }
119            } else {
120                RouteError::InvalidClientCredentials {
121                    client_id: client.id,
122                    source: err,
123                }
124            }
125        })?;
126
127    if !client.grant_types.contains(&GrantType::DeviceCode) {
128        return Err(RouteError::ClientNotAllowed(client.id));
129    }
130
131    let scope = client_authorization
132        .form
133        .and_then(|f| f.scope)
134        // XXX: Is this really how we do empty scopes?
135        .unwrap_or(std::iter::empty::<ScopeToken>().collect());
136
137    let expires_in = Duration::microseconds(20 * 60 * 1000 * 1000);
138
139    let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
140    let ip_address = activity_tracker.ip();
141
142    let device_code = Alphanumeric.sample_string(&mut rng, 32);
143    let user_code = Alphanumeric.sample_string(&mut rng, 6).to_uppercase();
144
145    let device_code = repo
146        .oauth2_device_code_grant()
147        .add(
148            &mut rng,
149            &clock,
150            OAuth2DeviceCodeGrantParams {
151                client: &client,
152                scope,
153                device_code,
154                user_code,
155                expires_in,
156                user_agent,
157                ip_address,
158            },
159        )
160        .await?;
161
162    repo.save().await?;
163
164    let response = DeviceAuthorizationResponse {
165        device_code: device_code.device_code,
166        user_code: device_code.user_code.clone(),
167        verification_uri: url_builder.device_code_link(),
168        verification_uri_complete: Some(url_builder.device_code_link_full(device_code.user_code)),
169        expires_in,
170        interval: Some(Duration::microseconds(5 * 1000 * 1000)),
171    };
172
173    Ok((
174        StatusCode::OK,
175        TypedHeader(CacheControl::new().with_no_store()),
176        TypedHeader(Pragma::no_cache()),
177        Json(response),
178    ))
179}
180
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        registration::ClientRegistrationResponse, requests::DeviceAuthorizationResponse,
    };
    use sqlx::PgPool;

    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};

    /// Happy path for the device authorization endpoint.
    ///
    /// First registers a public client (`token_endpoint_auth_method: none`)
    /// whose only grant type is the device code grant. Then requests a device
    /// authorization for it and checks the lengths of the issued device code
    /// and user code.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_device_code_request(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Step 1: dynamic client registration.
        let registration_body = serde_json::json!({
            "client_uri": "https://example.com/",
            "token_endpoint_auth_method": "none",
            "grant_types": ["urn:ietf:params:oauth:grant-type:device_code"],
            "response_types": [],
        });
        let registration_request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(registration_body);
        let registration_response = state.request(registration_request).await;
        registration_response.assert_status(StatusCode::CREATED);

        let registered: ClientRegistrationResponse = registration_response.json();
        let client_id = registered.client_id;

        // Step 2: the client is allowed to use the device code grant, so the
        // device authorization request must succeed.
        let authorization_form = serde_json::json!({
            "client_id": client_id,
            "scope": "openid",
        });
        let authorization_request =
            Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH)
                .form(authorization_form);
        let authorization_response = state.request(authorization_request).await;
        authorization_response.assert_status(StatusCode::OK);

        let issued: DeviceAuthorizationResponse = authorization_response.json();
        assert_eq!(issued.device_code.len(), 32);
        assert_eq!(issued.user_code.len(), 6);
    }
}