mas_handlers/activity_tracker/worker.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::{collections::HashMap, net::IpAddr};
8
9use chrono::{DateTime, Utc};
10use mas_storage::{RepositoryAccess, RepositoryError, user::BrowserSessionRepository};
11use opentelemetry::{
12    Key, KeyValue,
13    metrics::{Counter, Gauge, Histogram},
14};
15use sqlx::PgPool;
16use tokio_util::sync::CancellationToken;
17use ulid::Ulid;
18
19use crate::{
20    METER,
21    activity_tracker::{Message, SessionKind},
22};
23
/// The maximum number of pending activity records before we flush them to the
/// database automatically.
///
/// The [`ActivityRecord`] structure plus the key in the [`HashMap`] takes less
/// than 100 bytes, so this should allocate around 100kB of memory.
///
/// This is a `const` rather than a `static`: it is a plain value that is only
/// ever read, so it does not need a fixed memory location.
const MAX_PENDING_RECORDS: usize = 1000;
30
/// Metric attribute naming the kind of message handled by the worker
/// (`record`, `flush` or `shutdown`).
const TYPE: Key = Key::from_static_str("type");
/// Metric attribute naming the kind of session an activity record belongs to.
const SESSION_KIND: Key = Key::from_static_str("session_kind");
/// Metric attribute naming the outcome of a flush (`success` or `failure`).
const RESULT: Key = Key::from_static_str("result");
34
/// Activity for a single session, accumulated in memory until the next flush.
#[derive(Clone, Copy, Debug)]
struct ActivityRecord {
    /// Timestamp of the first message received for this session in the
    /// current batch.
    // XXX: We don't actually use the start time for now
    #[allow(dead_code)]
    start_time: DateTime<Utc>,
    /// Latest activity timestamp seen for this session in the current batch;
    /// this is the time written to the database on flush.
    end_time: DateTime<Utc>,
    /// IP address associated with the recorded activity, if one was known.
    ip: Option<IpAddr>,
}
43
/// Handles writing activity records to the database.
///
/// Records are accumulated in memory and written in batches, either when
/// explicitly asked to flush, when too many are pending, or on shutdown.
pub struct Worker {
    /// Database pool used to obtain a repository for each flush.
    pool: PgPool,
    /// Activity waiting to be written, keyed by session kind and session ID so
    /// that repeated activity for the same session is merged into one record.
    pending_records: HashMap<(SessionKind, Ulid), ActivityRecord>,
    /// Reports how many records are currently pending.
    pending_records_gauge: Gauge<u64>,
    /// Counts messages handled by the worker, by type (and session kind for
    /// records).
    message_counter: Counter<u64>,
    /// Records how long each flush took, in milliseconds, by outcome.
    flush_time_histogram: Histogram<u64>,
}
52
53impl Worker {
54    pub(crate) fn new(pool: PgPool) -> Self {
55        let message_counter = METER
56            .u64_counter("mas.activity_tracker.messages")
57            .with_description("The number of messages received by the activity tracker")
58            .with_unit("{messages}")
59            .build();
60
61        // Record stuff on the counter so that the metrics are initialized
62        for kind in &[
63            SessionKind::OAuth2,
64            SessionKind::Compat,
65            SessionKind::Browser,
66        ] {
67            message_counter.add(
68                0,
69                &[
70                    KeyValue::new(TYPE, "record"),
71                    KeyValue::new(SESSION_KIND, kind.as_str()),
72                ],
73            );
74        }
75        message_counter.add(0, &[KeyValue::new(TYPE, "flush")]);
76        message_counter.add(0, &[KeyValue::new(TYPE, "shutdown")]);
77
78        let flush_time_histogram = METER
79            .u64_histogram("mas.activity_tracker.flush_time")
80            .with_description("The time it took to flush the activity tracker")
81            .with_unit("ms")
82            .build();
83
84        let pending_records_gauge = METER
85            .u64_gauge("mas.activity_tracker.pending_records")
86            .with_description("The number of pending activity records")
87            .with_unit("{records}")
88            .build();
89        pending_records_gauge.record(0, &[]);
90
91        Self {
92            pool,
93            pending_records: HashMap::with_capacity(MAX_PENDING_RECORDS),
94            pending_records_gauge,
95            message_counter,
96            flush_time_histogram,
97        }
98    }
99
100    pub(super) async fn run(
101        mut self,
102        mut receiver: tokio::sync::mpsc::Receiver<Message>,
103        cancellation_token: CancellationToken,
104    ) {
105        // This guard on the shutdown token is to ensure that if this task crashes for
106        // any reason, the server will shut down
107        let _guard = cancellation_token.clone().drop_guard();
108
109        loop {
110            let message = tokio::select! {
111                // Because we want the cancellation token to trigger only once,
112                // we looked whether we closed the channel or not
113                () = cancellation_token.cancelled(), if !receiver.is_closed() => {
114                    // We only close the channel, which will make it flush all
115                    // the pending messages
116                    receiver.close();
117                    tracing::debug!("Shutting down activity tracker");
118                    continue;
119                },
120
121                message = receiver.recv()  => {
122                    // We consumed all the messages, break out of the loop
123                    let Some(message) = message else { break };
124                    message
125                }
126            };
127
128            match message {
129                Message::Record {
130                    kind,
131                    id,
132                    date_time,
133                    ip,
134                } => {
135                    if self.pending_records.len() >= MAX_PENDING_RECORDS {
136                        tracing::warn!("Too many pending activity records, flushing");
137                        self.flush().await;
138                    }
139
140                    if self.pending_records.len() >= MAX_PENDING_RECORDS {
141                        tracing::error!(
142                            kind = kind.as_str(),
143                            %id,
144                            %date_time,
145                            "Still too many pending activity records, dropping"
146                        );
147                        continue;
148                    }
149
150                    self.message_counter.add(
151                        1,
152                        &[
153                            KeyValue::new(TYPE, "record"),
154                            KeyValue::new(SESSION_KIND, kind.as_str()),
155                        ],
156                    );
157
158                    let record =
159                        self.pending_records
160                            .entry((kind, id))
161                            .or_insert_with(|| ActivityRecord {
162                                start_time: date_time,
163                                end_time: date_time,
164                                ip,
165                            });
166
167                    record.end_time = date_time.max(record.end_time);
168                }
169
170                Message::Flush(tx) => {
171                    self.message_counter.add(1, &[KeyValue::new(TYPE, "flush")]);
172
173                    self.flush().await;
174                    let _ = tx.send(());
175                }
176            }
177
178            // Update the gauge
179            self.pending_records_gauge
180                .record(self.pending_records.len() as u64, &[]);
181        }
182
183        // Flush one last time
184        self.flush().await;
185    }
186
187    /// Flush the activity tracker.
188    async fn flush(&mut self) {
189        // Short path: if there are no pending records, we don't need to flush
190        if self.pending_records.is_empty() {
191            return;
192        }
193
194        let start = std::time::Instant::now();
195        let res = self.try_flush().await;
196
197        // Measure the time it took to flush the activity tracker
198        let duration = start.elapsed();
199        let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
200
201        match res {
202            Ok(()) => {
203                self.flush_time_histogram
204                    .record(duration_ms, &[KeyValue::new(RESULT, "success")]);
205            }
206            Err(e) => {
207                self.flush_time_histogram
208                    .record(duration_ms, &[KeyValue::new(RESULT, "failure")]);
209                tracing::error!(
210                    error = &e as &dyn std::error::Error,
211                    "Failed to flush activity tracker"
212                );
213            }
214        }
215    }
216
217    /// Fallible part of [`Self::flush`].
218    #[tracing::instrument(name = "activity_tracker.flush", skip(self))]
219    async fn try_flush(&mut self) -> Result<(), RepositoryError> {
220        let pending_records = &self.pending_records;
221
222        let mut repo = mas_storage_pg::PgRepository::from_pool(&self.pool)
223            .await
224            .map_err(RepositoryError::from_error)?
225            .boxed();
226
227        let mut browser_sessions = Vec::new();
228        let mut oauth2_sessions = Vec::new();
229        let mut compat_sessions = Vec::new();
230
231        for ((kind, id), record) in pending_records {
232            match kind {
233                SessionKind::Browser => {
234                    browser_sessions.push((*id, record.end_time, record.ip));
235                }
236                SessionKind::OAuth2 => {
237                    oauth2_sessions.push((*id, record.end_time, record.ip));
238                }
239                SessionKind::Compat => {
240                    compat_sessions.push((*id, record.end_time, record.ip));
241                }
242            }
243        }
244
245        tracing::info!(
246            "Flushing {} activity records to the database",
247            pending_records.len()
248        );
249
250        repo.browser_session()
251            .record_batch_activity(browser_sessions)
252            .await?;
253        repo.oauth2_session()
254            .record_batch_activity(oauth2_sessions)
255            .await?;
256        repo.compat_session()
257            .record_batch_activity(compat_sessions)
258            .await?;
259
260        repo.save().await?;
261        self.pending_records.clear();
262
263        Ok(())
264    }
265}