syn2mas/mas_writer/
mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6//! # MAS Writer
7//!
8//! This module is responsible for writing new records to MAS' database.
9
10use std::{
11    fmt::Display,
12    net::IpAddr,
13    sync::{
14        Arc,
15        atomic::{AtomicU32, Ordering},
16    },
17};
18
19use chrono::{DateTime, Utc};
20use futures_util::{FutureExt, TryStreamExt, future::BoxFuture};
21use sqlx::{Executor, PgConnection, query, query_as};
22use thiserror::Error;
23use thiserror_ext::{Construct, ContextInto};
24use tokio::sync::mpsc::{self, Receiver, Sender};
25use tracing::{Instrument, error, info, warn};
26use uuid::{NonNilUuid, Uuid};
27
28use self::{
29    constraint_pausing::{ConstraintDescription, IndexDescription},
30    locking::LockedMasDatabase,
31};
32use crate::Progress;
33
34pub mod checks;
35pub mod locking;
36
37mod constraint_pausing;
38
/// Errors that can occur while writing migrated data into the MAS database.
#[derive(Debug, Error, Construct, ContextInto)]
pub enum Error {
    /// A database operation failed.
    ///
    /// `context` describes the operation and is rendered after "whilst" in
    /// the error message.
    #[error("database error whilst {context}")]
    Database {
        /// The underlying sqlx error.
        #[source]
        source: sqlx::Error,
        /// What was being done when the error occurred.
        context: String,
    },

    /// A writer task previously failed, so the writer connection pool can no
    /// longer hand out connections.
    #[error("writer connection pool shut down due to error")]
    #[expect(clippy::enum_variant_names)]
    WriterConnectionPoolError,

    /// The MAS database is in a state syn2mas did not expect (for example,
    /// only some of the syn2mas resumption tables exist).
    #[error("inconsistent database: {0}")]
    Inconsistent(String),

    /// Not every write buffer declared itself finished before committing;
    /// this indicates a bug in syn2mas.
    #[error("bug in syn2mas: write buffers not finished")]
    WriteBuffersNotFinished,

    /// Several errors occurred.
    // NOTE(review): presumably built from the `Vec<Error>` returned when
    // finishing the writer pool — confirm against callers.
    #[error("{0}")]
    Multiple(MultipleErrors),
}
61
62#[derive(Debug)]
63pub struct MultipleErrors {
64    errors: Vec<Error>,
65}
66
67impl Display for MultipleErrors {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        write!(f, "multiple errors")?;
70        for error in &self.errors {
71            write!(f, "\n- {error}")?;
72        }
73        Ok(())
74    }
75}
76
77impl From<Vec<Error>> for MultipleErrors {
78    fn from(value: Vec<Error>) -> Self {
79        MultipleErrors { errors: value }
80    }
81}
82
/// A fixed-size pool of writer connections, built on a bounded channel.
///
/// A connection is taken out of the channel for each spawned task and sent
/// back when the task completes. If a task fails, its error is sent back in
/// place of the connection, so later users of the pool observe the failure.
struct WriterConnectionPool {
    /// How many connections are in circulation
    /// (used by `finish` to detect connections that were never returned).
    num_connections: usize,

    /// A receiver handle to get a writer connection
    /// The writer connection will be mid-transaction!
    connection_rx: Receiver<Result<PgConnection, Error>>,

    /// A sender handle to return a writer connection to the pool
    /// The connection should still be mid-transaction!
    connection_tx: Sender<Result<PgConnection, Error>>,
}
95
96impl WriterConnectionPool {
97    pub fn new(connections: Vec<PgConnection>) -> Self {
98        let num_connections = connections.len();
99        let (connection_tx, connection_rx) = mpsc::channel(num_connections);
100        for connection in connections {
101            connection_tx
102                .try_send(Ok(connection))
103                .expect("there should be room for this connection");
104        }
105
106        WriterConnectionPool {
107            num_connections,
108            connection_rx,
109            connection_tx,
110        }
111    }
112
113    pub async fn spawn_with_connection<F>(&mut self, task: F) -> Result<(), Error>
114    where
115        F: for<'conn> FnOnce(&'conn mut PgConnection) -> BoxFuture<'conn, Result<(), Error>>
116            + Send
117            + 'static,
118    {
119        match self.connection_rx.recv().await {
120            Some(Ok(mut connection)) => {
121                let connection_tx = self.connection_tx.clone();
122                tokio::task::spawn(
123                    async move {
124                        let to_return = match task(&mut connection).await {
125                            Ok(()) => Ok(connection),
126                            Err(error) => {
127                                error!("error in writer: {error}");
128                                Err(error)
129                            }
130                        };
131                        // This should always succeed in sending unless we're already shutting
132                        // down for some other reason.
133                        let _: Result<_, _> = connection_tx.send(to_return).await;
134                    }
135                    .instrument(tracing::debug_span!("spawn_with_connection")),
136                );
137
138                Ok(())
139            }
140            Some(Err(error)) => {
141                // This should always succeed in sending unless we're already shutting
142                // down for some other reason.
143                let _: Result<_, _> = self.connection_tx.send(Err(error)).await;
144
145                Err(Error::WriterConnectionPoolError)
146            }
147            None => {
148                unreachable!("we still hold a reference to the sender, so this shouldn't happen")
149            }
150        }
151    }
152
153    /// Finishes writing to the database, committing all changes.
154    ///
155    /// # Errors
156    ///
157    /// - If any errors were returned to the pool.
158    /// - If committing the changes failed.
159    ///
160    /// # Panics
161    ///
162    /// - If connections were not returned to the pool. (This indicates a
163    ///   serious bug.)
164    pub async fn finish(self) -> Result<(), Vec<Error>> {
165        let mut errors = Vec::new();
166
167        let Self {
168            num_connections,
169            mut connection_rx,
170            connection_tx,
171        } = self;
172        // Drop the sender handle so we gracefully allow the receiver to close
173        drop(connection_tx);
174
175        let mut finished_connections = 0;
176
177        while let Some(connection_or_error) = connection_rx.recv().await {
178            finished_connections += 1;
179
180            match connection_or_error {
181                Ok(mut connection) => {
182                    if let Err(err) = query("COMMIT;").execute(&mut connection).await {
183                        errors.push(err.into_database("commit writer transaction"));
184                    }
185                }
186                Err(error) => {
187                    errors.push(error);
188                }
189            }
190        }
191        assert_eq!(
192            finished_connections, num_connections,
193            "syn2mas had a bug: connections went missing {finished_connections} != {num_connections}"
194        );
195
196        if errors.is_empty() {
197            Ok(())
198        } else {
199            Err(errors)
200        }
201    }
202}
203
204/// Small utility to make sure `finish()` is called on all write buffers
205/// before committing to the database.
206#[derive(Default)]
207struct FinishChecker {
208    counter: Arc<AtomicU32>,
209}
210
211struct FinishCheckerHandle {
212    counter: Arc<AtomicU32>,
213}
214
215impl FinishChecker {
216    /// Acquire a new handle, for a task that should declare when it has
217    /// finished.
218    pub fn handle(&self) -> FinishCheckerHandle {
219        self.counter.fetch_add(1, Ordering::SeqCst);
220        FinishCheckerHandle {
221            counter: Arc::clone(&self.counter),
222        }
223    }
224
225    /// Check that all handles have been declared as finished.
226    pub fn check_all_finished(self) -> Result<(), Error> {
227        if self.counter.load(Ordering::SeqCst) == 0 {
228            Ok(())
229        } else {
230            Err(Error::WriteBuffersNotFinished)
231        }
232    }
233}
234
235impl FinishCheckerHandle {
236    /// Declare that the task this handle represents has been finished.
237    pub fn declare_finished(self) {
238        self.counter.fetch_sub(1, Ordering::SeqCst);
239    }
240}
241
/// Writes migrated data into the MAS database.
// NOTE(review): fields are dropped in declaration order (`conn` before
// `writer_pool`); keep that in mind before reordering them.
pub struct MasWriter {
    /// Connection used for schema-level work (temporary tables, pausing and
    /// restoring indices/constraints), separate from the writer pool.
    conn: LockedMasDatabase,
    /// Pool of connections, each inside its own transaction, used for the
    /// bulk inserts.
    writer_pool: WriterConnectionPool,
    /// Whether this is a dry run.
    // NOTE(review): how this is honoured is outside this view — presumably the
    // final commit is skipped; confirm.
    dry_run: bool,

