stackable_telemetry/instrumentation/axum/
mod.rs

1//! This module contains types which can be used as [`axum`] layers to produce
2//! [OpenTelemetry][1] compatible [HTTP spans][2].
3//!
4//! These spans include a wide variety of fields / attributes defined by the
5//! semantic conventions specification. A few examples are:
6//!
7//! - `http.request.method`
8//! - `http.response.status_code`
9//! - `user_agent.original`
10//!
11//! [1]: https://opentelemetry.io/
12//! [2]: https://opentelemetry.io/docs/specs/semconv/http/http-spans/
13use std::{future::Future, net::SocketAddr, num::ParseIntError, task::Poll};
14
15use axum::{
16    extract::{ConnectInfo, MatchedPath, Request},
17    http::{
18        HeaderMap,
19        header::{HOST, USER_AGENT},
20    },
21    response::Response,
22};
23use futures_util::ready;
24use opentelemetry::{
25    Context,
26    trace::{SpanKind, TraceContextExt},
27};
28use opentelemetry_semantic_conventions as semconv;
29use opentelemetry_semantic_conventions::trace::HTTP_REQUEST_HEADER;
30use pin_project::pin_project;
31use snafu::{ResultExt, Snafu};
32use tower::{Layer, Service};
33use tracing::{Span, field::Empty};
34use tracing_opentelemetry::OpenTelemetrySpanExt;
35
36mod extractor;
37mod injector;
38
39pub use extractor::*;
40pub use injector::*;
41
/// Header in which a reverse proxy may pass the host originally requested by
/// the client.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

/// Port implied by the `https` scheme when a host value carries no port.
const DEFAULT_HTTPS_PORT: u16 = 443;

/// Port implied by the `http` scheme when a host value carries no port.
const DEFAULT_HTTP_PORT: u16 = 80;

// NOTE (@Techassi): These constants are defined here because they are private in
// the tracing-opentelemetry crate.

/// Special span field used to set the name of the exported span.
const OTEL_NAME: &str = "otel.name";

/// Special span field used to set the kind of the exported span.
const OTEL_KIND: &str = "otel.kind";

/// Event field holding the trace id of the context active before re-parenting.
const OTEL_TRACE_ID_FROM: &str = "opentelemetry.trace_id.from";

