mas_context/
future.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6use std::{
7    pin::Pin,
8    sync::atomic::Ordering,
9    task::{Context, Poll},
10};
11
12use quanta::Instant;
13use tokio::task::futures::TaskLocalFuture;
14
15use crate::LogContext;
16
17pub type LogContextFuture<F> = TaskLocalFuture<crate::LogContext, PollRecordingFuture<F>>;
18
19impl LogContext {
20    /// Wrap a future with the given log context
21    pub(crate) fn wrap_future<F: Future>(&self, future: F) -> LogContextFuture<F> {
22        let future = PollRecordingFuture::new(future);
23        crate::CURRENT_LOG_CONTEXT.scope(self.clone(), future)
24    }
25}
26
27pin_project_lite::pin_project! {
28    /// A future which records the elapsed time and the number of polls in the
29    /// active log context
30    pub struct PollRecordingFuture<F> {
31        #[pin]
32        inner: F,
33    }
34}
35
36impl<F: Future> PollRecordingFuture<F> {
37    pub(crate) fn new(inner: F) -> Self {
38        Self { inner }
39    }
40}
41
42impl<F: Future> Future for PollRecordingFuture<F> {
43    type Output = F::Output;
44
45    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
46        let start = Instant::now();
47        let this = self.project();
48        let result = this.inner.poll(cx);
49
50        // Record the number of polls and the time we spent polling the future
51        let elapsed = start.elapsed().as_nanos().try_into().unwrap_or(u64::MAX);
52        let _ = crate::CURRENT_LOG_CONTEXT.try_with(|c| {
53            c.inner.polls.fetch_add(1, Ordering::Relaxed);
54            c.inner.cpu_time.fetch_add(elapsed, Ordering::Relaxed);
55        });
56
57        result
58    }
59}