mas_handlers/upstream_oauth2/
link.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::sync::{Arc, LazyLock};
8
9use axum::{
10    Form,
11    extract::{Path, State},
12    response::{Html, IntoResponse, Response},
13};
14use axum_extra::typed_header::TypedHeader;
15use hyper::StatusCode;
16use mas_axum_utils::{
17    GenericError, SessionInfoExt,
18    cookies::CookieJar,
19    csrf::{CsrfExt, ProtectedForm},
20    record_error,
21};
22use mas_jose::jwt::Jwt;
23use mas_matrix::HomeserverConnection;
24use mas_policy::Policy;
25use mas_router::UrlBuilder;
26use mas_storage::{
27    BoxClock, BoxRepository, BoxRng, RepositoryAccess,
28    queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
29    upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
30    user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
31};
32use mas_templates::{
33    AccountInactiveContext, ErrorContext, FieldError, FormError, TemplateContext, Templates,
34    ToFormState, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
35};
36use minijinja::Environment;
37use opentelemetry::{Key, KeyValue, metrics::Counter};
38use serde::{Deserialize, Serialize};
39use thiserror::Error;
40use tracing::warn;
41use ulid::Ulid;
42
43use super::{
44    UpstreamSessionsCookie,
45    template::{AttributeMappingContext, environment},
46};
47use crate::{
48    BoundActivityTracker, METER, PreferredLanguage, SiteConfig, impl_from_error_for_route,
49    views::shared::OptionalPostAuthAction,
50};
51
52static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
53    METER
54        .u64_counter("mas.upstream_oauth2.login")
55        .with_description("Successful upstream OAuth 2.0 login to existing accounts")
56        .with_unit("{login}")
57        .build()
58});
59static REGISTRATION_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
60    METER
61        .u64_counter("mas.upstream_oauth2.registration")
62        .with_description("Successful upstream OAuth 2.0 registration")
63        .with_unit("{registration}")
64        .build()
65});
66const PROVIDER: Key = Key::from_static_str("provider");
67
/// Localpart template used when the provider does not configure one.
const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}";
/// Display name template used when the provider does not configure one.
const DEFAULT_DISPLAYNAME_TEMPLATE: &str = "{{ user.name }}";
/// Email template used when the provider does not configure one.
const DEFAULT_EMAIL_TEMPLATE: &str = "{{ user.email }}";
71
72#[derive(Debug, Error)]
73pub(crate) enum RouteError {
74    /// Couldn't find the link specified in the URL
75    #[error("Link not found")]
76    LinkNotFound,
77
78    /// Couldn't find the session on the link
79    #[error("Session {0} not found")]
80    SessionNotFound(Ulid),
81
82    /// Couldn't find the user
83    #[error("User {0} not found")]
84    UserNotFound(Ulid),
85
86    /// Couldn't find upstream provider
87    #[error("Upstream provider {0} not found")]
88    ProviderNotFound(Ulid),
89
90    /// Required attribute rendered to an empty string
91    #[error("Template {template:?} rendered to an empty string")]
92    RequiredAttributeEmpty { template: String },
93
94    /// Required claim was missing in `id_token`
95    #[error(
96        "Template {template:?} could not be rendered from the upstream provider's response for required claim"
97    )]
98    RequiredAttributeRender {
99        template: String,
100
101        #[source]
102        source: minijinja::Error,
103    },
104
105    /// Session was already consumed
106    #[error("Session {0} already consumed")]
107    SessionConsumed(Ulid),
108
109    #[error("Missing session cookie")]
110    MissingCookie,
111
112    #[error("Invalid form action")]
113    InvalidFormAction,
114
115    #[error("Homeserver connection error")]
116    HomeserverConnection(#[source] anyhow::Error),
117
118    #[error(transparent)]
119    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
120}
121
122impl_from_error_for_route!(mas_templates::TemplateError);
123impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
124impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
125impl_from_error_for_route!(mas_storage::RepositoryError);
126impl_from_error_for_route!(mas_policy::EvaluationError);
127impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
128
129impl IntoResponse for RouteError {
130    fn into_response(self) -> axum::response::Response {
131        let sentry_event_id = record_error!(
132            self,
133            Self::Internal(_)
134                | Self::RequiredAttributeEmpty { .. }
135                | Self::RequiredAttributeRender { .. }
136                | Self::SessionNotFound(_)
137                | Self::ProviderNotFound(_)
138                | Self::UserNotFound(_)
139                | Self::HomeserverConnection(_)
140        );
141
142        let status_code = match self {
143            Self::LinkNotFound => StatusCode::NOT_FOUND,
144            _ => StatusCode::INTERNAL_SERVER_ERROR,
145        };
146
147        let response = GenericError::new(status_code, self);
148        (sentry_event_id, response).into_response()
149    }
150}
151
152/// Utility function to render an attribute template.
153///
154/// # Parameters
155///
156/// * `environment` - The minijinja environment to use to render the template
157/// * `template` - The template to use to render the claim
158/// * `required` - Whether the attribute is required or not
159///
160/// # Errors
161///
162/// Returns an error if the attribute is required but fails to render or is
163/// empty
164fn render_attribute_template(
165    environment: &Environment,
166    template: &str,
167    context: &minijinja::Value,
168    required: bool,
169) -> Result<Option<String>, RouteError> {
170    match environment.render_str(template, context) {
171        Ok(value) if value.is_empty() => {
172            if required {
173                return Err(RouteError::RequiredAttributeEmpty {
174                    template: template.to_owned(),
175                });
176            }
177
178            Ok(None)
179        }
180
181        Ok(value) => Ok(Some(value)),
182
183        Err(source) => {
184            if required {
185                return Err(RouteError::RequiredAttributeRender {
186                    template: template.to_owned(),
187                    source,
188                });
189            }
190
191            tracing::warn!(error = &source as &dyn std::error::Error, %template, "Error while rendering template");
192            Ok(None)
193        }
194    }
195}
196
197#[derive(Deserialize, Serialize)]
198#[serde(rename_all = "lowercase", tag = "action")]
199pub(crate) enum FormData {
200    Register {
201        #[serde(default)]
202        username: Option<String>,
203        #[serde(default)]
204        import_email: Option<String>,
205        #[serde(default)]
206        import_display_name: Option<String>,
207        #[serde(default)]
208        accept_terms: Option<String>,
209    },
210    Link,
211}
212
213impl ToFormState for FormData {
214    type Field = mas_templates::UpstreamRegisterFormField;
215}
216
217#[tracing::instrument(
218    name = "handlers.upstream_oauth2.link.get",
219    fields(upstream_oauth_link.id = %link_id),
220    skip_all,
221)]
222pub(crate) async fn get(
223    mut rng: BoxRng,
224    clock: BoxClock,
225    mut repo: BoxRepository,
226    mut policy: Policy,
227    PreferredLanguage(locale): PreferredLanguage,
228    State(templates): State<Templates>,
229    State(url_builder): State<UrlBuilder>,
230    State(homeserver): State<Arc<dyn HomeserverConnection>>,
231    cookie_jar: CookieJar,
232    activity_tracker: BoundActivityTracker,
233    user_agent: Option<TypedHeader<headers::UserAgent>>,
234    Path(link_id): Path<Ulid>,
235) -> Result<impl IntoResponse, RouteError> {
236    let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
237    let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
238    let (session_id, post_auth_action) = sessions_cookie
239        .lookup_link(link_id)
240        .map_err(|_| RouteError::MissingCookie)?;
241
242    let post_auth_action = OptionalPostAuthAction {
243        post_auth_action: post_auth_action.cloned(),
244    };
245
246    let link = repo
247        .upstream_oauth_link()
248        .lookup(link_id)
249        .await?
250        .ok_or(RouteError::LinkNotFound)?;
251
252    let upstream_session = repo
253        .upstream_oauth_session()
254        .lookup(session_id)
255        .await?
256        .ok_or(RouteError::SessionNotFound(session_id))?;
257
258    // This checks that we're in a browser session which is allowed to consume this
259    // link: the upstream auth session should have been started in this browser.
260    if upstream_session.link_id() != Some(link.id) {
261        return Err(RouteError::SessionNotFound(session_id));
262    }
263
264    if upstream_session.is_consumed() {
265        return Err(RouteError::SessionConsumed(session_id));
266    }
267
268    let (user_session_info, cookie_jar) = cookie_jar.session_info();
269    let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
270    let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
271
272    let response = match (maybe_user_session, link.user_id) {
273        (Some(session), Some(user_id)) if session.user.id == user_id => {
274            // Session already linked, and link matches the currently logged
275            // user. Mark the session as consumed and renew the authentication.
276            let upstream_session = repo
277                .upstream_oauth_session()
278                .consume(&clock, upstream_session)
279                .await?;
280
281            repo.browser_session()
282                .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
283                .await?;
284
285            cookie_jar = cookie_jar.set_session(&session);
286
287            repo.save().await?;
288
289            post_auth_action.go_next(&url_builder).into_response()
290        }
291
292        (Some(user_session), Some(user_id)) => {
293            // Session already linked, but link doesn't match the currently
294            // logged user. Suggest logging out of the current user
295            // and logging in with the new one
296            let user = repo
297                .user()
298                .lookup(user_id)
299                .await?
300                .ok_or(RouteError::UserNotFound(user_id))?;
301
302            let ctx = UpstreamExistingLinkContext::new(user)
303                .with_session(user_session)
304                .with_csrf(csrf_token.form_value())
305                .with_language(locale);
306
307            Html(templates.render_upstream_oauth2_link_mismatch(&ctx)?).into_response()
308        }
309
310        (Some(user_session), None) => {
311            // Session not linked, but user logged in: suggest linking account
312            let ctx = UpstreamSuggestLink::new(&link)
313                .with_session(user_session)
314                .with_csrf(csrf_token.form_value())
315                .with_language(locale);
316
317            Html(templates.render_upstream_oauth2_suggest_link(&ctx)?).into_response()
318        }
319
320        (None, Some(user_id)) => {
321            // Session linked, but user not logged in: do the login
322            let user = repo
323                .user()
324                .lookup(user_id)
325                .await?
326                .ok_or(RouteError::UserNotFound(user_id))?;
327
328            // Check that the user is not locked or deactivated
329            if user.deactivated_at.is_some() {
330                // The account is deactivated, show the 'account deactivated' fallback
331                let ctx = AccountInactiveContext::new(user)
332                    .with_csrf(csrf_token.form_value())
333                    .with_language(locale);
334                let fallback = templates.render_account_deactivated(&ctx)?;
335                return Ok((cookie_jar, Html(fallback).into_response()));
336            }
337
338            if user.locked_at.is_some() {
339                // The account is locked, show the 'account locked' fallback
340                let ctx = AccountInactiveContext::new(user)
341                    .with_csrf(csrf_token.form_value())
342                    .with_language(locale);
343                let fallback = templates.render_account_locked(&ctx)?;
344                return Ok((cookie_jar, Html(fallback).into_response()));
345            }
346
347            let session = repo
348                .browser_session()
349                .add(&mut rng, &clock, &user, user_agent)
350                .await?;
351
352            let upstream_session = repo
353                .upstream_oauth_session()
354                .consume(&clock, upstream_session)
355                .await?;
356
357            repo.browser_session()
358                .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
359                .await?;
360
361            cookie_jar = sessions_cookie
362                .consume_link(link_id)?
363                .save(cookie_jar, &clock);
364            cookie_jar = cookie_jar.set_session(&session);
365
366            repo.save().await?;
367
368            LOGIN_COUNTER.add(
369                1,
370                &[KeyValue::new(
371                    PROVIDER,
372                    upstream_session.provider_id.to_string(),
373                )],
374            );
375
376            post_auth_action.go_next(&url_builder).into_response()
377        }
378
379        (None, None) => {
380            // Session not linked and used not logged in: suggest creating an
381            // account or logging in an existing user
382            let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
383
384            let provider = repo
385                .upstream_oauth_provider()
386                .lookup(link.provider_id)
387                .await?
388                .ok_or(RouteError::ProviderNotFound(link.provider_id))?;
389
390            let ctx = UpstreamRegister::new(link.clone(), provider.clone());
391
392            let env = environment();
393
394            let mut context = AttributeMappingContext::new();
395            if let Some(id_token) = id_token {
396                let (_, payload) = id_token.into_parts();
397                context = context.with_id_token_claims(payload);
398            }
399            if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
400                context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
401            }
402            if let Some(userinfo) = upstream_session.userinfo() {
403                context = context.with_userinfo_claims(userinfo.clone());
404            }
405            let context = context.build();
406
407            let ctx = if provider.claims_imports.displayname.ignore() {
408                ctx
409            } else {
410                let template = provider
411                    .claims_imports
412                    .displayname
413                    .template
414                    .as_deref()
415                    .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
416
417                match render_attribute_template(
418                    &env,
419                    template,
420                    &context,
421                    provider.claims_imports.displayname.is_required(),
422                )? {
423                    Some(value) => ctx
424                        .with_display_name(value, provider.claims_imports.displayname.is_forced()),
425                    None => ctx,
426                }
427            };
428
429            let ctx = if provider.claims_imports.email.ignore() {
430                ctx
431            } else {
432                let template = provider
433                    .claims_imports
434                    .email
435                    .template
436                    .as_deref()
437                    .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
438
439                match render_attribute_template(
440                    &env,
441                    template,
442                    &context,
443                    provider.claims_imports.email.is_required(),
444                )? {
445                    Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()),
446                    None => ctx,
447                }
448            };
449
450            let ctx = if provider.claims_imports.localpart.ignore() {
451                ctx
452            } else {
453                let template = provider
454                    .claims_imports
455                    .localpart
456                    .template
457                    .as_deref()
458                    .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
459
460                match render_attribute_template(
461                    &env,
462                    template,
463                    &context,
464                    provider.claims_imports.localpart.is_required(),
465                )? {
466                    Some(localpart) => {
467                        // We could run policy & existing user checks when the user submits the
468                        // form, but this lead to poor UX. This is why we do
469                        // it ahead of time here.
470                        let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
471                        let is_available = homeserver
472                            .is_localpart_available(&localpart)
473                            .await
474                            .map_err(RouteError::HomeserverConnection)?;
475
476                        if maybe_existing_user.is_some() || !is_available {
477                            if let Some(existing_user) = maybe_existing_user {
478                                // The mapper returned a username which already exists, but isn't
479                                // linked to this upstream user.
480                                warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
481                            }
482
483                            // TODO: translate
484                            let ctx = ErrorContext::new()
485                                .with_code("User exists")
486                                .with_description(format!(
487                                    r"Upstream account provider returned {localpart:?} as username,
488                                    which is not linked to that upstream account"
489                                ))
490                                .with_language(&locale);
491
492                            return Ok((
493                                cookie_jar,
494                                Html(templates.render_error(&ctx)?).into_response(),
495                            ));
496                        }
497
498                        let res = policy
499                            .evaluate_register(mas_policy::RegisterInput {
500                                registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
501                                username: &localpart,
502                                email: None,
503                                requester: mas_policy::Requester {
504                                    ip_address: activity_tracker.ip(),
505                                    user_agent: user_agent.clone(),
506                                },
507                            })
508                            .await?;
509
510                        if res.valid() {
511                            // The username passes the policy check, add it to the context
512                            ctx.with_localpart(
513                                localpart,
514                                provider.claims_imports.localpart.is_forced(),
515                            )
516                        } else if provider.claims_imports.localpart.is_forced() {
517                            // If the username claim is 'forced' but doesn't pass the policy check,
518                            // we display an error message.
519                            // TODO: translate
520                            let ctx = ErrorContext::new()
521                                .with_code("Policy error")
522                                .with_description(format!(
523                                    r"Upstream account provider returned {localpart:?} as username,
524                                    which does not pass the policy check: {res}"
525                                ))
526                                .with_language(&locale);
527
528                            return Ok((
529                                cookie_jar,
530                                Html(templates.render_error(&ctx)?).into_response(),
531                            ));
532                        } else {
533                            // Else, we just ignore it when it doesn't pass the policy check.
534                            ctx
535                        }
536                    }
537                    None => ctx,
538                }
539            };
540
541            let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
542
543            Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
544        }
545    };
546
547    Ok((cookie_jar, response))
548}
549
550#[tracing::instrument(
551    name = "handlers.upstream_oauth2.link.post",
552    fields(upstream_oauth_link.id = %link_id),
553    skip_all,
554)]
555pub(crate) async fn post(
556    mut rng: BoxRng,
557    clock: BoxClock,
558    mut repo: BoxRepository,
559    cookie_jar: CookieJar,
560    user_agent: Option<TypedHeader<headers::UserAgent>>,
561    mut policy: Policy,
562    PreferredLanguage(locale): PreferredLanguage,
563    activity_tracker: BoundActivityTracker,
564    State(templates): State<Templates>,
565    State(homeserver): State<Arc<dyn HomeserverConnection>>,
566    State(url_builder): State<UrlBuilder>,
567    State(site_config): State<SiteConfig>,
568    Path(link_id): Path<Ulid>,
569    Form(form): Form<ProtectedForm<FormData>>,
570) -> Result<Response, RouteError> {
571    let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
572    let form = cookie_jar.verify_form(&clock, form)?;
573
574    let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
575    let (session_id, post_auth_action) = sessions_cookie
576        .lookup_link(link_id)
577        .map_err(|_| RouteError::MissingCookie)?;
578
579    let post_auth_action = OptionalPostAuthAction {
580        post_auth_action: post_auth_action.cloned(),
581    };
582
583    let link = repo
584        .upstream_oauth_link()
585        .lookup(link_id)
586        .await?
587        .ok_or(RouteError::LinkNotFound)?;
588
589    let upstream_session = repo
590        .upstream_oauth_session()
591        .lookup(session_id)
592        .await?
593        .ok_or(RouteError::SessionNotFound(session_id))?;
594
595    // This checks that we're in a browser session which is allowed to consume this
596    // link: the upstream auth session should have been started in this browser.
597    if upstream_session.link_id() != Some(link.id) {
598        return Err(RouteError::SessionNotFound(session_id));
599    }
600
601    if upstream_session.is_consumed() {
602        return Err(RouteError::SessionConsumed(session_id));
603    }
604
605    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
606    let (user_session_info, cookie_jar) = cookie_jar.session_info();
607    let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
608    let form_state = form.to_form_state();
609
610    let session = match (maybe_user_session, link.user_id, form) {
611        (Some(session), None, FormData::Link) => {
612            // The user is already logged in, the link is not linked to any user, and the
613            // user asked to link their account.
614            repo.upstream_oauth_link()
615                .associate_to_user(&link, &session.user)
616                .await?;
617
618            session
619        }
620
621        (
622            None,
623            None,
624            FormData::Register {
625                username,
626                import_email,
627                import_display_name,
628                accept_terms,
629            },
630        ) => {
631            // The user got the form to register a new account, and is not logged in.
632            // Depending on the claims_imports, we've let the user choose their username,
633            // choose whether they want to import the email and display name, or
634            // not.
635
636            // Those fields are Some("on") if the checkbox is checked
637            let import_email = import_email.is_some();
638            let import_display_name = import_display_name.is_some();
639            let accept_terms = accept_terms.is_some();
640
641            let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
642
643            let provider = repo
644                .upstream_oauth_provider()
645                .lookup(link.provider_id)
646                .await?
647                .ok_or(RouteError::ProviderNotFound(link.provider_id))?;
648
649            // Let's try to import the claims from the ID token
650            let env = environment();
651
652            let mut context = AttributeMappingContext::new();
653            if let Some(id_token) = id_token {
654                let (_, payload) = id_token.into_parts();
655                context = context.with_id_token_claims(payload);
656            }
657            if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
658                context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
659            }
660            if let Some(userinfo) = upstream_session.userinfo() {
661                context = context.with_userinfo_claims(userinfo.clone());
662            }
663            let context = context.build();
664
665            // Create a template context in case we need to re-render because of an error
666            let ctx = UpstreamRegister::new(link.clone(), provider.clone());
667
668            let display_name = if provider
669                .claims_imports
670                .displayname
671                .should_import(import_display_name)
672            {
673                let template = provider
674                    .claims_imports
675                    .displayname
676                    .template
677                    .as_deref()
678                    .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
679
680                render_attribute_template(
681                    &env,
682                    template,
683                    &context,
684                    provider.claims_imports.displayname.is_required(),
685                )?
686            } else {
687                None
688            };
689
690            let ctx = if let Some(ref display_name) = display_name {
691                ctx.with_display_name(
692                    display_name.clone(),
693                    provider.claims_imports.email.is_forced(),
694                )
695            } else {
696                ctx
697            };
698
699            let email = if provider.claims_imports.email.should_import(import_email) {
700                let template = provider
701                    .claims_imports
702                    .email
703                    .template
704                    .as_deref()
705                    .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
706
707                render_attribute_template(
708                    &env,
709                    template,
710                    &context,
711                    provider.claims_imports.email.is_required(),
712                )?
713            } else {
714                None
715            };
716
717            let ctx = if let Some(ref email) = email {
718                ctx.with_email(email.clone(), provider.claims_imports.email.is_forced())
719            } else {
720                ctx
721            };
722
723            let username = if provider.claims_imports.localpart.is_forced() {
724                let template = provider
725                    .claims_imports
726                    .localpart
727                    .template
728                    .as_deref()
729                    .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
730
731                render_attribute_template(&env, template, &context, true)?
732            } else {
733                // If there is no forced username, we can use the one the user entered
734                username
735            }
736            .unwrap_or_default();
737
738            let ctx = ctx.with_localpart(
739                username.clone(),
740                provider.claims_imports.localpart.is_forced(),
741            );
742
743            // Validate the form
744            let form_state = {
745                let mut form_state = form_state;
746                let mut homeserver_denied_username = false;
747                if username.is_empty() {
748                    form_state.add_error_on_field(
749                        mas_templates::UpstreamRegisterFormField::Username,
750                        FieldError::Required,
751                    );
752                } else if repo.user().exists(&username).await? {
753                    form_state.add_error_on_field(
754                        mas_templates::UpstreamRegisterFormField::Username,
755                        FieldError::Exists,
756                    );
757                } else if !homeserver
758                    .is_localpart_available(&username)
759                    .await
760                    .map_err(RouteError::HomeserverConnection)?
761                {
762                    // The user already exists on the homeserver
763                    tracing::warn!(
764                        %username,
765                        "Homeserver denied username provided by user"
766                    );
767
768                    // We defer adding the error on the field, until we know whether we had another
769                    // error from the policy, to avoid showing both
770                    homeserver_denied_username = true;
771                }
772
773                // If we have a TOS in the config, make sure the user has accepted it
774                if site_config.tos_uri.is_some() && !accept_terms {
775                    form_state.add_error_on_field(
776                        mas_templates::UpstreamRegisterFormField::AcceptTerms,
777                        FieldError::Required,
778                    );
779                }
780
781                // Policy check
782                let res = policy
783                    .evaluate_register(mas_policy::RegisterInput {
784                        registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
785                        username: &username,
786                        email: email.as_deref(),
787                        requester: mas_policy::Requester {
788                            ip_address: activity_tracker.ip(),
789                            user_agent: user_agent.clone(),
790                        },
791                    })
792                    .await?;
793
794                for violation in res.violations {
795                    match violation.field.as_deref() {
796                        Some("username") => {
797                            // If the homeserver denied the username, but we also had an error on
798                            // the policy side, we don't want to show
799                            // both, so we reset the state here
800                            homeserver_denied_username = false;
801                            form_state.add_error_on_field(
802                                mas_templates::UpstreamRegisterFormField::Username,
803                                FieldError::Policy {
804                                    code: violation.code.map(|c| c.as_str()),
805                                    message: violation.msg,
806                                },
807                            );
808                        }
809                        _ => form_state.add_error_on_form(FormError::Policy {
810                            code: violation.code.map(|c| c.as_str()),
811                            message: violation.msg,
812                        }),
813                    }
814                }
815
816                if homeserver_denied_username {
817                    // XXX: we may want to return different errors like "this username is reserved"
818                    form_state.add_error_on_field(
819                        mas_templates::UpstreamRegisterFormField::Username,
820                        FieldError::Exists,
821                    );
822                }
823
824                form_state
825            };
826
827            if !form_state.is_valid() {
828                let ctx = ctx
829                    .with_form_state(form_state)
830                    .with_csrf(csrf_token.form_value())
831                    .with_language(locale);
832
833                return Ok((
834                    cookie_jar,
835                    Html(templates.render_upstream_oauth2_do_register(&ctx)?),
836                )
837                    .into_response());
838            }
839
840            REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);
841
842            // Now we can create the user
843            let user = repo.user().add(&mut rng, &clock, username).await?;
844
845            if let Some(terms_url) = &site_config.tos_uri {
846                repo.user_terms()
847                    .accept_terms(&mut rng, &clock, &user, terms_url.clone())
848                    .await?;
849            }
850
851            // And schedule the job to provision it
852            let mut job = ProvisionUserJob::new(&user);
853
854            // If we have a display name, set it during provisioning
855            if let Some(name) = display_name {
856                job = job.set_display_name(name);
857            }
858
859            repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
860
861            // If we have an email, add it to the user
862            if let Some(email) = email {
863                repo.user_email()
864                    .add(&mut rng, &clock, &user, email)
865                    .await?;
866            }
867
868            repo.upstream_oauth_link()
869                .associate_to_user(&link, &user)
870                .await?;
871
872            repo.browser_session()
873                .add(&mut rng, &clock, &user, user_agent)
874                .await?
875        }
876
877        _ => return Err(RouteError::InvalidFormAction),
878    };
879
880    let upstream_session = repo
881        .upstream_oauth_session()
882        .consume(&clock, upstream_session)
883        .await?;
884
885    repo.browser_session()
886        .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
887        .await?;
888
889    let cookie_jar = sessions_cookie
890        .consume_link(link_id)?
891        .save(cookie_jar, &clock);
892    let cookie_jar = cookie_jar.set_session(&session);
893
894    repo.save().await?;
895
896    Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
897}
898
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode, header::CONTENT_TYPE};
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportPreference,
        UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
    use mas_router::Route;
    use mas_storage::{
        Pagination, upstream_oauth2::UpstreamOAuthProviderParams, user::UserEmailFilter,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use sqlx::PgPool;

    use super::UpstreamSessionsCookie;
    use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};

    /// Extracts the value of the hidden `csrf` form input from a rendered HTML
    /// page.
    ///
    /// The value is the text between the first `name="csrf" value="` marker
    /// and the next double quote.
    ///
    /// # Panics
    ///
    /// Panics if the page contains no such marker, which means the form was
    /// not rendered as expected.
    fn extract_csrf_token(html: &str) -> &str {
        html.split("name=\"csrf\" value=\"")
            .nth(1)
            .and_then(|after_marker| after_marker.split('"').next())
            .expect("page should contain a CSRF form field")
    }

    /// Registering a new account through an upstream OAuth 2.0 link should:
    /// - create the user named after the upstream `preferred_username`;
    /// - associate the link with that user;
    /// - import the upstream email address.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let mut rng = state.rng();
        let cookies = CookieHelper::new();

        // Both the localpart and the email are forcibly imported from the
        // upstream claims, without a custom template.
        let forced_import = || UpstreamOAuthProviderImportPreference {
            action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
            template: None,
        };
        let claims_imports = UpstreamOAuthProviderClaimsImports {
            localpart: forced_import(),
            email: forced_import(),
            ..UpstreamOAuthProviderClaimsImports::default()
        };

        let upstream_claims = serde_json::json!({
            "preferred_username": "john",
            "email": "john@example.com",
            "email_verified": true,
        });

        // Sign the ID token with an RS256 key the test state already holds,
        // instead of generating a fresh one.
        let key = state
            .key_store
            .signing_key_for_algorithm(&JsonWebSignatureAlg::Rs256)
            .unwrap();
        let signing_key = key
            .params()
            .signing_key_for_alg(&JsonWebSignatureAlg::Rs256)
            .unwrap();
        let jws_header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256);
        let signed_id_token =
            Jwt::sign_with_rng(&mut rng, jws_header, upstream_claims, &signing_key).unwrap();

        // Provision a provider, an upstream session and a link. Then complete
        // the session with the signed ID token.
        let mut repo = state.repository().await.unwrap();
        let provider_params = UpstreamOAuthProviderParams {
            issuer: Some("https://example.com/".to_owned()),
            human_name: Some("Example Ltd.".to_owned()),
            brand_name: None,
            scope: Scope::from_iter([OPENID]),
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            token_endpoint_signing_alg: None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            client_id: "client".to_owned(),
            encrypted_client_secret: None,
            claims_imports,
            authorization_endpoint_override: None,
            token_endpoint_override: None,
            userinfo_endpoint_override: None,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            jwks_uri_override: None,
            discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            ui_order: 0,
        };
        let provider = repo
            .upstream_oauth_provider()
            .add(&mut rng, &state.clock, provider_params)
            .await
            .unwrap();

        let pending_session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &state.clock,
                &provider,
                "state".to_owned(),
                None,
                "nonce".to_owned(),
            )
            .await
            .unwrap();

        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &state.clock, &provider, "subject".to_owned(), None)
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .complete_with_link(
                &state.clock,
                pending_session,
                &link,
                Some(signed_id_token.into_string()),
                None,
                None,
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        // Seed the browser with an upstream sessions cookie that points at the link.
        let jar = state.cookie_jar();
        let upstream_sessions = UpstreamSessionsCookie::default()
            .add(session.id, provider.id, "state".to_owned(), None)
            .add_link_to_session(session.id, link.id)
            .unwrap();
        cookies.import(upstream_sessions.save(jar, &state.clock));

        let link_route = mas_router::UpstreamOAuth2Link::new(link.id);

        // Loading the link page should render the registration form as HTML.
        let request = cookies.with_cookies(Request::get(&*link_route.path()).empty());
        let page = state.request(request).await;
        cookies.save_cookies(&page);
        page.assert_status(StatusCode::OK);
        page.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");

        let csrf_token = extract_csrf_token(page.body());

        // Submit the registration form, importing the email and accepting the terms.
        let form = serde_json::json!({
            "csrf": csrf_token,
            "action": "register",
            "import_email": "on",
            "accept_terms": "on",
        });
        let request = cookies.with_cookies(Request::post(&*link_route.path()).form(form));
        let response = state.request(request).await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::SEE_OTHER);

        // The user must exist, own the link, and have the upstream email attached.
        let mut repo = state.repository().await.unwrap();
        let user = repo
            .user()
            .find_by_username("john")
            .await
            .unwrap()
            .expect("user exists");

        let stored_link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "subject")
            .await
            .unwrap()
            .expect("link exists");
        assert_eq!(stored_link.user_id, Some(user.id));

        let emails = repo
            .user_email()
            .list(UserEmailFilter::new().for_user(&user), Pagination::first(1))
            .await
            .unwrap();
        let email = emails.edges.first().expect("email exists");
        assert_eq!(email.email, "john@example.com");
    }
}