/// Event field holding the trace id extracted from the request headers.
const OTEL_TRACE_ID_TO: &str = "opentelemetry.trace_id.to";
53
/// A Tower [`Layer`] which wraps services in a [`TraceService`].
///
/// Every request handled by the wrapped service gets its own span, populated
/// according to the OpenTelemetry semantic conventions for HTTP spans.
///
/// ### Example with Axum
///
/// ```
/// use stackable_telemetry::AxumTraceLayer;
/// use axum::{routing::get, Router};
///
/// let trace_layer = AxumTraceLayer::new();
/// let router = Router::new()
///     .route("/", get(|| async { "Hello, World!" }))
///     .layer(trace_layer);
///
/// # let _: Router = router;
/// ```
///
/// This layer is implemented based on [this][1] official Tower guide.
///
/// [1]: https://github.com/tower-rs/tower/blob/master/guides/building-a-middleware-from-scratch.md
#[derive(Clone, Debug, Default)]
pub struct TraceLayer {
    /// Whether fields marked as opt-in by the semantic conventions are
    /// recorded. Disabled by default, enabled via [`TraceLayer::with_opt_in`].
    opt_in: bool,
}
77
78impl<S> Layer<S> for TraceLayer {
79    type Service = TraceService<S>;
80
81    fn layer(&self, inner: S) -> Self::Service {
82        TraceService {
83            inner,
84            opt_in: self.opt_in,
85        }
86    }
87}
88
89impl TraceLayer {
90    /// Creates a new default trace layer.
91    pub fn new() -> Self {
92        Self::default()
93    }
94
95    /// Enables various fields marked as opt-in by the specification.
96    ///
97    /// This will require more computing power and will increase the latency.
98    /// See <https://opentelemetry.io/docs/specs/semconv/http/http-spans/>
99    pub fn with_opt_in(mut self) -> Self {
100        self.opt_in = true;
101        self
102    }
103}
104
/// A Tower [`Service`] which wraps an inner service and traces every request
/// it handles.
///
/// For each request a span is created from the request (see
/// [`SpanExt::from_request`]). The span is finalized once the inner service
/// returns a response or an error. When finalized with a response, the span
/// context is injected into the HTTP response headers.
///
/// Instances are usually created via [`TraceLayer`].
#[derive(Clone, Debug)]
pub struct TraceService<S> {
    /// The wrapped service which actually handles the request.
    inner: S,
    /// Whether opt-in span fields are recorded, see [`TraceLayer::with_opt_in`].
    opt_in: bool,
}
111
112impl<S> Service<Request> for TraceService<S>
113where
114    S: Service<Request, Response = Response> + Send + 'static,
115    S::Error: std::error::Error + 'static,
116    S::Future: Send + 'static,
117{
118    type Error = S::Error;
119    type Future = ResponseFuture<S::Future>;
120    type Response = S::Response;
121
122    fn poll_ready(
123        &mut self,
124        cx: &mut std::task::Context<'_>,
125    ) -> std::task::Poll<Result<(), Self::Error>> {
126        self.inner.poll_ready(cx)
127    }
128
129    fn call(&mut self, req: Request) -> Self::Future {
130        let span = Span::from_request(&req, self.opt_in);
131
132        let future = {
133            let _guard = span.enter();
134            self.inner.call(req)
135        };
136
137        ResponseFuture { future, span }
138    }
139}
140
141/// This future contains the inner service future and the current [`Span`].
142#[pin_project]
143pub struct ResponseFuture<F> {
144    #[pin]
145    pub(crate) future: F,
146    pub(crate) span: Span,
147}
148
149impl<F, E> Future for ResponseFuture<F>
150where
151    F: Future<Output = Result<Response, E>>,
152    E: std::error::Error + 'static,
153{
154    type Output = Result<Response, E>;
155
156    fn poll(
157        self: std::pin::Pin<&mut Self>,
158        cx: &mut std::task::Context<'_>,
159    ) -> std::task::Poll<Self::Output> {
160        let this = self.project();
161        let _guard = this.span.enter();
162
163        let mut result = ready!(this.future.poll(cx));
164        this.span.finalize(&mut result);
165
166        Poll::Ready(result)
167    }
168}
169
170/// Errors which can be encountered when extracting the server host from a [`Request`].
171#[derive(Debug, Snafu)]
172pub enum ServerHostError {
173    /// Indicates that parsing the port of the server host from the [`Request`] as a `u16` failed.
174    #[snafu(display("failed to parse port {port:?} as u16 from string"))]
175    ParsePort {
176        #[allow(missing_docs)]
177        source: ParseIntError,
178
179        // TODO (@Techassi): Make snafu re-emit this
180        /// The original input which was attempted to be parsed.
181        port: String,
182    },
183
184    /// Indicates that the server host from the [`Request`] contains an invalid/unknown scheme.
185    #[snafu(display("encountered invalid request scheme {scheme:?}"))]
186    InvalidScheme {
187        /// The original scheme.
188        scheme: String,
189    },
190
191    // TODO (@Techassi): Make snafu re-emit this
192    /// Indicates that no method of extracting the server host from the [`Request`] succeeded.
193    #[snafu(display("failed to extract any host information from request"))]
194    ExtractHost,
195}
196
197/// This trait provides various helper functions to extract data from a HTTP [`Request`].
198pub trait RequestExt {
199    /// Returns the client socket address, if available.
200    fn client_socket_address(&self) -> Option<SocketAddr>;
201
202    /// Returns the server host, if available.
203    ///
204    /// ### Value Selection Strategy
205    ///
206    /// The following value selection strategy is taken verbatim from [this][1]
207    /// section of the HTTP span semantic conventions:
208    ///
209    /// > HTTP server instrumentations SHOULD do the best effort when populating
210    /// > server.address and server.port attributes and SHOULD determine them by
211    /// > using the first of the following that applies:
212    /// >
213    /// > - The original host which may be passed by the reverse proxy in the
214    /// >  Forwarded#host, X-Forwarded-Host, or a similar header.
215    /// > - The :authority pseudo-header in case of HTTP/2 or HTTP/3
216    /// > - The Host header.
217    ///
218    /// [1]: https://opentelemetry.io/docs/specs/semconv/http/http-spans/#setting-serveraddress-and-serverport-attributes
219    fn server_host(&self) -> Result<(String, u16), ServerHostError>;
220
221    /// Returns the matched path, like `/object/:object_id/tags`.
222    ///
223    /// The returned path has low cardinality. It will never contain any path
224    /// or query parameter. This behaviour is suggested by the conventions
225    /// specification.
226    fn matched_path(&self) -> Option<&MatchedPath>;
227
228    /// Returns the span name.
229    ///
230    /// The format is either `{method} {http.route}` or `{method}` if
231    /// `http.route` is not available. Examples are:
232    ///
233    /// - `GET /object/:object_id/tags`
234    /// - `PUT /upload/:file_id`
235    /// - `POST /convert`
236    /// - `OPTIONS`
237    fn span_name(&self) -> String;
238
239    /// Returns the user agent, if available.
240    fn user_agent(&self) -> Option<&str>;
241}
242
243impl RequestExt for Request {
244    fn server_host(&self) -> Result<(String, u16), ServerHostError> {
245        // There is currently no obvious way to use the Host extractor from Axum
246        // directly. Using that extractor either requires impossible code (async
247        // in the Service's call function, unnecessary cloning or consuming self
248        // and returning a newly created request). That's why the following
249        // section mirrors the Axum extractor implementation. The implementation
250        // currently only looks for the X-Forwarded-Host / Host header and falls
251        // back to the request URI host. The Axum implementation also extracts
252        // data from the Forwarded header.
253
254        if let Some(host) = self
255            .headers()
256            .get(X_FORWARDED_HOST_HEADER_KEY)
257            .and_then(|host| host.to_str().ok())
258        {
259            return server_host_to_tuple(host, self.uri().scheme_str());
260        }
261
262        if let Some(host) = self.headers().get(HOST).and_then(|host| host.to_str().ok()) {
263            return server_host_to_tuple(host, self.uri().scheme_str());
264        }
265
266        if let (Some(host), Some(port)) = (self.uri().host(), self.uri().port_u16()) {
267            return Ok((host.to_owned(), port));
268        }
269
270        ExtractHostSnafu.fail()
271    }
272
273    fn client_socket_address(&self) -> Option<SocketAddr> {
274        self.extensions()
275            .get::<ConnectInfo<SocketAddr>>()
276            .map(|ci| ci.0)
277    }
278
279    fn matched_path(&self) -> Option<&MatchedPath> {
280        self.extensions().get::<MatchedPath>()
281    }
282
283    fn span_name(&self) -> String {
284        let http_method = self.method().as_str();
285
286        match self.matched_path() {
287            Some(matched_path) => format!("{http_method} {}", matched_path.as_str()),
288            None => http_method.to_string(),
289        }
290    }
291
292    fn user_agent(&self) -> Option<&str> {
293        self.headers()
294            .get(USER_AGENT)
295            .map(|ua| ua.to_str().unwrap_or_default())
296    }
297}
298
299fn server_host_to_tuple(
300    host: &str,
301    scheme: Option<&str>,
302) -> Result<(String, u16), ServerHostError> {
303    if let Some((host, port)) = host.split_once(':') {
304        // First, see if the host header value contains a colon indicating that
305        // it includes a non-default port.
306        let port: u16 = port.parse().context(ParsePortSnafu { port })?;
307        Ok((host.to_owned(), port))
308    } else {
309        // If there is no port included in the header value, the port is implied.
310        // Port 443 for HTTPS and port 80 for HTTP.
311        let port = match scheme {
312            Some("https") => DEFAULT_HTTPS_PORT,
313            Some("http") => DEFAULT_HTTP_PORT,
314            Some(scheme) => return InvalidSchemeSnafu { scheme }.fail(),
315            _ => return InvalidSchemeSnafu { scheme: "" }.fail(),
316        };
317
318        Ok((host.to_owned(), port))
319    }
320}
321
322/// This trait provides various helper functions to create a [`Span`] out of
323/// an HTTP [`Request`].
324pub trait SpanExt {
325    /// Create a span according to the semantic conventions for HTTP spans from
326    /// an Axum [`Request`].
327    ///
328    /// The individual fields are defined in [this specification][1]. Some of
329    /// them are:
330    ///
331    /// - `http.request.method`
332    /// - `http.response.status_code`
333    /// - `network.protocol.version`
334    ///
335    /// Setting the `opt_in` parameter to `true` enables various fields marked
336    /// as opt-in by the specification. This will require more computing power
337    /// and will increase the latency.
338    ///
339    /// [1]: https://opentelemetry.io/docs/specs/semconv/http/http-spans/
340    fn from_request(req: &Request, opt_in: bool) -> Self;
341
342    /// Adds HTTP request headers to the span as a `http.request.header.<key>`
343    /// field.
344    ///
345    /// NOTE: This is currently not supported, because [`tracing`] doesn't
346    /// support recording dynamic fields.
347    fn add_header_fields(&self, headers: &HeaderMap);
348
349    /// Finalize the [`Span`] with an Axum [`Response`].
350    fn finalize_with_response(&self, response: &mut Response);
351
352    /// Finalize the [`Span`] with an error.
353    fn finalize_with_error<E>(&self, error: &mut E)
354    where
355        E: std::error::Error;
356
357    /// Finalize the [`Span`] with a result.
358    ///
359    /// The default implementation internally calls:
360    ///
361    /// - [`SpanExt::finalize_with_response`] when [`Ok`]
362    /// - [`SpanExt::finalize_with_error`] when [`Err`]
363    fn finalize<E>(&self, result: &mut Result<Response, E>)
364    where
365        E: std::error::Error,
366    {
367        match result {
368            Ok(response) => self.finalize_with_response(response),
369            Err(error) => self.finalize_with_error(error),
370        }
371    }
372}
373
374impl SpanExt for Span {
375    fn from_request(req: &Request, opt_in: bool) -> Self {
376        let http_method = req.method().as_str();
377        let span_name = req.span_name();
378        let url = req.uri();
379
380        tracing::trace!(
381            http_method,
382            span_name,
383            ?url,
384            "extracted http method, span name and request url"
385        );
386
387        // The span name follows the format `{method} {http.route}` defined
388        // by the semantic conventions spec from the OpenTelemetry project.
389        // Currently, the tracing crate doesn't allow non 'static span names,
390        // and thus, the special field otel.name is used to set the span name.
391        // The span name defined in the trace_span macro only serves as a
392        // placeholder.
393        //
394        // - https://docs.rs/tracing-opentelemetry/latest/tracing_opentelemetry/#special-fields
395        // - https://github.com/tokio-rs/tracing/issues/1047
396        // - https://github.com/tokio-rs/tracing/pull/732
397        //
398        // Additionally we cannot use consts for field names. There was an
399        // upstream PR to add support for it, but it was unexpectedly closed.
400        // See https://github.com/tokio-rs/tracing/pull/2254.
401        //
402        // If this is eventually supported (maybe with our efforts), we can use
403        // the opentelemetry-semantic-conventions crate, see here:
404        // https://docs.rs/opentelemetry-semantic-conventions/latest/opentelemetry_semantic_conventions/index.html
405
406        // Setting common fields first
407        // See https://opentelemetry.io/docs/specs/semconv/http/http-spans/#common-attributes
408
409        let span = tracing::trace_span!(
410            "HTTP request",
411            { OTEL_NAME } = span_name,
412            { OTEL_KIND } = ?SpanKind::Server,
413            { semconv::attribute::OTEL_STATUS_CODE } = Empty,
414            // The current tracing-opentelemetry version still uses the old semantic convention
415            // See https://github.com/tokio-rs/tracing-opentelemetry/pull/209
416            { semconv::attribute::OTEL_STATUS_DESCRIPTION } = Empty,
417            { semconv::trace::HTTP_REQUEST_METHOD } = http_method,
418            { semconv::trace::HTTP_RESPONSE_STATUS_CODE } = Empty,
419            { semconv::trace::HTTP_ROUTE } = Empty,
420            { semconv::trace::URL_PATH } = url.path(),
421            { semconv::trace::URL_QUERY } = url.query(),
422            { semconv::trace::URL_SCHEME } = url.scheme_str().unwrap_or_default(),
423            { semconv::trace::USER_AGENT_ORIGINAL } = Empty,
424            { semconv::trace::SERVER_ADDRESS } = Empty,
425            { semconv::trace::SERVER_PORT } = Empty,
426            { semconv::trace::CLIENT_ADDRESS } = Empty,
427            { semconv::trace::CLIENT_PORT } = Empty,
428            // TODO (@Techassi): Add network.protocol.version
429        );
430
431        // Set the parent span based on the extracted context
432        //
433        // The OpenTelemetry spec does not allow trace ids to be updated after
434        // a span has been created. Since the (optional) new trace id given by
435        // a client is only knowable after handling the request, it is not
436        // available to the existing parent spans for the lower layers (tcp/tls
437        // handling).
438        //
439        // Therefore, we have to made a decision about linking the two traces.
440        // These are the options:
441        // 1. Link to the trace id supplied in the incoming request, or
442        // 2. Link to the current trace id, then set the parent context based on
443        //    trace information supplied in the incoming request.
444        //
445        // Neither is ideal, as it means there are (at least) two traces to look
446        // at to get a complete picture of what happened over the request.
447        //
448        // Option 1 is not viable, as the trace id in the response headers will
449        // not be the same as what the client sent. Yet we are supposed to pass
450        // their trace id in any further requests.
451        //
452        // We will go with option 2 as it at least keeps the higher layer spans
453        // in one trace, which is likely going to be more useful to the person
454        // visualizing the traces.
455        let new_parent = HeaderExtractor::new(req.headers()).extract_context();
456        let new_span_context = new_parent.span().span_context().clone();
457        let current_span_context = Context::current().span().span_context().clone();
458
459        if new_span_context != current_span_context {
460            tracing::trace!(
461                { OTEL_TRACE_ID_FROM } = ?current_span_context.trace_id(),
462                { OTEL_TRACE_ID_TO } = ?new_span_context.trace_id(),
463                "set parent span context based on context extracted from request headers"
464            );
465
466            Span::current().add_link(new_parent.span().span_context().clone());
467            span.add_link(Context::current().span().span_context().to_owned());
468            let _ = span.set_parent(new_parent);
469        }
470
471        if let Some(user_agent) = req.user_agent() {
472            span.record(semconv::trace::USER_AGENT_ORIGINAL, user_agent);
473        }
474
475        // Setting server.address and server.port
476        // See https://opentelemetry.io/docs/specs/semconv/http/http-spans/#setting-serveraddress-and-serverport-attributes
477
478        if let Ok((host, port)) = req.server_host() {
479            // NOTE (@Techassi): We cast to i64, because otherwise the field
480            // will NOT be recorded as a number but as a string. This is likely
481            // an issue in the tracing-opentelemetry crate.
482            span.record(semconv::trace::SERVER_ADDRESS, host)
483                .record(semconv::trace::SERVER_PORT, port as i64);
484        }
485
486        // Setting fields according to the HTTP server semantic conventions
487        // See https://opentelemetry.io/docs/specs/semconv/http/http-spans/#http-server-semantic-conventions
488
489        if let Some(client_socket_address) = req.client_socket_address() {
490            span.record(
491                semconv::trace::CLIENT_ADDRESS,
492                client_socket_address.ip().to_string(),
493            );
494
495            if opt_in {
496                // NOTE (@Techassi): We cast to i64, because otherwise the field
497                // will NOT be recorded as a number but as a string. This is
498                // likely an issue in the tracing-opentelemetry crate.
499                span.record(
500                    semconv::trace::CLIENT_PORT,
501                    client_socket_address.port() as i64,
502                );
503            }
504        }
505
506        // Only include the headers if the user opted in, because this might
507        // potentially be an expensive operation when many different headers
508        // are present. The OpenTelemetry spec also marks this as opt-in.
509
510        // FIXME (@Techassi): Currently, tracing doesn't support recording
511        // fields which are not registered at span creation which thus makes it
512        // impossible to record request headers at runtime.
513        // See: https://github.com/tokio-rs/tracing/issues/1343
514
515        if let Some(http_route) = req.matched_path() {
516            span.record(semconv::trace::HTTP_ROUTE, http_route.as_str());
517        }
518
519        span
520    }
521
522    fn add_header_fields(&self, headers: &HeaderMap) {
523        for (header_name, header_value) in headers {
524            // TODO (@Techassi): Add an allow list for header names
525            // TODO (@Techassi): Handle multiple headers with the same name
526
527            // header_name.as_str() always returns lowercase strings and thus we
528            // don't need to call to_lowercase on it.
529            let header_name = header_name.as_str();
530            let field_name = format!("{HTTP_REQUEST_HEADER}.{header_name}");
531
532            self.record(
533                field_name.as_str(),
534                header_value.to_str().unwrap_or_default(),
535            );
536        }
537    }
538
539    fn finalize_with_response(&self, response: &mut Response) {
540        let status_code = response.status();
541
542        // NOTE (@Techassi): We cast to i64, because otherwise the field will
543        // NOT be recorded as a number but as a string. This is likely an issue
544        // in the tracing-opentelemetry crate.
545        self.record(
546            semconv::trace::HTTP_RESPONSE_STATUS_CODE,
547            status_code.as_u16() as i64,
548        );
549
550        // Only set the span status to "Error" when we encountered an server
551        // error. See:
552        //
553        // - https://opentelemetry.io/docs/specs/semconv/http/http-spans/#status
554        // - https://github.com/open-telemetry/opentelemetry-specification/blob/v1.26.0/specification/trace/api.md#set-status
555        if status_code.is_server_error() {
556            self.record(semconv::attribute::OTEL_STATUS_CODE, "Error");
557            // NOTE (@Techassi): Can we add a status_description here as well?
558        }
559
560        let mut injector = HeaderInjector::new(response.headers_mut());
561        injector.inject_context(&Span::current().context());
562    }
563
564    fn finalize_with_error<E>(&self, error: &mut E)
565    where
566        E: std::error::Error,
567    {
568        // NOTE (@Techassi): This field might get renamed: https://github.com/tokio-rs/tracing-opentelemetry/issues/115
569        // NOTE (@Techassi): It got renamed, a fixed version of tracing-opentelemetry is not available yet
570        self.record(semconv::attribute::OTEL_STATUS_CODE, "Error")
571            .record(
572                semconv::attribute::OTEL_STATUS_DESCRIPTION,
573                error.to_string(),
574            );
575    }
576}
577
#[cfg(test)]
mod test {
    use std::{convert::Infallible, future::Ready};

