1use std::sync::{Arc, LazyLock};
8
9use axum::{
10 Form,
11 extract::{Path, State},
12 response::{Html, IntoResponse, Response},
13};
14use axum_extra::typed_header::TypedHeader;
15use hyper::StatusCode;
16use mas_axum_utils::{
17 GenericError, SessionInfoExt,
18 cookies::CookieJar,
19 csrf::{CsrfExt, ProtectedForm},
20 record_error,
21};
22use mas_jose::jwt::Jwt;
23use mas_matrix::HomeserverConnection;
24use mas_policy::Policy;
25use mas_router::UrlBuilder;
26use mas_storage::{
27 BoxClock, BoxRepository, BoxRng, RepositoryAccess,
28 queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
29 upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
30 user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
31};
32use mas_templates::{
33 AccountInactiveContext, ErrorContext, FieldError, FormError, TemplateContext, Templates,
34 ToFormState, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
35};
36use minijinja::Environment;
37use opentelemetry::{Key, KeyValue, metrics::Counter};
38use serde::{Deserialize, Serialize};
39use thiserror::Error;
40use tracing::warn;
41use ulid::Ulid;
42
43use super::{
44 UpstreamSessionsCookie,
45 template::{AttributeMappingContext, environment},
46};
47use crate::{
48 BoundActivityTracker, METER, PreferredLanguage, SiteConfig, impl_from_error_for_route,
49 views::shared::OptionalPostAuthAction,
50};
51
/// Counts successful logins, through an upstream OAuth 2.0 provider, into an
/// account that was already linked to the upstream identity.
static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
    METER
        .u64_counter("mas.upstream_oauth2.login")
        .with_description("Successful upstream OAuth 2.0 login to existing accounts")
        .with_unit("{login}")
        .build()
});
/// Counts accounts registered through an upstream OAuth 2.0 provider.
static REGISTRATION_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
    METER
        .u64_counter("mas.upstream_oauth2.registration")
        .with_description("Successful upstream OAuth 2.0 registration")
        .with_unit("{registration}")
        .build()
});
/// Metric attribute key carrying the ID of the upstream provider.
const PROVIDER: Key = Key::from_static_str("provider");

