1use std::{
11 fmt::Display,
12 net::IpAddr,
13 sync::{
14 Arc,
15 atomic::{AtomicU32, Ordering},
16 },
17};
18
19use chrono::{DateTime, Utc};
20use futures_util::{FutureExt, TryStreamExt, future::BoxFuture};
21use sqlx::{Executor, PgConnection, query, query_as};
22use thiserror::Error;
23use thiserror_ext::{Construct, ContextInto};
24use tokio::sync::mpsc::{self, Receiver, Sender};
25use tracing::{Instrument, error, info, warn};
26use uuid::{NonNilUuid, Uuid};
27
28use self::{
29 constraint_pausing::{ConstraintDescription, IndexDescription},
30 locking::LockedMasDatabase,
31};
32use crate::Progress;
33
34pub mod checks;
35pub mod locking;
36
37mod constraint_pausing;
38
/// An error occurring while writing to the MAS database.
#[derive(Debug, Error, Construct, ContextInto)]
pub enum Error {
    /// An underlying sqlx error, with a human-readable description
    /// (`context`) of what was being attempted when it occurred.
    #[error("database error whilst {context}")]
    Database {
        #[source]
        source: sqlx::Error,
        context: String,
    },

    /// The writer connection pool was shut down because an earlier write task
    /// failed; the original error was logged when it happened.
    #[error("writer connection pool shut down due to error")]
    #[expect(clippy::enum_variant_names)]
    WriterConnectionPoolError,

    /// The MAS database was found in a state that should not be possible.
    #[error("inconsistent database: {0}")]
    Inconsistent(String),

    /// `MasWriter::finish` was called whilst some `MasWriteBuffer`s had not
    /// been finished — an internal bug, as buffered rows could be lost.
    #[error("bug in syn2mas: write buffers not finished")]
    WriteBuffersNotFinished,

    /// Several errors collected together (e.g. from the writer connections
    /// when shutting the pool down).
    #[error("{0}")]
    Multiple(MultipleErrors),
}
61
/// A collection of errors, displayed as a header followed by one bulleted
/// line per error.
#[derive(Debug)]
pub struct MultipleErrors {
    // The individual errors, in the order they were collected.
    errors: Vec<Error>,
}
66
67impl Display for MultipleErrors {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 write!(f, "multiple errors")?;
70 for error in &self.errors {
71 write!(f, "\n- {error}")?;
72 }
73 Ok(())
74 }
75}
76
77impl From<Vec<Error>> for MultipleErrors {
78 fn from(value: Vec<Error>) -> Self {
79 MultipleErrors { errors: value }
80 }
81}
82
/// A simple pool of database connections used to parallelise writes to the
/// MAS database.
///
/// Idle connections live in a channel: a write task takes one out, performs
/// its writes, then sends it back (or sends the error that occurred instead).
struct WriterConnectionPool {
    // How many connections the pool was created with; used at `finish` time
    // to assert that none have gone missing.
    num_connections: usize,

    // Receiving end of the channel of idle connections (or errors from
    // failed write tasks).
    connection_rx: Receiver<Result<PgConnection, Error>>,

    // Sending end; cloned into each spawned write task so the task can hand
    // its connection back when done.
    connection_tx: Sender<Result<PgConnection, Error>>,
}
95
impl WriterConnectionPool {
    /// Creates a pool seeded with the given, already-established connections.
    pub fn new(connections: Vec<PgConnection>) -> Self {
        let num_connections = connections.len();
        let (connection_tx, connection_rx) = mpsc::channel(num_connections);
        // The channel capacity equals the number of connections, so
        // `try_send` cannot fail with `Full` here.
        for connection in connections {
            connection_tx
                .try_send(Ok(connection))
                .expect("there should be room for this connection");
        }

        WriterConnectionPool {
            num_connections,
            connection_rx,
            connection_tx,
        }
    }

    /// Waits for a connection to become available, then spawns a tokio task
    /// that runs `task` against it.
    ///
    /// When the task completes, the connection is sent back into the pool;
    /// if the task failed, its error is sent back instead, 'poisoning' the
    /// pool so that a later call reports `WriterConnectionPoolError`.
    ///
    /// # Errors
    ///
    /// Returns `Error::WriterConnectionPoolError` if a previously-spawned
    /// task failed (the underlying error was logged when it happened).
    pub async fn spawn_with_connection<F>(&mut self, task: F) -> Result<(), Error>
    where
        F: for<'conn> FnOnce(&'conn mut PgConnection) -> BoxFuture<'conn, Result<(), Error>>
            + Send
            + 'static,
    {
        match self.connection_rx.recv().await {
            Some(Ok(mut connection)) => {
                let connection_tx = self.connection_tx.clone();
                tokio::task::spawn(
                    async move {
                        let to_return = match task(&mut connection).await {
                            Ok(()) => Ok(connection),
                            Err(error) => {
                                error!("error in writer: {error}");
                                Err(error)
                            }
                        };
                        // If this send fails, the receiving half of the pool
                        // has already been dropped; nothing useful to do.
                        let _: Result<_, _> = connection_tx.send(to_return).await;
                    }
                    .instrument(tracing::debug_span!("spawn_with_connection")),
                );

                Ok(())
            }
            Some(Err(error)) => {
                // Put the error back into the channel so that subsequent
                // callers also see that the pool is poisoned.
                let _: Result<_, _> = self.connection_tx.send(Err(error)).await;

                Err(Error::WriterConnectionPoolError)
            }
            None => {
                unreachable!("we still hold a reference to the sender, so this shouldn't happen")
            }
        }
    }

    /// Waits for all outstanding write tasks to return their connections,
    /// commits each connection's transaction, and shuts the pool down.
    ///
    /// # Errors
    ///
    /// Returns every error accumulated from failed write tasks and from
    /// committing the transactions.
    pub async fn finish(self) -> Result<(), Vec<Error>> {
        let mut errors = Vec::new();

        let Self {
            num_connections,
            mut connection_rx,
            connection_tx,
        } = self;
        // Dropping our sender means the channel closes once every in-flight
        // task (each of which holds a sender clone) has sent its result.
        drop(connection_tx);

        let mut finished_connections = 0;

        while let Some(connection_or_error) = connection_rx.recv().await {
            finished_connections += 1;

            match connection_or_error {
                Ok(mut connection) => {
                    // Commit the long-lived write transaction that was opened
                    // on this connection in `MasWriter::new`.
                    if let Err(err) = query("COMMIT;").execute(&mut connection).await {
                        errors.push(err.into_database("commit writer transaction"));
                    }
                }
                Err(error) => {
                    errors.push(error);
                }
            }
        }
        assert_eq!(
            finished_connections, num_connections,
            "syn2mas had a bug: connections went missing {finished_connections} != {num_connections}"
        );

        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }
}
203
/// Tracks a set of tasks that must all declare themselves finished before the
/// migration can complete; each outstanding handle bumps the shared counter.
#[derive(Default)]
struct FinishChecker {
    // Number of handles that have not yet declared themselves finished.
    counter: Arc<AtomicU32>,
}
210
/// A handle for one task tracked by a `FinishChecker`; the task must call
/// `declare_finished` to release it.
struct FinishCheckerHandle {
    // Shared counter of unfinished tasks, decremented on `declare_finished`.
    counter: Arc<AtomicU32>,
}
214
215impl FinishChecker {
216 pub fn handle(&self) -> FinishCheckerHandle {
219 self.counter.fetch_add(1, Ordering::SeqCst);
220 FinishCheckerHandle {
221 counter: Arc::clone(&self.counter),
222 }
223 }
224
225 pub fn check_all_finished(self) -> Result<(), Error> {
227 if self.counter.load(Ordering::SeqCst) == 0 {
228 Ok(())
229 } else {
230 Err(Error::WriteBuffersNotFinished)
231 }
232 }
233}
234
235impl FinishCheckerHandle {
236 pub fn declare_finished(self) {
238 self.counter.fetch_sub(1, Ordering::SeqCst);
239 }
240}
241
/// Writes migrated rows into the MAS database.
pub struct MasWriter {
    // The main connection, which holds the lock on the MAS database and is
    // used for schema manipulation.
    conn: LockedMasDatabase,
    // Pool of extra connections used to perform row writes in parallel.
    writer_pool: WriterConnectionPool,
    // If true, all imported data is wiped again at `finish` time.
    dry_run: bool,