    use axum::{
        Router,
        body::Body,
        http::StatusCode,
        routing::get,
    };

    use super::*;

    /// Inner service which answers every request with `200 OK`, so the trace
    /// layer can be exercised without binding a socket.
    #[derive(Clone)]
    struct AlwaysOk;

    impl Service<Request> for AlwaysOk {
        type Error = Infallible;
        type Future = Ready<Result<Response, Infallible>>;
        type Response = Response;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: Request) -> Self::Future {
            std::future::ready(Ok(Response::new(Body::from("Hello, World!"))))
        }
    }

    #[test]
    fn trace_layer_can_be_applied_to_router() {
        let _router: Router = Router::new()
            .route("/", get(|| async { "Hello, World!" }))
            .layer(TraceLayer::new());
    }

    #[tokio::test]
    async fn trace_service_passes_response_through() {
        let mut service = TraceLayer::new().with_opt_in().layer(AlwaysOk);

        let request = axum::http::Request::builder()
            .uri("/")
            .header(HOST, "example.com:8080")
            .header(USER_AGENT, "stackable-telemetry-test")
            .body(Body::empty())
            .expect("request must be valid");

        std::future::poll_fn(|cx| service.poll_ready(cx))
            .await
            .expect("inner service is always ready");
        let response = service
            .call(request)
            .await
            .expect("inner service never fails");

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn server_host_prefers_forwarded_host() {
        let request = axum::http::Request::builder()
            .uri("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, "public.example.com:8443")
            .header(HOST, "internal.example.com:8080")
            .body(Body::empty())
            .expect("request must be valid");

        let (host, port) = request.server_host().expect("host must be extracted");
        assert_eq!(host, "public.example.com");
        assert_eq!(port, 8443);
    }

    #[test]
    fn server_host_to_tuple_handles_ports_and_schemes() {
        assert_eq!(
            server_host_to_tuple("example.com:8080", None).unwrap(),
            ("example.com".to_owned(), 8080)
        );
        assert_eq!(
            server_host_to_tuple("example.com", Some("https")).unwrap(),
            ("example.com".to_owned(), 443)
        );
        assert_eq!(
            server_host_to_tuple("example.com", Some("http")).unwrap(),
            ("example.com".to_owned(), 80)
        );
        assert!(matches!(
            server_host_to_tuple("example.com:http", None),
            Err(ServerHostError::ParsePort { .. })
        ));
        assert!(matches!(
            server_host_to_tuple("example.com", Some("ftp")),
            Err(ServerHostError::InvalidScheme { .. })
        ));
        assert!(matches!(
            server_host_to_tuple("example.com", None),
            Err(ServerHostError::InvalidScheme { .. })
        ));
    }

    #[test]
    fn request_helpers_without_extensions() {
        let request = axum::http::Request::builder()
            .method("POST")
            .uri("/convert")
            .header(USER_AGENT, "curl/8.0.0")
            .body(Body::empty())
            .expect("request must be valid");

        assert_eq!(request.span_name(), "POST");
        assert_eq!(request.user_agent(), Some("curl/8.0.0"));
        assert!(request.matched_path().is_none());
        assert!(request.client_socket_address().is_none());
    }
}