mas_handlers/oauth2/device/
consent.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use anyhow::Context;
8use axum::{
9    Form,
10    extract::{Path, State},
11    response::{Html, IntoResponse, Response},
12};
13use axum_extra::TypedHeader;
14use mas_axum_utils::{
15    InternalError,
16    cookies::CookieJar,
17    csrf::{CsrfExt, ProtectedForm},
18};
19use mas_policy::Policy;
20use mas_router::UrlBuilder;
21use mas_storage::{BoxClock, BoxRepository, BoxRng};
22use mas_templates::{DeviceConsentContext, PolicyViolationContext, TemplateContext, Templates};
23use serde::Deserialize;
24use tracing::warn;
25use ulid::Ulid;
26
27use crate::{
28    BoundActivityTracker, PreferredLanguage,
29    session::{SessionOrFallback, load_session_or_fallback},
30};
31
32#[derive(Deserialize, Debug)]
33#[serde(rename_all = "lowercase")]
34enum Action {
35    Consent,
36    Reject,
37}
38
39#[derive(Deserialize, Debug)]
40pub(crate) struct ConsentForm {
41    action: Action,
42}
43
44#[tracing::instrument(name = "handlers.oauth2.device.consent.get", skip_all)]
45pub(crate) async fn get(
46    mut rng: BoxRng,
47    clock: BoxClock,
48    PreferredLanguage(locale): PreferredLanguage,
49    State(templates): State<Templates>,
50    State(url_builder): State<UrlBuilder>,
51    mut repo: BoxRepository,
52    mut policy: Policy,
53    activity_tracker: BoundActivityTracker,
54    user_agent: Option<TypedHeader<headers::UserAgent>>,
55    cookie_jar: CookieJar,
56    Path(grant_id): Path<Ulid>,
57) -> Result<Response, InternalError> {
58    let (cookie_jar, maybe_session) = match load_session_or_fallback(
59        cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
60    )
61    .await?
62    {
63        SessionOrFallback::MaybeSession {
64            cookie_jar,
65            maybe_session,
66            ..
67        } => (cookie_jar, maybe_session),
68        SessionOrFallback::Fallback { response } => return Ok(response),
69    };
70
71    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
72
73    let user_agent = user_agent.map(|ua| ua.to_string());
74
75    let Some(session) = maybe_session else {
76        let login = mas_router::Login::and_continue_device_code_grant(grant_id);
77        return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
78    };
79
80    activity_tracker
81        .record_browser_session(&clock, &session)
82        .await;
83
84    // TODO: better error handling
85    let grant = repo
86        .oauth2_device_code_grant()
87        .lookup(grant_id)
88        .await?
89        .context("Device grant not found")
90        .map_err(InternalError::from_anyhow)?;
91
92    if grant.expires_at < clock.now() {
93        return Err(InternalError::from_anyhow(anyhow::anyhow!(
94            "Grant is expired"
95        )));
96    }
97
98    let client = repo
99        .oauth2_client()
100        .lookup(grant.client_id)
101        .await?
102        .context("Client not found")
103        .map_err(InternalError::from_anyhow)?;
104
105    // Evaluate the policy
106    let res = policy
107        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
108            grant_type: mas_policy::GrantType::DeviceCode,
109            client: &client,
110            scope: &grant.scope,
111            user: Some(&session.user),
112            requester: mas_policy::Requester {
113                ip_address: activity_tracker.ip(),
114                user_agent,
115            },
116        })
117        .await?;
118    if !res.valid() {
119        warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id);
120
121        let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
122        let ctx = PolicyViolationContext::for_device_code_grant(grant, client)
123            .with_session(session)
124            .with_csrf(csrf_token.form_value())
125            .with_language(locale);
126
127        let content = templates.render_policy_violation(&ctx)?;
128
129        return Ok((cookie_jar, Html(content)).into_response());
130    }
131
132    let ctx = DeviceConsentContext::new(grant, client)
133        .with_session(session)
134        .with_csrf(csrf_token.form_value())
135        .with_language(locale);
136
137    let rendered = templates
138        .render_device_consent(&ctx)
139        .context("Failed to render template")
140        .map_err(InternalError::from_anyhow)?;
141
142    Ok((cookie_jar, Html(rendered)).into_response())
143}
144
145#[tracing::instrument(name = "handlers.oauth2.device.consent.post", skip_all)]
146pub(crate) async fn post(
147    mut rng: BoxRng,
148    clock: BoxClock,
149    PreferredLanguage(locale): PreferredLanguage,
150    State(templates): State<Templates>,
151    State(url_builder): State<UrlBuilder>,
152    mut repo: BoxRepository,
153    mut policy: Policy,
154    activity_tracker: BoundActivityTracker,
155    user_agent: Option<TypedHeader<headers::UserAgent>>,
156    cookie_jar: CookieJar,
157    Path(grant_id): Path<Ulid>,
158    Form(form): Form<ProtectedForm<ConsentForm>>,
159) -> Result<Response, InternalError> {
160    let form = cookie_jar.verify_form(&clock, form)?;
161    let (cookie_jar, maybe_session) = match load_session_or_fallback(
162        cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
163    )
164    .await?
165    {
166        SessionOrFallback::MaybeSession {
167            cookie_jar,
168            maybe_session,
169            ..
170        } => (cookie_jar, maybe_session),
171        SessionOrFallback::Fallback { response } => return Ok(response),
172    };
173    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
174
175    let user_agent = user_agent.map(|TypedHeader(ua)| ua.to_string());
176
177    let Some(session) = maybe_session else {
178        let login = mas_router::Login::and_continue_device_code_grant(grant_id);
179        return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
180    };
181
182    activity_tracker
183        .record_browser_session(&clock, &session)
184        .await;
185
186    // TODO: better error handling
187    let grant = repo
188        .oauth2_device_code_grant()
189        .lookup(grant_id)
190        .await?
191        .context("Device grant not found")
192        .map_err(InternalError::from_anyhow)?;
193
194    if grant.expires_at < clock.now() {
195        return Err(InternalError::from_anyhow(anyhow::anyhow!(
196            "Grant is expired"
197        )));
198    }
199
200    let client = repo
201        .oauth2_client()
202        .lookup(grant.client_id)
203        .await?
204        .context("Client not found")
205        .map_err(InternalError::from_anyhow)?;
206
207    // Evaluate the policy
208    let res = policy
209        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
210            grant_type: mas_policy::GrantType::DeviceCode,
211            client: &client,
212            scope: &grant.scope,
213            user: Some(&session.user),
214            requester: mas_policy::Requester {
215                ip_address: activity_tracker.ip(),
216                user_agent,
217            },
218        })
219        .await?;
220    if !res.valid() {
221        warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id);
222
223        let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
224        let ctx = PolicyViolationContext::for_device_code_grant(grant, client)
225            .with_session(session)
226            .with_csrf(csrf_token.form_value())
227            .with_language(locale);
228
229        let content = templates.render_policy_violation(&ctx)?;
230
231        return Ok((cookie_jar, Html(content)).into_response());
232    }
233
234    let grant = if grant.is_pending() {
235        match form.action {
236            Action::Consent => {
237                repo.oauth2_device_code_grant()
238                    .fulfill(&clock, grant, &session)
239                    .await?
240            }
241            Action::Reject => {
242                repo.oauth2_device_code_grant()
243                    .reject(&clock, grant, &session)
244                    .await?
245            }
246        }
247    } else {
248        // XXX: In case we're not pending, let's just return the grant as-is
249        // since it might just be a form resubmission, and feedback is nice enough
250        warn!(
251            oauth2_device_code.id = %grant.id,
252            browser_session.id = %session.id,
253            user.id = %session.user.id,
254            "Grant is not pending",
255        );
256        grant
257    };
258
259    repo.save().await?;
260
261    let ctx = DeviceConsentContext::new(grant, client)
262        .with_session(session)
263        .with_csrf(csrf_token.form_value())
264        .with_language(locale);
265
266    let rendered = templates
267        .render_device_consent(&ctx)
268        .context("Failed to render template")
269        .map_err(InternalError::from_anyhow)?;
270
271    Ok((cookie_jar, Html(rendered)).into_response())
272}