    // Indices dropped for the duration of the migration, to be recreated at
    // the end.
    indices_to_restore: Vec<IndexDescription>,
    // Constraints dropped for the duration of the migration, to be recreated
    // at the end.
    constraints_to_restore: Vec<ConstraintDescription>,

    // Ensures every `MasWriteBuffer` is finished before `finish` is called.
    write_buffer_finish_checker: FinishChecker,
}
252
/// A row type that can be bulk-written to the MAS database in batches.
pub trait WriteBatch: Send + Sync + Sized + 'static {
    /// Writes the given batch of rows to the database over `conn`.
    fn write_batch(
        conn: &mut PgConnection,
        batch: Vec<Self>,
    ) -> impl Future<Output = Result<(), Error>> + Send;
}
259
/// A new user row to be inserted into the MAS database.
pub struct MasNewUser {
    pub user_id: NonNilUuid,
    pub username: String,
    pub created_at: DateTime<Utc>,
    /// When the user was locked, if at all.
    pub locked_at: Option<DateTime<Utc>>,
    /// When the user was deactivated, if at all.
    pub deactivated_at: Option<DateTime<Utc>>,
    pub can_request_admin: bool,
    pub is_guest: bool,
}
272
impl WriteBatch for MasNewUser {
    /// Inserts the whole batch of users with a single multi-row `INSERT`,
    /// built by `UNNEST`ing one Postgres array per column.
    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
        // Decompose the batch into parallel per-column vectors, as required
        // by the UNNEST-based bulk insert.
        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
        let mut usernames: Vec<String> = Vec::with_capacity(batch.len());
        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
        let mut locked_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
        let mut deactivated_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
        let mut can_request_admins: Vec<bool> = Vec::with_capacity(batch.len());
        let mut is_guests: Vec<bool> = Vec::with_capacity(batch.len());
        for MasNewUser {
            user_id,
            username,
            created_at,
            locked_at,
            deactivated_at,
            can_request_admin,
            is_guest,
        } in batch
        {
            user_ids.push(user_id.get());
            usernames.push(username);
            created_ats.push(created_at);
            locked_ats.push(locked_at);
            deactivated_ats.push(deactivated_at);
            can_request_admins.push(can_request_admin);
            is_guests.push(is_guest);
        }

        sqlx::query!(
            r#"
            INSERT INTO syn2mas__users (
                user_id, username,
                created_at, locked_at,
                deactivated_at,
                can_request_admin, is_guest)
            SELECT * FROM UNNEST(
                $1::UUID[], $2::TEXT[],
                $3::TIMESTAMP WITH TIME ZONE[], $4::TIMESTAMP WITH TIME ZONE[],
                $5::TIMESTAMP WITH TIME ZONE[],
                $6::BOOL[], $7::BOOL[])
            "#,
            &user_ids[..],
            &usernames[..],
            &created_ats[..],
            // The explicit slice casts tell sqlx these columns are nullable.
            &locked_ats[..] as &[Option<DateTime<Utc>>],
            &deactivated_ats[..] as &[Option<DateTime<Utc>>],
            &can_request_admins[..],
            &is_guests[..],
        )
        .execute(&mut *conn)
        .await
        .into_database("writing users to MAS")?;

