mas_storage_pg/oauth2/
client.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::{
8    collections::{BTreeMap, BTreeSet},
9    string::ToString,
10};
11
12use async_trait::async_trait;
13use mas_data_model::{Client, JwksOrJwksUri};
14use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
15use mas_jose::jwk::PublicJsonWebKeySet;
16use mas_storage::{Clock, oauth2::OAuth2ClientRepository};
17use oauth2_types::{oidc::ApplicationType, requests::GrantType};
18use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
19use rand::RngCore;
20use sqlx::PgConnection;
21use tracing::{Instrument, info_span};
22use ulid::Ulid;
23use url::Url;
24use uuid::Uuid;
25
26use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
27
/// An implementation of [`OAuth2ClientRepository`] for a PostgreSQL connection
///
/// The repository holds an exclusive (mutable) borrow of the connection for
/// its lifetime `'c`; every query it issues is executed on that connection.
pub struct PgOAuth2ClientRepository<'c> {
    /// Connection all of this repository's queries are executed on
    conn: &'c mut PgConnection,
}
32
33impl<'c> PgOAuth2ClientRepository<'c> {
34    /// Create a new [`PgOAuth2ClientRepository`] from an active PostgreSQL
35    /// connection
36    pub fn new(conn: &'c mut PgConnection) -> Self {
37        Self { conn }
38    }
39}
40
/// Raw row of the `oauth2_clients` table, as selected by the queries of
/// [`PgOAuth2ClientRepository`].
///
/// Text columns are kept exactly as stored; they are parsed and validated
/// only when the row is converted into a [`Client`].
// The four `grant_type_*` flags each mirror one boolean column selected from
// `oauth2_clients`, so the lint is silenced instead of folding them into
// another representation.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug)]
struct OAuth2ClientLookup {
    /// Primary key; converted into the client's [`Ulid`] and `client_id`.
    oauth2_client_id: Uuid,
    /// Value matched by `find_by_metadata_digest`; `None` when unset.
    metadata_digest: Option<String>,
    /// Stored client secret, copied into the [`Client`] unchanged.
    encrypted_client_secret: Option<String>,
    /// Textual application type, parsed during conversion.
    application_type: Option<String>,
    /// Redirect URIs as a text array; every entry must parse as a URL.
    redirect_uris: Vec<String>,
    /// Whether the authorization code grant is enabled.
    grant_type_authorization_code: bool,
    /// Whether the refresh token grant is enabled.
    grant_type_refresh_token: bool,
    /// Whether the client credentials grant is enabled.
    grant_type_client_credentials: bool,
    /// Whether the device code grant is enabled.
    grant_type_device_code: bool,
    /// Human-readable client name, copied unchanged.
    client_name: Option<String>,
    /// Textual URL, parsed during conversion.
    logo_uri: Option<String>,
    /// Textual URL, parsed during conversion.
    client_uri: Option<String>,
    /// Textual URL, parsed during conversion.
    policy_uri: Option<String>,
    /// Textual URL, parsed during conversion.
    tos_uri: Option<String>,
    /// Textual URL of the key set; mutually exclusive with `jwks`.
    jwks_uri: Option<String>,
    /// Inline key set as JSON; mutually exclusive with `jwks_uri`.
    jwks: Option<serde_json::Value>,
    /// Textual JWS algorithm name, parsed during conversion.
    id_token_signed_response_alg: Option<String>,
    /// Textual JWS algorithm name, parsed during conversion.
    userinfo_signed_response_alg: Option<String>,
    /// Textual client authentication method, parsed during conversion.
    token_endpoint_auth_method: Option<String>,
    /// Textual JWS algorithm name, parsed during conversion.
    token_endpoint_auth_signing_alg: Option<String>,
    /// Textual URL, parsed during conversion.
    initiate_login_uri: Option<String>,
}
66
67impl TryInto<Client> for OAuth2ClientLookup {
68    type Error = DatabaseInconsistencyError;
69
70    #[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
71    fn try_into(self) -> Result<Client, Self::Error> {
72        let id = Ulid::from(self.oauth2_client_id);
73
74        let redirect_uris: Result<Vec<Url>, _> =
75            self.redirect_uris.iter().map(|s| s.parse()).collect();
76        let redirect_uris = redirect_uris.map_err(|e| {
77            DatabaseInconsistencyError::on("oauth2_clients")
78                .column("redirect_uris")
79                .row(id)
80                .source(e)
81        })?;
82
83        let application_type = self
84            .application_type
85            .map(|s| s.parse())
86            .transpose()
87            .map_err(|e| {
88                DatabaseInconsistencyError::on("oauth2_clients")
89                    .column("application_type")
90                    .row(id)
91                    .source(e)
92            })?;
93
94        let mut grant_types = Vec::new();
95        if self.grant_type_authorization_code {
96            grant_types.push(GrantType::AuthorizationCode);
97        }
98        if self.grant_type_refresh_token {
99            grant_types.push(GrantType::RefreshToken);
100        }
101        if self.grant_type_client_credentials {
102            grant_types.push(GrantType::ClientCredentials);
103        }
104        if self.grant_type_device_code {
105            grant_types.push(GrantType::DeviceCode);
106        }
107
108        let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
109            DatabaseInconsistencyError::on("oauth2_clients")
110                .column("logo_uri")
111                .row(id)
112                .source(e)
113        })?;
114
115        let client_uri = self
116            .client_uri
117            .map(|s| s.parse())
118            .transpose()
119            .map_err(|e| {
120                DatabaseInconsistencyError::on("oauth2_clients")
121                    .column("client_uri")
122                    .row(id)
123                    .source(e)
124            })?;
125
126        let policy_uri = self
127            .policy_uri
128            .map(|s| s.parse())
129            .transpose()
130            .map_err(|e| {
131                DatabaseInconsistencyError::on("oauth2_clients")
132                    .column("policy_uri")
133                    .row(id)
134                    .source(e)
135            })?;
136
137        let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
138            DatabaseInconsistencyError::on("oauth2_clients")
139                .column("tos_uri")
140                .row(id)
141                .source(e)
142        })?;
143
144        let id_token_signed_response_alg = self
145            .id_token_signed_response_alg
146            .map(|s| s.parse())
147            .transpose()
148            .map_err(|e| {
149                DatabaseInconsistencyError::on("oauth2_clients")
150                    .column("id_token_signed_response_alg")
151                    .row(id)
152                    .source(e)
153            })?;
154
155        let userinfo_signed_response_alg = self
156            .userinfo_signed_response_alg
157            .map(|s| s.parse())
158            .transpose()
159            .map_err(|e| {
160                DatabaseInconsistencyError::on("oauth2_clients")
161                    .column("userinfo_signed_response_alg")
162                    .row(id)
163                    .source(e)
164            })?;
165
166        let token_endpoint_auth_method = self
167            .token_endpoint_auth_method
168            .map(|s| s.parse())
169            .transpose()
170            .map_err(|e| {
171                DatabaseInconsistencyError::on("oauth2_clients")
172                    .column("token_endpoint_auth_method")
173                    .row(id)
174                    .source(e)
175            })?;
176
177        let token_endpoint_auth_signing_alg = self
178            .token_endpoint_auth_signing_alg
179            .map(|s| s.parse())
180            .transpose()
181            .map_err(|e| {
182                DatabaseInconsistencyError::on("oauth2_clients")
183                    .column("token_endpoint_auth_signing_alg")
184                    .row(id)
185                    .source(e)
186            })?;
187
188        let initiate_login_uri = self
189            .initiate_login_uri
190            .map(|s| s.parse())
191            .transpose()
192            .map_err(|e| {
193                DatabaseInconsistencyError::on("oauth2_clients")
194                    .column("initiate_login_uri")
195                    .row(id)
196                    .source(e)
197            })?;
198
199        let jwks = match (self.jwks, self.jwks_uri) {
200            (None, None) => None,
201            (Some(jwks), None) => {
202                let jwks = serde_json::from_value(jwks).map_err(|e| {
203                    DatabaseInconsistencyError::on("oauth2_clients")
204                        .column("jwks")
205                        .row(id)
206                        .source(e)
207                })?;
208                Some(JwksOrJwksUri::Jwks(jwks))
209            }
210            (None, Some(jwks_uri)) => {
211                let jwks_uri = jwks_uri.parse().map_err(|e| {
212                    DatabaseInconsistencyError::on("oauth2_clients")
213                        .column("jwks_uri")
214                        .row(id)
215                        .source(e)
216                })?;
217
218                Some(JwksOrJwksUri::JwksUri(jwks_uri))
219            }
220            _ => {
221                return Err(DatabaseInconsistencyError::on("oauth2_clients")
222                    .column("jwks(_uri)")
223                    .row(id));
224            }
225        };
226
227        Ok(Client {
228            id,
229            client_id: id.to_string(),
230            metadata_digest: self.metadata_digest,
231            encrypted_client_secret: self.encrypted_client_secret,
232            application_type,
233            redirect_uris,
234            grant_types,
235            client_name: self.client_name,
236            logo_uri,
237            client_uri,
238            policy_uri,
239            tos_uri,
240            jwks,
241            id_token_signed_response_alg,
242            userinfo_signed_response_alg,
243            token_endpoint_auth_method,
244            token_endpoint_auth_signing_alg,
245            initiate_login_uri,
246        })
247    }
248}
249
250#[async_trait]
251impl OAuth2ClientRepository for PgOAuth2ClientRepository<'_> {
252    type Error = DatabaseError;
253
254    #[tracing::instrument(
255        name = "db.oauth2_client.lookup",
256        skip_all,
257        fields(
258            db.query.text,
259            oauth2_client.id = %id,
260        ),
261        err,
262    )]
263    async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
264        let res = sqlx::query_as!(
265            OAuth2ClientLookup,
266            r#"
267                SELECT oauth2_client_id
268                     , metadata_digest
269                     , encrypted_client_secret
270                     , application_type
271                     , redirect_uris
272                     , grant_type_authorization_code
273                     , grant_type_refresh_token
274                     , grant_type_client_credentials
275                     , grant_type_device_code
276                     , client_name
277                     , logo_uri
278                     , client_uri
279                     , policy_uri
280                     , tos_uri
281                     , jwks_uri
282                     , jwks
283                     , id_token_signed_response_alg
284                     , userinfo_signed_response_alg
285                     , token_endpoint_auth_method
286                     , token_endpoint_auth_signing_alg
287                     , initiate_login_uri
288                FROM oauth2_clients c
289
290                WHERE oauth2_client_id = $1
291            "#,
292            Uuid::from(id),
293        )
294        .traced()
295        .fetch_optional(&mut *self.conn)
296        .await?;
297
298        let Some(res) = res else { return Ok(None) };
299
300        Ok(Some(res.try_into()?))
301    }
302
303    #[tracing::instrument(
304        name = "db.oauth2_client.find_by_metadata_digest",
305        skip_all,
306        fields(
307            db.query.text,
308        ),
309        err,
310    )]
311    async fn find_by_metadata_digest(
312        &mut self,
313        digest: &str,
314    ) -> Result<Option<Client>, Self::Error> {
315        let res = sqlx::query_as!(
316            OAuth2ClientLookup,
317            r#"
318                SELECT oauth2_client_id
319                    , metadata_digest
320                    , encrypted_client_secret
321                    , application_type
322                    , redirect_uris
323                    , grant_type_authorization_code
324                    , grant_type_refresh_token
325                    , grant_type_client_credentials
326                    , grant_type_device_code
327                    , client_name
328                    , logo_uri
329                    , client_uri
330                    , policy_uri
331                    , tos_uri
332                    , jwks_uri
333                    , jwks
334                    , id_token_signed_response_alg
335                    , userinfo_signed_response_alg
336                    , token_endpoint_auth_method
337                    , token_endpoint_auth_signing_alg
338                    , initiate_login_uri
339                FROM oauth2_clients
340                WHERE metadata_digest = $1
341            "#,
342            digest,
343        )
344        .traced()
345        .fetch_optional(&mut *self.conn)
346        .await?;
347
348        let Some(res) = res else { return Ok(None) };
349
350        Ok(Some(res.try_into()?))
351    }
352
353    #[tracing::instrument(
354        name = "db.oauth2_client.load_batch",
355        skip_all,
356        fields(
357            db.query.text,
358        ),
359        err,
360    )]
361    async fn load_batch(
362        &mut self,
363        ids: BTreeSet<Ulid>,
364    ) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
365        let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
366        let res = sqlx::query_as!(
367            OAuth2ClientLookup,
368            r#"
369                SELECT oauth2_client_id
370                     , metadata_digest
371                     , encrypted_client_secret
372                     , application_type
373                     , redirect_uris
374                     , grant_type_authorization_code
375                     , grant_type_refresh_token
376                     , grant_type_client_credentials
377                     , grant_type_device_code
378                     , client_name
379                     , logo_uri
380                     , client_uri
381                     , policy_uri
382                     , tos_uri
383                     , jwks_uri
384                     , jwks
385                     , id_token_signed_response_alg
386                     , userinfo_signed_response_alg
387                     , token_endpoint_auth_method
388                     , token_endpoint_auth_signing_alg
389                     , initiate_login_uri
390                FROM oauth2_clients c
391
392                WHERE oauth2_client_id = ANY($1::uuid[])
393            "#,
394            &ids,
395        )
396        .traced()
397        .fetch_all(&mut *self.conn)
398        .await?;
399
400        res.into_iter()
401            .map(|r| {
402                r.try_into()
403                    .map(|c: Client| (c.id, c))
404                    .map_err(DatabaseError::from)
405            })
406            .collect()
407    }
408
409    #[tracing::instrument(
410        name = "db.oauth2_client.add",
411        skip_all,
412        fields(
413            db.query.text,
414            client.id,
415            client.name = client_name
416        ),
417        err,
418    )]
419    #[allow(clippy::too_many_lines)]
420    async fn add(
421        &mut self,
422        rng: &mut (dyn RngCore + Send),
423        clock: &dyn Clock,
424        redirect_uris: Vec<Url>,
425        metadata_digest: Option<String>,
426        encrypted_client_secret: Option<String>,
427        application_type: Option<ApplicationType>,
428        grant_types: Vec<GrantType>,
429        client_name: Option<String>,
430        logo_uri: Option<Url>,
431        client_uri: Option<Url>,
432        policy_uri: Option<Url>,
433        tos_uri: Option<Url>,
434        jwks_uri: Option<Url>,
435        jwks: Option<PublicJsonWebKeySet>,
436        id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
437        userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
438        token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
439        token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
440        initiate_login_uri: Option<Url>,
441    ) -> Result<Client, Self::Error> {
442        let now = clock.now();
443        let id = Ulid::from_datetime_with_source(now.into(), rng);
444        tracing::Span::current().record("client.id", tracing::field::display(id));
445
446        let jwks_json = jwks
447            .as_ref()
448            .map(serde_json::to_value)
449            .transpose()
450            .map_err(DatabaseError::to_invalid_operation)?;
451
452        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
453
454        sqlx::query!(
455            r#"
456                INSERT INTO oauth2_clients
457                    ( oauth2_client_id
458                    , metadata_digest
459                    , encrypted_client_secret
460                    , application_type
461                    , redirect_uris
462                    , grant_type_authorization_code
463                    , grant_type_refresh_token
464                    , grant_type_client_credentials
465                    , grant_type_device_code
466                    , client_name
467                    , logo_uri
468                    , client_uri
469                    , policy_uri
470                    , tos_uri
471                    , jwks_uri
472                    , jwks
473                    , id_token_signed_response_alg
474                    , userinfo_signed_response_alg
475                    , token_endpoint_auth_method
476                    , token_endpoint_auth_signing_alg
477                    , initiate_login_uri
478                    , is_static
479                    )
480                VALUES
481                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13,
482                    $14, $15, $16, $17, $18, $19, $20, $21, FALSE)
483            "#,
484            Uuid::from(id),
485            metadata_digest,
486            encrypted_client_secret,
487            application_type.as_ref().map(ToString::to_string),
488            &redirect_uris_array,
489            grant_types.contains(&GrantType::AuthorizationCode),
490            grant_types.contains(&GrantType::RefreshToken),
491            grant_types.contains(&GrantType::ClientCredentials),
492            grant_types.contains(&GrantType::DeviceCode),
493            client_name,
494            logo_uri.as_ref().map(Url::as_str),
495            client_uri.as_ref().map(Url::as_str),
496            policy_uri.as_ref().map(Url::as_str),
497            tos_uri.as_ref().map(Url::as_str),
498            jwks_uri.as_ref().map(Url::as_str),
499            jwks_json,
500            id_token_signed_response_alg
501                .as_ref()
502                .map(ToString::to_string),
503            userinfo_signed_response_alg
504                .as_ref()
505                .map(ToString::to_string),
506            token_endpoint_auth_method.as_ref().map(ToString::to_string),
507            token_endpoint_auth_signing_alg
508                .as_ref()
509                .map(ToString::to_string),
510            initiate_login_uri.as_ref().map(Url::as_str),
511        )
512        .traced()
513        .execute(&mut *self.conn)
514        .await?;
515
516        let jwks = match (jwks, jwks_uri) {
517            (None, None) => None,
518            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
519            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
520            _ => return Err(DatabaseError::invalid_operation()),
521        };
522
523        Ok(Client {
524            id,
525            client_id: id.to_string(),
526            metadata_digest: None,
527            encrypted_client_secret,
528            application_type,
529            redirect_uris,
530            grant_types,
531            client_name,
532            logo_uri,
533            client_uri,
534            policy_uri,
535            tos_uri,
536            jwks,
537            id_token_signed_response_alg,
538            userinfo_signed_response_alg,
539            token_endpoint_auth_method,
540            token_endpoint_auth_signing_alg,
541            initiate_login_uri,
542        })
543    }
544
545    #[tracing::instrument(
546        name = "db.oauth2_client.upsert_static",
547        skip_all,
548        fields(
549            db.query.text,
550            client.id = %client_id,
551        ),
552        err,
553    )]
554    async fn upsert_static(
555        &mut self,
556        client_id: Ulid,
557        client_name: Option<String>,
558        client_auth_method: OAuthClientAuthenticationMethod,
559        encrypted_client_secret: Option<String>,
560        jwks: Option<PublicJsonWebKeySet>,
561        jwks_uri: Option<Url>,
562        redirect_uris: Vec<Url>,
563    ) -> Result<Client, Self::Error> {
564        let jwks_json = jwks
565            .as_ref()
566            .map(serde_json::to_value)
567            .transpose()
568            .map_err(DatabaseError::to_invalid_operation)?;
569
570        let client_auth_method = client_auth_method.to_string();
571        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
572
573        sqlx::query!(
574            r#"
575                INSERT INTO oauth2_clients
576                    ( oauth2_client_id
577                    , encrypted_client_secret
578                    , redirect_uris
579                    , grant_type_authorization_code
580                    , grant_type_refresh_token
581                    , grant_type_client_credentials
582                    , grant_type_device_code
583                    , token_endpoint_auth_method
584                    , jwks
585                    , client_name
586                    , jwks_uri
587                    , is_static
588                    )
589                VALUES
590                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, TRUE)
591                ON CONFLICT (oauth2_client_id)
592                DO
593                    UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
594                             , redirect_uris = EXCLUDED.redirect_uris
595                             , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
596                             , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
597                             , grant_type_client_credentials = EXCLUDED.grant_type_client_credentials
598                             , grant_type_device_code = EXCLUDED.grant_type_device_code
599                             , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
600                             , jwks = EXCLUDED.jwks
601                             , client_name = EXCLUDED.client_name
602                             , jwks_uri = EXCLUDED.jwks_uri
603                             , is_static = TRUE
604            "#,
605            Uuid::from(client_id),
606            encrypted_client_secret,
607            &redirect_uris_array,
608            true,
609            true,
610            true,
611            true,
612            client_auth_method,
613            jwks_json,
614            client_name,
615            jwks_uri.as_ref().map(Url::as_str),
616        )
617        .traced()
618        .execute(&mut *self.conn)
619        .await?;
620
621        let jwks = match (jwks, jwks_uri) {
622            (None, None) => None,
623            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
624            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
625            _ => return Err(DatabaseError::invalid_operation()),
626        };
627
628        Ok(Client {
629            id: client_id,
630            client_id: client_id.to_string(),
631            metadata_digest: None,
632            encrypted_client_secret,
633            application_type: None,
634            redirect_uris,
635            grant_types: vec![
636                GrantType::AuthorizationCode,
637                GrantType::RefreshToken,
638                GrantType::ClientCredentials,
639            ],
640            client_name,
641            logo_uri: None,
642            client_uri: None,
643            policy_uri: None,
644            tos_uri: None,
645            jwks,
646            id_token_signed_response_alg: None,
647            userinfo_signed_response_alg: None,
648            token_endpoint_auth_method: None,
649            token_endpoint_auth_signing_alg: None,
650            initiate_login_uri: None,
651        })
652    }
653
654    #[tracing::instrument(
655        name = "db.oauth2_client.all_static",
656        skip_all,
657        fields(
658            db.query.text,
659        ),
660        err,
661    )]
662    async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error> {
663        let res = sqlx::query_as!(
664            OAuth2ClientLookup,
665            r#"
666                SELECT oauth2_client_id
667                     , metadata_digest
668                     , encrypted_client_secret
669                     , application_type
670                     , redirect_uris
671                     , grant_type_authorization_code
672                     , grant_type_refresh_token
673                     , grant_type_client_credentials
674                     , grant_type_device_code
675                     , client_name
676                     , logo_uri
677                     , client_uri
678                     , policy_uri
679                     , tos_uri
680                     , jwks_uri
681                     , jwks
682                     , id_token_signed_response_alg
683                     , userinfo_signed_response_alg
684                     , token_endpoint_auth_method
685                     , token_endpoint_auth_signing_alg
686                     , initiate_login_uri
687                FROM oauth2_clients c
688                WHERE is_static = TRUE
689            "#,
690        )
691        .traced()
692        .fetch_all(&mut *self.conn)
693        .await?;
694
695        res.into_iter()
696            .map(|r| r.try_into().map_err(DatabaseError::from))
697            .collect()
698    }
699
700    #[tracing::instrument(
701        name = "db.oauth2_client.delete_by_id",
702        skip_all,
703        fields(
704            db.query.text,
705            client.id = %id,
706        ),
707        err,
708    )]
709    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
710        // Delete the authorization grants
711        {
712            let span = info_span!(
713                "db.oauth2_client.delete_by_id.authorization_grants",
714                { DB_QUERY_TEXT } = tracing::field::Empty,
715            );
716
717            sqlx::query!(
718                r#"
719                    DELETE FROM oauth2_authorization_grants
720                    WHERE oauth2_client_id = $1
721                "#,
722                Uuid::from(id),
723            )
724            .record(&span)
725            .execute(&mut *self.conn)
726            .instrument(span)
727            .await?;
728        }
729
730        // Delete the user consents
731        {
732            let span = info_span!(
733                "db.oauth2_client.delete_by_id.consents",
734                { DB_QUERY_TEXT } = tracing::field::Empty,
735            );
736
737            sqlx::query!(
738                r#"
739                    DELETE FROM oauth2_consents
740                    WHERE oauth2_client_id = $1
741                "#,
742                Uuid::from(id),
743            )
744            .record(&span)
745            .execute(&mut *self.conn)
746            .instrument(span)
747            .await?;
748        }
749
750        // Delete the OAuth 2 sessions related data
751        {
752            let span = info_span!(
753                "db.oauth2_client.delete_by_id.access_tokens",
754                { DB_QUERY_TEXT } = tracing::field::Empty,
755            );
756
757            sqlx::query!(
758                r#"
759                    DELETE FROM oauth2_access_tokens
760                    WHERE oauth2_session_id IN (
761                        SELECT oauth2_session_id
762                        FROM oauth2_sessions
763                        WHERE oauth2_client_id = $1
764                    )
765                "#,
766                Uuid::from(id),
767            )
768            .record(&span)
769            .execute(&mut *self.conn)
770            .instrument(span)
771            .await?;
772        }
773
774        {
775            let span = info_span!(
776                "db.oauth2_client.delete_by_id.refresh_tokens",
777                { DB_QUERY_TEXT } = tracing::field::Empty,
778            );
779
780            sqlx::query!(
781                r#"
782                    DELETE FROM oauth2_refresh_tokens
783                    WHERE oauth2_session_id IN (
784                        SELECT oauth2_session_id
785                        FROM oauth2_sessions
786                        WHERE oauth2_client_id = $1
787                    )
788                "#,
789                Uuid::from(id),
790            )
791            .record(&span)
792            .execute(&mut *self.conn)
793            .instrument(span)
794            .await?;
795        }
796
797        {
798            let span = info_span!(
799                "db.oauth2_client.delete_by_id.sessions",
800                { DB_QUERY_TEXT } = tracing::field::Empty,
801            );
802
803            sqlx::query!(
804                r#"
805                    DELETE FROM oauth2_sessions
806                    WHERE oauth2_client_id = $1
807                "#,
808                Uuid::from(id),
809            )
810            .record(&span)
811            .execute(&mut *self.conn)
812            .instrument(span)
813            .await?;
814        }
815
816        // Now delete the client itself
817        let res = sqlx::query!(
818            r#"
819                DELETE FROM oauth2_clients
820                WHERE oauth2_client_id = $1
821            "#,
822            Uuid::from(id),
823        )
824        .traced()
825        .execute(&mut *self.conn)
826        .await?;
827
828        DatabaseError::ensure_affected_rows(&res, 1)
829    }
830}