/// Localpart template used when the provider does not configure one.
const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}";
/// Display name template used when the provider does not configure one.
const DEFAULT_DISPLAYNAME_TEMPLATE: &str = "{{ user.name }}";
/// Email template used when the provider does not configure one.
const DEFAULT_EMAIL_TEMPLATE: &str = "{{ user.email }}";
71
/// Errors which can happen while handling the upstream OAuth 2.0 link routes.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// The upstream OAuth link referenced in the URL path does not exist.
    #[error("Link not found")]
    LinkNotFound,

    /// The upstream OAuth session referenced by the cookie does not exist, or
    /// was not completed with the requested link.
    #[error("Session {0} not found")]
    SessionNotFound(Ulid),

    /// The user the link is attached to does not exist.
    #[error("User {0} not found")]
    UserNotFound(Ulid),

    /// The upstream provider the link belongs to does not exist.
    #[error("Upstream provider {0} not found")]
    ProviderNotFound(Ulid),

    /// A claim-import template for a required attribute rendered to an empty
    /// string.
    #[error("Template {template:?} rendered to an empty string")]
    RequiredAttributeEmpty { template: String },

    /// A claim-import template for a required attribute failed to render.
    #[error(
        "Template {template:?} could not be rendered from the upstream provider's response for required claim"
    )]
    RequiredAttributeRender {
        template: String,

        #[source]
        source: minijinja::Error,
    },

    /// The upstream OAuth session was already consumed.
    #[error("Session {0} already consumed")]
    SessionConsumed(Ulid),

    /// The upstream sessions cookie has no entry for the requested link.
    #[error("Missing session cookie")]
    MissingCookie,

    /// The submitted form action is not allowed for the current combination of
    /// browser session and link state.
    #[error("Invalid form action")]
    InvalidFormAction,

    /// Querying the homeserver failed.
    #[error("Homeserver connection error")]
    HomeserverConnection(#[source] anyhow::Error),

    /// Any other error. Presumably this is what the `impl_from_error_for_route!`
    /// conversions produce; confirm against the macro definition.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}
121
// Conversions from the errors of the crates used by the handlers into
// `RouteError`, so they can be propagated with `?`. The variant they map to is
// decided by the macro (presumably `RouteError::Internal`; confirm there).
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
128
129impl IntoResponse for RouteError {
130 fn into_response(self) -> axum::response::Response {
131 let sentry_event_id = record_error!(
132 self,
133 Self::Internal(_)
134 | Self::RequiredAttributeEmpty { .. }
135 | Self::RequiredAttributeRender { .. }
136 | Self::SessionNotFound(_)
137 | Self::ProviderNotFound(_)
138 | Self::UserNotFound(_)
139 | Self::HomeserverConnection(_)
140 );
141
142 let status_code = match self {
143 Self::LinkNotFound => StatusCode::NOT_FOUND,
144 _ => StatusCode::INTERNAL_SERVER_ERROR,
145 };
146
147 let response = GenericError::new(status_code, self);
148 (sentry_event_id, response).into_response()
149 }
150}
151
152fn render_attribute_template(
165 environment: &Environment,
166 template: &str,
167 context: &minijinja::Value,
168 required: bool,
169) -> Result<Option<String>, RouteError> {
170 match environment.render_str(template, context) {
171 Ok(value) if value.is_empty() => {
172 if required {
173 return Err(RouteError::RequiredAttributeEmpty {
174 template: template.to_owned(),
175 });
176 }
177
178 Ok(None)
179 }
180
181 Ok(value) => Ok(Some(value)),
182
183 Err(source) => {
184 if required {
185 return Err(RouteError::RequiredAttributeRender {
186 template: template.to_owned(),
187 source,
188 });
189 }
190
191 tracing::warn!(error = &source as &dyn std::error::Error, %template, "Error while rendering template");
192 Ok(None)
193 }
194 }
195}
196
/// Form submitted from the upstream OAuth 2.0 link page.
///
/// The variant is selected by the `action` field, which is either `"register"`
/// or `"link"`.
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "lowercase", tag = "action")]
pub(crate) enum FormData {
    /// Register a new account attached to the upstream identity.
    Register {
        /// Requested localpart; ignored when the provider forces the localpart.
        #[serde(default)]
        username: Option<String>,
        /// Only its presence matters (e.g. a checked checkbox): import the email.
        #[serde(default)]
        import_email: Option<String>,
        /// Only its presence matters: import the display name.
        #[serde(default)]
        import_display_name: Option<String>,
        /// Only its presence matters: the user accepted the terms of service.
        #[serde(default)]
        accept_terms: Option<String>,
    },
    /// Attach the upstream identity to the currently logged-in user.
    Link,
}

