stackable_telemetry/instrumentation/axum/mod.rs
1//! This module contains types which can be used as [`axum`] layers to produce
2//! [OpenTelemetry][1] compatible [HTTP spans][2].
3//!
4//! These spans include a wide variety of fields / attributes defined by the
5//! semantic conventions specification. A few examples are:
6//!
7//! - `http.request.method`
8//! - `http.response.status_code`
9//! - `user_agent.original`
10//!
11//! [1]: https://opentelemetry.io/
12//! [2]: https://opentelemetry.io/docs/specs/semconv/http/http-spans/
13use std::{future::Future, net::SocketAddr, num::ParseIntError, task::Poll};
14
15use axum::{
16 extract::{ConnectInfo, MatchedPath, Request},
17 http::{
18 HeaderMap,
19 header::{HOST, USER_AGENT},
20 },
21 response::Response,
22};
23use futures_util::ready;
24use opentelemetry::{
25 Context,
26 trace::{SpanKind, TraceContextExt},
27};
28use opentelemetry_semantic_conventions as semconv;
29use opentelemetry_semantic_conventions::trace::HTTP_REQUEST_HEADER;
30use pin_project::pin_project;
31use snafu::{ResultExt, Snafu};
32use tower::{Layer, Service};
33use tracing::{Span, field::Empty};
34use tracing_opentelemetry::OpenTelemetrySpanExt;
35
36mod extractor;
37mod injector;
38
39pub use extractor::*;
40pub use injector::*;
41
/// Name of the de-facto standard header a reverse proxy may use to pass on
/// the host originally requested by the client.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

/// Port implied by the `https` scheme when the host value carries no port.
const DEFAULT_HTTPS_PORT: u16 = 443;

/// Port implied by the `http` scheme when the host value carries no port.
const DEFAULT_HTTP_PORT: u16 = 80;

// NOTE (@Techassi): The two `otel.*` field names below are defined here
// because they are private in the tracing-opentelemetry crate.

/// Special field used to set the span name at runtime.
const OTEL_NAME: &str = "otel.name";

/// Special field used to set the span kind.
const OTEL_KIND: &str = "otel.kind";

/// Event field holding the trace id of the current context at the time the
/// request span gets re-parented.
const OTEL_TRACE_ID_FROM: &str = "opentelemetry.trace_id.from";