    /// Indices that were dropped to speed up bulk loading and must be
    /// recreated afterwards.
    indices_to_restore: Vec<IndexDescription>,
    /// Constraints that were dropped to speed up bulk loading and must be
    /// recreated afterwards.
    constraints_to_restore: Vec<ConstraintDescription>,

    /// Checks that every write buffer was finished before committing.
    write_buffer_finish_checker: FinishChecker,
}
252
/// A record type that can be bulk-inserted into the MAS database.
pub trait WriteBatch: Send + Sync + Sized + 'static {
    /// Inserts every record in `batch` using `conn`.
    ///
    /// # Errors
    ///
    /// The implementations in this module return [`Error::Database`] if the
    /// insert fails.
    fn write_batch(
        conn: &mut PgConnection,
        batch: Vec<Self>,
    ) -> impl Future<Output = Result<(), Error>> + Send;
}
259
260pub struct MasNewUser {
261    pub user_id: NonNilUuid,
262    pub username: String,
263    pub created_at: DateTime<Utc>,
264    pub locked_at: Option<DateTime<Utc>>,
265    pub deactivated_at: Option<DateTime<Utc>>,
266    pub can_request_admin: bool,
267    /// Whether the user was a Synapse guest.
268    /// Although MAS doesn't support guest access, it's still useful to track
269    /// for the future.
270    pub is_guest: bool,
271}
272
273impl WriteBatch for MasNewUser {
274    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
275        // `UNNEST` is a fast way to do bulk inserts, as it lets us send multiple rows
276        // in one statement without having to change the statement
277        // SQL thus altering the query plan. See <https://github.com/launchbadge/sqlx/blob/main/FAQ.md#how-can-i-bind-an-array-to-a-values-clause-how-can-i-do-bulk-inserts>.
278        // In the future we could consider using sqlx's support for `PgCopyIn` / the
279        // `COPY FROM STDIN` statement, which is allegedly the best
280        // for insert performance, but is less simple to encode.
281        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
282        let mut usernames: Vec<String> = Vec::with_capacity(batch.len());
283        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
284        let mut locked_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
285        let mut deactivated_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
286        let mut can_request_admins: Vec<bool> = Vec::with_capacity(batch.len());
287        let mut is_guests: Vec<bool> = Vec::with_capacity(batch.len());
288        for MasNewUser {
289            user_id,
290            username,
291            created_at,
292            locked_at,
293            deactivated_at,
294            can_request_admin,
295            is_guest,
296        } in batch
297        {
298            user_ids.push(user_id.get());
299            usernames.push(username);
300            created_ats.push(created_at);
301            locked_ats.push(locked_at);
302            deactivated_ats.push(deactivated_at);
303            can_request_admins.push(can_request_admin);
304            is_guests.push(is_guest);
305        }
306
307        sqlx::query!(
308            r#"
309            INSERT INTO syn2mas__users (
310              user_id, username,
311              created_at, locked_at,
312              deactivated_at,
313              can_request_admin, is_guest)
314            SELECT * FROM UNNEST(
315              $1::UUID[], $2::TEXT[],
316              $3::TIMESTAMP WITH TIME ZONE[], $4::TIMESTAMP WITH TIME ZONE[],
317              $5::TIMESTAMP WITH TIME ZONE[],
318              $6::BOOL[], $7::BOOL[])
319            "#,
320            &user_ids[..],
321            &usernames[..],
322            &created_ats[..],
323            // We need to override the typing for arrays of optionals (sqlx limitation)
324            &locked_ats[..] as &[Option<DateTime<Utc>>],
325            &deactivated_ats[..] as &[Option<DateTime<Utc>>],
326            &can_request_admins[..],
327            &is_guests[..],
328        )
329        .execute(&mut *conn)
330        .await
331        .into_database("writing users to MAS")?;
332
333        Ok(())
334    }
335}
336
337pub struct MasNewUserPassword {
338    pub user_password_id: Uuid,
339    pub user_id: NonNilUuid,
340    pub hashed_password: String,
341    pub created_at: DateTime<Utc>,
342}
343
344impl WriteBatch for MasNewUserPassword {
345    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
346        let mut user_password_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
347        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
348        let mut hashed_passwords: Vec<String> = Vec::with_capacity(batch.len());
349        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
350        let mut versions: Vec<i32> = Vec::with_capacity(batch.len());
351        for MasNewUserPassword {
352            user_password_id,
353            user_id,
354            hashed_password,
355            created_at,
356        } in batch
357        {
358            user_password_ids.push(user_password_id);
359            user_ids.push(user_id.get());
360            hashed_passwords.push(hashed_password);
361            created_ats.push(created_at);
362            versions.push(MIGRATED_PASSWORD_VERSION.into());
363        }
364
365        sqlx::query!(
366            r#"
367            INSERT INTO syn2mas__user_passwords
368            (user_password_id, user_id, hashed_password, created_at, version)
369            SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[], $5::INTEGER[])
370            "#,
371            &user_password_ids[..],
372            &user_ids[..],
373            &hashed_passwords[..],
374            &created_ats[..],
375            &versions[..],
376        ).execute(&mut *conn).await.into_database("writing users to MAS")?;
377
378        Ok(())
379    }
380}
381
382pub struct MasNewEmailThreepid {
383    pub user_email_id: Uuid,
384    pub user_id: NonNilUuid,
385    pub email: String,
386    pub created_at: DateTime<Utc>,
387}
388
389impl WriteBatch for MasNewEmailThreepid {
390    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
391        let mut user_email_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
392        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
393        let mut emails: Vec<String> = Vec::with_capacity(batch.len());
394        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
395
396        for MasNewEmailThreepid {
397            user_email_id,
398            user_id,
399            email,
400            created_at,
401        } in batch
402        {
403            user_email_ids.push(user_email_id);
404            user_ids.push(user_id.get());
405            emails.push(email);
406            created_ats.push(created_at);
407        }
408
409        // `confirmed_at` is going to get removed in a future MAS release,
410        // so just populate with `created_at`
411        sqlx::query!(
412            r#"
413            INSERT INTO syn2mas__user_emails
414            (user_email_id, user_id, email, created_at, confirmed_at)
415            SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[], $4::TIMESTAMP WITH TIME ZONE[])
416            "#,
417            &user_email_ids[..],
418            &user_ids[..],
419            &emails[..],
420            &created_ats[..],
421        ).execute(&mut *conn).await.into_database("writing emails to MAS")?;
422
423        Ok(())
424    }
425}
426
427pub struct MasNewUnsupportedThreepid {
428    pub user_id: NonNilUuid,
429    pub medium: String,
430    pub address: String,
431    pub created_at: DateTime<Utc>,
432}
433
434impl WriteBatch for MasNewUnsupportedThreepid {
435    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
436        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
437        let mut mediums: Vec<String> = Vec::with_capacity(batch.len());
438        let mut addresses: Vec<String> = Vec::with_capacity(batch.len());
439        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
440
441        for MasNewUnsupportedThreepid {
442            user_id,
443            medium,
444            address,
445            created_at,
446        } in batch
447        {
448            user_ids.push(user_id.get());
449            mediums.push(medium);
450            addresses.push(address);
451            created_ats.push(created_at);
452        }
453
454        sqlx::query!(
455            r#"
456            INSERT INTO syn2mas__user_unsupported_third_party_ids
457            (user_id, medium, address, created_at)
458            SELECT * FROM UNNEST($1::UUID[], $2::TEXT[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[])
459            "#,
460            &user_ids[..],
461            &mediums[..],
462            &addresses[..],
463            &created_ats[..],
464        )
465        .execute(&mut *conn)
466        .await
467        .into_database("writing unsupported threepids to MAS")?;
468
469        Ok(())
470    }
471}
472
473pub struct MasNewUpstreamOauthLink {
474    pub link_id: Uuid,
475    pub user_id: NonNilUuid,
476    pub upstream_provider_id: Uuid,
477    pub subject: String,
478    pub created_at: DateTime<Utc>,
479}
480
481impl WriteBatch for MasNewUpstreamOauthLink {
482    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
483        let mut link_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
484        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
485        let mut upstream_provider_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
486        let mut subjects: Vec<String> = Vec::with_capacity(batch.len());
487        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
488
489        for MasNewUpstreamOauthLink {
490            link_id,
491            user_id,
492            upstream_provider_id,
493            subject,
494            created_at,
495        } in batch
496        {
497            link_ids.push(link_id);
498            user_ids.push(user_id.get());
499            upstream_provider_ids.push(upstream_provider_id);
500            subjects.push(subject);
501            created_ats.push(created_at);
502        }
503
504        sqlx::query!(
505            r#"
506            INSERT INTO syn2mas__upstream_oauth_links
507            (upstream_oauth_link_id, user_id, upstream_oauth_provider_id, subject, created_at)
508            SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::UUID[], $4::TEXT[], $5::TIMESTAMP WITH TIME ZONE[])
509            "#,
510            &link_ids[..],
511            &user_ids[..],
512            &upstream_provider_ids[..],
513            &subjects[..],
514            &created_ats[..],
515        ).execute(&mut *conn).await.into_database("writing unsupported threepids to MAS")?;
516
517        Ok(())
518    }
519}
520
521pub struct MasNewCompatSession {
522    pub session_id: Uuid,
523    pub user_id: NonNilUuid,
524    pub device_id: Option<String>,
525    pub human_name: Option<String>,
526    pub created_at: DateTime<Utc>,
527    pub is_synapse_admin: bool,
528    pub last_active_at: Option<DateTime<Utc>>,
529    pub last_active_ip: Option<IpAddr>,
530    pub user_agent: Option<String>,
531}
532
533impl WriteBatch for MasNewCompatSession {
534    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
535        let mut session_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
536        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
537        let mut device_ids: Vec<Option<String>> = Vec::with_capacity(batch.len());
538        let mut human_names: Vec<Option<String>> = Vec::with_capacity(batch.len());
539        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
540        let mut is_synapse_admins: Vec<bool> = Vec::with_capacity(batch.len());
541        let mut last_active_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
542        let mut last_active_ips: Vec<Option<IpAddr>> = Vec::with_capacity(batch.len());
543        let mut user_agents: Vec<Option<String>> = Vec::with_capacity(batch.len());
544
545        for MasNewCompatSession {
546            session_id,
547            user_id,
548            device_id,
549            human_name,
550            created_at,
551            is_synapse_admin,
552            last_active_at,
553            last_active_ip,
554            user_agent,
555        } in batch
556        {
557            session_ids.push(session_id);
558            user_ids.push(user_id.get());
559            device_ids.push(device_id);
560            human_names.push(human_name);
561            created_ats.push(created_at);
562            is_synapse_admins.push(is_synapse_admin);
563            last_active_ats.push(last_active_at);
564            last_active_ips.push(last_active_ip);
565            user_agents.push(user_agent);
566        }
567
568        sqlx::query!(
569            r#"
570            INSERT INTO syn2mas__compat_sessions (
571              compat_session_id, user_id,
572              device_id, human_name,
573              created_at, is_synapse_admin,
574              last_active_at, last_active_ip,
575              user_agent)
576            SELECT * FROM UNNEST(
577              $1::UUID[], $2::UUID[],
578              $3::TEXT[], $4::TEXT[],
579              $5::TIMESTAMP WITH TIME ZONE[], $6::BOOLEAN[],
580              $7::TIMESTAMP WITH TIME ZONE[], $8::INET[],
581              $9::TEXT[])
582            "#,
583            &session_ids[..],
584            &user_ids[..],
585            &device_ids[..] as &[Option<String>],
586            &human_names[..] as &[Option<String>],
587            &created_ats[..],
588            &is_synapse_admins[..],
589            // We need to override the typing for arrays of optionals (sqlx limitation)
590            &last_active_ats[..] as &[Option<DateTime<Utc>>],
591            &last_active_ips[..] as &[Option<IpAddr>],
592            &user_agents[..] as &[Option<String>],
593        )
594        .execute(&mut *conn)
595        .await
596        .into_database("writing compat sessions to MAS")?;
597
598        Ok(())
599    }
600}
601
602pub struct MasNewCompatAccessToken {
603    pub token_id: Uuid,
604    pub session_id: Uuid,
605    pub access_token: String,
606    pub created_at: DateTime<Utc>,
607    pub expires_at: Option<DateTime<Utc>>,
608}
609
610impl WriteBatch for MasNewCompatAccessToken {
611    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
612        let mut token_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
613        let mut session_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
614        let mut access_tokens: Vec<String> = Vec::with_capacity(batch.len());
615        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
616        let mut expires_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
617
618        for MasNewCompatAccessToken {
619            token_id,
620            session_id,
621            access_token,
622            created_at,
623            expires_at,
624        } in batch
625        {
626            token_ids.push(token_id);
627            session_ids.push(session_id);
628            access_tokens.push(access_token);
629            created_ats.push(created_at);
630            expires_ats.push(expires_at);
631        }
632
633        sqlx::query!(
634            r#"
635            INSERT INTO syn2mas__compat_access_tokens (
636              compat_access_token_id,
637              compat_session_id,
638              access_token,
639              created_at,
640              expires_at)
641            SELECT * FROM UNNEST(
642              $1::UUID[],
643              $2::UUID[],
644              $3::TEXT[],
645              $4::TIMESTAMP WITH TIME ZONE[],
646              $5::TIMESTAMP WITH TIME ZONE[])
647            "#,
648            &token_ids[..],
649            &session_ids[..],
650            &access_tokens[..],
651            &created_ats[..],
652            // We need to override the typing for arrays of optionals (sqlx limitation)
653            &expires_ats[..] as &[Option<DateTime<Utc>>],
654        )
655        .execute(&mut *conn)
656        .await
657        .into_database("writing compat access tokens to MAS")?;
658
659        Ok(())
660    }
661}
662
663pub struct MasNewCompatRefreshToken {
664    pub refresh_token_id: Uuid,
665    pub session_id: Uuid,
666    pub access_token_id: Uuid,
667    pub refresh_token: String,
668    pub created_at: DateTime<Utc>,
669}
670
671impl WriteBatch for MasNewCompatRefreshToken {
672    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
673        let mut refresh_token_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
674        let mut session_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
675        let mut access_token_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
676        let mut refresh_tokens: Vec<String> = Vec::with_capacity(batch.len());
677        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
678
679        for MasNewCompatRefreshToken {
680            refresh_token_id,
681            session_id,
682            access_token_id,
683            refresh_token,
684            created_at,
685        } in batch
686        {
687            refresh_token_ids.push(refresh_token_id);
688            session_ids.push(session_id);
689            access_token_ids.push(access_token_id);
690            refresh_tokens.push(refresh_token);
691            created_ats.push(created_at);
692        }
693
694        sqlx::query!(
695            r#"
696            INSERT INTO syn2mas__compat_refresh_tokens (
697              compat_refresh_token_id,
698              compat_session_id,
699              compat_access_token_id,
700              refresh_token,
701              created_at)
702            SELECT * FROM UNNEST(
703              $1::UUID[],
704              $2::UUID[],
705              $3::UUID[],
706              $4::TEXT[],
707              $5::TIMESTAMP WITH TIME ZONE[])
708            "#,
709            &refresh_token_ids[..],
710            &session_ids[..],
711            &access_token_ids[..],
712            &refresh_tokens[..],
713            &created_ats[..],
714        )
715        .execute(&mut *conn)
716        .await
717        .into_database("writing compat refresh tokens to MAS")?;
718
719        Ok(())
720    }
721}
722
/// The 'version' of the password hashing scheme used for passwords when they
/// are migrated from Synapse to MAS.
/// This is version 1, as in the previous syn2mas script.
///
/// Written (as an `i32`) into the `version` column of
/// `syn2mas__user_passwords` for every migrated password.
// TODO hardcoding version to `1` may not be correct long-term?
pub const MIGRATED_PASSWORD_VERSION: u16 = 1;
728
/// List of all MAS tables that are written to by syn2mas.
///
/// Names are given without the `syn2mas__` prefix. When an interrupted
/// migration is resumed, `MasWriter::new` truncates `syn2mas__<name>` for
/// every entry, so this list must cover every table the `WriteBatch`
/// implementations insert into.
pub const MAS_TABLES_AFFECTED_BY_MIGRATION: &[&str] = &[
    "users",
    "user_passwords",
    "user_emails",
    "user_unsupported_third_party_ids",
    "upstream_oauth_links",
    "compat_sessions",
    "compat_access_tokens",
    "compat_refresh_tokens",
];
740
741/// Detect whether a syn2mas migration has started on the given database.
742///
743/// Concretly, this checks for the presence of syn2mas restoration tables.
744///
745/// Returns `true` if syn2mas has started, or `false` if it hasn't.
746///
747/// # Errors
748///
749/// Errors are returned under the following circumstances:
750///
751/// - If any database error occurs whilst querying the database.
752/// - If some, but not all, syn2mas restoration tables are present. (This
753///   shouldn't be possible without syn2mas having been sabotaged!)
754pub async fn is_syn2mas_in_progress(conn: &mut PgConnection) -> Result<bool, Error> {
755    // Names of tables used for syn2mas resumption
756    // Must be `String`s, not just `&str`, for the query.
757    let restore_table_names = vec![
758        "syn2mas_restore_constraints".to_owned(),
759        "syn2mas_restore_indices".to_owned(),
760    ];
761
762    let num_resumption_tables = query!(
763        r#"
764        SELECT 1 AS _dummy FROM pg_tables WHERE schemaname = current_schema
765        AND tablename = ANY($1)
766        "#,
767        &restore_table_names,
768    )
769    .fetch_all(conn.as_mut())
770    .await
771    .into_database("failed to query count of resumption tables")?
772    .len();
773
774    if num_resumption_tables == 0 {
775        Ok(false)
776    } else if num_resumption_tables == restore_table_names.len() {
777        Ok(true)
778    } else {
779        Err(Error::inconsistent(
780            "some, but not all, syn2mas resumption tables were found",
781        ))
782    }
783}
784
785impl MasWriter {
786    /// Creates a new MAS writer.
787    ///
788    /// # Errors
789    ///
790    /// Errors are returned in the following conditions:
791    ///
792    /// - If the database connection experiences an error.
793    #[tracing::instrument(name = "syn2mas.mas_writer.new", skip_all)]
794    pub async fn new(
795        mut conn: LockedMasDatabase,
796        mut writer_connections: Vec<PgConnection>,
797        dry_run: bool,
798    ) -> Result<Self, Error> {
799        // Given that we don't have any concurrent transactions here,
800        // the READ COMMITTED isolation level is sufficient.
801        query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
802            .execute(conn.as_mut())
803            .await
804            .into_database("begin MAS transaction")?;
805
806        let syn2mas_started = is_syn2mas_in_progress(conn.as_mut()).await?;
807
808        let indices_to_restore;
809        let constraints_to_restore;
810
811        if syn2mas_started {
812            // We are resuming from a partially-done syn2mas migration
813            // We should reset the database so that we're starting from scratch.
814            warn!("Partial syn2mas migration has already been done; resetting.");
815            for table in MAS_TABLES_AFFECTED_BY_MIGRATION {
816                query(&format!("TRUNCATE syn2mas__{table};"))
817                    .execute(conn.as_mut())
818                    .await
819                    .into_database_with(|| format!("failed to truncate table syn2mas__{table}"))?;
820            }
821
822            indices_to_restore = query_as!(
823                IndexDescription,
824                "SELECT table_name, name, definition FROM syn2mas_restore_indices ORDER BY order_key"
825            )
826                .fetch_all(conn.as_mut())
827                .await
828                .into_database("failed to get syn2mas restore data (index descriptions)")?;
829            constraints_to_restore = query_as!(
830                ConstraintDescription,
831                "SELECT table_name, name, definition FROM syn2mas_restore_constraints ORDER BY order_key"
832            )
833                .fetch_all(conn.as_mut())
834                .await
835                .into_database("failed to get syn2mas restore data (constraint descriptions)")?;
836        } else {
837            info!("Starting new syn2mas migration");
838
839            conn.as_mut()
840                .execute_many(include_str!("syn2mas_temporary_tables.sql"))
841                // We don't care about any query results
842                .try_collect::<Vec<_>>()
843                .await
844                .into_database("could not create temporary tables")?;
845
846            // Pause (temporarily drop) indices and constraints in order to improve
847            // performance of bulk data loading.
848            (indices_to_restore, constraints_to_restore) =
849                Self::pause_indices(conn.as_mut()).await?;
850
851            // Persist these index and constraint definitions.
852            for IndexDescription {
853                name,
854                table_name,
855                definition,
856            } in &indices_to_restore
857            {
858                query!(
859                    r#"
860                    INSERT INTO syn2mas_restore_indices (name, table_name, definition)
861                    VALUES ($1, $2, $3)
862                    "#,
863                    name,
864                    table_name,
865                    definition
866                )
867                .execute(conn.as_mut())
868                .await
869                .into_database("failed to save restore data (index)")?;
870            }
871            for ConstraintDescription {
872                name,
873                table_name,
874                definition,
875            } in &constraints_to_restore
876            {
877                query!(
878                    r#"
879                    INSERT INTO syn2mas_restore_constraints (name, table_name, definition)
880                    VALUES ($1, $2, $3)
881                    "#,
882                    name,
883                    table_name,
884                    definition
885                )
886                .execute(conn.as_mut())
887                .await
888                .into_database("failed to save restore data (index)")?;
889            }
890        }
891
892        query("COMMIT;")
893            .execute(conn.as_mut())
894            .await
895            .into_database("begin MAS transaction")?;
896
897        // Now after all the schema changes have been done, begin writer transactions
898        for writer_connection in &mut writer_connections {
899            query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
900                .execute(&mut *writer_connection)
901                .await
902                .into_database("begin MAS writer transaction")?;
903        }
904
905        Ok(Self {
906            conn,
907            dry_run,
908            writer_pool: WriterConnectionPool::new(writer_connections),
909            indices_to_restore,
910            constraints_to_restore,
911            write_buffer_finish_checker: FinishChecker::default(),
912        })
913    }
914
915    #[tracing::instrument(skip_all)]
916    async fn pause_indices(
917        conn: &mut PgConnection,
918    ) -> Result<(Vec<IndexDescription>, Vec<ConstraintDescription>), Error> {
919        let mut indices_to_restore = Vec::new();
920        let mut constraints_to_restore = Vec::new();
921
922        for &unprefixed_table in MAS_TABLES_AFFECTED_BY_MIGRATION {
923            let table = format!("syn2mas__{unprefixed_table}");
924            // First drop incoming foreign key constraints
925            for constraint in
926                constraint_pausing::describe_foreign_key_constraints_to_table(&mut *conn, &table)
927                    .await?
928            {
929                constraint_pausing::drop_constraint(&mut *conn, &constraint).await?;
930                constraints_to_restore.push(constraint);
931            }
932            // After all incoming foreign key constraints have been removed,
933            // we can now drop internal constraints.
934            for constraint in
935                constraint_pausing::describe_constraints_on_table(&mut *conn, &table).await?
936            {
937                constraint_pausing::drop_constraint(&mut *conn, &constraint).await?;
938                constraints_to_restore.push(constraint);
939            }
940            // After all constraints have been removed, we can drop indices.
941            for index in constraint_pausing::describe_indices_on_table(&mut *conn, &table).await? {
942                constraint_pausing::drop_index(&mut *conn, &index).await?;
943                indices_to_restore.push(index);
944            }
945        }
946
947        Ok((indices_to_restore, constraints_to_restore))
948    }
949
950    async fn restore_indices(
951        conn: &mut LockedMasDatabase,
952        indices_to_restore: &[IndexDescription],
953        constraints_to_restore: &[ConstraintDescription],
954        progress: &Progress,
955    ) -> Result<(), Error> {
956        // First restore all indices. The order is not important as far as I know.
957        // However the indices are needed before constraints.
958        for index in indices_to_restore.iter().rev() {
959            progress.rebuild_index(index.name.clone());
960            constraint_pausing::restore_index(conn.as_mut(), index).await?;
961        }
962        // Then restore all constraints.
963        // The order here is the reverse of drop order, since some constraints may rely
964        // on other constraints to work.
965        for constraint in constraints_to_restore.iter().rev() {
966            progress.rebuild_constraint(constraint.name.clone());
967            constraint_pausing::restore_constraint(conn.as_mut(), constraint).await?;
968        }
969        Ok(())
970    }
971
972    /// Finish writing to the MAS database, flushing and committing all changes.
973    /// It returns the unlocked underlying connection.
974    ///
975    /// # Errors
976    ///
977    /// Errors are returned in the following conditions:
978    ///
979    /// - If the database connection experiences an error.
980    #[tracing::instrument(skip_all)]
981    pub async fn finish(mut self, progress: &Progress) -> Result<PgConnection, Error> {
982        self.write_buffer_finish_checker.check_all_finished()?;
983
984        // Commit all writer transactions to the database.
985        self.writer_pool
986            .finish()
987            .await
988            .map_err(|errors| Error::Multiple(MultipleErrors::from(errors)))?;
989
990        // Now all the data has been migrated, finish off by restoring indices and
991        // constraints!
992        query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
993            .execute(self.conn.as_mut())
994            .await
995            .into_database("begin MAS transaction")?;
996
997        Self::restore_indices(
998            &mut self.conn,
999            &self.indices_to_restore,
1000            &self.constraints_to_restore,
1001            progress,
1002        )
1003        .await?;
1004
1005        self.conn
1006            .as_mut()
1007            .execute_many(include_str!("syn2mas_revert_temporary_tables.sql"))
1008            // We don't care about any query results
1009            .try_collect::<Vec<_>>()
1010            .await
1011            .into_database("could not revert temporary tables")?;
1012
1013        // If we're in dry-run mode, truncate all the tables we've written to
1014        if self.dry_run {
1015            warn!("Migration ran in dry-run mode, deleting all imported data");
1016            let tables = MAS_TABLES_AFFECTED_BY_MIGRATION
1017                .iter()
1018                .map(|table| format!("\"{table}\""))
1019                .collect::<Vec<_>>()
1020                .join(", ");
1021
1022            // Note that we do that with CASCADE, because we do that *after*
1023            // restoring the FK constraints.
1024            //
1025            // The alternative would be to list all the tables we have FK to
1026            // those tables, which would be a hassle, or to do that after
1027            // restoring the constraints, which would mean we wouldn't validate
1028            // that we've done valid FKs in dry-run mode.
1029            query(&format!("TRUNCATE TABLE {tables} CASCADE;"))
1030                .execute(self.conn.as_mut())
1031                .await
1032                .into_database_with(|| "failed to truncate all tables")?;
1033        }
1034
1035        query("COMMIT;")
1036            .execute(self.conn.as_mut())
1037            .await
1038            .into_database("ending MAS transaction")?;
1039
1040        let conn = self
1041            .conn
1042            .unlock()
1043            .await
1044            .into_database("could not unlock MAS database")?;
1045
1046        Ok(conn)
1047    }
1048}
1049
/// Number of rows a [`MasWriteBuffer`] accumulates before it sends them to the
/// database as a single batch.
const WRITE_BUFFER_BATCH_SIZE: usize = 4096;
1053
1054/// A buffer for writing rows to the MAS database.
1055/// Generic over the type of rows.
1056pub struct MasWriteBuffer<T> {
1057    rows: Vec<T>,
1058    finish_checker_handle: FinishCheckerHandle,
1059}
1060
1061impl<T> MasWriteBuffer<T>
1062where
1063    T: WriteBatch,
1064{
1065    pub fn new(writer: &MasWriter) -> Self {
1066        MasWriteBuffer {
1067            rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
1068            finish_checker_handle: writer.write_buffer_finish_checker.handle(),
1069        }
1070    }
1071
1072    pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> {
1073        self.flush(writer).await?;
1074        self.finish_checker_handle.declare_finished();
1075        Ok(())
1076    }
1077
1078    pub async fn flush(&mut self, writer: &mut MasWriter) -> Result<(), Error> {
1079        if self.rows.is_empty() {
1080            return Ok(());
1081        }
1082        let rows = std::mem::take(&mut self.rows);
1083        self.rows.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
1084        writer
1085            .writer_pool
1086            .spawn_with_connection(move |conn| T::write_batch(conn, rows).boxed())
1087            .boxed()
1088            .await?;
1089        Ok(())
1090    }
1091
1092    pub async fn write(&mut self, writer: &mut MasWriter, row: T) -> Result<(), Error> {
1093        self.rows.push(row);
1094        if self.rows.len() >= WRITE_BUFFER_BATCH_SIZE {
1095            self.flush(writer).await?;
1096        }
1097        Ok(())
1098    }
1099}
1100
#[cfg(test)]
mod test {
    use std::collections::{BTreeMap, BTreeSet};