// Form errors are reported against the registration form fields.
impl ToFormState for FormData {
    type Field = mas_templates::UpstreamRegisterFormField;
}
216
217#[tracing::instrument(
218 name = "handlers.upstream_oauth2.link.get",
219 fields(upstream_oauth_link.id = %link_id),
220 skip_all,
221)]
222pub(crate) async fn get(
223 mut rng: BoxRng,
224 clock: BoxClock,
225 mut repo: BoxRepository,
226 mut policy: Policy,
227 PreferredLanguage(locale): PreferredLanguage,
228 State(templates): State<Templates>,
229 State(url_builder): State<UrlBuilder>,
230 State(homeserver): State<Arc<dyn HomeserverConnection>>,
231 cookie_jar: CookieJar,
232 activity_tracker: BoundActivityTracker,
233 user_agent: Option<TypedHeader<headers::UserAgent>>,
234 Path(link_id): Path<Ulid>,
235) -> Result<impl IntoResponse, RouteError> {
236 let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
237 let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
238 let (session_id, post_auth_action) = sessions_cookie
239 .lookup_link(link_id)
240 .map_err(|_| RouteError::MissingCookie)?;
241
242 let post_auth_action = OptionalPostAuthAction {
243 post_auth_action: post_auth_action.cloned(),
244 };
245
246 let link = repo
247 .upstream_oauth_link()
248 .lookup(link_id)
249 .await?
250 .ok_or(RouteError::LinkNotFound)?;
251
252 let upstream_session = repo
253 .upstream_oauth_session()
254 .lookup(session_id)
255 .await?
256 .ok_or(RouteError::SessionNotFound(session_id))?;
257
258 if upstream_session.link_id() != Some(link.id) {
261 return Err(RouteError::SessionNotFound(session_id));
262 }
263
264 if upstream_session.is_consumed() {
265 return Err(RouteError::SessionConsumed(session_id));
266 }
267
268 let (user_session_info, cookie_jar) = cookie_jar.session_info();
269 let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
270 let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
271
272 let response = match (maybe_user_session, link.user_id) {
273 (Some(session), Some(user_id)) if session.user.id == user_id => {
274 let upstream_session = repo
277 .upstream_oauth_session()
278 .consume(&clock, upstream_session)
279 .await?;
280
281 repo.browser_session()
282 .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
283 .await?;
284
285 cookie_jar = cookie_jar.set_session(&session);
286
287 repo.save().await?;
288
289 post_auth_action.go_next(&url_builder).into_response()
290 }
291
292 (Some(user_session), Some(user_id)) => {
293 let user = repo
297 .user()
298 .lookup(user_id)
299 .await?
300 .ok_or(RouteError::UserNotFound(user_id))?;
301
302 let ctx = UpstreamExistingLinkContext::new(user)
303 .with_session(user_session)
304 .with_csrf(csrf_token.form_value())
305 .with_language(locale);
306
307 Html(templates.render_upstream_oauth2_link_mismatch(&ctx)?).into_response()
308 }
309
310 (Some(user_session), None) => {
311 let ctx = UpstreamSuggestLink::new(&link)
313 .with_session(user_session)
314 .with_csrf(csrf_token.form_value())
315 .with_language(locale);
316
317 Html(templates.render_upstream_oauth2_suggest_link(&ctx)?).into_response()
318 }
319
320 (None, Some(user_id)) => {
321 let user = repo
323 .user()
324 .lookup(user_id)
325 .await?
326 .ok_or(RouteError::UserNotFound(user_id))?;
327
328 if user.deactivated_at.is_some() {
330 let ctx = AccountInactiveContext::new(user)
332 .with_csrf(csrf_token.form_value())
333 .with_language(locale);
334 let fallback = templates.render_account_deactivated(&ctx)?;
335 return Ok((cookie_jar, Html(fallback).into_response()));
336 }
337
338 if user.locked_at.is_some() {
339 let ctx = AccountInactiveContext::new(user)
341 .with_csrf(csrf_token.form_value())
342 .with_language(locale);
343 let fallback = templates.render_account_locked(&ctx)?;
344 return Ok((cookie_jar, Html(fallback).into_response()));
345 }
346
347 let session = repo
348 .browser_session()
349 .add(&mut rng, &clock, &user, user_agent)
350 .await?;
351
352 let upstream_session = repo
353 .upstream_oauth_session()
354 .consume(&clock, upstream_session)
355 .await?;
356
357 repo.browser_session()
358 .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
359 .await?;
360
361 cookie_jar = sessions_cookie
362 .consume_link(link_id)?
363 .save(cookie_jar, &clock);
364 cookie_jar = cookie_jar.set_session(&session);
365
366 repo.save().await?;
367
368 LOGIN_COUNTER.add(
369 1,
370 &[KeyValue::new(
371 PROVIDER,
372 upstream_session.provider_id.to_string(),
373 )],
374 );
375
376 post_auth_action.go_next(&url_builder).into_response()
377 }
378
379 (None, None) => {
380 let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
383
384 let provider = repo
385 .upstream_oauth_provider()
386 .lookup(link.provider_id)
387 .await?
388 .ok_or(RouteError::ProviderNotFound(link.provider_id))?;
389
390 let ctx = UpstreamRegister::new(link.clone(), provider.clone());
391
392 let env = environment();
393
394 let mut context = AttributeMappingContext::new();
395 if let Some(id_token) = id_token {
396 let (_, payload) = id_token.into_parts();
397 context = context.with_id_token_claims(payload);
398 }
399 if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
400 context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
401 }
402 if let Some(userinfo) = upstream_session.userinfo() {
403 context = context.with_userinfo_claims(userinfo.clone());
404 }
405 let context = context.build();
406
407 let ctx = if provider.claims_imports.displayname.ignore() {
408 ctx
409 } else {
410 let template = provider
411 .claims_imports
412 .displayname
413 .template
414 .as_deref()
415 .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
416
417 match render_attribute_template(
418 &env,
419 template,
420 &context,
421 provider.claims_imports.displayname.is_required(),
422 )? {
423 Some(value) => ctx
424 .with_display_name(value, provider.claims_imports.displayname.is_forced()),
425 None => ctx,
426 }
427 };
428
429 let ctx = if provider.claims_imports.email.ignore() {
430 ctx
431 } else {
432 let template = provider
433 .claims_imports
434 .email
435 .template
436 .as_deref()
437 .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
438
439 match render_attribute_template(
440 &env,
441 template,
442 &context,
443 provider.claims_imports.email.is_required(),
444 )? {
445 Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()),
446 None => ctx,
447 }
448 };
449
450 let ctx = if provider.claims_imports.localpart.ignore() {
451 ctx
452 } else {
453 let template = provider
454 .claims_imports
455 .localpart
456 .template
457 .as_deref()
458 .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
459
460 match render_attribute_template(
461 &env,
462 template,
463 &context,
464 provider.claims_imports.localpart.is_required(),
465 )? {
466 Some(localpart) => {
467 let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
471 let is_available = homeserver
472 .is_localpart_available(&localpart)
473 .await
474 .map_err(RouteError::HomeserverConnection)?;
475
476 if maybe_existing_user.is_some() || !is_available {
477 if let Some(existing_user) = maybe_existing_user {
478 warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
481 }
482
483 let ctx = ErrorContext::new()
485 .with_code("User exists")
486 .with_description(format!(
487 r"Upstream account provider returned {localpart:?} as username,
488 which is not linked to that upstream account"
489 ))
490 .with_language(&locale);
491
492 return Ok((
493 cookie_jar,
494 Html(templates.render_error(&ctx)?).into_response(),
495 ));
496 }
497
498 let res = policy
499 .evaluate_register(mas_policy::RegisterInput {
500 registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
501 username: &localpart,
502 email: None,
503 requester: mas_policy::Requester {
504 ip_address: activity_tracker.ip(),
505 user_agent: user_agent.clone(),
506 },
507 })
508 .await?;
509
510 if res.valid() {
511 ctx.with_localpart(
513 localpart,
514 provider.claims_imports.localpart.is_forced(),
515 )
516 } else if provider.claims_imports.localpart.is_forced() {
517 let ctx = ErrorContext::new()
521 .with_code("Policy error")
522 .with_description(format!(
523 r"Upstream account provider returned {localpart:?} as username,
524 which does not pass the policy check: {res}"
525 ))
526 .with_language(&locale);
527
528 return Ok((
529 cookie_jar,
530 Html(templates.render_error(&ctx)?).into_response(),
531 ));
532 } else {
533 ctx
535 }
536 }
537 None => ctx,
538 }
539 };
540
541 let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
542
543 Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
544 }
545 };
546
547 Ok((cookie_jar, response))
548}
549
550#[tracing::instrument(
551 name = "handlers.upstream_oauth2.link.post",
552 fields(upstream_oauth_link.id = %link_id),
553 skip_all,
554)]
555pub(crate) async fn post(
556 mut rng: BoxRng,
557 clock: BoxClock,
558 mut repo: BoxRepository,
559 cookie_jar: CookieJar,
560 user_agent: Option<TypedHeader<headers::UserAgent>>,
561 mut policy: Policy,
562 PreferredLanguage(locale): PreferredLanguage,
563 activity_tracker: BoundActivityTracker,
564 State(templates): State<Templates>,
565 State(homeserver): State<Arc<dyn HomeserverConnection>>,
566 State(url_builder): State<UrlBuilder>,
567 State(site_config): State<SiteConfig>,
568 Path(link_id): Path<Ulid>,
569 Form(form): Form<ProtectedForm<FormData>>,
570) -> Result<Response, RouteError> {
571 let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
572 let form = cookie_jar.verify_form(&clock, form)?;
573
574 let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
575 let (session_id, post_auth_action) = sessions_cookie
576 .lookup_link(link_id)
577 .map_err(|_| RouteError::MissingCookie)?;
578
579 let post_auth_action = OptionalPostAuthAction {
580 post_auth_action: post_auth_action.cloned(),
581 };
582
583 let link = repo
584 .upstream_oauth_link()
585 .lookup(link_id)
586 .await?
587 .ok_or(RouteError::LinkNotFound)?;
588
589 let upstream_session = repo
590 .upstream_oauth_session()
591 .lookup(session_id)
592 .await?
593 .ok_or(RouteError::SessionNotFound(session_id))?;
594
595 if upstream_session.link_id() != Some(link.id) {
598 return Err(RouteError::SessionNotFound(session_id));
599 }
600
601 if upstream_session.is_consumed() {
602 return Err(RouteError::SessionConsumed(session_id));
603 }
604
605 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
606 let (user_session_info, cookie_jar) = cookie_jar.session_info();
607 let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
608 let form_state = form.to_form_state();
609
610 let session = match (maybe_user_session, link.user_id, form) {
611 (Some(session), None, FormData::Link) => {
612 repo.upstream_oauth_link()
615 .associate_to_user(&link, &session.user)
616 .await?;
617
618 session
619 }
620
621 (
622 None,
623 None,
624 FormData::Register {
625 username,
626 import_email,
627 import_display_name,
628 accept_terms,
629 },
630 ) => {
631 let import_email = import_email.is_some();
638 let import_display_name = import_display_name.is_some();
639 let accept_terms = accept_terms.is_some();
640
641 let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
642
643 let provider = repo
644 .upstream_oauth_provider()
645 .lookup(link.provider_id)
646 .await?
647 .ok_or(RouteError::ProviderNotFound(link.provider_id))?;
648
649 let env = environment();
651
652 let mut context = AttributeMappingContext::new();
653 if let Some(id_token) = id_token {
654 let (_, payload) = id_token.into_parts();
655 context = context.with_id_token_claims(payload);
656 }
657 if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
658 context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
659 }
660 if let Some(userinfo) = upstream_session.userinfo() {
661 context = context.with_userinfo_claims(userinfo.clone());
662 }
663 let context = context.build();
664
665 let ctx = UpstreamRegister::new(link.clone(), provider.clone());
667
668 let display_name = if provider
669 .claims_imports
670 .displayname
671 .should_import(import_display_name)
672 {
673 let template = provider
674 .claims_imports
675 .displayname
676 .template
677 .as_deref()
678 .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
679
680 render_attribute_template(
681 &env,
682 template,
683 &context,
684 provider.claims_imports.displayname.is_required(),
685 )?
686 } else {
687 None
688 };
689
690 let ctx = if let Some(ref display_name) = display_name {
691 ctx.with_display_name(
692 display_name.clone(),
693 provider.claims_imports.email.is_forced(),
694 )
695 } else {
696 ctx
697 };
698
699 let email = if provider.claims_imports.email.should_import(import_email) {
700 let template = provider
701 .claims_imports
702 .email
703 .template
704 .as_deref()
705 .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
706
707 render_attribute_template(
708 &env,
709 template,
710 &context,
711 provider.claims_imports.email.is_required(),
712 )?
713 } else {
714 None
715 };
716
717 let ctx = if let Some(ref email) = email {
718 ctx.with_email(email.clone(), provider.claims_imports.email.is_forced())
719 } else {
720 ctx
721 };
722
723 let username = if provider.claims_imports.localpart.is_forced() {
724 let template = provider
725 .claims_imports
726 .localpart
727 .template
728 .as_deref()
729 .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
730
731 render_attribute_template(&env, template, &context, true)?
732 } else {
733 username
735 }
736 .unwrap_or_default();
737
738 let ctx = ctx.with_localpart(
739 username.clone(),
740 provider.claims_imports.localpart.is_forced(),
741 );
742
743 let form_state = {
745 let mut form_state = form_state;
746 let mut homeserver_denied_username = false;
747 if username.is_empty() {
748 form_state.add_error_on_field(
749 mas_templates::UpstreamRegisterFormField::Username,
750 FieldError::Required,
751 );
752 } else if repo.user().exists(&username).await? {
753 form_state.add_error_on_field(
754 mas_templates::UpstreamRegisterFormField::Username,
755 FieldError::Exists,
756 );
757 } else if !homeserver
758 .is_localpart_available(&username)
759 .await
760 .map_err(RouteError::HomeserverConnection)?
761 {
762 tracing::warn!(
764 %username,
765 "Homeserver denied username provided by user"
766 );
767
768 homeserver_denied_username = true;
771 }
772
773 if site_config.tos_uri.is_some() && !accept_terms {
775 form_state.add_error_on_field(
776 mas_templates::UpstreamRegisterFormField::AcceptTerms,
777 FieldError::Required,
778 );
779 }
780
781 let res = policy
783 .evaluate_register(mas_policy::RegisterInput {
784 registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
785 username: &username,
786 email: email.as_deref(),
787 requester: mas_policy::Requester {
788 ip_address: activity_tracker.ip(),
789 user_agent: user_agent.clone(),
790 },
791 })
792 .await?;
793
794 for violation in res.violations {
795 match violation.field.as_deref() {
796 Some("username") => {
797 homeserver_denied_username = false;
801 form_state.add_error_on_field(
802 mas_templates::UpstreamRegisterFormField::Username,
803 FieldError::Policy {
804 code: violation.code.map(|c| c.as_str()),
805 message: violation.msg,
806 },
807 );
808 }
809 _ => form_state.add_error_on_form(FormError::Policy {
810 code: violation.code.map(|c| c.as_str()),
811 message: violation.msg,
812 }),
813 }
814 }
815
816 if homeserver_denied_username {
817 form_state.add_error_on_field(
819 mas_templates::UpstreamRegisterFormField::Username,
820 FieldError::Exists,
821 );
822 }
823
824 form_state
825 };
826
827 if !form_state.is_valid() {
828 let ctx = ctx
829 .with_form_state(form_state)
830 .with_csrf(csrf_token.form_value())
831 .with_language(locale);
832
833 return Ok((
834 cookie_jar,
835 Html(templates.render_upstream_oauth2_do_register(&ctx)?),
836 )
837 .into_response());
838 }
839
840 REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);
841
842 let user = repo.user().add(&mut rng, &clock, username).await?;
844
845 if let Some(terms_url) = &site_config.tos_uri {
846 repo.user_terms()
847 .accept_terms(&mut rng, &clock, &user, terms_url.clone())
848 .await?;
849 }
850
851 let mut job = ProvisionUserJob::new(&user);
853
854 if let Some(name) = display_name {
856 job = job.set_display_name(name);
857 }
858
859 repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
860
861 if let Some(email) = email {
863 repo.user_email()
864 .add(&mut rng, &clock, &user, email)
865 .await?;
866 }
867
868 repo.upstream_oauth_link()
869 .associate_to_user(&link, &user)
870 .await?;
871
872 repo.browser_session()
873 .add(&mut rng, &clock, &user, user_agent)
874 .await?
875 }
876
877 _ => return Err(RouteError::InvalidFormAction),
878 };
879
880 let upstream_session = repo
881 .upstream_oauth_session()
882 .consume(&clock, upstream_session)
883 .await?;
884
885 repo.browser_session()
886 .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
887 .await?;
888
889 let cookie_jar = sessions_cookie
890 .consume_link(link_id)?
891 .save(cookie_jar, &clock);
892 let cookie_jar = cookie_jar.set_session(&session);
893
894 repo.save().await?;
895
896 Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
897}
898
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode, header::CONTENT_TYPE};
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportPreference,
        UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
    use mas_router::Route;
    use mas_storage::{
        Pagination, upstream_oauth2::UpstreamOAuthProviderParams, user::UserEmailFilter,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use sqlx::PgPool;

    use super::UpstreamSessionsCookie;
    use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};

    /// Registers a new account through an upstream provider that forces the
    /// localpart and email imports. Then checks that the user, the link and the
    /// email were created from the ID token claims.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let mut rng = state.rng();
        let cookies = CookieHelper::new();

        // Force the import of both the localpart and the email, using the
        // default templates.
        let claims_imports = UpstreamOAuthProviderClaimsImports {
            localpart: UpstreamOAuthProviderImportPreference {
                action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
                template: None,
            },
            email: UpstreamOAuthProviderImportPreference {
                action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
                template: None,
            },
            ..UpstreamOAuthProviderClaimsImports::default()
        };

        // Claims read by the default localpart and email templates.
        let id_token = serde_json::json!({
            "preferred_username": "john",
            "email": "john@example.com",
            "email_verified": true,
        });

        // Sign the ID token with one of the server's own RS256 keys.
        let key = state
            .key_store
            .signing_key_for_algorithm(&JsonWebSignatureAlg::Rs256)
            .unwrap();

        let signer = key
            .params()
            .signing_key_for_alg(&JsonWebSignatureAlg::Rs256)
            .unwrap();
        let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256);
        let id_token = Jwt::sign_with_rng(&mut rng, header, id_token, &signer).unwrap();

        // Create the provider, an upstream session and an unattached link, then
        // complete the session with the link and the signed ID token.
        let mut repo = state.repository().await.unwrap();
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &state.clock,
                UpstreamOAuthProviderParams {
                    issuer: Some("https://example.com/".to_owned()),
                    human_name: Some("Example Ltd.".to_owned()),
                    brand_name: None,
                    scope: Scope::from_iter([OPENID]),
                    token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
                    token_endpoint_signing_alg: None,
                    id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
                    client_id: "client".to_owned(),
                    encrypted_client_secret: None,
                    claims_imports,
                    authorization_endpoint_override: None,
                    token_endpoint_override: None,
                    userinfo_endpoint_override: None,
                    fetch_userinfo: false,
                    userinfo_signed_response_alg: None,
                    jwks_uri_override: None,
                    discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
                    pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
                    response_mode: None,
                    additional_authorization_parameters: Vec::new(),
                    ui_order: 0,
                },
            )
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &state.clock,
                &provider,
                "state".to_owned(),
                None,
                "nonce".to_owned(),
            )
            .await
            .unwrap();

        let link = repo
            .upstream_oauth_link()
            .add(
                &mut rng,
                &state.clock,
                &provider,
                "subject".to_owned(),
                None,
            )
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .complete_with_link(
                &state.clock,
                session,
                &link,
                Some(id_token.into_string()),
                None,
                None,
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        // Record the upstream session and its link in the upstream sessions
        // cookie, as the callback handler would.
        let cookie_jar = state.cookie_jar();
        let upstream_sessions = UpstreamSessionsCookie::default()
            .add(session.id, provider.id, "state".to_owned(), None)
            .add_link_to_session(session.id, link.id)
            .unwrap();
        let cookie_jar = upstream_sessions.save(cookie_jar, &state.clock);
        cookies.import(cookie_jar);

        // Nobody is logged in and the link is unattached, so the link page
        // renders an HTML page (the registration form).
        let request = Request::get(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).empty();
        let request = cookies.with_cookies(request);
        let response = state.request(request).await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::OK);
        response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");

        // Extract the CSRF token from the rendered form.
        let csrf_token = response
            .body()
            .split("name=\"csrf\" value=\"")
            .nth(1)
            .unwrap()
            .split('\"')
            .next()
            .unwrap();

        // Submit the registration form; success redirects.
        let request = Request::post(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).form(
            serde_json::json!({
                "csrf": csrf_token,
                "action": "register",
                "import_email": "on",
                "accept_terms": "on",
            }),
        );
        let request = cookies.with_cookies(request);
        let response = state.request(request).await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::SEE_OTHER);

        // The user was created with the forced localpart, owns the link, and
        // has the imported email.
        let mut repo = state.repository().await.unwrap();
        let user = repo
            .user()
            .find_by_username("john")
            .await
            .unwrap()
            .expect("user exists");

        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "subject")
            .await
            .unwrap()
            .expect("link exists");

        assert_eq!(link.user_id, Some(user.id));

        let page = repo
            .user_email()
            .list(UserEmailFilter::new().for_user(&user), Pagination::first(1))
            .await
            .unwrap();
        let email = page.edges.first().expect("email exists");

        assert_eq!(email.email, "john@example.com");
    }
}