1use std::{str::FromStr, sync::Arc};
8
9use axum::{
10 extract::{Form, Query, State},
11 response::{Html, IntoResponse, Response},
12};
13use axum_extra::typed_header::TypedHeader;
14use hyper::StatusCode;
15use lettre::Address;
16use mas_axum_utils::{
17 InternalError, SessionInfoExt,
18 cookies::CookieJar,
19 csrf::{CsrfExt, CsrfToken, ProtectedForm},
20};
21use mas_data_model::CaptchaConfig;
22use mas_i18n::DataLocale;
23use mas_matrix::HomeserverConnection;
24use mas_policy::Policy;
25use mas_router::UrlBuilder;
26use mas_storage::{
27 BoxClock, BoxRepository, BoxRng, RepositoryAccess,
28 queue::{QueueJobRepositoryExt as _, SendEmailAuthenticationCodeJob},
29 user::{UserEmailRepository, UserRepository},
30};
31use mas_templates::{
32 FieldError, FormError, FormState, PasswordRegisterContext, RegisterFormField, TemplateContext,
33 Templates, ToFormState,
34};
35use serde::{Deserialize, Serialize};
36use zeroize::Zeroizing;
37
38use super::cookie::UserRegistrationSessions;
39use crate::{
40 BoundActivityTracker, Limiter, PreferredLanguage, RequesterFingerprint, SiteConfig,
41 captcha::Form as CaptchaForm, passwords::PasswordManager,
42 views::shared::OptionalPostAuthAction,
43};
44
45#[derive(Debug, Deserialize, Serialize)]
46pub(crate) struct RegisterForm {
47 username: String,
48 email: String,
49 password: String,
50 password_confirm: String,
51 #[serde(default)]
52 accept_terms: String,
53
54 #[serde(flatten, skip_serializing)]
55 captcha: CaptchaForm,
56}
57
58impl ToFormState for RegisterForm {
59 type Field = RegisterFormField;
60}
61
62#[derive(Deserialize)]
63pub struct QueryParams {
64 username: Option<String>,
65 #[serde(flatten)]
66 action: OptionalPostAuthAction,
67}
68
69#[tracing::instrument(name = "handlers.views.password_register.get", skip_all)]
70pub(crate) async fn get(
71 mut rng: BoxRng,
72 clock: BoxClock,
73 PreferredLanguage(locale): PreferredLanguage,
74 State(templates): State<Templates>,
75 State(url_builder): State<UrlBuilder>,
76 State(site_config): State<SiteConfig>,
77 mut repo: BoxRepository,
78 Query(query): Query<QueryParams>,
79 cookie_jar: CookieJar,
80) -> Result<Response, InternalError> {
81 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
82 let (session_info, cookie_jar) = cookie_jar.session_info();
83
84 let maybe_session = session_info.load_active_session(&mut repo).await?;
85
86 if maybe_session.is_some() {
87 let reply = query.action.go_next(&url_builder);
88 return Ok((cookie_jar, reply).into_response());
89 }
90
91 if !site_config.password_registration_enabled {
92 return Ok(url_builder
94 .redirect(&mas_router::Login::from(query.action.post_auth_action))
95 .into_response());
96 }
97
98 let mut ctx = PasswordRegisterContext::default();
99
100 if let Some(username) = query.username {
102 let mut form_state = FormState::default();
103 form_state.set_value(RegisterFormField::Username, Some(username));
104 ctx = ctx.with_form_state(form_state);
105 }
106
107 let content = render(
108 locale,
109 ctx,
110 query.action,
111 csrf_token,
112 &mut repo,
113 &templates,
114 site_config.captcha.clone(),
115 )
116 .await?;
117
118 Ok((cookie_jar, Html(content)).into_response())
119}
120
121#[tracing::instrument(name = "handlers.views.password_register.post", skip_all)]
122#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
123pub(crate) async fn post(
124 mut rng: BoxRng,
125 clock: BoxClock,
126 PreferredLanguage(locale): PreferredLanguage,
127 State(password_manager): State<PasswordManager>,
128 State(templates): State<Templates>,
129 State(url_builder): State<UrlBuilder>,
130 State(site_config): State<SiteConfig>,
131 State(homeserver): State<Arc<dyn HomeserverConnection>>,
132 State(http_client): State<reqwest::Client>,
133 (State(limiter), requester): (State<Limiter>, RequesterFingerprint),
134 mut policy: Policy,
135 mut repo: BoxRepository,
136 (user_agent, activity_tracker): (
137 Option<TypedHeader<headers::UserAgent>>,
138 BoundActivityTracker,
139 ),
140 Query(query): Query<OptionalPostAuthAction>,
141 cookie_jar: CookieJar,
142 Form(form): Form<ProtectedForm<RegisterForm>>,
143) -> Result<Response, InternalError> {
144 let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
145
146 let ip_address = activity_tracker.ip();
147 if !site_config.password_registration_enabled {
148 return Ok(StatusCode::METHOD_NOT_ALLOWED.into_response());
149 }
150
151 let form = cookie_jar.verify_form(&clock, form)?;
152
153 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
154
155 let passed_captcha = form
158 .captcha
159 .verify(
160 &activity_tracker,
161 &http_client,
162 url_builder.public_hostname(),
163 site_config.captcha.as_ref(),
164 )
165 .await
166 .is_ok();
167
168 let state = {
170 let mut state = form.to_form_state();
171
172 if !passed_captcha {
173 state.add_error_on_form(FormError::Captcha);
174 }
175
176 let mut homeserver_denied_username = false;
177 if form.username.is_empty() {
178 state.add_error_on_field(RegisterFormField::Username, FieldError::Required);
179 } else if repo.user().exists(&form.username).await? {
180 state.add_error_on_field(RegisterFormField::Username, FieldError::Exists);
182 } else if !homeserver
183 .is_localpart_available(&form.username)
184 .await
185 .map_err(InternalError::from_anyhow)?
186 {
187 tracing::warn!(
189 username = &form.username,
190 "Homeserver denied username provided by user"
191 );
192
193 homeserver_denied_username = true;
196 }
197
198 if form.email.is_empty() {
202 state.add_error_on_field(RegisterFormField::Email, FieldError::Required);
203 } else if Address::from_str(&form.email).is_err() {
204 state.add_error_on_field(RegisterFormField::Email, FieldError::Invalid);
205 }
206
207 if form.password.is_empty() {
208 state.add_error_on_field(RegisterFormField::Password, FieldError::Required);
209 }
210
211 if form.password_confirm.is_empty() {
212 state.add_error_on_field(RegisterFormField::PasswordConfirm, FieldError::Required);
213 }
214
215 if form.password != form.password_confirm {
216 state.add_error_on_field(RegisterFormField::Password, FieldError::Unspecified);
217 state.add_error_on_field(
218 RegisterFormField::PasswordConfirm,
219 FieldError::PasswordMismatch,
220 );
221 }
222
223 if !password_manager.is_password_complex_enough(&form.password)? {
224 state.add_error_on_field(
226 RegisterFormField::Password,
227 FieldError::Policy {
228 code: None,
229 message: "Password is too weak".to_owned(),
230 },
231 );
232 }
233
234 if site_config.tos_uri.is_some() && form.accept_terms != "on" {
236 state.add_error_on_field(RegisterFormField::AcceptTerms, FieldError::Required);
237 }
238
239 let res = policy
240 .evaluate_register(mas_policy::RegisterInput {
241 registration_method: mas_policy::RegistrationMethod::Password,
242 username: &form.username,
243 email: Some(&form.email),
244 requester: mas_policy::Requester {
245 ip_address: activity_tracker.ip(),
246 user_agent: user_agent.clone(),
247 },
248 })
249 .await?;
250
251 for violation in res.violations {
252 match violation.field.as_deref() {
253 Some("email") => state.add_error_on_field(
254 RegisterFormField::Email,
255 FieldError::Policy {
256 code: violation.code.map(|c| c.as_str()),
257 message: violation.msg,
258 },
259 ),
260 Some("username") => {
261 homeserver_denied_username = false;
264 state.add_error_on_field(
265 RegisterFormField::Username,
266 FieldError::Policy {
267 code: violation.code.map(|c| c.as_str()),
268 message: violation.msg,
269 },
270 );
271 }
272 Some("password") => state.add_error_on_field(
273 RegisterFormField::Password,
274 FieldError::Policy {
275 code: violation.code.map(|c| c.as_str()),
276 message: violation.msg,
277 },
278 ),
279 _ => state.add_error_on_form(FormError::Policy {
280 code: violation.code.map(|c| c.as_str()),
281 message: violation.msg,
282 }),
283 }
284 }
285
286 if homeserver_denied_username {
287 state.add_error_on_field(RegisterFormField::Username, FieldError::Exists);
289 }
290
291 if state.is_valid() {
292 if let Err(e) = limiter.check_registration(requester) {
294 tracing::warn!(error = &e as &dyn std::error::Error);
295 state.add_error_on_form(FormError::RateLimitExceeded);
296 }
297
298 if let Err(e) = limiter.check_email_authentication_email(requester, &form.email) {
299 tracing::warn!(error = &e as &dyn std::error::Error);
300 state.add_error_on_form(FormError::RateLimitExceeded);
301 }
302 }
303
304 state
305 };
306
307 if !state.is_valid() {
308 let content = render(
309 locale,
310 PasswordRegisterContext::default().with_form_state(state),
311 query,
312 csrf_token,
313 &mut repo,
314 &templates,
315 site_config.captcha.clone(),
316 )
317 .await?;
318
319 return Ok((cookie_jar, Html(content)).into_response());
320 }
321
322 let post_auth_action = query
323 .post_auth_action
324 .map(serde_json::to_value)
325 .transpose()?;
326 let registration = repo
327 .user_registration()
328 .add(
329 &mut rng,
330 &clock,
331 form.username,
332 ip_address,
333 user_agent,
334 post_auth_action,
335 )
336 .await?;
337
338 let registration = if let Some(tos_uri) = &site_config.tos_uri {
339 repo.user_registration()
340 .set_terms_url(registration, tos_uri.clone())
341 .await?
342 } else {
343 registration
344 };
345
346 let user_email_authentication = repo
348 .user_email()
349 .add_authentication_for_registration(&mut rng, &clock, form.email, ®istration)
350 .await?;
351
352 repo.queue_job()
354 .schedule_job(
355 &mut rng,
356 &clock,
357 SendEmailAuthenticationCodeJob::new(&user_email_authentication, locale.to_string()),
358 )
359 .await?;
360
361 let registration = repo
362 .user_registration()
363 .set_email_authentication(registration, &user_email_authentication)
364 .await?;
365
366 let password = Zeroizing::new(form.password.into_bytes());
368 let (version, hashed_password) = password_manager
369 .hash(&mut rng, password)
370 .await
371 .map_err(InternalError::from_anyhow)?;
372
373 let registration = repo
375 .user_registration()
376 .set_password(registration, hashed_password, version)
377 .await?;
378
379 repo.save().await?;
380
381 let cookie_jar = UserRegistrationSessions::load(&cookie_jar)
382 .add(®istration)
383 .save(cookie_jar, &clock);
384
385 Ok((
386 cookie_jar,
387 url_builder.redirect(&mas_router::RegisterFinish::new(registration.id)),
388 )
389 .into_response())
390}
391
392async fn render(
393 locale: DataLocale,
394 ctx: PasswordRegisterContext,
395 action: OptionalPostAuthAction,
396 csrf_token: CsrfToken,
397 repo: &mut impl RepositoryAccess,
398 templates: &Templates,
399 captcha_config: Option<CaptchaConfig>,
400) -> Result<String, InternalError> {
401 let next = action
402 .load_context(repo)
403 .await
404 .map_err(InternalError::from_anyhow)?;
405 let ctx = if let Some(next) = next {
406 ctx.with_post_action(next)
407 } else {
408 ctx
409 };
410 let ctx = ctx
411 .with_captcha(captcha_config)
412 .with_csrf(csrf_token.form_value())
413 .with_language(locale);
414
415 let content = templates.render_password_register(&ctx)?;
416 Ok(content)
417}
418
#[cfg(test)]
mod tests {
    use hyper::{
        Request, StatusCode,
        header::{CONTENT_TYPE, LOCATION},
    };
    use mas_router::Route;
    use sqlx::PgPool;