        Ok(())
    }
}
336
/// A new user password row to be inserted into the MAS database.
pub struct MasNewUserPassword {
    pub user_password_id: Uuid,
    pub user_id: NonNilUuid,
    pub hashed_password: String,
    pub created_at: DateTime<Utc>,
}
343
344impl WriteBatch for MasNewUserPassword {
345 async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
346 let mut user_password_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
347 let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
348 let mut hashed_passwords: Vec<String> = Vec::with_capacity(batch.len());
349 let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
350 let mut versions: Vec<i32> = Vec::with_capacity(batch.len());
351 for MasNewUserPassword {
352 user_password_id,
353 user_id,
354 hashed_password,
355 created_at,
356 } in batch
357 {
358 user_password_ids.push(user_password_id);
359 user_ids.push(user_id.get());
360 hashed_passwords.push(hashed_password);
361 created_ats.push(created_at);
362 versions.push(MIGRATED_PASSWORD_VERSION.into());
363 }
364
365 sqlx::query!(
366 r#"
367 INSERT INTO syn2mas__user_passwords
368 (user_password_id, user_id, hashed_password, created_at, version)
369 SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[], $5::INTEGER[])
370 "#,
371 &user_password_ids[..],
372 &user_ids[..],
373 &hashed_passwords[..],
374 &created_ats[..],
375 &versions[..],
376 ).execute(&mut *conn).await.into_database("writing users to MAS")?;
377
378 Ok(())
379 }
380}
381
/// A new user e-mail address row to be inserted into the MAS database.
pub struct MasNewEmailThreepid {
    pub user_email_id: Uuid,
    pub user_id: NonNilUuid,
    pub email: String,
    pub created_at: DateTime<Utc>,
}
388
impl WriteBatch for MasNewEmailThreepid {
    /// Inserts the whole batch of e-mail addresses with a single multi-row
    /// `INSERT`, built by `UNNEST`ing one Postgres array per column.
    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
        // Decompose the batch into parallel per-column vectors, as required
        // by the UNNEST-based bulk insert.
        let mut user_email_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
        let mut emails: Vec<String> = Vec::with_capacity(batch.len());
        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());

        for MasNewEmailThreepid {
            user_email_id,
            user_id,
            email,
            created_at,
        } in batch
        {
            user_email_ids.push(user_email_id);
            user_ids.push(user_id.get());
            emails.push(email);
            created_ats.push(created_at);
        }

        // NOTE(review): `$4` is bound to both `created_at` and
        // `confirmed_at`, so migrated e-mails are recorded as confirmed at
        // their creation time — appears intentional; confirm if changing.
        sqlx::query!(
            r#"
            INSERT INTO syn2mas__user_emails
            (user_email_id, user_id, email, created_at, confirmed_at)
            SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[], $4::TIMESTAMP WITH TIME ZONE[])
            "#,
            &user_email_ids[..],
            &user_ids[..],
            &emails[..],
            &created_ats[..],
        ).execute(&mut *conn).await.into_database("writing emails to MAS")?;

        Ok(())
    }
}
426
/// A new third-party ID row, of a kind other than e-mail, to be inserted into
/// the MAS database.
pub struct MasNewUnsupportedThreepid {
    pub user_id: NonNilUuid,
    pub medium: String,
    pub address: String,
    pub created_at: DateTime<Utc>,
}
433
impl WriteBatch for MasNewUnsupportedThreepid {
    /// Inserts the whole batch of unsupported third-party IDs with a single
    /// multi-row `INSERT`, built by `UNNEST`ing one Postgres array per column.
    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
        // Decompose the batch into parallel per-column vectors, as required
        // by the UNNEST-based bulk insert.
        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
        let mut mediums: Vec<String> = Vec::with_capacity(batch.len());
        let mut addresses: Vec<String> = Vec::with_capacity(batch.len());
        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());

        for MasNewUnsupportedThreepid {
            user_id,
            medium,
            address,
            created_at,
        } in batch
        {
            user_ids.push(user_id.get());
            mediums.push(medium);
            addresses.push(address);
            created_ats.push(created_at);
        }

        sqlx::query!(
            r#"
            INSERT INTO syn2mas__user_unsupported_third_party_ids
            (user_id, medium, address, created_at)
            SELECT * FROM UNNEST($1::UUID[], $2::TEXT[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[])
            "#,
            &user_ids[..],
            &mediums[..],
            &addresses[..],
            &created_ats[..],
        )
        .execute(&mut *conn)
        .await
        .into_database("writing unsupported threepids to MAS")?;

        Ok(())
    }
}
472
/// A new upstream OAuth link row to be inserted into the MAS database.
pub struct MasNewUpstreamOauthLink {
    pub link_id: Uuid,
    pub user_id: NonNilUuid,
    pub upstream_provider_id: Uuid,
    pub subject: String,
    pub created_at: DateTime<Utc>,
}
480
481impl WriteBatch for MasNewUpstreamOauthLink {
482 async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
483 let mut link_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
484 let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
485 let mut upstream_provider_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
486 let mut subjects: Vec<String> = Vec::with_capacity(batch.len());
487 let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
488
489 for MasNewUpstreamOauthLink {
490 link_id,
491 user_id,
492 upstream_provider_id,
493 subject,
494 created_at,
495 } in batch
496 {
497 link_ids.push(link_id);
498 user_ids.push(user_id.get());
499 upstream_provider_ids.push(upstream_provider_id);
500 subjects.push(subject);
501 created_ats.push(created_at);
502 }
503
504 sqlx::query!(
505 r#"
506 INSERT INTO syn2mas__upstream_oauth_links
507 (upstream_oauth_link_id, user_id, upstream_oauth_provider_id, subject, created_at)
508 SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::UUID[], $4::TEXT[], $5::TIMESTAMP WITH TIME ZONE[])
509 "#,
510 &link_ids[..],
511 &user_ids[..],
512 &upstream_provider_ids[..],
513 &subjects[..],
514 &created_ats[..],
515 ).execute(&mut *conn).await.into_database("writing unsupported threepids to MAS")?;
516
517 Ok(())
518 }
519}
520
/// A new compatibility-layer (Synapse-API) session row to be inserted into
/// the MAS database.
pub struct MasNewCompatSession {
    pub session_id: Uuid,
    pub user_id: NonNilUuid,
    pub device_id: Option<String>,
    pub human_name: Option<String>,
    pub created_at: DateTime<Utc>,
    pub is_synapse_admin: bool,
    pub last_active_at: Option<DateTime<Utc>>,
    pub last_active_ip: Option<IpAddr>,
    pub user_agent: Option<String>,
}
532
impl WriteBatch for MasNewCompatSession {
    /// Inserts the whole batch of compat sessions with a single multi-row
    /// `INSERT`, built by `UNNEST`ing one Postgres array per column.
    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
        // Decompose the batch into parallel per-column vectors, as required
        // by the UNNEST-based bulk insert.
        let mut session_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
        let mut device_ids: Vec<Option<String>> = Vec::with_capacity(batch.len());
        let mut human_names: Vec<Option<String>> = Vec::with_capacity(batch.len());
        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
        let mut is_synapse_admins: Vec<bool> = Vec::with_capacity(batch.len());
        let mut last_active_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
        let mut last_active_ips: Vec<Option<IpAddr>> = Vec::with_capacity(batch.len());
        let mut user_agents: Vec<Option<String>> = Vec::with_capacity(batch.len());

        for MasNewCompatSession {
            session_id,
            user_id,
            device_id,
            human_name,
            created_at,
            is_synapse_admin,
            last_active_at,
            last_active_ip,
            user_agent,
        } in batch
        {
            session_ids.push(session_id);
            user_ids.push(user_id.get());
            device_ids.push(device_id);
            human_names.push(human_name);
            created_ats.push(created_at);
            is_synapse_admins.push(is_synapse_admin);
            last_active_ats.push(last_active_at);
            last_active_ips.push(last_active_ip);
            user_agents.push(user_agent);
        }

