mas_config/sections/
clients.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::ops::Deref;
8
9use figment::Figment;
10use mas_iana::oauth::OAuthClientAuthenticationMethod;
11use mas_jose::jwk::PublicJsonWebKeySet;
12use schemars::JsonSchema;
13use serde::{Deserialize, Serialize, de::Error};
14use ulid::Ulid;
15use url::Url;
16
17use super::ConfigurationSection;
18
19#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
20#[serde(rename_all = "snake_case")]
21pub enum JwksOrJwksUri {
22    Jwks(PublicJsonWebKeySet),
23    JwksUri(Url),
24}
25
26impl From<PublicJsonWebKeySet> for JwksOrJwksUri {
27    fn from(jwks: PublicJsonWebKeySet) -> Self {
28        Self::Jwks(jwks)
29    }
30}
31
32/// Authentication method used by clients
33#[derive(JsonSchema, Serialize, Deserialize, Copy, Clone, Debug)]
34#[serde(rename_all = "snake_case")]
35pub enum ClientAuthMethodConfig {
36    /// `none`: No authentication
37    None,
38
39    /// `client_secret_basic`: `client_id` and `client_secret` used as basic
40    /// authorization credentials
41    ClientSecretBasic,
42
43    /// `client_secret_post`: `client_id` and `client_secret` sent in the
44    /// request body
45    ClientSecretPost,
46
47    /// `client_secret_basic`: a `client_assertion` sent in the request body and
48    /// signed using the `client_secret`
49    ClientSecretJwt,
50
51    /// `client_secret_basic`: a `client_assertion` sent in the request body and
52    /// signed by an asymmetric key
53    PrivateKeyJwt,
54}
55
56impl std::fmt::Display for ClientAuthMethodConfig {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        match self {
59            ClientAuthMethodConfig::None => write!(f, "none"),
60            ClientAuthMethodConfig::ClientSecretBasic => write!(f, "client_secret_basic"),
61            ClientAuthMethodConfig::ClientSecretPost => write!(f, "client_secret_post"),
62            ClientAuthMethodConfig::ClientSecretJwt => write!(f, "client_secret_jwt"),
63            ClientAuthMethodConfig::PrivateKeyJwt => write!(f, "private_key_jwt"),
64        }
65    }
66}
67
68/// An OAuth 2.0 client configuration
69#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
70pub struct ClientConfig {
71    /// The client ID
72    #[schemars(
73        with = "String",
74        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
75        description = "A ULID as per https://github.com/ulid/spec"
76    )]
77    pub client_id: Ulid,
78
79    /// Authentication method used for this client
80    client_auth_method: ClientAuthMethodConfig,
81
82    /// Name of the `OAuth2` client
83    #[serde(skip_serializing_if = "Option::is_none")]
84    pub client_name: Option<String>,
85
86    /// The client secret, used by the `client_secret_basic`,
87    /// `client_secret_post` and `client_secret_jwt` authentication methods
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub client_secret: Option<String>,
90
91    /// The JSON Web Key Set (JWKS) used by the `private_key_jwt` authentication
92    /// method. Mutually exclusive with `jwks_uri`
93    #[serde(skip_serializing_if = "Option::is_none")]
94    pub jwks: Option<PublicJsonWebKeySet>,
95
96    /// The URL of the JSON Web Key Set (JWKS) used by the `private_key_jwt`
97    /// authentication method. Mutually exclusive with `jwks`
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub jwks_uri: Option<Url>,
100
101    /// List of allowed redirect URIs
102    #[serde(default, skip_serializing_if = "Vec::is_empty")]
103    pub redirect_uris: Vec<Url>,
104}
105
106impl ClientConfig {
107    fn validate(&self) -> Result<(), figment::error::Error> {
108        let auth_method = self.client_auth_method;
109        match self.client_auth_method {
110            ClientAuthMethodConfig::PrivateKeyJwt => {
111                if self.jwks.is_none() && self.jwks_uri.is_none() {
112                    let error = figment::error::Error::custom(
113                        "jwks or jwks_uri is required for private_key_jwt",
114                    );
115                    return Err(error.with_path("client_auth_method"));
116                }
117
118                if self.jwks.is_some() && self.jwks_uri.is_some() {
119                    let error =
120                        figment::error::Error::custom("jwks and jwks_uri are mutually exclusive");
121                    return Err(error.with_path("jwks"));
122                }
123
124                if self.client_secret.is_some() {
125                    let error = figment::error::Error::custom(
126                        "client_secret is not allowed with private_key_jwt",
127                    );
128                    return Err(error.with_path("client_secret"));
129                }
130            }
131
132            ClientAuthMethodConfig::ClientSecretPost
133            | ClientAuthMethodConfig::ClientSecretBasic
134            | ClientAuthMethodConfig::ClientSecretJwt => {
135                if self.client_secret.is_none() {
136                    let error = figment::error::Error::custom(format!(
137                        "client_secret is required for {auth_method}"
138                    ));
139                    return Err(error.with_path("client_auth_method"));
140                }
141
142                if self.jwks.is_some() {
143                    let error = figment::error::Error::custom(format!(
144                        "jwks is not allowed with {auth_method}"
145                    ));
146                    return Err(error.with_path("jwks"));
147                }
148
149                if self.jwks_uri.is_some() {
150                    let error = figment::error::Error::custom(format!(
151                        "jwks_uri is not allowed with {auth_method}"
152                    ));
153                    return Err(error.with_path("jwks_uri"));
154                }
155            }
156
157            ClientAuthMethodConfig::None => {
158                if self.client_secret.is_some() {
159                    let error = figment::error::Error::custom(
160                        "client_secret is not allowed with none authentication method",
161                    );
162                    return Err(error.with_path("client_secret"));
163                }
164
165                if self.jwks.is_some() {
166                    let error = figment::error::Error::custom(
167                        "jwks is not allowed with none authentication method",
168                    );
169                    return Err(error);
170                }
171
172                if self.jwks_uri.is_some() {
173                    let error = figment::error::Error::custom(
174                        "jwks_uri is not allowed with none authentication method",
175                    );
176                    return Err(error);
177                }
178            }
179        }
180
181        Ok(())
182    }
183
184    /// Authentication method used for this client
185    #[must_use]
186    pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
187        match self.client_auth_method {
188            ClientAuthMethodConfig::None => OAuthClientAuthenticationMethod::None,
189            ClientAuthMethodConfig::ClientSecretBasic => {
190                OAuthClientAuthenticationMethod::ClientSecretBasic
191            }
192            ClientAuthMethodConfig::ClientSecretPost => {
193                OAuthClientAuthenticationMethod::ClientSecretPost
194            }
195            ClientAuthMethodConfig::ClientSecretJwt => {
196                OAuthClientAuthenticationMethod::ClientSecretJwt
197            }
198            ClientAuthMethodConfig::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt,
199        }
200    }
201}
202
203/// List of OAuth 2.0/OIDC clients config
204#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
205#[serde(transparent)]
206pub struct ClientsConfig(#[schemars(with = "Vec::<ClientConfig>")] Vec<ClientConfig>);
207
208impl ClientsConfig {
209    /// Returns true if all fields are at their default values
210    pub(crate) fn is_default(&self) -> bool {
211        self.0.is_empty()
212    }
213}
214
215impl Deref for ClientsConfig {
216    type Target = Vec<ClientConfig>;
217
218    fn deref(&self) -> &Self::Target {
219        &self.0
220    }
221}
222
223impl IntoIterator for ClientsConfig {
224    type Item = ClientConfig;
225    type IntoIter = std::vec::IntoIter<ClientConfig>;
226
227    fn into_iter(self) -> Self::IntoIter {
228        self.0.into_iter()
229    }
230}
231
232impl ConfigurationSection for ClientsConfig {
233    const PATH: Option<&'static str> = Some("clients");
234
235    fn validate(&self, figment: &Figment) -> Result<(), figment::error::Error> {
236        for (index, client) in self.0.iter().enumerate() {
237            client.validate().map_err(|mut err| {
238                // Save the error location information in the error
239                err.metadata = figment.find_metadata(Self::PATH.unwrap()).cloned();
240                err.profile = Some(figment::Profile::Default);
241                err.path.insert(0, Self::PATH.unwrap().to_owned());
242                err.path.insert(1, format!("{index}"));
243                err
244            })?;
245        }
246
247        Ok(())
248    }
249}
250
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };

    use super::*;

    /// Loads one client per authentication method from YAML and checks the
    /// parsed list.
    #[test]
    fn load_config() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r#"
                  clients:
                    - client_id: 01GFWR28C4KNE04WG3HKXB7C9R
                      client_auth_method: none
                      redirect_uris:
                        - https://exemple.fr/callback

                    - client_id: 01GFWR32NCQ12B8Z0J8CPXRRB6
                      client_auth_method: client_secret_basic
                      client_secret: hello

                    - client_id: 01GFWR3WHR93Y5HK389H28VHZ9
                      client_auth_method: client_secret_post
                      client_secret: hello

                    - client_id: 01GFWR43R2ZZ8HX9CVBNW9TJWG
                      client_auth_method: client_secret_jwt
                      client_secret: hello

                    - client_id: 01GFWR4BNFDCC4QDG6AMSP1VRR
                      client_auth_method: private_key_jwt
                      jwks:
                        keys:
                        - kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
                          kty: "RSA"
                          alg: "RS256"
                          use: "sig"
                          e: "AQAB"
                          n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"

                        - kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
                          kty: "RSA"
                          alg: "RS256"
                          use: "sig"
                          e: "AQAB"
                          n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
                "#,
            )?;

            let clients = Figment::new()
                .merge(Yaml::file("config.yaml"))
                .extract_inner::<ClientsConfig>("clients")?;

            assert_eq!(clients.0.len(), 5);

            // First client: public client with a single redirect URI.
            let first = &clients.0[0];
            assert_eq!(
                first.client_id,
                Ulid::from_str("01GFWR28C4KNE04WG3HKXB7C9R").unwrap()
            );
            assert_eq!(
                first.redirect_uris,
                vec!["https://exemple.fr/callback".parse().unwrap()]
            );

            // Second client: `redirect_uris` omitted, so it defaults to empty.
            let second = &clients.0[1];
            assert_eq!(
                second.client_id,
                Ulid::from_str("01GFWR32NCQ12B8Z0J8CPXRRB6").unwrap()
            );
            assert!(second.redirect_uris.is_empty());

            Ok(())
        });
    }
}