1use std::sync::LazyLock;
8
9use axum::{Json, extract::State, response::IntoResponse};
10use axum_extra::TypedHeader;
11use hyper::StatusCode;
12use mas_axum_utils::record_error;
13use mas_iana::oauth::OAuthClientAuthenticationMethod;
14use mas_keystore::Encrypter;
15use mas_policy::{EvaluationResult, Policy};
16use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2ClientRepository};
17use oauth2_types::{
18 errors::{ClientError, ClientErrorCode},
19 registration::{
20 ClientMetadata, ClientMetadataVerificationError, ClientRegistrationResponse, Localized,
21 VerifiedClientMetadata,
22 },
23};
24use opentelemetry::{Key, KeyValue, metrics::Counter};
25use psl::Psl;
26use rand::distributions::{Alphanumeric, DistString};
27use serde::Serialize;
28use sha2::Digest as _;
29use thiserror::Error;
30use tracing::info;
31use url::Url;
32
33use crate::{BoundActivityTracker, METER, impl_from_error_for_route};
34
35static REGISTRATION_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
36 METER
37 .u64_counter("mas.oauth2.registration_request")
38 .with_description("Number of OAuth2 registration requests")
39 .with_unit("{request}")
40 .build()
41});
42const RESULT: Key = Key::from_static_str("result");
43
/// Errors that can happen while handling a client registration request.
///
/// Each variant is mapped to an OAuth 2.0 client error response by the
/// `IntoResponse` implementation for this type.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// Unexpected server-side failure, answered with `server_error` and a
    /// 500 status.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync>),

    /// The request body could not be extracted as JSON client metadata.
    #[error(transparent)]
    JsonExtract(#[from] axum::extract::rejection::JsonRejection),

    /// The client metadata was parsed but failed validation.
    #[error("invalid client metadata")]
    InvalidClientMetadata(#[from] ClientMetadataVerificationError),

    /// The named metadata field (e.g. `"client_uri"`) uses a bare public
    /// suffix as its host.
    #[error("{0} is a public suffix, not a valid domain")]
    UrlIsPublicSuffix(&'static str),

    /// The policy engine rejected the registration; carries the evaluation
    /// result with its violations.
    #[error("client registration denied by the policy: {0}")]
    PolicyDenied(EvaluationResult),
}

// Conversions from the various backend errors this handler can hit.
// NOTE(review): presumably these wrap into `RouteError::Internal`; confirm
// against the `impl_from_error_for_route!` macro definition.
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_keystore::aead::Error);
impl_from_error_for_route!(serde_json::Error);
67
68impl IntoResponse for RouteError {
69 fn into_response(self) -> axum::response::Response {
70 let sentry_event_id = record_error!(self, Self::Internal(_));
71
72 REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "denied")]);
73
74 let response = match self {
75 Self::Internal(_) => (
76 StatusCode::INTERNAL_SERVER_ERROR,
77 Json(ClientError::from(ClientErrorCode::ServerError)),
78 )
79 .into_response(),
80
81 Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => (
85 StatusCode::BAD_REQUEST,
86 Json(
87 ClientError::from(ClientErrorCode::InvalidClientMetadata)
88 .with_description(e.to_string()),
89 ),
90 )
91 .into_response(),
92
93 Self::JsonExtract(_) => (
96 StatusCode::BAD_REQUEST,
97 Json(ClientError::from(ClientErrorCode::InvalidRequest)),
98 )
99 .into_response(),
100
101 Self::InvalidClientMetadata(
105 ClientMetadataVerificationError::MissingRedirectUris
106 | ClientMetadataVerificationError::RedirectUriWithFragment(_),
107 ) => (
108 StatusCode::BAD_REQUEST,
109 Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
110 )
111 .into_response(),
112
113 Self::InvalidClientMetadata(e) => (
114 StatusCode::BAD_REQUEST,
115 Json(
116 ClientError::from(ClientErrorCode::InvalidClientMetadata)
117 .with_description(e.to_string()),
118 ),
119 )
120 .into_response(),
121
122 Self::UrlIsPublicSuffix("redirect_uri") => (
126 StatusCode::BAD_REQUEST,
127 Json(
128 ClientError::from(ClientErrorCode::InvalidRedirectUri)
129 .with_description("redirect_uri is not using a valid domain".to_owned()),
130 ),
131 )
132 .into_response(),
133
134 Self::UrlIsPublicSuffix(field) => (
135 StatusCode::BAD_REQUEST,
136 Json(
137 ClientError::from(ClientErrorCode::InvalidClientMetadata)
138 .with_description(format!("{field} is not using a valid domain")),
139 ),
140 )
141 .into_response(),
142
143 Self::PolicyDenied(evaluation) => {
147 let code = if evaluation
149 .violations
150 .iter()
151 .any(|v| v.msg.contains("redirect_uri"))
152 {
153 ClientErrorCode::InvalidRedirectUri
154 } else {
155 ClientErrorCode::InvalidClientMetadata
156 };
157
158 let collected = &evaluation
159 .violations
160 .iter()
161 .map(|v| v.msg.clone())
162 .collect::<Vec<String>>();
163 let joined = collected.join("; ");
164
165 (
166 StatusCode::BAD_REQUEST,
167 Json(ClientError::from(code).with_description(joined)),
168 )
169 .into_response()
170 }
171 };
172
173 (sentry_event_id, response).into_response()
174 }
175}
176
/// Successful registration response body.
///
/// Both parts are flattened into a single JSON object: the registration
/// information (client id, optional secret, timestamps) alongside the
/// client's verified metadata.
#[derive(Serialize)]
struct RouteResponse {
    #[serde(flatten)]
    response: ClientRegistrationResponse,
    #[serde(flatten)]
    metadata: VerifiedClientMetadata,
}
184
185fn host_is_public_suffix(url: &Url) -> bool {
187 let host = url.host_str().unwrap_or_default().as_bytes();
188 let Some(suffix) = psl::List.suffix(host) else {
189 return false;
192 };
193
194 if !suffix.is_known() {
195 return false;
197 }
198
199 if host.len() <= suffix.as_bytes().len() + 1 {
203 return true;
205 }
206
207 false
208}
209
210fn localised_url_has_public_suffix(url: &Localized<Url>) -> bool {
212 url.iter().any(|(_lang, url)| host_is_public_suffix(url))
213}
214
215#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all)]
216pub(crate) async fn post(
217 mut rng: BoxRng,
218 clock: BoxClock,
219 mut repo: BoxRepository,
220 mut policy: Policy,
221 activity_tracker: BoundActivityTracker,
222 user_agent: Option<TypedHeader<headers::UserAgent>>,
223 State(encrypter): State<Encrypter>,
224 body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
225) -> Result<impl IntoResponse, RouteError> {
226 let Json(body) = body?;
228
229 let body = body.sorted();
231
232 let body_json = serde_json::to_string(&body)?;
234
235 info!(body = body_json, "Client registration");
236
237 let user_agent = user_agent.map(|ua| ua.to_string());
238
239 let metadata = body.validate()?;
241
242 if let Some(client_uri) = &metadata.client_uri {
245 if localised_url_has_public_suffix(client_uri) {
246 return Err(RouteError::UrlIsPublicSuffix("client_uri"));
247 }
248 }
249
250 if let Some(logo_uri) = &metadata.logo_uri {
251 if localised_url_has_public_suffix(logo_uri) {
252 return Err(RouteError::UrlIsPublicSuffix("logo_uri"));
253 }
254 }
255
256 if let Some(policy_uri) = &metadata.policy_uri {
257 if localised_url_has_public_suffix(policy_uri) {
258 return Err(RouteError::UrlIsPublicSuffix("policy_uri"));
259 }
260 }
261
262 if let Some(tos_uri) = &metadata.tos_uri {
263 if localised_url_has_public_suffix(tos_uri) {
264 return Err(RouteError::UrlIsPublicSuffix("tos_uri"));
265 }
266 }
267
268 if let Some(initiate_login_uri) = &metadata.initiate_login_uri {
269 if host_is_public_suffix(initiate_login_uri) {
270 return Err(RouteError::UrlIsPublicSuffix("initiate_login_uri"));
271 }
272 }
273
274 for redirect_uri in metadata.redirect_uris() {
275 if host_is_public_suffix(redirect_uri) {
276 return Err(RouteError::UrlIsPublicSuffix("redirect_uri"));
277 }
278 }
279
280 let res = policy
281 .evaluate_client_registration(mas_policy::ClientRegistrationInput {
282 client_metadata: &metadata,
283 requester: mas_policy::Requester {
284 ip_address: activity_tracker.ip(),
285 user_agent,
286 },
287 })
288 .await?;
289 if !res.valid() {
290 return Err(RouteError::PolicyDenied(res));
291 }
292
293 let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
294 Some(
295 OAuthClientAuthenticationMethod::ClientSecretJwt
296 | OAuthClientAuthenticationMethod::ClientSecretPost
297 | OAuthClientAuthenticationMethod::ClientSecretBasic,
298 ) => {
299 let client_secret = Alphanumeric.sample_string(&mut rng, 20);
301 let encrypted_client_secret = encrypter.encrypt_to_string(client_secret.as_bytes())?;
302 (Some(client_secret), Some(encrypted_client_secret))
303 }
304 _ => (None, None),
305 };
306
307 let (digest_hash, existing_client) = if client_secret.is_none() {
310 let hash = sha2::Sha256::digest(body_json);
317 let hash = hex::encode(hash);
318 let client = repo.oauth2_client().find_by_metadata_digest(&hash).await?;
319 (Some(hash), client)
320 } else {
321 (None, None)
322 };
323
324 let client = if let Some(client) = existing_client {
325 tracing::info!(%client.id, "Reusing existing client");
326 REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "reused")]);
327 client
328 } else {
329 let client = repo
330 .oauth2_client()
331 .add(
332 &mut rng,
333 &clock,
334 metadata.redirect_uris().to_vec(),
335 digest_hash,
336 encrypted_client_secret,
337 metadata.application_type.clone(),
338 metadata.grant_types().to_vec(),
340 metadata
341 .client_name
342 .clone()
343 .map(Localized::to_non_localized),
344 metadata.logo_uri.clone().map(Localized::to_non_localized),
345 metadata.client_uri.clone().map(Localized::to_non_localized),
346 metadata.policy_uri.clone().map(Localized::to_non_localized),
347 metadata.tos_uri.clone().map(Localized::to_non_localized),
348 metadata.jwks_uri.clone(),
349 metadata.jwks.clone(),
350 metadata.id_token_signed_response_alg.clone(),
352 metadata.userinfo_signed_response_alg.clone(),
353 metadata.token_endpoint_auth_method.clone(),
354 metadata.token_endpoint_auth_signing_alg.clone(),
355 metadata.initiate_login_uri.clone(),
356 )
357 .await?;
358 tracing::info!(%client.id, "Registered new client");
359 REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "created")]);
360 client
361 };
362
363 let response = ClientRegistrationResponse {
364 client_id: client.client_id.clone(),
365 client_secret,
366 client_id_issued_at: Some(client.id.datetime().into()),
368 client_secret_expires_at: None,
369 };
370
371 let metadata = client.into_metadata().validate()?;
374
375 repo.save().await?;
376
377 let response = RouteResponse { response, metadata };
378
379 Ok((StatusCode::CREATED, Json(response)))
380}
381
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        errors::{ClientError, ClientErrorCode},
        registration::ClientRegistrationResponse,
    };
    use sqlx::PgPool;
    use url::Url;

    use crate::{
        oauth2::registration::host_is_public_suffix,
        test_utils::{RequestBuilderExt, ResponseExt, TestState, setup},
    };

    /// Bare public suffixes, with or without a leading or trailing dot, are
    /// detected. Registrable domains, hostless URLs and hosts without a known
    /// suffix are not.
    #[test]
    fn test_public_suffix_list() {
        fn url_is_public_suffix(url: &str) -> bool {
            host_is_public_suffix(&Url::parse(url).unwrap())
        }

        assert!(url_is_public_suffix("https://.com"));
        assert!(url_is_public_suffix("https://.com."));
        assert!(url_is_public_suffix("https://co.uk"));
        assert!(url_is_public_suffix("https://github.io"));
        assert!(!url_is_public_suffix("https://example.com"));
        assert!(!url_is_public_suffix("https://example.com."));
        assert!(!url_is_public_suffix("https://x.com"));
        assert!(!url_is_public_suffix("https://x.com."));
        assert!(!url_is_public_suffix("https://matrix-org.github.io"));
        assert!(!url_is_public_suffix("http://localhost"));
        assert!(!url_is_public_suffix("org.matrix:/callback"));
        assert!(!url_is_public_suffix("http://somerandominternaldomain"));
    }

    /// Each kind of invalid registration request is mapped to the expected
    /// RFC 7591 error code (and description, where one is asserted).
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_error(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // A body that is not JSON at all is an invalid request
        let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH)
            .body("this is not a json".to_owned())
            .unwrap();

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRequest);

        // JSON whose `client_uri` is not a URL is invalid client metadata
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "this is not a uri",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // A web client with a plain-http redirect URI is rejected with the
        // redirect-URI-specific error code
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "application_type": "web",
                "client_uri": "https://example.com/",
                "redirect_uris": ["http://this-is-insecure.com/"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);

        // An `id_token` response type with only the `authorization_code`
        // grant is expected to be rejected as invalid metadata
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["id_token"],
                "grant_types": ["authorization_code"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // Both URIs use a public suffix; `client_uri` is checked before
        // `redirect_uri`, so it is the one reported
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://github.io/",
                "redirect_uris": ["https://github.io/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );

        // A public suffix in a localized variant only (`client_uri#fr-FR`)
        // is also caught
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_uri#fr-FR": "https://github.io/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );
    }

    /// Successful registrations: a public client (`none`) gets no secret; a
    /// `client_secret_basic` client gets one.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_none());

        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_some());
    }

    /// Clients without a secret are deduplicated on their canonical metadata,
    /// even when fields and list entries are reordered. Clients with a secret
    /// always get a new registration.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_dedupe(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#en": "Example",
                "client_name#fr": "Exemple",
                "client_name#de": "Beispiel",
                "redirect_uris": ["https://example.com/", "https://example.com/callback"],
                "response_types": ["code"],
                "grant_types": ["authorization_code", "urn:ietf:params:oauth:grant-type:device_code"],
                "token_endpoint_auth_method": "none",
            }));

        // The same public-client request twice yields the same client
        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        let client_id = response.client_id;

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // Same metadata with localized names, redirect URIs and grant types
        // in a different order still maps to the same client
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#de": "Beispiel",
                "client_name#fr": "Exemple",
                "client_name#en": "Example",
                "redirect_uris": ["https://example.com/callback", "https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // A confidential client is never deduplicated, neither against the
        // public client nor against itself
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_ne!(response.client_id, client_id);
        let client_id = response.client_id;

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_ne!(response.client_id, client_id);
    }
}