// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
    time::Duration,
};

use futures_util::{StreamExt, stream::SelectAll};
use hyper::{Request, Response};
use hyper_util::{
    rt::{TokioExecutor, TokioIo},
    server::conn::auto::Connection,
    service::TowerToHyperService,
};
use mas_context::LogContext;
use pin_project_lite::pin_project;
use thiserror::Error;
use tokio_rustls::rustls::ServerConfig;
use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
use tower::Service;
use tower_http::add_extension::AddExtension;
use tracing::Instrument;

use crate::{
    ConnectionInfo,
    maybe_tls::{MaybeTlsAcceptor, MaybeTlsStream, TlsStreamInfo},
    proxy_protocol::{MaybeProxyAcceptor, ProxyAcceptError},
    rewind::Rewind,
    unix_or_tcp::{SocketAddr, UnixOrTcpConnection, UnixOrTcpListener},
};

/// The timeout for the handshake to complete
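///
/// This covers both the proxy protocol and TLS handshakes, which [`accept`]
/// wraps in a single timeout.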
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(5);

pub struct Server<S> {
    tls: Option<Arc<ServerConfig>>,
    proxy: bool,
    listener: UnixOrTcpListener,
    service: S,
}

impl<S> Server<S> {
    /// # Errors
    ///
    /// Returns an error if the listener couldn't be converted via [`TryInto`]
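    ///
    /// A minimal sketch, assuming a `TryInto<UnixOrTcpListener>` impl exists
    /// for a `std::net::TcpListener` in non-blocking mode, and that `svc` is
    /// a compatible tower service:
    ///
    /// ```ignore
    /// let listener = std::net::TcpListener::bind("127.0.0.1:8080")?;
    /// listener.set_nonblocking(true)?;
    /// let server = Server::try_new(listener, svc)?;
    /// ```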
    pub fn try_new<L>(listener: L, service: S) -> Result<Self, L::Error>
    where
        L: TryInto<UnixOrTcpListener>,
    {
        Ok(Self {
            tls: None,
            proxy: false,
            listener: listener.try_into()?,
            service,
        })
    }

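    /// Build a server from a listener that converts infallibly into a
    /// [`UnixOrTcpListener`].
    ///
    /// A minimal sketch, assuming such a conversion exists for
    /// `tokio::net::TcpListener` and that `svc` is a compatible tower
    /// service:
    ///
    /// ```ignore
    /// let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await?;
    /// let server = Server::new(listener, svc);
    /// ```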
    #[must_use]
    pub fn new(listener: impl Into<UnixOrTcpListener>, service: S) -> Self {
        Self {
            tls: None,
            proxy: false,
            listener: listener.into(),
            service,
        }
    }

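    /// Expect the PROXY protocol handshake on incoming connections; the
    /// parsed proxy information is exposed to the service through the
    /// [`ConnectionInfo`] extension.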
    #[must_use]
    pub const fn with_proxy(mut self) -> Self {
        self.proxy = true;
        self
    }

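    /// Serve TLS on this listener with the given rustls configuration.
    ///
    /// A minimal sketch of building the config; certificate and key loading
    /// is elided here, and the exact builder calls depend on the rustls
    /// version in use:
    ///
    /// ```ignore
    /// let config = ServerConfig::builder()
    ///     .with_no_client_auth()
    ///     .with_single_cert(cert_chain, private_key)?;
    /// let server = server.with_tls(Arc::new(config));
    /// ```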
    #[must_use]
    pub fn with_tls(mut self, config: Arc<ServerConfig>) -> Self {
        self.tls = Some(config);
        self
    }

    /// Run a single server
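    ///
    /// Serves connections until `soft_shutdown_token` is cancelled, then
    /// drains in-flight connections; cancelling `hard_shutdown_token` aborts
    /// the drain. This is a convenience wrapper around [`run_servers`] for a
    /// single listener.
    ///
    /// A minimal sketch, assuming `tower::service_fn` and
    /// `http_body_util::Full` for the service and body (neither is used by
    /// this module itself):
    ///
    /// ```ignore
    /// let svc = tower::service_fn(|_req| async {
    ///     Ok::<_, std::convert::Infallible>(Response::new(
    ///         http_body_util::Full::new(bytes::Bytes::from_static(b"hello")),
    ///     ))
    /// });
    /// Server::new(listener, svc)
    ///     .run(CancellationToken::new(), CancellationToken::new())
    ///     .await;
    /// ```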
    pub async fn run<B>(
        self,
        soft_shutdown_token: CancellationToken,
        hard_shutdown_token: CancellationToken,
    ) where
        S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
        S::Future: Send + 'static,
        S::Error: std::error::Error + Send + Sync + 'static,
        B: http_body::Body + Send + 'static,
        B::Data: Send,
        B::Error: std::error::Error + Send + Sync + 'static,
    {
        run_servers(
            std::iter::once(self),
            soft_shutdown_token,
            hard_shutdown_token,
        )
        .await;
    }
}

#[derive(Debug, Error)]
#[non_exhaustive]
enum AcceptError {
    #[error("failed to complete the TLS handshake")]
    TlsHandshake {
        #[source]
        source: std::io::Error,
    },

    #[error("failed to complete the proxy protocol handshake")]
    ProxyHandshake {
        #[source]
        source: ProxyAcceptError,
    },

    #[error("connection handshake timed out")]
    HandshakeTimeout {
        #[source]
        source: tokio::time::error::Elapsed,
    },
}

impl AcceptError {
    fn tls_handshake(source: std::io::Error) -> Self {
        Self::TlsHandshake { source }
    }

    fn proxy_handshake(source: ProxyAcceptError) -> Self {
        Self::ProxyHandshake { source }
    }

    fn handshake_timeout(source: tokio::time::error::Elapsed) -> Self {
        Self::HandshakeTimeout { source }
    }
}

/// Accept a connection and do the proxy protocol and TLS handshake
///
/// Returns an error if the proxy protocol or TLS handshake failed.
/// Returns the connection, which should be used to spawn a task to serve the
/// connection.
#[allow(clippy::type_complexity)]
#[tracing::instrument(
    name = "accept",
    skip_all,
    fields(
        network.protocol.name = "http",
        network.peer.address,
        network.peer.port,
    ),
)]
async fn accept<S, B>(
    maybe_proxy_acceptor: &MaybeProxyAcceptor,
    maybe_tls_acceptor: &MaybeTlsAcceptor,
    peer_addr: SocketAddr,
    stream: UnixOrTcpConnection,
    service: S,
) -> Result<
    Connection<
        'static,
        TokioIo<MaybeTlsStream<Rewind<UnixOrTcpConnection>>>,
        TowerToHyperService<AddExtension<S, ConnectionInfo>>,
        TokioExecutor,
    >,
    AcceptError,
>
where
    S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Send + Clone + 'static,
    S::Error: std::error::Error + Send + Sync + 'static,
    S::Future: Send + 'static,
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    let span = tracing::Span::current();

    match peer_addr {
        SocketAddr::Net(addr) => {
            span.record("network.peer.address", tracing::field::display(addr.ip()));
            span.record("network.peer.port", addr.port());
        }
        SocketAddr::Unix(ref addr) => {
            span.record("network.peer.address", tracing::field::debug(addr));
        }
    }

    // Wrap the connection acceptance logic in a timeout
    tokio::time::timeout(HANDSHAKE_TIMEOUT, async move {
        let (proxy, stream) = maybe_proxy_acceptor
            .accept(stream)
            .await
            .map_err(AcceptError::proxy_handshake)?;

        let stream = maybe_tls_acceptor
            .accept(stream)
            .await
            .map_err(AcceptError::tls_handshake)?;

        let tls = stream.tls_info();

        // Figure out if it's HTTP/2 based on the negotiated ALPN info
        let is_h2 = tls.as_ref().is_some_and(TlsStreamInfo::is_alpn_h2);

        let info = ConnectionInfo {
            tls,
            proxy,
            net_peer_addr: peer_addr.into_net(),
        };

        let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
        if is_h2 {
            builder = builder.http2_only();
        }
        builder.http1().keep_alive(true);

        let service = TowerToHyperService::new(AddExtension::new(service, info));

        let conn = builder
            .serve_connection(TokioIo::new(stream), service)
            .into_owned();

        Ok(conn)
    })
    .instrument(span)
    .await
    .map_err(AcceptError::handshake_timeout)?
}

pin_project! {
    /// A wrapper around a connection that can be aborted when a shutdown signal is received.
    ///
    /// This works by polling a future waiting for the shared [`CancellationToken`] to be
    /// cancelled before polling the underlying connection. Once the token is cancelled,
    /// the wrapper starts a graceful shutdown of the connection, exactly once, tracked
    /// by the `did_start_shutdown` flag.
    ///
    /// Polling the cancellation future also registers the task for wakeup when the
    /// shutdown signal is received, because the connection needs to be polled again to
    /// start the graceful shutdown.
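    ///
    /// A minimal sketch of the intended use, mirroring what [`run_servers`] does with a
    /// child token per connection:
    ///
    /// ```ignore
    /// let token = CancellationToken::new();
    /// let conn = AbortableConnection::new(connection, token.child_token());
    /// token.cancel(); // the graceful shutdown starts on the next poll
    /// conn.await?;
    /// ```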
    struct AbortableConnection<C> {
        #[pin]
        connection: C,
        #[pin]
        cancellation_future: WaitForCancellationFutureOwned,
        did_start_shutdown: bool,
    }
}

impl<C> AbortableConnection<C> {
    fn new(connection: C, cancellation_token: CancellationToken) -> Self {
        Self {
            connection,
            cancellation_future: cancellation_token.cancelled_owned(),
            did_start_shutdown: false,
        }
    }
}

impl<T, S, B> Future
    for AbortableConnection<Connection<'static, T, TowerToHyperService<S>, TokioExecutor>>
where
    Connection<'static, T, TowerToHyperService<S>, TokioExecutor>: Future,
    S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Send + Clone + 'static,
    S::Future: Send + 'static,
    S::Error: std::error::Error + Send + Sync,
    T: hyper::rt::Read + hyper::rt::Write + Unpin,
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    type Output = <Connection<'static, T, TowerToHyperService<S>, TokioExecutor> as Future>::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        if let Poll::Ready(()) = this.cancellation_future.poll(cx) {
            if !*this.did_start_shutdown {
                *this.did_start_shutdown = true;
                this.connection.as_mut().graceful_shutdown();
            }
        }

        this.connection.poll(cx)
    }
}

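/// Run a set of servers until shutdown.
///
/// Accepts connections from every listener concurrently. When
/// `soft_shutdown_token` is cancelled the listeners stop accepting and
/// in-flight connections are drained gracefully; cancelling
/// `hard_shutdown_token` aborts the remaining tasks.
///
/// A minimal sketch serving two listeners with one shared service (`tcp`,
/// `unix` and `svc` are placeholders):
///
/// ```ignore
/// let servers = vec![Server::new(tcp, svc.clone()), Server::new(unix, svc)];
/// run_servers(servers, soft_token, hard_token).await;
/// ```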
#[allow(clippy::too_many_lines)]
pub async fn run_servers<S, B>(
    listeners: impl IntoIterator<Item = Server<S>>,
    soft_shutdown_token: CancellationToken,
    hard_shutdown_token: CancellationToken,
) where
    S: Service<Request<hyper::body::Incoming>, Response = Response<B>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: std::error::Error + Send + Sync + 'static,
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    // This guard on the shutdown token is to ensure that if this task crashes for
    // any reason, the server will shut down
    let _guard = soft_shutdown_token.clone().drop_guard();

    // Create a stream of accepted connections out of the listeners
    let mut accept_stream: SelectAll<_> = listeners
        .into_iter()
        .map(|server| {
            let maybe_proxy_acceptor = MaybeProxyAcceptor::new(server.proxy);
            let maybe_tls_acceptor = MaybeTlsAcceptor::new(server.tls);
            futures_util::stream::poll_fn(move |cx| {
                let res =
                    std::task::ready!(server.listener.poll_accept(cx)).map(|(addr, stream)| {
                        (
                            maybe_proxy_acceptor,
                            maybe_tls_acceptor.clone(),
                            server.service.clone(),
                            addr,
                            stream,
                        )
                    });
                Poll::Ready(Some(res))
            })
        })
        .collect();

    // A JoinSet which collects connections that are being accepted
    let mut accept_tasks = tokio::task::JoinSet::new();
    // A JoinSet which collects connections that are being served
    let mut connection_tasks = tokio::task::JoinSet::new();

    loop {
        tokio::select! {
            biased;

            // First look for the shutdown signal
            () = soft_shutdown_token.cancelled() => {
                tracing::debug!("Shutting down listeners");
                break;
            },

            // Poll on the JoinSet to collect connections to serve
            res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
                match res {
                    Some(Ok(Some(connection))) => {
                        let token = soft_shutdown_token.child_token();
                        connection_tasks.spawn(LogContext::new("http-serve").run(async move || {
                            tracing::debug!("Accepted connection");
                            if let Err(e) = AbortableConnection::new(connection, token).await {
                                tracing::warn!(error = &*e as &dyn std::error::Error, "Failed to serve connection");
                            }
                        }));
                    },
                    Some(Ok(None)) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
                    Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
                    None => tracing::error!("Join set was polled even though it was empty"),
                }
            },

            // Poll on the JoinSet to collect finished connections
            res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
                match res {
                    Some(Ok(())) => { /* Connection finished, any errors should be logged in the spawned task */ },
                    Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
                    None => tracing::error!("Join set was polled even though it was empty"),
                }
            },

            // Look for connections to accept
            res = accept_stream.next() => {
                let Some(res) = res else { continue };

                // Spawn the connection in the set, so we don't have to wait for the handshake to
                // accept the next connection. This allows us to keep track of active connections
                // and wait on them for a graceful shutdown
                accept_tasks.spawn(LogContext::new("http-accept").run(async move || {
                    let (maybe_proxy_acceptor, maybe_tls_acceptor, service, peer_addr, stream) = match res {
                        Ok(res) => res,
                        Err(e) => {
                            tracing::warn!(error = &e as &dyn std::error::Error, "Failed to accept connection from the underlying socket");
                            return None;
                        }
                    };

                    match accept(&maybe_proxy_acceptor, &maybe_tls_acceptor, peer_addr, stream, service).await {
                        Ok(connection) => Some(connection),
                        Err(e) => {
                            tracing::warn!(error = &e as &dyn std::error::Error, "Failed to accept connection");
                            None
                        }
                    }
                }));
            },
        };
    }

    // Wait for connections to clean up
    if !accept_tasks.is_empty() || !connection_tasks.is_empty() {
        tracing::info!(
            "There are {active} active connections ({pending} pending), performing a graceful shutdown. Send the shutdown signal again to force.",
            active = connection_tasks.len(),
            pending = accept_tasks.len(),
        );

        while !accept_tasks.is_empty() || !connection_tasks.is_empty() {
            tokio::select! {
                biased;

                // Poll on the JoinSet to collect connections to serve
                res = accept_tasks.join_next(), if !accept_tasks.is_empty() => {
                    match res {
                        Some(Ok(Some(connection))) => {
                            let token = soft_shutdown_token.child_token();
                            connection_tasks.spawn(LogContext::new("http-serve").run(async || {
                                tracing::debug!("Accepted connection");
                                if let Err(e) = AbortableConnection::new(connection, token).await {
                                    tracing::warn!(error = &*e as &dyn std::error::Error, "Failed to serve connection");
                                }
                            }));
                        }
                        Some(Ok(None)) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
                        Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
                        None => tracing::error!("Join set was polled even though it was empty"),
                    }
                },

                // Poll on the JoinSet to collect finished connections
                res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
                    match res {
                        Some(Ok(())) => { /* Connection finished, any errors should be logged in the spawned task */ },
                        Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
                        None => tracing::error!("Join set was polled even though it was empty"),
                    }
                },

                // Handle when we are asked to hard shutdown
                () = hard_shutdown_token.cancelled() => {
                    tracing::warn!(
                        "Forcing shutdown ({active} active connections, {pending} pending connections)",
                        active = connection_tasks.len(),
                        pending = accept_tasks.len(),
                    );
                    break;
                },
            }
        }
    }

    accept_tasks.shutdown().await;
    connection_tasks.shutdown().await;
}