/// Event field holding the trace id extracted from the request headers.
const OTEL_TRACE_ID_TO: &str = "opentelemetry.trace_id.to";
53
54/// A Tower [`Layer`][1] which decorates [`TraceService`].
55///
56/// ### Example with Axum
57///
58/// ```
59/// use stackable_telemetry::AxumTraceLayer;
60/// use axum::{routing::get, Router};
61///
62/// let trace_layer = AxumTraceLayer::new();
63/// let router = Router::new()
64/// .route("/", get(|| async { "Hello, World!" }))
65/// .layer(trace_layer);
66///
67/// # let _: Router = router;
68/// ```
69///
70/// This layer is implemented based on [this][1] official Tower guide.
71///
72/// [1]: https://github.com/tower-rs/tower/blob/master/guides/building-a-middleware-from-scratch.md
73#[derive(Clone, Debug, Default)]
74pub struct TraceLayer {
75 opt_in: bool,
76}
77
78impl<S> Layer<S> for TraceLayer {
79 type Service = TraceService<S>;
80
81 fn layer(&self, inner: S) -> Self::Service {
82 TraceService {
83 inner,
84 opt_in: self.opt_in,
85 }
86 }
87}
88
89impl TraceLayer {
90 /// Creates a new default trace layer.
91 pub fn new() -> Self {
92 Self::default()
93 }
94
95 /// Enables various fields marked as opt-in by the specification.
96 ///
97 /// This will require more computing power and will increase the latency.
98 /// See <https://opentelemetry.io/docs/specs/semconv/http/http-spans/>
99 pub fn with_opt_in(mut self) -> Self {
100 self.opt_in = true;
101 self
102 }
103}
104
105/// A Tower [`Service`] which injects Span Context into HTTP Response Headers.
106#[derive(Debug, Clone)]
107pub struct TraceService<S> {
108 inner: S,
109 opt_in: bool,
110}
111
112impl<S> Service<Request> for TraceService<S>
113where
114 S: Service<Request, Response = Response> + Send + 'static,
115 S::Error: std::error::Error + 'static,
116 S::Future: Send + 'static,
117{
118 type Error = S::Error;
119 type Future = ResponseFuture<S::Future>;
120 type Response = S::Response;
121
122 fn poll_ready(
123 &mut self,
124 cx: &mut std::task::Context<'_>,
125 ) -> std::task::Poll<Result<(), Self::Error>> {
126 self.inner.poll_ready(cx)
127 }
128
129 fn call(&mut self, req: Request) -> Self::Future {
130 let span = Span::from_request(&req, self.opt_in);
131
132 let future = {
133 let _guard = span.enter();
134 self.inner.call(req)
135 };
136
137 ResponseFuture { future, span }
138 }
139}
140
141/// This future contains the inner service future and the current [`Span`].
142#[pin_project]
143pub struct ResponseFuture<F> {
144 #[pin]
145 pub(crate) future: F,
146 pub(crate) span: Span,
147}
148
149impl<F, E> Future for ResponseFuture<F>
150where
151 F: Future<Output = Result<Response, E>>,
152 E: std::error::Error + 'static,
153{
154 type Output = Result<Response, E>;
155
156 fn poll(
157 self: std::pin::Pin<&mut Self>,
158 cx: &mut std::task::Context<'_>,
159 ) -> std::task::Poll<Self::Output> {
160 let this = self.project();
161 let _guard = this.span.enter();
162
163 let mut result = ready!(this.future.poll(cx));
164 this.span.finalize(&mut result);
165
166 Poll::Ready(result)
167 }
168}
169
170/// Errors which can be encountered when extracting the server host from a [`Request`].
171#[derive(Debug, Snafu)]
172pub enum ServerHostError {
173 /// Indicates that parsing the port of the server host from the [`Request`] as a `u16` failed.
174 #[snafu(display("failed to parse port {port:?} as u16 from string"))]
175 ParsePort {
176 #[allow(missing_docs)]
177 source: ParseIntError,
178
179 // TODO (@Techassi): Make snafu re-emit this
180 /// The original input which was attempted to be parsed.
181 port: String,
182 },
183
184 /// Indicates that the server host from the [`Request`] contains an invalid/unknown scheme.
185 #[snafu(display("encountered invalid request scheme {scheme:?}"))]
186 InvalidScheme {
187 /// The original scheme.
188 scheme: String,
189 },
190
191 // TODO (@Techassi): Make snafu re-emit this
192 /// Indicates that no method of extracting the server host from the [`Request`] succeeded.
193 #[snafu(display("failed to extract any host information from request"))]
194 ExtractHost,
195}
196
197/// This trait provides various helper functions to extract data from a HTTP [`Request`].
198pub trait RequestExt {
199 /// Returns the client socket address, if available.
200 fn client_socket_address(&self) -> Option<SocketAddr>;
201
202 /// Returns the server host, if available.
203 ///
204 /// ### Value Selection Strategy
205 ///
206 /// The following value selection strategy is taken verbatim from [this][1]
207 /// section of the HTTP span semantic conventions:
208 ///
209 /// > HTTP server instrumentations SHOULD do the best effort when populating
210 /// > server.address and server.port attributes and SHOULD determine them by
211 /// > using the first of the following that applies:
212 /// >
213 /// > - The original host which may be passed by the reverse proxy in the
214 /// > Forwarded#host, X-Forwarded-Host, or a similar header.
215 /// > - The :authority pseudo-header in case of HTTP/2 or HTTP/3
216 /// > - The Host header.
217 ///
218 /// [1]: https://opentelemetry.io/docs/specs/semconv/http/http-spans/#setting-serveraddress-and-serverport-attributes
219 fn server_host(&self) -> Result<(String, u16), ServerHostError>;
220
221 /// Returns the matched path, like `/object/:object_id/tags`.
222 ///
223 /// The returned path has low cardinality. It will never contain any path
224 /// or query parameter. This behaviour is suggested by the conventions
225 /// specification.
226 fn matched_path(&self) -> Option<&MatchedPath>;
227
228 /// Returns the span name.
229 ///
230 /// The format is either `{method} {http.route}` or `{method}` if
231 /// `http.route` is not available. Examples are:
232 ///
233 /// - `GET /object/:object_id/tags`
234 /// - `PUT /upload/:file_id`
235 /// - `POST /convert`
236 /// - `OPTIONS`
237 fn span_name(&self) -> String;
238
239 /// Returns the user agent, if available.
240 fn user_agent(&self) -> Option<&str>;
241}
242
243impl RequestExt for Request {
244 fn server_host(&self) -> Result<(String, u16), ServerHostError> {
245 // There is currently no obvious way to use the Host extractor from Axum
246 // directly. Using that extractor either requires impossible code (async
247 // in the Service's call function, unnecessary cloning or consuming self
248 // and returning a newly created request). That's why the following
249 // section mirrors the Axum extractor implementation. The implementation
250 // currently only looks for the X-Forwarded-Host / Host header and falls
251 // back to the request URI host. The Axum implementation also extracts
252 // data from the Forwarded header.
253
254 if let Some(host) = self
255 .headers()
256 .get(X_FORWARDED_HOST_HEADER_KEY)
257 .and_then(|host| host.to_str().ok())
258 {
259 return server_host_to_tuple(host, self.uri().scheme_str());
260 }
261
262 if let Some(host) = self.headers().get(HOST).and_then(|host| host.to_str().ok()) {
263 return server_host_to_tuple(host, self.uri().scheme_str());
264 }
265
266 if let (Some(host), Some(port)) = (self.uri().host(), self.uri().port_u16()) {
267 return Ok((host.to_owned(), port));
268 }
269
270 ExtractHostSnafu.fail()
271 }
272
273 fn client_socket_address(&self) -> Option<SocketAddr> {
274 self.extensions()
275 .get::<ConnectInfo<SocketAddr>>()
276 .map(|ci| ci.0)
277 }
278
279 fn matched_path(&self) -> Option<&MatchedPath> {
280 self.extensions().get::<MatchedPath>()
281 }
282
283 fn span_name(&self) -> String {
284 let http_method = self.method().as_str();
285
286 match self.matched_path() {
287 Some(matched_path) => format!("{http_method} {}", matched_path.as_str()),
288 None => http_method.to_string(),
289 }
290 }
291
292 fn user_agent(&self) -> Option<&str> {
293 self.headers()
294 .get(USER_AGENT)
295 .map(|ua| ua.to_str().unwrap_or_default())
296 }
297}
298
299fn server_host_to_tuple(
300 host: &str,
301 scheme: Option<&str>,
302) -> Result<(String, u16), ServerHostError> {
303 if let Some((host, port)) = host.split_once(':') {
304 // First, see if the host header value contains a colon indicating that
305 // it includes a non-default port.
306 let port: u16 = port.parse().context(ParsePortSnafu { port })?;
307 Ok((host.to_owned(), port))
308 } else {
309 // If there is no port included in the header value, the port is implied.
310 // Port 443 for HTTPS and port 80 for HTTP.
311 let port = match scheme {
312 Some("https") => DEFAULT_HTTPS_PORT,
313 Some("http") => DEFAULT_HTTP_PORT,
314 Some(scheme) => return InvalidSchemeSnafu { scheme }.fail(),
315 _ => return InvalidSchemeSnafu { scheme: "" }.fail(),
316 };
317
318 Ok((host.to_owned(), port))
319 }
320}
321
322/// This trait provides various helper functions to create a [`Span`] out of
323/// an HTTP [`Request`].
324pub trait SpanExt {
325 /// Create a span according to the semantic conventions for HTTP spans from
326 /// an Axum [`Request`].
327 ///
328 /// The individual fields are defined in [this specification][1]. Some of
329 /// them are:
330 ///
331 /// - `http.request.method`
332 /// - `http.response.status_code`
333 /// - `network.protocol.version`
334 ///
335 /// Setting the `opt_in` parameter to `true` enables various fields marked
336 /// as opt-in by the specification. This will require more computing power
337 /// and will increase the latency.
338 ///
339 /// [1]: https://opentelemetry.io/docs/specs/semconv/http/http-spans/
340 fn from_request(req: &Request, opt_in: bool) -> Self;
341
342 /// Adds HTTP request headers to the span as a `http.request.header.<key>`
343 /// field.
344 ///
345 /// NOTE: This is currently not supported, because [`tracing`] doesn't
346 /// support recording dynamic fields.
347 fn add_header_fields(&self, headers: &HeaderMap);
348
349 /// Finalize the [`Span`] with an Axum [`Response`].
350 fn finalize_with_response(&self, response: &mut Response);
351
352 /// Finalize the [`Span`] with an error.
353 fn finalize_with_error<E>(&self, error: &mut E)
354 where
355 E: std::error::Error;
356
357 /// Finalize the [`Span`] with a result.
358 ///
359 /// The default implementation internally calls:
360 ///
361 /// - [`SpanExt::finalize_with_response`] when [`Ok`]
362 /// - [`SpanExt::finalize_with_error`] when [`Err`]
363 fn finalize<E>(&self, result: &mut Result<Response, E>)
364 where
365 E: std::error::Error,
366 {
367 match result {
368 Ok(response) => self.finalize_with_response(response),
369 Err(error) => self.finalize_with_error(error),
370 }
371 }
372}
373
374impl SpanExt for Span {
375 fn from_request(req: &Request, opt_in: bool) -> Self {
376 let http_method = req.method().as_str();
377 let span_name = req.span_name();
378 let url = req.uri();
379
380 tracing::trace!(
381 http_method,
382 span_name,
383 ?url,
384 "extracted http method, span name and request url"
385 );
386
387 // The span name follows the format `{method} {http.route}` defined
388 // by the semantic conventions spec from the OpenTelemetry project.
389 // Currently, the tracing crate doesn't allow non 'static span names,
390 // and thus, the special field otel.name is used to set the span name.
391 // The span name defined in the trace_span macro only serves as a
392 // placeholder.
393 //
394 // - https://docs.rs/tracing-opentelemetry/latest/tracing_opentelemetry/#special-fields
395 // - https://github.com/tokio-rs/tracing/issues/1047
396 // - https://github.com/tokio-rs/tracing/pull/732
397 //
398 // Additionally we cannot use consts for field names. There was an
399 // upstream PR to add support for it, but it was unexpectedly closed.
400 // See https://github.com/tokio-rs/tracing/pull/2254.
401 //
402 // If this is eventually supported (maybe with our efforts), we can use
403 // the opentelemetry-semantic-conventions crate, see here:
404 // https://docs.rs/opentelemetry-semantic-conventions/latest/opentelemetry_semantic_conventions/index.html
405
406 // Setting common fields first
407 // See https://opentelemetry.io/docs/specs/semconv/http/http-spans/#common-attributes
408
409 let span = tracing::trace_span!(
410 "HTTP request",
411 { OTEL_NAME } = span_name,
412 { OTEL_KIND } = ?SpanKind::Server,
413 { semconv::attribute::OTEL_STATUS_CODE } = Empty,
414 // The current tracing-opentelemetry version still uses the old semantic convention
415 // See https://github.com/tokio-rs/tracing-opentelemetry/pull/209
416 { semconv::attribute::OTEL_STATUS_DESCRIPTION } = Empty,
417 { semconv::trace::HTTP_REQUEST_METHOD } = http_method,
418 { semconv::trace::HTTP_RESPONSE_STATUS_CODE } = Empty,
419 { semconv::trace::HTTP_ROUTE } = Empty,
420 { semconv::trace::URL_PATH } = url.path(),
421 { semconv::trace::URL_QUERY } = url.query(),
422 { semconv::trace::URL_SCHEME } = url.scheme_str().unwrap_or_default(),
423 { semconv::trace::USER_AGENT_ORIGINAL } = Empty,
424 { semconv::trace::SERVER_ADDRESS } = Empty,
425 { semconv::trace::SERVER_PORT } = Empty,
426 { semconv::trace::CLIENT_ADDRESS } = Empty,
427 { semconv::trace::CLIENT_PORT } = Empty,
428 // TODO (@Techassi): Add network.protocol.version
429 );
430
431 // Set the parent span based on the extracted context
432 //
433 // The OpenTelemetry spec does not allow trace ids to be updated after
434 // a span has been created. Since the (optional) new trace id given by
435 // a client is only knowable after handling the request, it is not
436 // available to the existing parent spans for the lower layers (tcp/tls
437 // handling).
438 //
439 // Therefore, we have to made a decision about linking the two traces.
440 // These are the options:
441 // 1. Link to the trace id supplied in the incoming request, or
442 // 2. Link to the current trace id, then set the parent context based on
443 // trace information supplied in the incoming request.
444 //
445 // Neither is ideal, as it means there are (at least) two traces to look
446 // at to get a complete picture of what happened over the request.
447 //
448 // Option 1 is not viable, as the trace id in the response headers will
449 // not be the same as what the client sent. Yet we are supposed to pass
450 // their trace id in any further requests.
451 //
452 // We will go with option 2 as it at least keeps the higher layer spans
453 // in one trace, which is likely going to be more useful to the person
454 // visualizing the traces.
455 let new_parent = HeaderExtractor::new(req.headers()).extract_context();
456 let new_span_context = new_parent.span().span_context().clone();
457 let current_span_context = Context::current().span().span_context().clone();
458
459 if new_span_context != current_span_context {
460 tracing::trace!(
461 { OTEL_TRACE_ID_FROM } = ?current_span_context.trace_id(),
462 { OTEL_TRACE_ID_TO } = ?new_span_context.trace_id(),
463 "set parent span context based on context extracted from request headers"
464 );
465
466 Span::current().add_link(new_parent.span().span_context().clone());
467 span.add_link(Context::current().span().span_context().to_owned());
468 let _ = span.set_parent(new_parent);
469 }
470
471 if let Some(user_agent) = req.user_agent() {
472 span.record(semconv::trace::USER_AGENT_ORIGINAL, user_agent);
473 }
474
475 // Setting server.address and server.port
476 // See https://opentelemetry.io/docs/specs/semconv/http/http-spans/#setting-serveraddress-and-serverport-attributes
477
478 if let Ok((host, port)) = req.server_host() {
479 // NOTE (@Techassi): We cast to i64, because otherwise the field
480 // will NOT be recorded as a number but as a string. This is likely
481 // an issue in the tracing-opentelemetry crate.
482 span.record(semconv::trace::SERVER_ADDRESS, host)
483 .record(semconv::trace::SERVER_PORT, port as i64);
484 }
485
486 // Setting fields according to the HTTP server semantic conventions
487 // See https://opentelemetry.io/docs/specs/semconv/http/http-spans/#http-server-semantic-conventions
488
489 if let Some(client_socket_address) = req.client_socket_address() {
490 span.record(
491 semconv::trace::CLIENT_ADDRESS,
492 client_socket_address.ip().to_string(),
493 );
494
495 if opt_in {
496 // NOTE (@Techassi): We cast to i64, because otherwise the field
497 // will NOT be recorded as a number but as a string. This is
498 // likely an issue in the tracing-opentelemetry crate.
499 span.record(
500 semconv::trace::CLIENT_PORT,
501 client_socket_address.port() as i64,
502 );
503 }
504 }
505
506 // Only include the headers if the user opted in, because this might
507 // potentially be an expensive operation when many different headers
508 // are present. The OpenTelemetry spec also marks this as opt-in.
509
510 // FIXME (@Techassi): Currently, tracing doesn't support recording
511 // fields which are not registered at span creation which thus makes it
512 // impossible to record request headers at runtime.
513 // See: https://github.com/tokio-rs/tracing/issues/1343
514
515 if let Some(http_route) = req.matched_path() {
516 span.record(semconv::trace::HTTP_ROUTE, http_route.as_str());
517 }
518
519 span
520 }
521
522 fn add_header_fields(&self, headers: &HeaderMap) {
523 for (header_name, header_value) in headers {
524 // TODO (@Techassi): Add an allow list for header names
525 // TODO (@Techassi): Handle multiple headers with the same name
526
527 // header_name.as_str() always returns lowercase strings and thus we
528 // don't need to call to_lowercase on it.
529 let header_name = header_name.as_str();
530 let field_name = format!("{HTTP_REQUEST_HEADER}.{header_name}");
531
532 self.record(
533 field_name.as_str(),
534 header_value.to_str().unwrap_or_default(),
535 );
536 }
537 }
538
539 fn finalize_with_response(&self, response: &mut Response) {
540 let status_code = response.status();
541
542 // NOTE (@Techassi): We cast to i64, because otherwise the field will
543 // NOT be recorded as a number but as a string. This is likely an issue
544 // in the tracing-opentelemetry crate.
545 self.record(
546 semconv::trace::HTTP_RESPONSE_STATUS_CODE,
547 status_code.as_u16() as i64,
548 );
549
550 // Only set the span status to "Error" when we encountered an server
551 // error. See:
552 //
553 // - https://opentelemetry.io/docs/specs/semconv/http/http-spans/#status
554 // - https://github.com/open-telemetry/opentelemetry-specification/blob/v1.26.0/specification/trace/api.md#set-status
555 if status_code.is_server_error() {
556 self.record(semconv::attribute::OTEL_STATUS_CODE, "Error");
557 // NOTE (@Techassi): Can we add a status_description here as well?
558 }
559
560 let mut injector = HeaderInjector::new(response.headers_mut());
561 injector.inject_context(&Span::current().context());
562 }
563
564 fn finalize_with_error<E>(&self, error: &mut E)
565 where
566 E: std::error::Error,
567 {
568 // NOTE (@Techassi): This field might get renamed: https://github.com/tokio-rs/tracing-opentelemetry/issues/115
569 // NOTE (@Techassi): It got renamed, a fixed version of tracing-opentelemetry is not available yet
570 self.record(semconv::attribute::OTEL_STATUS_CODE, "Error")
571 .record(
572 semconv::attribute::OTEL_STATUS_DESCRIPTION,
573 error.to_string(),
574 );
575 }
576}
577
#[cfg(test)]
mod test {
    use std::time::Duration;

    use axum::{Router, routing::get};

    use super::*;

    /// Smoke test: a router using the trace layer can be served and shuts
    /// down gracefully after one second.
    #[tokio::test]
    async fn test() {
        let app = Router::new()
            .route("/", get(|| async { "Hello, World!" }))
            .layer(TraceLayer::new());

        // Port 0 lets the operating system pick a free port.
        let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
        let shutdown_signal = tokio::time::sleep(Duration::from_secs(1));

        axum::serve(listener, app)
            .with_graceful_shutdown(shutdown_signal)
            .await
            .unwrap();
    }
}
597}