stackable_telemetry/instrumentation/axum/
mod.rs

1//! This module contains types which can be used as [`axum`] layers to produce
2//! [OpenTelemetry][1] compatible [HTTP spans][2].
3//!
4//! These spans include a wide variety of fields / attributes defined by the
5//! semantic conventions specification. A few examples are:
6//!
7//! - `http.request.method`
8//! - `http.response.status_code`
9//! - `user_agent.original`
10//!
11//! [1]: https://opentelemetry.io/
12//! [2]: https://opentelemetry.io/docs/specs/semconv/http/http-spans/
13use std::{future::Future, net::SocketAddr, num::ParseIntError, task::Poll};
14
15use axum::{
16    extract::{ConnectInfo, MatchedPath, Request},
17    http::{
18        HeaderMap,
19        header::{HOST, USER_AGENT},
20    },
21    response::Response,
22};
23use futures_util::ready;
24use opentelemetry::{
25    Context,
26    trace::{SpanKind, TraceContextExt},
27};
28use opentelemetry_semantic_conventions::{
29    attribute::{OTEL_STATUS_CODE, OTEL_STATUS_DESCRIPTION},
30    trace::{
31        CLIENT_ADDRESS, CLIENT_PORT, HTTP_REQUEST_HEADER, HTTP_REQUEST_METHOD,
32        HTTP_RESPONSE_STATUS_CODE, HTTP_ROUTE, SERVER_ADDRESS, SERVER_PORT, URL_PATH, URL_QUERY,
33        URL_SCHEME, USER_AGENT_ORIGINAL,
34    },
35};
36use pin_project::pin_project;
37use snafu::{ResultExt, Snafu};
38use tower::{Layer, Service};
39use tracing::{Span, field::Empty};
40use tracing_opentelemetry::OpenTelemetrySpanExt;
41
42mod extractor;
43mod injector;
44
45pub use extractor::*;
46pub use injector::*;
47
// Ports implied by the request scheme when the host value does not carry an
// explicit port.
const DEFAULT_HTTP_PORT: u16 = 80;
const DEFAULT_HTTPS_PORT: u16 = 443;

// Header consulted first when determining the server host, because a reverse
// proxy may put the originally requested host there.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

// Special span fields understood by tracing-opentelemetry. They are redeclared
// here because the crate keeps its own constants private.
const OTEL_NAME: &str = "otel.name";
const OTEL_KIND: &str = "otel.kind";