    use crate::{
        SiteConfig,
        test_utils::{
            CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup, test_site_config,
        },
    };

    /// Load the registration page with the given cookie jar, check that it
    /// rendered as HTML, and return the CSRF token embedded in its form.
    async fn fetch_csrf_token(state: &TestState, cookies: &CookieHelper) -> String {
        let page_request =
            Request::get(&*mas_router::PasswordRegister::default().path_and_query()).empty();
        let page_response = state.request(cookies.with_cookies(page_request)).await;
        cookies.save_cookies(&page_response);
        page_response.assert_status(StatusCode::OK);
        page_response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");

        page_response
            .body()
            .split("name=\"csrf\" value=\"")
            .nth(1)
            .unwrap()
            .split('\"')
            .next()
            .unwrap()
            .to_owned()
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_password_disabled(pool: PgPool) {
        setup();
        let site_config = SiteConfig {
            password_login_enabled: false,
            password_registration_enabled: false,
            ..test_site_config()
        };
        let state = TestState::from_pool_with_site_config(pool, site_config)
            .await
            .unwrap();

        // The form page bounces to the login page
        let page_request =
            Request::get(&*mas_router::PasswordRegister::default().path_and_query()).empty();
        let page_response = state.request(page_request).await;
        page_response.assert_status(StatusCode::SEE_OTHER);
        page_response.assert_header_value(LOCATION, "/login");

        // Submissions are rejected outright
        let submit_request =
            Request::post(&*mas_router::PasswordRegister::default().path_and_query()).form(
                serde_json::json!({
                    "csrf": "abc",
                    "username": "john",
                    "email": "john@example.com",
                    "password": "hunter2",
                    "password_confirm": "hunter2",
                }),
            );
        let submit_response = state.request(submit_request).await;
        submit_response.assert_status(StatusCode::METHOD_NOT_ALLOWED);
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let cookies = CookieHelper::new();
        let csrf_token = fetch_csrf_token(&state, &cookies).await;

        let submit_request =
            Request::post(&*mas_router::PasswordRegister::default().path_and_query()).form(
                serde_json::json!({
                    "csrf": csrf_token,
                    "username": "john",
                    "email": "john@example.com",
                    "password": "correcthorsebatterystaple",
                    "password_confirm": "correcthorsebatterystaple",
                    "accept_terms": "on",
                }),
            );
        let submit_response = state.request(cookies.with_cookies(submit_request)).await;
        cookies.save_cookies(&submit_response);
        submit_response.assert_status(StatusCode::SEE_OTHER);

        // The registration ID is the second-to-last segment of the redirect path
        let location = submit_response
            .headers()
            .get(LOCATION)
            .unwrap()
            .to_str()
            .unwrap();
        let id = location.rsplit('/').nth(1).unwrap().parse().unwrap();

        let mut repo = state.repository().await.unwrap();
        let registration = repo.user_registration().lookup(id).await.unwrap().unwrap();
        assert_eq!(registration.username, "john".to_owned());
        assert!(registration.password.is_some());

        let authentication_id = registration.email_authentication_id.unwrap();
        let email_authentication = repo
            .user_email()
            .lookup_authentication(authentication_id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(email_authentication.email, "john@example.com");
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register_password_mismatch(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let cookies = CookieHelper::new();
        let csrf_token = fetch_csrf_token(&state, &cookies).await;

        let submit_request =
            Request::post(&*mas_router::PasswordRegister::default().path_and_query()).form(
                serde_json::json!({
                    "csrf": csrf_token,
                    "username": "john",
                    "email": "john@example.com",
                    "password": "hunter2",
                    "password_confirm": "mismatch",
                    "accept_terms": "on",
                }),
            );
        let submit_response = state.request(cookies.with_cookies(submit_request)).await;
        cookies.save_cookies(&submit_response);
        submit_response.assert_status(StatusCode::OK);
        assert!(submit_response.body().contains("Password fields don't match"));
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register_username_too_long(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let cookies = CookieHelper::new();
        let csrf_token = fetch_csrf_token(&state, &cookies).await;

        let submit_request =
            Request::post(&*mas_router::PasswordRegister::default().path_and_query()).form(
                serde_json::json!({
                    "csrf": csrf_token,
                    "username": "a".repeat(256),
                    "email": "john@example.com",
                    "password": "hunter2",
                    "password_confirm": "hunter2",
                    "accept_terms": "on",
                }),
            );
        let submit_response = state.request(cookies.with_cookies(submit_request)).await;
        cookies.save_cookies(&submit_response);
        submit_response.assert_status(StatusCode::OK);
        assert!(
            submit_response.body().contains("Username is too long"),
            "response body: {}",
            submit_response.body()
        );
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register_user_exists(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let mut rng = state.rng();
        let cookies = CookieHelper::new();

        // Pre-create the user the form will try to register
        let mut repo = state.repository().await.unwrap();
        repo.user()
            .add(&mut rng, &state.clock, "john".to_owned())
            .await
            .unwrap();
        repo.save().await.unwrap();

        let csrf_token = fetch_csrf_token(&state, &cookies).await;

        let submit_request =
            Request::post(&*mas_router::PasswordRegister::default().path_and_query()).form(
                serde_json::json!({
                    "csrf": csrf_token,
                    "username": "john",
                    "email": "john@example.com",
                    "password": "hunter2",
                    "password_confirm": "hunter2",
                    "accept_terms": "on",
                }),
            );
        let submit_response = state.request(cookies.with_cookies(submit_request)).await;
        cookies.save_cookies(&submit_response);
        submit_response.assert_status(StatusCode::OK);
        assert!(submit_response.body().contains("This username is already taken"));
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register_user_reserved(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let cookies = CookieHelper::new();
        let csrf_token = fetch_csrf_token(&state, &cookies).await;

        // The homeserver refuses this localpart
        state.homeserver_connection.reserve_localpart("john").await;

        let submit_request =
            Request::post(&*mas_router::PasswordRegister::default().path_and_query()).form(
                serde_json::json!({
                    "csrf": csrf_token,
                    "username": "john",
                    "email": "john@example.com",
                    "password": "hunter2",
                    "password_confirm": "hunter2",
                    "accept_terms": "on",
                }),
            );
        let submit_response = state.request(cookies.with_cookies(submit_request)).await;
        cookies.save_cookies(&submit_response);
        submit_response.assert_status(StatusCode::OK);
        assert!(submit_response.body().contains("This username is already taken"));
    }
}