stackable_webhook/tls/
mod.rs

1//! This module contains structs and functions to easily create a TLS termination
2//! server, which can be used in combination with an Axum [`Router`].
3use std::{convert::Infallible, net::SocketAddr, sync::Arc};
4
5use axum::{
6    Router,
7    extract::{ConnectInfo, Request},
8    middleware::AddExtension,
9};
10use hyper::{body::Incoming, service::service_fn};
11use hyper_util::rt::{TokioExecutor, TokioIo};
12use opentelemetry::trace::{FutureExt, SpanKind};
13use opentelemetry_semantic_conventions as semconv;
14use snafu::{ResultExt, Snafu};
15use stackable_shared::time::Duration;
16use tokio::{
17    net::{TcpListener, TcpStream},
18    sync::mpsc,
19};
20use tokio_rustls::{
21    TlsAcceptor,
22    rustls::{
23        ServerConfig,
24        crypto::ring::default_provider,
25        version::{TLS12, TLS13},
26    },
27};
28use tower::{Service, ServiceExt};
29use tracing::{Instrument, Span, field::Empty, instrument};
30use tracing_opentelemetry::OpenTelemetrySpanExt;
31use x509_cert::Certificate;
32
33use crate::{
34    options::WebhookOptions,
35    tls::cert_resolver::{CertificateResolver, CertificateResolverError},
36};
37
38mod cert_resolver;
39
40pub const WEBHOOK_CA_LIFETIME: Duration = Duration::from_hours_unchecked(24);
41pub const WEBHOOK_CERTIFICATE_LIFETIME: Duration = Duration::from_hours_unchecked(24);
42pub const WEBHOOK_CERTIFICATE_ROTATION_INTERVAL: Duration = Duration::from_hours_unchecked(20);
43
44pub type Result<T, E = TlsServerError> = std::result::Result<T, E>;
45
46#[derive(Debug, Snafu)]
47pub enum TlsServerError {
48    #[snafu(display("failed to create certificate resolver"))]
49    CreateCertificateResolver { source: CertificateResolverError },
50
51    #[snafu(display("failed to create TCP listener by binding to socket address {socket_addr:?}"))]
52    BindTcpListener {
53        source: std::io::Error,
54        socket_addr: SocketAddr,
55    },
56
57    #[snafu(display("failed to rotate certificate"))]
58    RotateCertificate { source: CertificateResolverError },
59
60    #[snafu(display("failed to set safe TLS protocol versions"))]
61    SetSafeTlsProtocolVersions { source: tokio_rustls::rustls::Error },
62}
63
64/// A server which terminates TLS connections and allows clients to communicate
65/// via HTTPS with the underlying HTTP router.
66///
67/// It also rotates the generated certificates as needed.
68pub struct TlsServer {
69    config: ServerConfig,
70    cert_resolver: Arc<CertificateResolver>,
71
72    socket_addr: SocketAddr,
73    router: Router,
74}
75
76impl TlsServer {
77    /// Create a new [`TlsServer`].
78    ///
79    /// This internally creates a `CertificateResolver` with the provided
80    /// `subject_alterative_dns_names`, which takes care of the certificate rotation. Afterwards it
81    /// creates the [`ServerConfig`], which let's the `CertificateResolver` provide the needed
82    /// certificates.
83    #[instrument(name = "create_tls_server", skip(router))]
84    pub async fn new(
85        router: Router,
86        options: WebhookOptions,
87    ) -> Result<(Self, mpsc::Receiver<Certificate>)> {
88        let (cert_tx, cert_rx) = mpsc::channel(1);
89
90        let WebhookOptions {
91            socket_addr,
92            subject_alterative_dns_names,
93        } = options;
94
95        let cert_resolver = CertificateResolver::new(subject_alterative_dns_names, cert_tx)
96            .await
97            .context(CreateCertificateResolverSnafu)?;
98        let cert_resolver = Arc::new(cert_resolver);
99
100        let tls_provider = default_provider();
101        let mut config = ServerConfig::builder_with_provider(tls_provider.into())
102            .with_protocol_versions(&[&TLS12, &TLS13])
103            .context(SetSafeTlsProtocolVersionsSnafu)?
104            .with_no_client_auth()
105            .with_cert_resolver(cert_resolver.clone());
106        config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
107
108        let tls_server = Self {
109            config,
110            cert_resolver,
111            socket_addr,
112            router,
113        };
114
115        Ok((tls_server, cert_rx))
116    }
117
118    /// Runs the TLS server by listening for incoming TCP connections on the
119    /// bound socket address. It only accepts TLS connections. Internally each
120    /// TLS stream get handled by a Hyper service, which in turn is an Axum
121    /// router.
122    ///
123    /// It also starts a background task to rotate the certificate as needed.
124    pub async fn run(self) -> Result<()> {
125        let start = tokio::time::Instant::now() + *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL;
126        let mut interval = tokio::time::interval_at(start, *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL);
127
128        let tls_acceptor = TlsAcceptor::from(Arc::new(self.config));
129        let tcp_listener =
130            TcpListener::bind(self.socket_addr)
131                .await
132                .context(BindTcpListenerSnafu {
133                    socket_addr: self.socket_addr,
134                })?;
135
136        // To be able to extract the connect info from incoming requests, it is
137        // required to turn the router into a Tower service which is capable of
138        // doing that. Calling `into_make_service_with_connect_info` returns a
139        // new struct `IntoMakeServiceWithConnectInfo` which implements the
140        // Tower Service trait. This service is called after the TCP connection
141        // has been accepted.
142        //
143        // Inspired by:
144        // - https://github.com/tokio-rs/axum/discussions/2397
145        // - https://github.com/tokio-rs/axum/blob/b02ce307371a973039018a13fa012af14775948c/examples/serve-with-hyper/src/main.rs#L98
146
147        let mut router = self
148            .router
149            .into_make_service_with_connect_info::<SocketAddr>();
150
151        loop {
152            let tls_acceptor = tls_acceptor.clone();
153
154            // Wait for either a new TCP connection or the certificate rotation interval tick
155            tokio::select! {
156                // We opt for a biased execution of arms to make sure we always check if the
157                // certificate needs rotation based on the interval. This ensures, we always use
158                // a valid certificate for the TLS connection.
159                biased;
160
161                // This is cancellation-safe. If this branch is cancelled, the tick is NOT consumed.
162                // As such, we will not miss rotating the certificate.
163                _ = interval.tick() => {
164                    self.cert_resolver
165                        .rotate_certificate()
166                        .await
167                        .context(RotateCertificateSnafu)?
168                }
169
170                // This is cancellation-safe. If cancelled, no new connections are accepted.
171                tcp_connection = tcp_listener.accept() => {
172                    let (tcp_stream, remote_addr) = match tcp_connection {
173                        Ok((stream, addr)) => (stream, addr),
174                        Err(err) => {
175                            tracing::trace!(%err, "failed to accept incoming TCP connection");
176                            continue;
177                        }
178                    };
179
180                    // Here, the connect info is extracted by calling Tower's Service
181                    // trait function on `IntoMakeServiceWithConnectInfo`
182                    let tower_service: Result<_, Infallible> = router.call(remote_addr).await;
183                    let tower_service = tower_service.expect("Infallible error can never happen");
184
185                    let span = tracing::debug_span!("accept tcp connection");
186                    tokio::spawn(
187                        async move {
188                            Self::handle_request(tcp_stream, remote_addr, tls_acceptor, tower_service, self.socket_addr)
189                            .instrument(span)
190                            .await
191                        }
192                    );
193                }
194            };
195        }
196    }
197
198    async fn handle_request(
199        tcp_stream: TcpStream,
200        remote_addr: SocketAddr,
201        tls_acceptor: TlsAcceptor,
202        tower_service: AddExtension<Router, ConnectInfo<SocketAddr>>,
203        socket_addr: SocketAddr,
204    ) {
205        let span = tracing::trace_span!(
206            "accept tls connection",
207            "otel.kind" = ?SpanKind::Server,
208            { semconv::attribute::OTEL_STATUS_CODE } = Empty,
209            { semconv::attribute::OTEL_STATUS_DESCRIPTION } = Empty,
210            { semconv::trace::CLIENT_ADDRESS } = remote_addr.ip().to_string(),
211            { semconv::trace::CLIENT_PORT } = remote_addr.port() as i64,
212            { semconv::trace::SERVER_ADDRESS } = Empty,
213            { semconv::trace::SERVER_PORT } = Empty,
214            { semconv::trace::NETWORK_PEER_ADDRESS } = remote_addr.ip().to_string(),
215            { semconv::trace::NETWORK_PEER_PORT } = remote_addr.port() as i64,
216            { semconv::trace::NETWORK_LOCAL_ADDRESS } = Empty,
217            { semconv::trace::NETWORK_LOCAL_PORT } = Empty,
218            { semconv::trace::NETWORK_TRANSPORT } = "tcp",
219            { semconv::trace::NETWORK_TYPE } = socket_addr.semantic_convention_network_type(),
220        );
221
222        if let Ok(local_addr) = tcp_stream.local_addr() {
223            let addr = &local_addr.ip().to_string();
224            let port = local_addr.port();
225            span.record(semconv::trace::SERVER_ADDRESS, addr)
226                .record(semconv::trace::SERVER_PORT, port as i64)
227                .record(semconv::trace::NETWORK_LOCAL_ADDRESS, addr)
228                .record(semconv::trace::NETWORK_LOCAL_PORT, port as i64);
229        }
230
231        // Wait for tls handshake to happen
232        let tls_stream = match tls_acceptor
233            .accept(tcp_stream)
234            .instrument(span.clone())
235            .await
236        {
237            Ok(tls_stream) => tls_stream,
238            Err(err) => {
239                span.record(semconv::attribute::OTEL_STATUS_CODE, "Error")
240                    .record(semconv::attribute::OTEL_STATUS_DESCRIPTION, err.to_string());
241                tracing::trace!(%remote_addr, "error during tls handshake connection");
242                return;
243            }
244        };
245
246        // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
247        // `TokioIo` converts between them.
248        let tls_stream = TokioIo::new(tls_stream);
249
250        // Hyper also has its own `Service` trait and doesn't use tower. We can use
251        // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
252        // `tower::Service::call`.
253        let hyper_service = service_fn(move |request: Request<Incoming>| {
254            // This carries the current context with the trace id so that the TraceLayer can use that as a parent
255            let otel_context = Span::current().context();
256            // We need to clone here, because oneshot consumes self
257            tower_service
258                .clone()
259                .oneshot(request)
260                .with_context(otel_context)
261        });
262
263        let span = tracing::debug_span!("serve connection");
264        hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
265            .serve_connection_with_upgrades(tls_stream, hyper_service)
266            .instrument(span.clone())
267            .await
268            .unwrap_or_else(|err| {
269                span.record(semconv::attribute::OTEL_STATUS_CODE, "Error")
270                    .record(semconv::attribute::OTEL_STATUS_DESCRIPTION, err.to_string());
271                tracing::warn!(%err, %remote_addr, "failed to serve connection");
272            })
273    }
274}
275
/// Extension trait which maps a socket address to the value used for the `network.type`
/// semantic convention attribute.
pub trait SocketAddrExt {
    /// Returns `"ipv4"` or `"ipv6"`, depending on the address family.
    fn semantic_convention_network_type(&self) -> &'static str;
}

impl SocketAddrExt for SocketAddr {
    fn semantic_convention_network_type(&self) -> &'static str {
        // A `SocketAddr` is either V4 or V6, so everything which is not IPv4 is IPv6.
        if self.is_ipv4() { "ipv4" } else { "ipv6" }
    }
}
288
289// TODO (@NickLarsenNZ): impl record_error(err: impl Error) for Span as a shortcut to set otel.status_* fields
290// TODO (@NickLarsenNZ): wrap tracing::span macros to automatically add otel fields