    use chrono::DateTime;
    use futures_util::TryStreamExt;
    use serde::Serialize;
    use sqlx::{Column, PgConnection, PgPool, Row};
    use uuid::{NonNilUuid, Uuid};

    use crate::{
        LockedMasDatabase, MasWriter, Progress,
        mas_writer::{
            MasNewCompatAccessToken, MasNewCompatRefreshToken, MasNewCompatSession,
            MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
            MasNewUserPassword, MasWriteBuffer,
        },
    };

    /// A snapshot of a whole database, keyed by table name.
    #[derive(Default, Serialize)]
    #[serde(transparent)]
    struct DatabaseSnapshot {
        tables: BTreeMap<String, TableSnapshot>,
    }

    /// All rows of one table, held in a sorted set so the output is
    /// deterministic.
    #[derive(Serialize)]
    #[serde(transparent)]
    struct TableSnapshot {
        rows: BTreeSet<RowSnapshot>,
    }

    /// One row, mapping each column name to its value cast to text (`None` for
    /// SQL `NULL`).
    #[derive(PartialEq, Eq, PartialOrd, Ord, Serialize)]
    #[serde(transparent)]
    struct RowSnapshot {
        columns_to_values: BTreeMap<String, Option<String>>,
    }

    const SKIPPED_TABLES: &[&str] = &["_sqlx_migrations"];