// Field names used when logging that the parent context was switched to the
// one propagated via the request headers.
const OTEL_TRACE_ID_FROM: &str = "opentelemetry.trace_id.from";
const OTEL_TRACE_ID_TO: &str = "opentelemetry.trace_id.to";
59
/// A Tower [`Layer`] which decorates [`TraceService`].
///
/// Every request passing through the wrapped service gets an HTTP server span
/// following the OpenTelemetry semantic conventions. Attributes marked as
/// opt-in by the specification can be enabled with [`TraceLayer::with_opt_in`].
///
/// ### Example with Axum
///
/// ```
/// use stackable_telemetry::AxumTraceLayer;
/// use axum::{routing::get, Router};
///
/// let trace_layer = AxumTraceLayer::new();
/// let router = Router::new()
///     .route("/", get(|| async { "Hello, World!" }))
///     .layer(trace_layer);
///
/// # let _: Router = router;
/// ```
///
/// This layer is implemented based on [this][1] official Tower guide.
///
/// [1]: https://github.com/tower-rs/tower/blob/master/guides/building-a-middleware-from-scratch.md
#[derive(Clone, Debug, Default)]
pub struct TraceLayer {
    /// Whether attributes marked as opt-in by the HTTP span semantic
    /// conventions are recorded (currently `client.port`).
    opt_in: bool,
}
83
84impl<S> Layer<S> for TraceLayer {
85    type Service = TraceService<S>;
86
87    fn layer(&self, inner: S) -> Self::Service {
88        TraceService {
89            inner,
90            opt_in: self.opt_in,
91        }
92    }
93}
94
95impl TraceLayer {
96    /// Creates a new default trace layer.
97    pub fn new() -> Self {
98        Self::default()
99    }
100
101    /// Enables various fields marked as opt-in by the specification.
102    ///
103    /// This will require more computing power and will increase the latency.
104    /// See <https://opentelemetry.io/docs/specs/semconv/http/http-spans/>
105    pub fn with_opt_in(mut self) -> Self {
106        self.opt_in = true;
107        self
108    }
109}
110
/// A Tower [`Service`] that wraps each request in an HTTP server [`Span`] and,
/// once a response is produced, injects that span's context into the response
/// headers.
#[derive(Debug, Clone)]
pub struct TraceService<S> {
    /// The wrapped service.
    inner: S,
    /// Whether opt-in span attributes are recorded for each request.
    opt_in: bool,
}
117
118impl<S> Service<Request> for TraceService<S>
119where
120    S: Service<Request, Response = Response> + Send + 'static,
121    S::Error: std::error::Error + 'static,
122    S::Future: Send + 'static,
123{
124    type Error = S::Error;
125    type Future = ResponseFuture<S::Future>;
126    type Response = S::Response;
127
128    fn poll_ready(
129        &mut self,
130        cx: &mut std::task::Context<'_>,
131    ) -> std::task::Poll<Result<(), Self::Error>> {
132        self.inner.poll_ready(cx)
133    }
134
135    fn call(&mut self, req: Request) -> Self::Future {
136        let span = Span::from_request(&req, self.opt_in);
137
138        let future = {
139            let _guard = span.enter();
140            self.inner.call(req)
141        };
142
143        ResponseFuture { future, span }
144    }
145}
146
147/// This future contains the inner service future and the current [`Span`].
148#[pin_project]
149pub struct ResponseFuture<F> {
150    #[pin]
151    pub(crate) future: F,
152    pub(crate) span: Span,
153}
154
155impl<F, E> Future for ResponseFuture<F>
156where
157    F: Future<Output = Result<Response, E>>,
158    E: std::error::Error + 'static,
159{
160    type Output = Result<Response, E>;
161
162    fn poll(
163        self: std::pin::Pin<&mut Self>,
164        cx: &mut std::task::Context<'_>,
165    ) -> std::task::Poll<Self::Output> {
166        let this = self.project();
167        let _guard = this.span.enter();
168
169        let mut result = ready!(this.future.poll(cx));
170        this.span.finalize(&mut result);
171
172        Poll::Ready(result)
173    }
174}
175
176/// Errors which can be encountered when extracting the server host from a [`Request`].
177#[derive(Debug, Snafu)]
178pub enum ServerHostError {
179    /// Indicates that parsing the port of the server host from the [`Request`] as a `u16` failed.
180    #[snafu(display("failed to parse port {port:?} as u16 from string"))]
181    ParsePort {
182        #[allow(missing_docs)]
183        source: ParseIntError,
184
185        // TODO (@Techassi): Make snafu re-emit this
186        /// The original input which was attempted to be parsed.
187        port: String,
188    },
189
190    /// Indicates that the server host from the [`Request`] contains an invalid/unknown scheme.
191    #[snafu(display("encountered invalid request scheme {scheme:?}"))]
192    InvalidScheme {
193        /// The original scheme.
194        scheme: String,
195    },
196
197    // TODO (@Techassi): Make snafu re-emit this
198    /// Indicates that no method of extracting the server host from the [`Request`] succeeded.
199    #[snafu(display("failed to extract any host information from request"))]
200    ExtractHost,
201}
202
203/// This trait provides various helper functions to extract data from a HTTP [`Request`].
204pub trait RequestExt {
205    /// Returns the client socket address, if available.
206    fn client_socket_address(&self) -> Option<SocketAddr>;
207
208    /// Returns the server host, if available.
209    ///
210    /// ### Value Selection Strategy
211    ///
212    /// The following value selection strategy is taken verbatim from [this][1]
213    /// section of the HTTP span semantic conventions:
214    ///
215    /// > HTTP server instrumentations SHOULD do the best effort when populating
216    /// > server.address and server.port attributes and SHOULD determine them by
217    /// > using the first of the following that applies:
218    /// >
219    /// > - The original host which may be passed by the reverse proxy in the
220    /// >  Forwarded#host, X-Forwarded-Host, or a similar header.
221    /// > - The :authority pseudo-header in case of HTTP/2 or HTTP/3
222    /// > - The Host header.
223    ///
224    /// [1]: https://opentelemetry.io/docs/specs/semconv/http/http-spans/#setting-serveraddress-and-serverport-attributes
225    fn server_host(&self) -> Result<(String, u16), ServerHostError>;
226
227    /// Returns the matched path, like `/object/:object_id/tags`.
228    ///
229    /// The returned path has low cardinality. It will never contain any path
230    /// or query parameter. This behaviour is suggested by the conventions
231    /// specification.
232    fn matched_path(&self) -> Option<&MatchedPath>;
233
234    /// Returns the span name.
235    ///
236    /// The format is either `{method} {http.route}` or `{method}` if
237    /// `http.route` is not available. Examples are:
238    ///
239    /// - `GET /object/:object_id/tags`
240    /// - `PUT /upload/:file_id`
241    /// - `POST /convert`
242    /// - `OPTIONS`
243    fn span_name(&self) -> String;
244
245    /// Returns the user agent, if available.
246    fn user_agent(&self) -> Option<&str>;
247}
248
249impl RequestExt for Request {
250    fn server_host(&self) -> Result<(String, u16), ServerHostError> {
251        // There is currently no obvious way to use the Host extractor from Axum
252        // directly. Using that extractor either requires impossible code (async
253        // in the Service's call function, unnecessary cloning or consuming self
254        // and returning a newly created request). That's why the following
255        // section mirrors the Axum extractor implementation. The implementation
256        // currently only looks for the X-Forwarded-Host / Host header and falls
257        // back to the request URI host. The Axum implementation also extracts
258        // data from the Forwarded header.
259
260        if let Some(host) = self
261            .headers()
262            .get(X_FORWARDED_HOST_HEADER_KEY)
263            .and_then(|host| host.to_str().ok())
264        {
265            return server_host_to_tuple(host, self.uri().scheme_str());
266        }
267
268        if let Some(host) = self.headers().get(HOST).and_then(|host| host.to_str().ok()) {
269            return server_host_to_tuple(host, self.uri().scheme_str());
270        }
271
272        if let (Some(host), Some(port)) = (self.uri().host(), self.uri().port_u16()) {
273            return Ok((host.to_owned(), port));
274        }
275
276        ExtractHostSnafu.fail()
277    }
278
279    fn client_socket_address(&self) -> Option<SocketAddr> {
280        self.extensions()
281            .get::<ConnectInfo<SocketAddr>>()
282            .map(|ci| ci.0)
283    }
284
285    fn matched_path(&self) -> Option<&MatchedPath> {
286        self.extensions().get::<MatchedPath>()
287    }
288
289    fn span_name(&self) -> String {
290        let http_method = self.method().as_str();
291
292        match self.matched_path() {
293            Some(matched_path) => format!("{http_method} {}", matched_path.as_str()),
294            None => http_method.to_string(),
295        }
296    }
297
298    fn user_agent(&self) -> Option<&str> {
299        self.headers()
300            .get(USER_AGENT)
301            .map(|ua| ua.to_str().unwrap_or_default())
302    }
303}
304
305fn server_host_to_tuple(
306    host: &str,
307    scheme: Option<&str>,
308) -> Result<(String, u16), ServerHostError> {
309    if let Some((host, port)) = host.split_once(':') {
310        // First, see if the host header value contains a colon indicating that
311        // it includes a non-default port.
312        let port: u16 = port.parse().context(ParsePortSnafu { port })?;
313        Ok((host.to_owned(), port))
314    } else {
315        // If there is no port included in the header value, the port is implied.
316        // Port 443 for HTTPS and port 80 for HTTP.
317        let port = match scheme {
318            Some("https") => DEFAULT_HTTPS_PORT,
319            Some("http") => DEFAULT_HTTP_PORT,
320            Some(scheme) => return InvalidSchemeSnafu { scheme }.fail(),
321            _ => return InvalidSchemeSnafu { scheme: "" }.fail(),
322        };
323
324        Ok((host.to_owned(), port))
325    }
326}
327
328/// This trait provides various helper functions to create a [`Span`] out of
329/// an HTTP [`Request`].
330pub trait SpanExt {
331    /// Create a span according to the semantic conventions for HTTP spans from
332    /// an Axum [`Request`].
333    ///
334    /// The individual fields are defined in [this specification][1]. Some of
335    /// them are:
336    ///
337    /// - `http.request.method`
338    /// - `http.response.status_code`
339    /// - `network.protocol.version`
340    ///
341    /// Setting the `opt_in` parameter to `true` enables various fields marked
342    /// as opt-in by the specification. This will require more computing power
343    /// and will increase the latency.
344    ///
345    /// [1]: https://opentelemetry.io/docs/specs/semconv/http/http-spans/
346    fn from_request(req: &Request, opt_in: bool) -> Self;
347
348    /// Adds HTTP request headers to the span as a `http.request.header.<key>`
349    /// field.
350    ///
351    /// NOTE: This is currently not supported, because [`tracing`] doesn't
352    /// support recording dynamic fields.
353    fn add_header_fields(&self, headers: &HeaderMap);
354
355    /// Finalize the [`Span`] with an Axum [`Response`].
356    fn finalize_with_response(&self, response: &mut Response);
357
358    /// Finalize the [`Span`] with an error.
359    fn finalize_with_error<E>(&self, error: &mut E)
360    where
361        E: std::error::Error;
362
363    /// Finalize the [`Span`] with a result.
364    ///
365    /// The default implementation internally calls:
366    ///
367    /// - [`SpanExt::finalize_with_response`] when [`Ok`]
368    /// - [`SpanExt::finalize_with_error`] when [`Err`]
369    fn finalize<E>(&self, result: &mut Result<Response, E>)
370    where
371        E: std::error::Error,
372    {
373        match result {
374            Ok(response) => self.finalize_with_response(response),
375            Err(error) => self.finalize_with_error(error),
376        }
377    }
378}
379
380impl SpanExt for Span {
381    fn from_request(req: &Request, opt_in: bool) -> Self {
382        let http_method = req.method().as_str();
383        let span_name = req.span_name();
384        let url = req.uri();
385
386        tracing::trace!(
387            http_method,
388            span_name,
389            ?url,
390            "extracted http method, span name and request url"
391        );
392
393        // The span name follows the format `{method} {http.route}` defined
394        // by the semantic conventions spec from the OpenTelemetry project.
395        // Currently, the tracing crate doesn't allow non 'static span names,
396        // and thus, the special field otel.name is used to set the span name.
397        // The span name defined in the trace_span macro only serves as a
398        // placeholder.
399        //
400        // - https://docs.rs/tracing-opentelemetry/latest/tracing_opentelemetry/#special-fields
401        // - https://github.com/tokio-rs/tracing/issues/1047
402        // - https://github.com/tokio-rs/tracing/pull/732
403        //
404        // Additionally we cannot use consts for field names. There was an
405        // upstream PR to add support for it, but it was unexpectedly closed.
406        // See https://github.com/tokio-rs/tracing/pull/2254.
407        //
408        // If this is eventually supported (maybe with our efforts), we can use
409        // the opentelemetry-semantic-conventions crate, see here:
410        // https://docs.rs/opentelemetry-semantic-conventions/latest/opentelemetry_semantic_conventions/index.html
411
412        // Setting common fields first
413        // See https://opentelemetry.io/docs/specs/semconv/http/http-spans/#common-attributes
414
415        let span = tracing::trace_span!(
416            "HTTP request",
417            { OTEL_NAME } = span_name,
418            { OTEL_KIND } = ?SpanKind::Server,
419            { OTEL_STATUS_CODE } = Empty,
420            // The current tracing-opentelemetry version still uses the old semantic convention
421            // See https://github.com/tokio-rs/tracing-opentelemetry/pull/209
422            { OTEL_STATUS_DESCRIPTION } = Empty,
423            { HTTP_REQUEST_METHOD } = http_method,
424            { HTTP_RESPONSE_STATUS_CODE } = Empty,
425            { HTTP_ROUTE } = Empty,
426            { URL_PATH } = url.path(),
427            { URL_QUERY } = url.query(),
428            { URL_SCHEME } = url.scheme_str().unwrap_or_default(),
429            { USER_AGENT_ORIGINAL } = Empty,
430            { SERVER_ADDRESS } = Empty,
431            { SERVER_PORT } = Empty,
432            { CLIENT_ADDRESS } = Empty,
433            { CLIENT_PORT } = Empty,
434            // TODO (@Techassi): Add network.protocol.version
435        );
436
437        // Set the parent span based on the extracted context
438        //
439        // The OpenTelemetry spec does not allow trace ids to be updated after
440        // a span has been created. Since the (optional) new trace id given by
441        // a client is only knowable after handling the request, it is not
442        // available to the existing parent spans for the lower layers (tcp/tls
443        // handling).
444        //
445        // Therefore, we have to made a decision about linking the two traces.
446        // These are the options:
447        // 1. Link to the trace id supplied in the incoming request, or
448        // 2. Link to the current trace id, then set the parent context based on
449        //    trace information supplied in the incoming request.
450        //
451        // Neither is ideal, as it means there are (at least) two traces to look
452        // at to get a complete picture of what happened over the request.
453        //
454        // Option 1 is not viable, as the trace id in the response headers will
455        // not be the same as what the client sent. Yet we are supposed to pass
456        // their trace id in any further requests.
457        //
458        // We will go with option 2 as it at least keeps the higher layer spans
459        // in one trace, which is likely going to be more useful to the person
460        // visualising the traces.
461        let new_parent = HeaderExtractor::new(req.headers()).extract_context();
462        let new_span_context = new_parent.span().span_context().clone();
463        let current_span_context = Context::current().span().span_context().clone();
464
465        if new_span_context != current_span_context {
466            tracing::trace!(
467                { OTEL_TRACE_ID_FROM } = ?current_span_context.trace_id(),
468                { OTEL_TRACE_ID_TO } = ?new_span_context.trace_id(),
469                "set parent span context based on context extracted from request headers"
470            );
471
472            Span::current().add_link(new_parent.span().span_context().clone());
473            span.add_link(Context::current().span().span_context().to_owned());
474            span.set_parent(new_parent);
475        }
476
477        if let Some(user_agent) = req.user_agent() {
478            span.record(USER_AGENT_ORIGINAL, user_agent);
479        }
480
481        // Setting server.address and server.port
482        // See https://opentelemetry.io/docs/specs/semconv/http/http-spans/#setting-serveraddress-and-serverport-attributes
483
484        if let Ok((host, port)) = req.server_host() {
485            // NOTE (@Techassi): We cast to i64, because otherwise the field
486            // will NOT be recorded as a number but as a string. This is likely
487            // an issue in the tracing-opentelemetry crate.
488            span.record(SERVER_ADDRESS, host)
489                .record(SERVER_PORT, port as i64);
490        }
491
492        // Setting fields according to the HTTP server semantic conventions
493        // See https://opentelemetry.io/docs/specs/semconv/http/http-spans/#http-server-semantic-conventions
494
495        if let Some(client_socket_address) = req.client_socket_address() {
496            span.record(CLIENT_ADDRESS, client_socket_address.ip().to_string());
497
498            if opt_in {
499                // NOTE (@Techassi): We cast to i64, because otherwise the field
500                // will NOT be recorded as a number but as a string. This is
501                // likely an issue in the tracing-opentelemetry crate.
502                span.record(CLIENT_PORT, client_socket_address.port() as i64);
503            }
504        }
505
506        // Only include the headers if the user opted in, because this might
507        // potentially be an expensive operation when many different headers
508        // are present. The OpenTelemetry spec also marks this as opt-in.
509
510        // FIXME (@Techassi): Currently, tracing doesn't support recording
511        // fields which are not registered at span creation which thus makes it
512        // impossible to record request headers at runtime.
513        // See: https://github.com/tokio-rs/tracing/issues/1343
514
515        if let Some(http_route) = req.matched_path() {
516            span.record(HTTP_ROUTE, http_route.as_str());
517        }
518
519        span
520    }
521
522    fn add_header_fields(&self, headers: &HeaderMap) {
523        for (header_name, header_value) in headers {
524            // TODO (@Techassi): Add an allow list for header names
525            // TODO (@Techassi): Handle multiple headers with the same name
526
527            // header_name.as_str() always returns lowercase strings and thus we
528            // don't need to call to_lowercase on it.
529            let header_name = header_name.as_str();
530            let field_name = format!("{HTTP_REQUEST_HEADER}.{header_name}");
531
532            self.record(
533                field_name.as_str(),
534                header_value.to_str().unwrap_or_default(),
535            );
536        }
537    }
538
539    fn finalize_with_response(&self, response: &mut Response) {
540        let status_code = response.status();
541
542        // NOTE (@Techassi): We cast to i64, because otherwise the field will
543        // NOT be recorded as a number but as a string. This is likely an issue
544        // in the tracing-opentelemetry crate.
545        self.record(HTTP_RESPONSE_STATUS_CODE, status_code.as_u16() as i64);
546
547        // Only set the span status to "Error" when we encountered an server
548        // error. See:
549        //
550        // - https://opentelemetry.io/docs/specs/semconv/http/http-spans/#status
551        // - https://github.com/open-telemetry/opentelemetry-specification/blob/v1.26.0/specification/trace/api.md#set-status
552        if status_code.is_server_error() {
553            self.record(OTEL_STATUS_CODE, "Error");
554            // NOTE (@Techassi): Can we add a status_description here as well?
555        }
556
557        let mut injector = HeaderInjector::new(response.headers_mut());
558        injector.inject_context(&Span::current().context());
559    }
560
561    fn finalize_with_error<E>(&self, error: &mut E)
562    where
563        E: std::error::Error,
564    {
565        // NOTE (@Techassi): This field might get renamed: https://github.com/tokio-rs/tracing-opentelemetry/issues/115
566        // NOTE (@Techassi): It got renamed, a fixed version of tracing-opentelemetry is not available yet
567        self.record(OTEL_STATUS_CODE, "Error")
568            .record(OTEL_STATUS_DESCRIPTION, error.to_string());
569    }
570}
571
#[cfg(test)]
mod test {
    use axum::{Router, body::Body, routing::get};

