mas_handlers/activity_tracker/
worker.rs1use std::{collections::HashMap, net::IpAddr};
8
9use chrono::{DateTime, Utc};
10use mas_storage::{RepositoryAccess, RepositoryError, user::BrowserSessionRepository};
11use opentelemetry::{
12 Key, KeyValue,
13 metrics::{Counter, Gauge, Histogram},
14};
15use sqlx::PgPool;
16use tokio_util::sync::CancellationToken;
17use ulid::Ulid;
18
19use crate::{
20 METER,
21 activity_tracker::{Message, SessionKind},
22};
23
/// Maximum number of distinct `(session kind, session ID)` entries buffered
/// in memory between flushes.
///
/// When the buffer reaches this size, the worker flushes early. If the buffer
/// is still full afterwards (the flush failed), new activity records are
/// dropped.
const MAX_PENDING_RECORDS: usize = 1000;

31const TYPE: Key = Key::from_static_str("type");
32const SESSION_KIND: Key = Key::from_static_str("session_kind");
33const RESULT: Key = Key::from_static_str("result");
34
35#[derive(Clone, Copy, Debug)]
36struct ActivityRecord {
37 #[allow(dead_code)]
39 start_time: DateTime<Utc>,
40 end_time: DateTime<Utc>,
41 ip: Option<IpAddr>,
42}
43
44pub struct Worker {
46 pool: PgPool,
47 pending_records: HashMap<(SessionKind, Ulid), ActivityRecord>,
48 pending_records_gauge: Gauge<u64>,
49 message_counter: Counter<u64>,
50 flush_time_histogram: Histogram<u64>,
51}
52
53impl Worker {
54 pub(crate) fn new(pool: PgPool) -> Self {
55 let message_counter = METER
56 .u64_counter("mas.activity_tracker.messages")
57 .with_description("The number of messages received by the activity tracker")
58 .with_unit("{messages}")
59 .build();
60
61 for kind in &[
63 SessionKind::OAuth2,
64 SessionKind::Compat,
65 SessionKind::Browser,
66 ] {
67 message_counter.add(
68 0,
69 &[
70 KeyValue::new(TYPE, "record"),
71 KeyValue::new(SESSION_KIND, kind.as_str()),
72 ],
73 );
74 }
75 message_counter.add(0, &[KeyValue::new(TYPE, "flush")]);
76 message_counter.add(0, &[KeyValue::new(TYPE, "shutdown")]);
77
78 let flush_time_histogram = METER
79 .u64_histogram("mas.activity_tracker.flush_time")
80 .with_description("The time it took to flush the activity tracker")
81 .with_unit("ms")
82 .build();
83
84 let pending_records_gauge = METER
85 .u64_gauge("mas.activity_tracker.pending_records")
86 .with_description("The number of pending activity records")
87 .with_unit("{records}")
88 .build();
89 pending_records_gauge.record(0, &[]);
90
91 Self {
92 pool,
93 pending_records: HashMap::with_capacity(MAX_PENDING_RECORDS),
94 pending_records_gauge,
95 message_counter,
96 flush_time_histogram,
97 }
98 }
99
100 pub(super) async fn run(
101 mut self,
102 mut receiver: tokio::sync::mpsc::Receiver<Message>,
103 cancellation_token: CancellationToken,
104 ) {
105 let _guard = cancellation_token.clone().drop_guard();
108
109 loop {
110 let message = tokio::select! {
111 () = cancellation_token.cancelled(), if !receiver.is_closed() => {
114 receiver.close();
117 tracing::debug!("Shutting down activity tracker");
118 continue;
119 },
120
121 message = receiver.recv() => {
122 let Some(message) = message else { break };
124 message
125 }
126 };
127
128 match message {
129 Message::Record {
130 kind,
131 id,
132 date_time,
133 ip,
134 } => {
135 if self.pending_records.len() >= MAX_PENDING_RECORDS {
136 tracing::warn!("Too many pending activity records, flushing");
137 self.flush().await;
138 }
139
140 if self.pending_records.len() >= MAX_PENDING_RECORDS {
141 tracing::error!(
142 kind = kind.as_str(),
143 %id,
144 %date_time,
145 "Still too many pending activity records, dropping"
146 );
147 continue;
148 }
149
150 self.message_counter.add(
151 1,
152 &[
153 KeyValue::new(TYPE, "record"),
154 KeyValue::new(SESSION_KIND, kind.as_str()),
155 ],
156 );
157
158 let record =
159 self.pending_records
160 .entry((kind, id))
161 .or_insert_with(|| ActivityRecord {
162 start_time: date_time,
163 end_time: date_time,
164 ip,
165 });
166
167 record.end_time = date_time.max(record.end_time);
168 }
169
170 Message::Flush(tx) => {
171 self.message_counter.add(1, &[KeyValue::new(TYPE, "flush")]);
172
173 self.flush().await;
174 let _ = tx.send(());
175 }
176 }
177
178 self.pending_records_gauge
180 .record(self.pending_records.len() as u64, &[]);
181 }
182
183 self.flush().await;
185 }
186
187 async fn flush(&mut self) {
189 if self.pending_records.is_empty() {
191 return;
192 }
193
194 let start = std::time::Instant::now();
195 let res = self.try_flush().await;
196
197 let duration = start.elapsed();
199 let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
200
201 match res {
202 Ok(()) => {
203 self.flush_time_histogram
204 .record(duration_ms, &[KeyValue::new(RESULT, "success")]);
205 }
206 Err(e) => {
207 self.flush_time_histogram
208 .record(duration_ms, &[KeyValue::new(RESULT, "failure")]);
209 tracing::error!(
210 error = &e as &dyn std::error::Error,
211 "Failed to flush activity tracker"
212 );
213 }
214 }
215 }
216
217 #[tracing::instrument(name = "activity_tracker.flush", skip(self))]
219 async fn try_flush(&mut self) -> Result<(), RepositoryError> {
220 let pending_records = &self.pending_records;
221
222 let mut repo = mas_storage_pg::PgRepository::from_pool(&self.pool)
223 .await
224 .map_err(RepositoryError::from_error)?
225 .boxed();
226
227 let mut browser_sessions = Vec::new();
228 let mut oauth2_sessions = Vec::new();
229 let mut compat_sessions = Vec::new();
230
231 for ((kind, id), record) in pending_records {
232 match kind {
233 SessionKind::Browser => {
234 browser_sessions.push((*id, record.end_time, record.ip));
235 }
236 SessionKind::OAuth2 => {
237 oauth2_sessions.push((*id, record.end_time, record.ip));
238 }
239 SessionKind::Compat => {
240 compat_sessions.push((*id, record.end_time, record.ip));
241 }
242 }
243 }
244
245 tracing::info!(
246 "Flushing {} activity records to the database",
247 pending_records.len()
248 );
249
250 repo.browser_session()
251 .record_batch_activity(browser_sessions)
252 .await?;
253 repo.oauth2_session()
254 .record_batch_activity(oauth2_sessions)
255 .await?;
256 repo.compat_session()
257 .record_batch_activity(compat_sessions)
258 .await?;
259
260 repo.save().await?;
261 self.pending_records.clear();
262
263 Ok(())
264 }
265}