    /// ID of the `alice` user written by every test below.
    const ALICE_ID: NonNilUuid = NonNilUuid::new(Uuid::from_u128(1u128)).unwrap();

    /// Builds the `alice` user (no password, not locked, not deactivated) that
    /// every test below writes.
    fn alice() -> MasNewUser {
        MasNewUser {
            user_id: ALICE_ID,
            username: "alice".to_owned(),
            created_at: DateTime::default(),
            locked_at: None,
            deactivated_at: None,
            can_request_admin: false,
            is_guest: false,
        }
    }

    /// Quotes `ident` as a PostgreSQL identifier, doubling any embedded `"`,
    /// so that reserved words and mixed-case names are referenced verbatim.
    fn quote_ident(ident: &str) -> String {
        format!("\"{}\"", ident.replace('"', "\"\""))
    }

    /// Produces a serialisable snapshot of a database, usable for snapshot
    /// testing
    ///
    /// For brevity, empty tables, as well as [`SKIPPED_TABLES`], will not be
    /// included in the snapshot.
    async fn snapshot_database(conn: &mut PgConnection) -> DatabaseSnapshot {
        let mut out = DatabaseSnapshot::default();
        let table_names: Vec<String> = sqlx::query_scalar(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = current_schema();",
        )
        .fetch_all(&mut *conn)
        .await
        .expect("failed to list tables for snapshotting");

        for table_name in table_names {
            if SKIPPED_TABLES.contains(&table_name.as_str()) {
                continue;
            }

            let column_names: Vec<String> = sqlx::query_scalar(
                "SELECT column_name FROM information_schema.columns WHERE table_name = $1 AND table_schema = current_schema();"
            ).bind(&table_name).fetch_all(&mut *conn).await.expect("failed to get column names for table for snapshotting");

            let column_name_list = column_names
                .iter()
                // Stringify all the values for simplicity. Identifiers are
                // quoted so that reserved-word or mixed-case columns still work.
                .map(|column_name| {
                    let quoted = quote_ident(column_name);
                    format!("{quoted}::TEXT AS {quoted}")
                })
                .collect::<Vec<_>>()
                .join(", ");

            let table_rows = sqlx::query(&format!(
                "SELECT {column_name_list} FROM {};",
                quote_ident(&table_name)
            ))
            .fetch(&mut *conn)
            .map_ok(|row| {
                let mut columns_to_values = BTreeMap::new();
                for (idx, column) in row.columns().iter().enumerate() {
                    columns_to_values.insert(column.name().to_owned(), row.get(idx));
                }
                RowSnapshot { columns_to_values }
            })
            .try_collect::<BTreeSet<RowSnapshot>>()
            .await
            .expect("failed to fetch rows from table for snapshotting");

            if !table_rows.is_empty() {
                out.tables
                    .insert(table_name, TableSnapshot { rows: table_rows });
            }
        }

        out
    }