    use super::*;

    #[tokio::test]
    async fn serve_with_trace_layer() {
        let trace_layer = TraceLayer::new();
        let router = Router::new()
            .route("/", get(|| async { "Hello, World!" }))
            .layer(trace_layer);

        // Bind to loopback only: the server never needs to be reachable from
        // other hosts, and binding every interface can trigger firewall prompts.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        axum::serve(listener, router)
            .with_graceful_shutdown(tokio::time::sleep(std::time::Duration::from_secs(1)))
            .await
            .unwrap();
    }

    #[test]
    fn server_host_from_host_header() {
        let req = axum::http::Request::builder()
            .uri("/")
            .header(HOST, "example.com:8080")
            .body(Body::empty())
            .unwrap();

        assert_eq!(req.server_host().unwrap(), ("example.com".to_owned(), 8080));
    }

    #[test]
    fn server_host_prefers_x_forwarded_host() {
        let req = axum::http::Request::builder()
            .uri("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, "public.example.com:8443")
            .header(HOST, "internal:80")
            .body(Body::empty())
            .unwrap();

        assert_eq!(
            req.server_host().unwrap(),
            ("public.example.com".to_owned(), 8443)
        );
    }

    #[test]
    fn request_accessors_without_extensions() {
        let req = axum::http::Request::builder()
            .uri("/object/42")
            .header(USER_AGENT, "curl/8.0")
            .body(Body::empty())
            .unwrap();

        // Without a MatchedPath extension the span name is just the method.
        assert_eq!(req.span_name(), "GET");
        assert!(req.matched_path().is_none());
        assert!(req.client_socket_address().is_none());
        assert_eq!(req.user_agent(), Some("curl/8.0"));
    }

    #[test]
    fn server_host_to_tuple_uses_scheme_default_port() {
        assert_eq!(
            server_host_to_tuple("example.com", Some("https")).unwrap(),
            ("example.com".to_owned(), 443)
        );
        assert_eq!(
            server_host_to_tuple("example.com", Some("http")).unwrap(),
            ("example.com".to_owned(), 80)
        );
    }

    #[test]
    fn server_host_to_tuple_rejects_unknown_or_missing_scheme() {
        assert!(matches!(
            server_host_to_tuple("example.com", Some("ftp")),
            Err(ServerHostError::InvalidScheme { .. })
        ));
        assert!(matches!(
            server_host_to_tuple("example.com", None),
            Err(ServerHostError::InvalidScheme { .. })
        ));
    }

    #[test]
    fn server_host_to_tuple_rejects_invalid_port() {
        assert!(matches!(
            server_host_to_tuple("example.com:http", None),
            Err(ServerHostError::ParsePort { .. })
        ));
    }
}