        sqlx::query!(
            r#"
            INSERT INTO syn2mas__compat_sessions (
                compat_session_id, user_id,
                device_id, human_name,
                created_at, is_synapse_admin,
                last_active_at, last_active_ip,
                user_agent)
            SELECT * FROM UNNEST(
                $1::UUID[], $2::UUID[],
                $3::TEXT[], $4::TEXT[],
                $5::TIMESTAMP WITH TIME ZONE[], $6::BOOLEAN[],
                $7::TIMESTAMP WITH TIME ZONE[], $8::INET[],
                $9::TEXT[])
            "#,
            &session_ids[..],
            &user_ids[..],
            // The explicit slice casts tell sqlx these columns are nullable.
            &device_ids[..] as &[Option<String>],
            &human_names[..] as &[Option<String>],
            &created_ats[..],
            &is_synapse_admins[..],
            &last_active_ats[..] as &[Option<DateTime<Utc>>],
            &last_active_ips[..] as &[Option<IpAddr>],
            &user_agents[..] as &[Option<String>],
        )
        .execute(&mut *conn)
        .await
        .into_database("writing compat sessions to MAS")?;

        Ok(())
    }
}
601
/// A new compatibility-layer access token row to be inserted into the MAS
/// database.
pub struct MasNewCompatAccessToken {
    pub token_id: Uuid,
    pub session_id: Uuid,
    pub access_token: String,
    pub created_at: DateTime<Utc>,
    pub expires_at: Option<DateTime<Utc>>,
}
609
impl WriteBatch for MasNewCompatAccessToken {
    /// Inserts the whole batch of compat access tokens with a single
    /// multi-row `INSERT`, built by `UNNEST`ing one Postgres array per column.
    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
        // Decompose the batch into parallel per-column vectors, as required
        // by the UNNEST-based bulk insert.
        let mut token_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
        let mut session_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
        let mut access_tokens: Vec<String> = Vec::with_capacity(batch.len());
        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
        let mut expires_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());

        for MasNewCompatAccessToken {
            token_id,
            session_id,
            access_token,
            created_at,
            expires_at,
        } in batch
        {
            token_ids.push(token_id);
            session_ids.push(session_id);
            access_tokens.push(access_token);
            created_ats.push(created_at);
            expires_ats.push(expires_at);
        }

        sqlx::query!(
            r#"
            INSERT INTO syn2mas__compat_access_tokens (
                compat_access_token_id,
                compat_session_id,
                access_token,
                created_at,
                expires_at)
            SELECT * FROM UNNEST(
                $1::UUID[],
                $2::UUID[],
                $3::TEXT[],
                $4::TIMESTAMP WITH TIME ZONE[],
                $5::TIMESTAMP WITH TIME ZONE[])
            "#,
            &token_ids[..],
            &session_ids[..],
            &access_tokens[..],
            &created_ats[..],
            // The explicit slice cast tells sqlx this column is nullable.
            &expires_ats[..] as &[Option<DateTime<Utc>>],
        )
        .execute(&mut *conn)
        .await
        .into_database("writing compat access tokens to MAS")?;

        Ok(())
    }
}
662
/// A new compatibility-layer refresh token row to be inserted into the MAS
/// database.
pub struct MasNewCompatRefreshToken {
    pub refresh_token_id: Uuid,
    pub session_id: Uuid,
    pub access_token_id: Uuid,
    pub refresh_token: String,
    pub created_at: DateTime<Utc>,
}
670
impl WriteBatch for MasNewCompatRefreshToken {
    /// Inserts the whole batch of compat refresh tokens with a single
    /// multi-row `INSERT`, built by `UNNEST`ing one Postgres array per column.
    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
        // Decompose the batch into parallel per-column vectors, as required
        // by the UNNEST-based bulk insert.
        let mut refresh_token_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
        let mut session_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
        let mut access_token_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
        let mut refresh_tokens: Vec<String> = Vec::with_capacity(batch.len());
        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());

        for MasNewCompatRefreshToken {
            refresh_token_id,
            session_id,
            access_token_id,
            refresh_token,
            created_at,
        } in batch
        {
            refresh_token_ids.push(refresh_token_id);
            session_ids.push(session_id);
            access_token_ids.push(access_token_id);
            refresh_tokens.push(refresh_token);
            created_ats.push(created_at);
        }

        sqlx::query!(
            r#"
            INSERT INTO syn2mas__compat_refresh_tokens (
                compat_refresh_token_id,
                compat_session_id,
                compat_access_token_id,
                refresh_token,
                created_at)
            SELECT * FROM UNNEST(
                $1::UUID[],
                $2::UUID[],
                $3::UUID[],
                $4::TEXT[],
                $5::TIMESTAMP WITH TIME ZONE[])
            "#,
            &refresh_token_ids[..],
            &session_ids[..],
            &access_token_ids[..],
            &refresh_tokens[..],
            &created_ats[..],
        )
        .execute(&mut *conn)
        .await
        .into_database("writing compat refresh tokens to MAS")?;

        Ok(())
    }
}
722
723pub const MIGRATED_PASSWORD_VERSION: u16 = 1;
728
/// The tables written to by the migration, named without the `syn2mas__`
/// prefix that is applied for the migration's duration.
pub const MAS_TABLES_AFFECTED_BY_MIGRATION: &[&str] = &[
    "users",
    "user_passwords",
    "user_emails",
    "user_unsupported_third_party_ids",
    "upstream_oauth_links",
    "compat_sessions",
    "compat_access_tokens",
    "compat_refresh_tokens",
];
740
741pub async fn is_syn2mas_in_progress(conn: &mut PgConnection) -> Result<bool, Error> {
755 let restore_table_names = vec![
758 "syn2mas_restore_constraints".to_owned(),
759 "syn2mas_restore_indices".to_owned(),
760 ];
761
762 let num_resumption_tables = query!(
763 r#"
764 SELECT 1 AS _dummy FROM pg_tables WHERE schemaname = current_schema
765 AND tablename = ANY($1)
766 "#,
767 &restore_table_names,
768 )
769 .fetch_all(conn.as_mut())
770 .await
771 .into_database("failed to query count of resumption tables")?
772 .len();
773
774 if num_resumption_tables == 0 {
775 Ok(false)
776 } else if num_resumption_tables == restore_table_names.len() {
777 Ok(true)
778 } else {
779 Err(Error::inconsistent(
780 "some, but not all, syn2mas resumption tables were found",
781 ))
782 }
783}
784
785impl MasWriter {
786 #[tracing::instrument(name = "syn2mas.mas_writer.new", skip_all)]
794 pub async fn new(
795 mut conn: LockedMasDatabase,
796 mut writer_connections: Vec<PgConnection>,
797 dry_run: bool,
798 ) -> Result<Self, Error> {
799 query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
802 .execute(conn.as_mut())
803 .await
804 .into_database("begin MAS transaction")?;
805
806 let syn2mas_started = is_syn2mas_in_progress(conn.as_mut()).await?;
807
808 let indices_to_restore;
809 let constraints_to_restore;
810
811 if syn2mas_started {
812 warn!("Partial syn2mas migration has already been done; resetting.");
815 for table in MAS_TABLES_AFFECTED_BY_MIGRATION {
816 query(&format!("TRUNCATE syn2mas__{table};"))
817 .execute(conn.as_mut())
818 .await
819 .into_database_with(|| format!("failed to truncate table syn2mas__{table}"))?;
820 }
821
822 indices_to_restore = query_as!(
823 IndexDescription,
824 "SELECT table_name, name, definition FROM syn2mas_restore_indices ORDER BY order_key"
825 )
826 .fetch_all(conn.as_mut())
827 .await
828 .into_database("failed to get syn2mas restore data (index descriptions)")?;
829 constraints_to_restore = query_as!(
830 ConstraintDescription,
831 "SELECT table_name, name, definition FROM syn2mas_restore_constraints ORDER BY order_key"
832 )
833 .fetch_all(conn.as_mut())
834 .await
835 .into_database("failed to get syn2mas restore data (constraint descriptions)")?;
836 } else {
837 info!("Starting new syn2mas migration");
838
839 conn.as_mut()
840 .execute_many(include_str!("syn2mas_temporary_tables.sql"))
841 .try_collect::<Vec<_>>()
843 .await
844 .into_database("could not create temporary tables")?;
845
846 (indices_to_restore, constraints_to_restore) =
849 Self::pause_indices(conn.as_mut()).await?;
850
851 for IndexDescription {
853 name,
854 table_name,
855 definition,
856 } in &indices_to_restore
857 {
858 query!(
859 r#"
860 INSERT INTO syn2mas_restore_indices (name, table_name, definition)
861 VALUES ($1, $2, $3)
862 "#,
863 name,
864 table_name,
865 definition
866 )
867 .execute(conn.as_mut())
868 .await
869 .into_database("failed to save restore data (index)")?;
870 }
871 for ConstraintDescription {
872 name,
873 table_name,
874 definition,
875 } in &constraints_to_restore
876 {
877 query!(
878 r#"
879 INSERT INTO syn2mas_restore_constraints (name, table_name, definition)
880 VALUES ($1, $2, $3)
881 "#,
882 name,
883 table_name,
884 definition
885 )
886 .execute(conn.as_mut())
887 .await
888 .into_database("failed to save restore data (index)")?;
889 }
890 }
891
892 query("COMMIT;")
893 .execute(conn.as_mut())
894 .await
895 .into_database("begin MAS transaction")?;
896
897 for writer_connection in &mut writer_connections {
899 query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
900 .execute(&mut *writer_connection)
901 .await
902 .into_database("begin MAS writer transaction")?;
903 }
904
905 Ok(Self {
906 conn,
907 dry_run,
908 writer_pool: WriterConnectionPool::new(writer_connections),
909 indices_to_restore,
910 constraints_to_restore,
911 write_buffer_finish_checker: FinishChecker::default(),
912 })
913 }
914
915 #[tracing::instrument(skip_all)]
916 async fn pause_indices(
917 conn: &mut PgConnection,
918 ) -> Result<(Vec<IndexDescription>, Vec<ConstraintDescription>), Error> {
919 let mut indices_to_restore = Vec::new();
920 let mut constraints_to_restore = Vec::new();
921
922 for &unprefixed_table in MAS_TABLES_AFFECTED_BY_MIGRATION {
923 let table = format!("syn2mas__{unprefixed_table}");
924 for constraint in
926 constraint_pausing::describe_foreign_key_constraints_to_table(&mut *conn, &table)
927 .await?
928 {
929 constraint_pausing::drop_constraint(&mut *conn, &constraint).await?;
930 constraints_to_restore.push(constraint);
931 }
932 for constraint in
935 constraint_pausing::describe_constraints_on_table(&mut *conn, &table).await?
936 {
937 constraint_pausing::drop_constraint(&mut *conn, &constraint).await?;
938 constraints_to_restore.push(constraint);
939 }
940 for index in constraint_pausing::describe_indices_on_table(&mut *conn, &table).await? {
942 constraint_pausing::drop_index(&mut *conn, &index).await?;
943 indices_to_restore.push(index);
944 }
945 }
946
947 Ok((indices_to_restore, constraints_to_restore))
948 }
949
950 async fn restore_indices(
951 conn: &mut LockedMasDatabase,
952 indices_to_restore: &[IndexDescription],
953 constraints_to_restore: &[ConstraintDescription],
954 progress: &Progress,
955 ) -> Result<(), Error> {
956 for index in indices_to_restore.iter().rev() {
959 progress.rebuild_index(index.name.clone());
960 constraint_pausing::restore_index(conn.as_mut(), index).await?;
961 }
962 for constraint in constraints_to_restore.iter().rev() {
966 progress.rebuild_constraint(constraint.name.clone());
967 constraint_pausing::restore_constraint(conn.as_mut(), constraint).await?;
968 }
969 Ok(())
970 }
971
972 #[tracing::instrument(skip_all)]
981 pub async fn finish(mut self, progress: &Progress) -> Result<PgConnection, Error> {
982 self.write_buffer_finish_checker.check_all_finished()?;
983
984 self.writer_pool
986 .finish()
987 .await
988 .map_err(|errors| Error::Multiple(MultipleErrors::from(errors)))?;
989
990 query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
993 .execute(self.conn.as_mut())
994 .await
995 .into_database("begin MAS transaction")?;
996
997 Self::restore_indices(
998 &mut self.conn,
999 &self.indices_to_restore,
1000 &self.constraints_to_restore,
1001 progress,
1002 )
1003 .await?;
1004
1005 self.conn
1006 .as_mut()
1007 .execute_many(include_str!("syn2mas_revert_temporary_tables.sql"))
1008 .try_collect::<Vec<_>>()
1010 .await
1011 .into_database("could not revert temporary tables")?;
1012
1013 if self.dry_run {
1015 warn!("Migration ran in dry-run mode, deleting all imported data");
1016 let tables = MAS_TABLES_AFFECTED_BY_MIGRATION
1017 .iter()
1018 .map(|table| format!("\"{table}\""))
1019 .collect::<Vec<_>>()
1020 .join(", ");
1021
1022 query(&format!("TRUNCATE TABLE {tables} CASCADE;"))
1030 .execute(self.conn.as_mut())
1031 .await
1032 .into_database_with(|| "failed to truncate all tables")?;
1033 }
1034
1035 query("COMMIT;")
1036 .execute(self.conn.as_mut())
1037 .await
1038 .into_database("ending MAS transaction")?;
1039
1040 let conn = self
1041 .conn
1042 .unlock()
1043 .await
1044 .into_database("could not unlock MAS database")?;
1045
1046 Ok(conn)
1047 }
1048}
1049
1050const WRITE_BUFFER_BATCH_SIZE: usize = 4096;
1053
/// Accumulates rows of type `T` and writes them out in batches of up to
/// `WRITE_BUFFER_BATCH_SIZE` via the `MasWriter`'s connection pool.
pub struct MasWriteBuffer<T> {
    // Rows not yet flushed to the database.
    rows: Vec<T>,
    // Tracks that this buffer is `finish`ed before the `MasWriter` is.
    finish_checker_handle: FinishCheckerHandle,
}
1060
1061impl<T> MasWriteBuffer<T>
1062where
1063 T: WriteBatch,
1064{
1065 pub fn new(writer: &MasWriter) -> Self {
1066 MasWriteBuffer {
1067 rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
1068 finish_checker_handle: writer.write_buffer_finish_checker.handle(),
1069 }
1070 }
1071
1072 pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> {
1073 self.flush(writer).await?;
1074 self.finish_checker_handle.declare_finished();
1075 Ok(())
1076 }
1077
1078 pub async fn flush(&mut self, writer: &mut MasWriter) -> Result<(), Error> {
1079 if self.rows.is_empty() {
1080 return Ok(());
1081 }
1082 let rows = std::mem::take(&mut self.rows);
1083 self.rows.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
1084 writer
1085 .writer_pool
1086 .spawn_with_connection(move |conn| T::write_batch(conn, rows).boxed())
1087 .boxed()
1088 .await?;
1089 Ok(())
1090 }
1091
1092 pub async fn write(&mut self, writer: &mut MasWriter, row: T) -> Result<(), Error> {
1093 self.rows.push(row);
1094 if self.rows.len() >= WRITE_BUFFER_BATCH_SIZE {
1095 self.flush(writer).await?;
1096 }
1097 Ok(())
1098 }
1099}
1100
1101#[cfg(test)]
1102mod test {
1103 use std::collections::{BTreeMap, BTreeSet};
1104
1105 use chrono::DateTime;
1106 use futures_util::TryStreamExt;
1107 use serde::Serialize;
1108 use sqlx::{Column, PgConnection, PgPool, Row};
1109 use uuid::{NonNilUuid, Uuid};
1110
1111 use crate::{
1112 LockedMasDatabase, MasWriter, Progress,
1113 mas_writer::{
1114 MasNewCompatAccessToken, MasNewCompatRefreshToken, MasNewCompatSession,
1115 MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
1116 MasNewUserPassword, MasWriteBuffer,
1117 },
1118 };
1119
    /// Serializable snapshot of a whole database: table name → rows.
    #[derive(Default, Serialize)]
    #[serde(transparent)]
    struct DatabaseSnapshot {
        tables: BTreeMap<String, TableSnapshot>,
    }
1126
    /// Serializable snapshot of one table's rows (ordered for determinism).
    #[derive(Serialize)]
    #[serde(transparent)]
    struct TableSnapshot {
        rows: BTreeSet<RowSnapshot>,
    }
1132
    /// Serializable snapshot of one row: column name → value rendered as
    /// text (`None` for SQL NULL).
    #[derive(PartialEq, Eq, PartialOrd, Ord, Serialize)]
    #[serde(transparent)]
    struct RowSnapshot {
        columns_to_values: BTreeMap<String, Option<String>>,
    }
1138
1139 const SKIPPED_TABLES: &[&str] = &["_sqlx_migrations"];
1140
    /// Produces a snapshot of every non-skipped, non-empty table in the
    /// database, with all values rendered as text.
    async fn snapshot_database(conn: &mut PgConnection) -> DatabaseSnapshot {
        let mut out = DatabaseSnapshot::default();
        let table_names: Vec<String> = sqlx::query_scalar(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = current_schema();",
        )
        .fetch_all(&mut *conn)
        .await
        .unwrap();

        for table_name in table_names {
            if SKIPPED_TABLES.contains(&table_name.as_str()) {
                continue;
            }

            let column_names: Vec<String> = sqlx::query_scalar(
                "SELECT column_name FROM information_schema.columns WHERE table_name = $1 AND table_schema = current_schema();"
            ).bind(&table_name).fetch_all(&mut *conn).await.expect("failed to get column names for table for snapshotting");

            // Cast every column to TEXT so rows can be captured uniformly
            // without knowing the column types.
            let column_name_list = column_names
                .iter()
                .map(|column_name| format!("{column_name}::TEXT AS \"{column_name}\""))
                .collect::<Vec<_>>()
                .join(", ");

            let table_rows = sqlx::query(&format!("SELECT {column_name_list} FROM {table_name};"))
                .fetch(&mut *conn)
                .map_ok(|row| {
                    let mut columns_to_values = BTreeMap::new();
                    for (idx, column) in row.columns().iter().enumerate() {
                        columns_to_values.insert(column.name().to_owned(), row.get(idx));
                    }
                    RowSnapshot { columns_to_values }
                })
                .try_collect::<BTreeSet<RowSnapshot>>()
                .await
                .expect("failed to fetch rows from table for snapshotting");

            // Omit empty tables to keep the snapshot compact.
            if !table_rows.is_empty() {
                out.tables
                    .insert(table_name, TableSnapshot { rows: table_rows });
            }
        }

        out
    }
1192
    /// Snapshots the database and asserts it against an insta YAML snapshot.
    macro_rules! assert_db_snapshot {
        ($db: expr) => {
            let db_snapshot = snapshot_database($db).await;
            ::insta::assert_yaml_snapshot!(db_snapshot);
        };
    }
1200
    /// Constructs a `MasWriter` from the test pool, with one locked main
    /// connection and two writer connections, not in dry-run mode.
    async fn make_mas_writer(pool: &PgPool) -> MasWriter {
        let main_conn = pool.acquire().await.unwrap().detach();
        let mut writer_conns = Vec::new();
        for _ in 0..2 {
            writer_conns.push(
                pool.acquire()
                    .await
                    .expect("failed to acquire MasWriter writer connection")
                    .detach(),
            );
        }
        let locked_main_conn = LockedMasDatabase::try_new(main_conn)
            .await
            .expect("failed to lock MAS database")
            .expect_left("MAS database is already locked");
        MasWriter::new(locked_main_conn, writer_conns, false)
            .await
            .expect("failed to construct MasWriter")
    }
1223
    /// Writes a single user through a `MasWriteBuffer` and checks the
    /// resulting database state against a snapshot.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;
        let mut buffer = MasWriteBuffer::new(&writer);

        buffer
            .write(
                &mut writer,
                MasNewUser {
                    user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
                    username: "alice".to_owned(),
                    created_at: DateTime::default(),
                    locked_at: None,
                    deactivated_at: None,
                    can_request_admin: false,
                    is_guest: false,
                },
            )
            .await
            .expect("failed to write user");

        buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish MasWriter");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }
1258
1259 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1261 async fn test_write_user_with_password(pool: PgPool) {
1262 const USER_ID: NonNilUuid = NonNilUuid::new(Uuid::from_u128(1u128)).unwrap();
1263
1264 let mut writer = make_mas_writer(&pool).await;
1265
1266 let mut user_buffer = MasWriteBuffer::new(&writer);
1267 let mut password_buffer = MasWriteBuffer::new(&writer);
1268
1269 user_buffer
1270 .write(
1271 &mut writer,
1272 MasNewUser {
1273 user_id: USER_ID,
1274 username: "alice".to_owned(),
1275 created_at: DateTime::default(),
1276 locked_at: None,
1277 deactivated_at: None,
1278 can_request_admin: false,
1279 is_guest: false,
1280 },
1281 )
1282 .await
1283 .expect("failed to write user");
1284
1285 password_buffer
1286 .write(
1287 &mut writer,
1288 MasNewUserPassword {
1289 user_password_id: Uuid::from_u128(42u128),
1290 user_id: USER_ID,
1291 hashed_password: "$bcrypt$aaaaaaaaaaa".to_owned(),
1292 created_at: DateTime::default(),
1293 },
1294 )
1295 .await
1296 .expect("failed to write password");
1297
1298 user_buffer
1299 .finish(&mut writer)
1300 .await
1301 .expect("failed to finish MasWriteBuffer");
1302 password_buffer
1303 .finish(&mut writer)
1304 .await
1305 .expect("failed to finish MasWriteBuffer");
1306
1307 let mut conn = writer
1308 .finish(&Progress::default())
1309 .await
1310 .expect("failed to finish MasWriter");
1311
1312 assert_db_snapshot!(&mut conn);
1313 }
1314
1315 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1317 async fn test_write_user_with_email(pool: PgPool) {
1318 let mut writer = make_mas_writer(&pool).await;
1319
1320 let mut user_buffer = MasWriteBuffer::new(&writer);
1321 let mut email_buffer = MasWriteBuffer::new(&writer);
1322
1323 user_buffer
1324 .write(
1325 &mut writer,
1326 MasNewUser {
1327 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1328 username: "alice".to_owned(),
1329 created_at: DateTime::default(),
1330 locked_at: None,
1331 deactivated_at: None,
1332 can_request_admin: false,
1333 is_guest: false,
1334 },
1335 )
1336 .await
1337 .expect("failed to write user");
1338
1339 email_buffer
1340 .write(
1341 &mut writer,
1342 MasNewEmailThreepid {
1343 user_email_id: Uuid::from_u128(2u128),
1344 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1345 email: "alice@example.org".to_owned(),
1346 created_at: DateTime::default(),
1347 },
1348 )
1349 .await
1350 .expect("failed to write e-mail");
1351
1352 user_buffer
1353 .finish(&mut writer)
1354 .await
1355 .expect("failed to finish user buffer");
1356 email_buffer
1357 .finish(&mut writer)
1358 .await
1359 .expect("failed to finish email buffer");
1360
1361 let mut conn = writer
1362 .finish(&Progress::default())
1363 .await
1364 .expect("failed to finish MasWriter");
1365
1366 assert_db_snapshot!(&mut conn);
1367 }
1368
1369 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1372 async fn test_write_user_with_unsupported_threepid(pool: PgPool) {
1373 let mut writer = make_mas_writer(&pool).await;
1374
1375 let mut user_buffer = MasWriteBuffer::new(&writer);
1376 let mut threepid_buffer = MasWriteBuffer::new(&writer);
1377
1378 user_buffer
1379 .write(
1380 &mut writer,
1381 MasNewUser {
1382 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1383 username: "alice".to_owned(),
1384 created_at: DateTime::default(),
1385 locked_at: None,
1386 deactivated_at: None,
1387 can_request_admin: false,
1388 is_guest: false,
1389 },
1390 )
1391 .await
1392 .expect("failed to write user");
1393
1394 threepid_buffer
1395 .write(
1396 &mut writer,
1397 MasNewUnsupportedThreepid {
1398 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1399 medium: "msisdn".to_owned(),
1400 address: "441189998819991197253".to_owned(),
1401 created_at: DateTime::default(),
1402 },
1403 )
1404 .await
1405 .expect("failed to write phone number (unsupported threepid)");
1406
1407 user_buffer
1408 .finish(&mut writer)
1409 .await
1410 .expect("failed to finish user buffer");
1411 threepid_buffer
1412 .finish(&mut writer)
1413 .await
1414 .expect("failed to finish threepid buffer");
1415
1416 let mut conn = writer
1417 .finish(&Progress::default())
1418 .await
1419 .expect("failed to finish MasWriter");
1420
1421 assert_db_snapshot!(&mut conn);
1422 }
1423
1424 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR", fixtures("upstream_provider"))]
1428 async fn test_write_user_with_upstream_provider_link(pool: PgPool) {
1429 let mut writer = make_mas_writer(&pool).await;
1430
1431 let mut user_buffer = MasWriteBuffer::new(&writer);
1432 let mut link_buffer = MasWriteBuffer::new(&writer);
1433
1434 user_buffer
1435 .write(
1436 &mut writer,
1437 MasNewUser {
1438 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1439 username: "alice".to_owned(),
1440 created_at: DateTime::default(),
1441 locked_at: None,
1442 deactivated_at: None,
1443 can_request_admin: false,
1444 is_guest: false,
1445 },
1446 )
1447 .await
1448 .expect("failed to write user");
1449
1450 link_buffer
1451 .write(
1452 &mut writer,
1453 MasNewUpstreamOauthLink {
1454 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1455 link_id: Uuid::from_u128(3u128),
1456 upstream_provider_id: Uuid::from_u128(4u128),
1457 subject: "12345.67890".to_owned(),
1458 created_at: DateTime::default(),
1459 },
1460 )
1461 .await
1462 .expect("failed to write link");
1463
1464 user_buffer
1465 .finish(&mut writer)
1466 .await
1467 .expect("failed to finish user buffer");
1468 link_buffer
1469 .finish(&mut writer)
1470 .await
1471 .expect("failed to finish link buffer");
1472
1473 let mut conn = writer
1474 .finish(&Progress::default())
1475 .await
1476 .expect("failed to finish MasWriter");
1477
1478 assert_db_snapshot!(&mut conn);
1479 }
1480
1481 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1483 async fn test_write_user_with_device(pool: PgPool) {
1484 let mut writer = make_mas_writer(&pool).await;
1485
1486 let mut user_buffer = MasWriteBuffer::new(&writer);
1487 let mut session_buffer = MasWriteBuffer::new(&writer);
1488
1489 user_buffer
1490 .write(
1491 &mut writer,
1492 MasNewUser {
1493 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1494 username: "alice".to_owned(),
1495 created_at: DateTime::default(),
1496 locked_at: None,
1497 deactivated_at: None,
1498 can_request_admin: false,
1499 is_guest: false,
1500 },
1501 )
1502 .await
1503 .expect("failed to write user");
1504
1505 session_buffer
1506 .write(
1507 &mut writer,
1508 MasNewCompatSession {
1509 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1510 session_id: Uuid::from_u128(5u128),
1511 created_at: DateTime::default(),
1512 device_id: Some("ADEVICE".to_owned()),
1513 human_name: Some("alice's pinephone".to_owned()),
1514 is_synapse_admin: true,
1515 last_active_at: Some(DateTime::default()),
1516 last_active_ip: Some("203.0.113.1".parse().unwrap()),
1517 user_agent: Some("Browser/5.0".to_owned()),
1518 },
1519 )
1520 .await
1521 .expect("failed to write compat session");
1522
1523 user_buffer
1524 .finish(&mut writer)
1525 .await
1526 .expect("failed to finish user buffer");
1527 session_buffer
1528 .finish(&mut writer)
1529 .await
1530 .expect("failed to finish session buffer");
1531
1532 let mut conn = writer
1533 .finish(&Progress::default())
1534 .await
1535 .expect("failed to finish MasWriter");
1536
1537 assert_db_snapshot!(&mut conn);
1538 }
1539
1540 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1542 async fn test_write_user_with_access_token(pool: PgPool) {
1543 let mut writer = make_mas_writer(&pool).await;
1544
1545 let mut user_buffer = MasWriteBuffer::new(&writer);
1546 let mut session_buffer = MasWriteBuffer::new(&writer);
1547 let mut token_buffer = MasWriteBuffer::new(&writer);
1548
1549 user_buffer
1550 .write(
1551 &mut writer,
1552 MasNewUser {
1553 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1554 username: "alice".to_owned(),
1555 created_at: DateTime::default(),
1556 locked_at: None,
1557 deactivated_at: None,
1558 can_request_admin: false,
1559 is_guest: false,
1560 },
1561 )
1562 .await
1563 .expect("failed to write user");
1564
1565 session_buffer
1566 .write(
1567 &mut writer,
1568 MasNewCompatSession {
1569 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1570 session_id: Uuid::from_u128(5u128),
1571 created_at: DateTime::default(),
1572 device_id: Some("ADEVICE".to_owned()),
1573 human_name: None,
1574 is_synapse_admin: false,
1575 last_active_at: None,
1576 last_active_ip: None,
1577 user_agent: None,
1578 },
1579 )
1580 .await
1581 .expect("failed to write compat session");
1582
1583 token_buffer
1584 .write(
1585 &mut writer,
1586 MasNewCompatAccessToken {
1587 token_id: Uuid::from_u128(6u128),
1588 session_id: Uuid::from_u128(5u128),
1589 access_token: "syt_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
1590 created_at: DateTime::default(),
1591 expires_at: None,
1592 },
1593 )
1594 .await
1595 .expect("failed to write access token");
1596
1597 user_buffer
1598 .finish(&mut writer)
1599 .await
1600 .expect("failed to finish user buffer");
1601 session_buffer
1602 .finish(&mut writer)
1603 .await
1604 .expect("failed to finish session buffer");
1605 token_buffer
1606 .finish(&mut writer)
1607 .await
1608 .expect("failed to finish token buffer");
1609
1610 let mut conn = writer
1611 .finish(&Progress::default())
1612 .await
1613 .expect("failed to finish MasWriter");
1614
1615 assert_db_snapshot!(&mut conn);
1616 }
1617
1618 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1621 async fn test_write_user_with_refresh_token(pool: PgPool) {
1622 let mut writer = make_mas_writer(&pool).await;
1623
1624 let mut user_buffer = MasWriteBuffer::new(&writer);
1625 let mut session_buffer = MasWriteBuffer::new(&writer);
1626 let mut token_buffer = MasWriteBuffer::new(&writer);
1627 let mut refresh_token_buffer = MasWriteBuffer::new(&writer);
1628
1629 user_buffer
1630 .write(
1631 &mut writer,
1632 MasNewUser {
1633 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1634 username: "alice".to_owned(),
1635 created_at: DateTime::default(),
1636 locked_at: None,
1637 deactivated_at: None,
1638 can_request_admin: false,
1639 is_guest: false,
1640 },
1641 )
1642 .await
1643 .expect("failed to write user");
1644
1645 session_buffer
1646 .write(
1647 &mut writer,
1648 MasNewCompatSession {
1649 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1650 session_id: Uuid::from_u128(5u128),
1651 created_at: DateTime::default(),
1652 device_id: Some("ADEVICE".to_owned()),
1653 human_name: None,
1654 is_synapse_admin: false,
1655 last_active_at: None,
1656 last_active_ip: None,
1657 user_agent: None,
1658 },
1659 )
1660 .await
1661 .expect("failed to write compat session");
1662
1663 token_buffer
1664 .write(
1665 &mut writer,
1666 MasNewCompatAccessToken {
1667 token_id: Uuid::from_u128(6u128),
1668 session_id: Uuid::from_u128(5u128),
1669 access_token: "syt_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
1670 created_at: DateTime::default(),
1671 expires_at: None,
1672 },
1673 )
1674 .await
1675 .expect("failed to write access token");
1676
1677 refresh_token_buffer
1678 .write(
1679 &mut writer,
1680 MasNewCompatRefreshToken {
1681 refresh_token_id: Uuid::from_u128(7u128),
1682 session_id: Uuid::from_u128(5u128),
1683 access_token_id: Uuid::from_u128(6u128),
1684 refresh_token: "syr_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
1685 created_at: DateTime::default(),
1686 },
1687 )
1688 .await
1689 .expect("failed to write refresh token");
1690
1691 user_buffer
1692 .finish(&mut writer)
1693 .await
1694 .expect("failed to finish user buffer");
1695 session_buffer
1696 .finish(&mut writer)
1697 .await
1698 .expect("failed to finish session buffer");
1699 token_buffer
1700 .finish(&mut writer)
1701 .await
1702 .expect("failed to finish token buffer");
1703 refresh_token_buffer
1704 .finish(&mut writer)
1705 .await
1706 .expect("failed to finish refresh token buffer");
1707
1708 let mut conn = writer
1709 .finish(&Progress::default())
1710 .await
1711 .expect("failed to finish MasWriter");
1712
1713 assert_db_snapshot!(&mut conn);
1714 }
1715}