    /// Make a snapshot assertion against the database.
    ///
    /// This stays a macro so that `insta` names each snapshot after the
    /// enclosing test function.
    macro_rules! assert_db_snapshot {
        ($db: expr) => {
            let db_snapshot = snapshot_database($db).await;
            ::insta::assert_yaml_snapshot!(db_snapshot);
        };
    }

    /// Creates a `MasWriter` backed by `pool`. One connection is locked as the
    /// main MAS connection and two more serve as writer connections.
    ///
    /// The caller is responsible for `finish`ing the returned `MasWriter`.
    async fn make_mas_writer(pool: &PgPool) -> MasWriter {
        let main_conn = pool.acquire().await.unwrap().detach();
        let mut writer_conns = Vec::new();
        for _ in 0..2 {
            writer_conns.push(
                pool.acquire()
                    .await
                    .expect("failed to acquire MasWriter writer connection")
                    .detach(),
            );
        }
        let locked_main_conn = LockedMasDatabase::try_new(main_conn)
            .await
            .expect("failed to lock MAS database")
            .expect_left("MAS database is already locked");
        MasWriter::new(locked_main_conn, writer_conns, false)
            .await
            .expect("failed to construct MasWriter")
    }

    /// Tests writing a single user, without a password.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;
        let mut buffer = MasWriteBuffer::new(&writer);

        buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish MasWriteBuffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with a password.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_password(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut password_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        password_buffer
            .write(
                &mut writer,
                MasNewUserPassword {
                    user_password_id: Uuid::from_u128(42u128),
                    user_id: ALICE_ID,
                    hashed_password: "$bcrypt$aaaaaaaaaaa".to_owned(),
                    created_at: DateTime::default(),
                },
            )
            .await
            .expect("failed to write password");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish MasWriteBuffer");
        password_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish MasWriteBuffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with an e-mail address associated.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_email(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut email_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        email_buffer
            .write(
                &mut writer,
                MasNewEmailThreepid {
                    user_email_id: Uuid::from_u128(2u128),
                    user_id: ALICE_ID,
                    email: "alice@example.org".to_owned(),
                    created_at: DateTime::default(),
                },
            )
            .await
            .expect("failed to write e-mail");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish user buffer");
        email_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish email buffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with an unsupported third-party ID
    /// associated.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_unsupported_threepid(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut threepid_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        threepid_buffer
            .write(
                &mut writer,
                MasNewUnsupportedThreepid {
                    user_id: ALICE_ID,
                    medium: "msisdn".to_owned(),
                    address: "441189998819991197253".to_owned(),
                    created_at: DateTime::default(),
                },
            )
            .await
            .expect("failed to write phone number (unsupported threepid)");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish user buffer");
        threepid_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish threepid buffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with a link to an upstream provider.
    /// There needs to be an upstream provider in the database already — in the
    /// real migration, this is done by running a provider sync first.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR", fixtures("upstream_provider"))]
    async fn test_write_user_with_upstream_provider_link(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut link_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        link_buffer
            .write(
                &mut writer,
                MasNewUpstreamOauthLink {
                    user_id: ALICE_ID,
                    link_id: Uuid::from_u128(3u128),
                    upstream_provider_id: Uuid::from_u128(4u128),
                    subject: "12345.67890".to_owned(),
                    created_at: DateTime::default(),
                },
            )
            .await
            .expect("failed to write link");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish user buffer");
        link_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish link buffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with a device (compat session).
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_device(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut session_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        session_buffer
            .write(
                &mut writer,
                MasNewCompatSession {
                    user_id: ALICE_ID,
                    session_id: Uuid::from_u128(5u128),
                    created_at: DateTime::default(),
                    device_id: Some("ADEVICE".to_owned()),
                    human_name: Some("alice's pinephone".to_owned()),
                    is_synapse_admin: true,
                    last_active_at: Some(DateTime::default()),
                    last_active_ip: Some("203.0.113.1".parse().unwrap()),
                    user_agent: Some("Browser/5.0".to_owned()),
                },
            )
            .await
            .expect("failed to write compat session");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish user buffer");
        session_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish session buffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with a device and an access token.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_access_token(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut session_buffer = MasWriteBuffer::new(&writer);
        let mut token_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        session_buffer
            .write(
                &mut writer,
                MasNewCompatSession {
                    user_id: ALICE_ID,
                    session_id: Uuid::from_u128(5u128),
                    created_at: DateTime::default(),
                    device_id: Some("ADEVICE".to_owned()),
                    human_name: None,
                    is_synapse_admin: false,
                    last_active_at: None,
                    last_active_ip: None,
                    user_agent: None,
                },
            )
            .await
            .expect("failed to write compat session");

        token_buffer
            .write(
                &mut writer,
                MasNewCompatAccessToken {
                    token_id: Uuid::from_u128(6u128),
                    session_id: Uuid::from_u128(5u128),
                    access_token: "syt_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
                    created_at: DateTime::default(),
                    expires_at: None,
                },
            )
            .await
            .expect("failed to write access token");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish user buffer");
        session_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish session buffer");
        token_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish token buffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with a device, an access token and a
    /// refresh token.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_refresh_token(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut session_buffer = MasWriteBuffer::new(&writer);
        let mut token_buffer = MasWriteBuffer::new(&writer);
        let mut refresh_token_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        session_buffer
            .write(
                &mut writer,
                MasNewCompatSession {
                    user_id: ALICE_ID,
                    session_id: Uuid::from_u128(5u128),
                    created_at: DateTime::default(),
                    device_id: Some("ADEVICE".to_owned()),
                    human_name: None,
                    is_synapse_admin: false,
                    last_active_at: None,
                    last_active_ip: None,
                    user_agent: None,
                },
            )
            .await
            .expect("failed to write compat session");

        token_buffer
            .write(
                &mut writer,
                MasNewCompatAccessToken {
                    token_id: Uuid::from_u128(6u128),
                    session_id: Uuid::from_u128(5u128),
                    access_token: "syt_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
                    created_at: DateTime::default(),
                    expires_at: None,
                },
            )
            .await
            .expect("failed to write access token");

        refresh_token_buffer
            .write(
                &mut writer,
                MasNewCompatRefreshToken {
                    refresh_token_id: Uuid::from_u128(7u128),
                    session_id: Uuid::from_u128(5u128),
                    access_token_id: Uuid::from_u128(6u128),
                    refresh_token: "syr_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
                    created_at: DateTime::default(),
                },
            )
            .await
            .expect("failed to write refresh token");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish user buffer");
        session_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish session buffer");
        token_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish token buffer");
        refresh_token_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish refresh token buffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }
}