1use std::{
7 str::FromStr,
8 sync::{Arc, LazyLock},
9 time::Duration,
10};
11
12use futures_util::FutureExt as _;
13use headers::{ContentLength, HeaderMapExt as _, UserAgent};
14use hyper_util::client::legacy::connect::{
15 HttpInfo,
16 dns::{GaiResolver, Name},
17};
18use opentelemetry::{
19 KeyValue,
20 metrics::{Histogram, UpDownCounter},
21};
22use opentelemetry_http::HeaderInjector;
23use opentelemetry_semantic_conventions::{
24 attribute::{HTTP_REQUEST_BODY_SIZE, HTTP_RESPONSE_BODY_SIZE},
25 metric::{HTTP_CLIENT_ACTIVE_REQUESTS, HTTP_CLIENT_REQUEST_DURATION},
26 trace::{
27 ERROR_TYPE, HTTP_REQUEST_METHOD, HTTP_RESPONSE_STATUS_CODE, NETWORK_LOCAL_ADDRESS,
28 NETWORK_LOCAL_PORT, NETWORK_PEER_ADDRESS, NETWORK_PEER_PORT, NETWORK_TRANSPORT,
29 NETWORK_TYPE, SERVER_ADDRESS, SERVER_PORT, URL_FULL, URL_SCHEME, USER_AGENT_ORIGINAL,
30 },
31};
32use rustls_platform_verifier::ConfigVerifierExt;
33use tokio::time::Instant;
34use tower::{BoxError, Service as _};
35use tracing::Instrument;
36use tracing_opentelemetry::OpenTelemetrySpanExt;
37
38use crate::METER;
39
40static USER_AGENT: &str = concat!("matrix-authentication-service/", env!("CARGO_PKG_VERSION"));
41
42static HTTP_REQUESTS_DURATION_HISTOGRAM: LazyLock<Histogram<u64>> = LazyLock::new(|| {
43 METER
44 .u64_histogram(HTTP_CLIENT_REQUEST_DURATION)
45 .with_unit("ms")
46 .with_description("Duration of HTTP client requests")
47 .build()
48});
49
50static HTTP_REQUESTS_IN_FLIGHT: LazyLock<UpDownCounter<i64>> = LazyLock::new(|| {
51 METER
52 .i64_up_down_counter(HTTP_CLIENT_ACTIVE_REQUESTS)
53 .with_unit("{requests}")
54 .with_description("Number of HTTP client requests in flight")
55 .build()
56});
57
58struct TracingResolver {
59 inner: GaiResolver,
60}
61
62impl TracingResolver {
63 fn new() -> Self {
64 let inner = GaiResolver::new();
65 Self { inner }
66 }
67}
68
69impl reqwest::dns::Resolve for TracingResolver {
70 fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
71 let span = tracing::info_span!("dns.resolve", name = name.as_str());
72 let inner = &mut self.inner.clone();
73 Box::pin(
74 inner
75 .call(Name::from_str(name.as_str()).unwrap())
76 .map(|result| {
77 result
78 .map(|addrs| -> reqwest::dns::Addrs { Box::new(addrs) })
79 .map_err(|err| -> BoxError { Box::new(err) })
80 })
81 .instrument(span),
82 )
83 }
84}
85
86#[must_use]
92pub fn client() -> reqwest::Client {
93 let tls_config = rustls::ClientConfig::with_platform_verifier();
95 reqwest::Client::builder()
96 .dns_resolver(Arc::new(TracingResolver::new()))
97 .use_preconfigured_tls(tls_config)
98 .user_agent(USER_AGENT)
99 .timeout(Duration::from_secs(60))
100 .connect_timeout(Duration::from_secs(30))
101 .build()
102 .expect("failed to create HTTP client")
103}
104
105async fn send_traced(
106 request: reqwest::RequestBuilder,
107) -> Result<reqwest::Response, reqwest::Error> {
108 let start = Instant::now();
109 let (client, request) = request.build_split();
110 let mut request = request?;
111
112 let headers = request.headers();
113 let server_address = request.url().host_str().map(ToOwned::to_owned);
114 let server_port = request.url().port_or_known_default();
115 let scheme = request.url().scheme().to_owned();
116 let user_agent = headers
117 .typed_get::<UserAgent>()
118 .map(tracing::field::display);
119 let content_length = headers.typed_get().map(|ContentLength(len)| len);
120 let method = request.method().to_string();
121
122 let span = tracing::info_span!(
124 "http.client.request",
125 "otel.kind" = "client",
126 "otel.status_code" = tracing::field::Empty,
127 { HTTP_REQUEST_METHOD } = method,
128 { URL_FULL } = %request.url(),
129 { HTTP_RESPONSE_STATUS_CODE } = tracing::field::Empty,
130 { SERVER_ADDRESS } = server_address,
131 { SERVER_PORT } = server_port,
132 { HTTP_REQUEST_BODY_SIZE } = content_length,
133 { HTTP_RESPONSE_BODY_SIZE } = tracing::field::Empty,
134 { NETWORK_TRANSPORT } = "tcp",
135 { NETWORK_TYPE } = tracing::field::Empty,
136 { NETWORK_LOCAL_ADDRESS } = tracing::field::Empty,
137 { NETWORK_LOCAL_PORT } = tracing::field::Empty,
138 { NETWORK_PEER_ADDRESS } = tracing::field::Empty,
139 { NETWORK_PEER_PORT } = tracing::field::Empty,
140 { USER_AGENT_ORIGINAL } = user_agent,
141 "rust.error" = tracing::field::Empty,
142 );
143
144 let context = span.context();
146 opentelemetry::global::get_text_map_propagator(|propagator| {
147 let mut injector = HeaderInjector(request.headers_mut());
148 propagator.inject_context(&context, &mut injector);
149 });
150
151 let mut metrics_labels = vec![
152 KeyValue::new(HTTP_REQUEST_METHOD, method.clone()),
153 KeyValue::new(URL_SCHEME, scheme),
154 ];
155
156 if let Some(server_address) = server_address {
157 metrics_labels.push(KeyValue::new(SERVER_ADDRESS, server_address));
158 }
159
160 if let Some(server_port) = server_port {
161 metrics_labels.push(KeyValue::new(SERVER_PORT, i64::from(server_port)));
162 }
163
164 HTTP_REQUESTS_IN_FLIGHT.add(1, &metrics_labels);
165 async move {
166 let span = tracing::Span::current();
167 let result = client.execute(request).await;
168
169 HTTP_REQUESTS_IN_FLIGHT.add(-1, &metrics_labels);
173
174 let duration = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
175 let result = match result {
176 Ok(response) => {
177 span.record("otel.status_code", "OK");
178 span.record(HTTP_RESPONSE_STATUS_CODE, response.status().as_u16());
179
180 if let Some(ContentLength(content_length)) = response.headers().typed_get() {
181 span.record(HTTP_RESPONSE_BODY_SIZE, content_length);
182 }
183
184 if let Some(http_info) = response.extensions().get::<HttpInfo>() {
185 let local = http_info.local_addr();
186 let peer = http_info.remote_addr();
187 let family = if local.is_ipv4() { "ipv4" } else { "ipv6" };
188 span.record(NETWORK_TYPE, family);
189 span.record(NETWORK_LOCAL_ADDRESS, local.ip().to_string());
190 span.record(NETWORK_LOCAL_PORT, local.port());
191 span.record(NETWORK_PEER_ADDRESS, peer.ip().to_string());
192 span.record(NETWORK_PEER_PORT, peer.port());
193 } else {
194 tracing::warn!("No HttpInfo injected in response extensions");
195 }
196
197 metrics_labels.push(KeyValue::new(
198 HTTP_RESPONSE_STATUS_CODE,
199 i64::from(response.status().as_u16()),
200 ));
201
202 Ok(response)
203 }
204 Err(err) => {
205 span.record("otel.status_code", "ERROR");
206 span.record("rust.error", &err as &dyn std::error::Error);
207
208 metrics_labels.push(KeyValue::new(ERROR_TYPE, "NO_RESPONSE"));
209
210 Err(err)
211 }
212 };
213
214 HTTP_REQUESTS_DURATION_HISTOGRAM.record(duration, &metrics_labels);
215
216 result
217 }
218 .instrument(span)
219 .await
220}
221
222pub trait RequestBuilderExt {
225 fn send_traced(self) -> impl Future<Output = Result<reqwest::Response, reqwest::Error>> + Send;
227}
228
229impl RequestBuilderExt for reqwest::RequestBuilder {
230 fn send_traced(self) -> impl Future<Output = Result<reqwest::Response, reqwest::Error>> + Send {
231 send_traced(self)
232 }
233}