stackable_webhook/tls/
mod.rs

1//! This module contains structs and functions to easily create a TLS termination
2//! server, which can be used in combination with an Axum [`Router`].
3use std::{convert::Infallible, net::SocketAddr, sync::Arc};
4
5use axum::{
6    Router,
7    extract::{ConnectInfo, Request},
8    middleware::AddExtension,
9};
10use hyper::{body::Incoming, service::service_fn};
11use hyper_util::rt::{TokioExecutor, TokioIo};
12use opentelemetry::trace::{FutureExt, SpanKind};
13use opentelemetry_semantic_conventions as semconv;
14use snafu::{OptionExt, ResultExt, Snafu};
15use stackable_shared::time::Duration;
16use tokio::{
17    net::{TcpListener, TcpStream},
18    sync::mpsc,
19};
20use tokio_rustls::{
21    TlsAcceptor,
22    rustls::{
23        ServerConfig,
24        crypto::CryptoProvider,
25        version::{TLS12, TLS13},
26    },
27};
28use tower::{Service, ServiceExt};
29use tracing::{Instrument, Span, field::Empty, instrument};
30use tracing_opentelemetry::OpenTelemetrySpanExt;
31use x509_cert::Certificate;
32
33use crate::{
34    options::WebhookOptions,
35    tls::cert_resolver::{CertificateResolver, CertificateResolverError},
36};
37
38mod cert_resolver;
39
/// 24 hours; going by its name, the lifetime of the webhook CA.
///
/// Not referenced in this file. NOTE(review): presumably consumed by the `cert_resolver`
/// submodule when generating the CA — confirm.
pub const WEBHOOK_CA_LIFETIME: Duration = Duration::from_hours_unchecked(24);
/// 24 hours; going by its name, the lifetime of a generated webhook certificate.
///
/// Not referenced in this file. NOTE(review): presumably consumed by the `cert_resolver`
/// submodule when generating certificates — confirm.
pub const WEBHOOK_CERTIFICATE_LIFETIME: Duration = Duration::from_hours_unchecked(24);
/// 20 hours. [`TlsServer::run`] uses this value both as the delay before the first certificate
/// rotation and as the period between subsequent rotations.
pub const WEBHOOK_CERTIFICATE_ROTATION_INTERVAL: Duration = Duration::from_hours_unchecked(20);
43
/// Result type used throughout this module; the error type defaults to [`TlsServerError`].
pub type Result<T, E = TlsServerError> = std::result::Result<T, E>;
45
/// Errors which can occur while creating or running a [`TlsServer`].
#[derive(Debug, Snafu)]
pub enum TlsServerError {
    /// Returned by [`TlsServer::new`] when `CertificateResolver::new` fails.
    #[snafu(display("failed to create certificate resolver"))]
    CreateCertificateResolver { source: CertificateResolverError },

    /// Returned by [`TlsServer::run`] when binding the TCP listener to `socket_addr` fails.
    #[snafu(display("failed to create TCP listener by binding to socket address {socket_addr:?}"))]
    BindTcpListener {
        source: std::io::Error,
        socket_addr: SocketAddr,
    },

    /// Returned by [`TlsServer::run`] when the certificate resolver fails to rotate the
    /// certificate on a rotation tick.
    #[snafu(display("failed to rotate certificate"))]
    RotateCertificate { source: CertificateResolverError },

    /// Returned by [`TlsServer::new`] when rustls rejects the requested protocol versions
    /// (TLS 1.2 and TLS 1.3) for the selected crypto provider.
    #[snafu(display("failed to set safe TLS protocol versions"))]
    SetSafeTlsProtocolVersions { source: tokio_rustls::rustls::Error },

    /// Returned by [`TlsServer::new`] when `CryptoProvider::get_default()` yields no provider.
    #[snafu(display("no default rustls CryptoProvider installed"))]
    NoDefaultCryptoProviderInstalled,
}
66
/// A server which terminates TLS connections and allows clients to communicate
/// via HTTPS with the underlying HTTP router.
///
/// It also rotates the generated certificates as needed.
pub struct TlsServer {
    /// rustls server configuration; [`TlsServer::run`] wraps it into a [`TlsAcceptor`].
    config: ServerConfig,
    /// Certificate resolver, also installed into `config`. [`TlsServer::run`] calls
    /// `rotate_certificate` on it on every rotation tick.
    cert_resolver: Arc<CertificateResolver>,

    /// Address the TCP listener is bound to by [`TlsServer::run`].
    socket_addr: SocketAddr,
    /// Axum router which serves the requests of every accepted TLS connection.
    router: Router,
}
78
79impl TlsServer {
80    /// Create a new [`TlsServer`].
81    ///
82    /// This internally creates a `CertificateResolver` with the provided
83    /// `subject_alterative_dns_names`, which takes care of the certificate rotation. Afterwards it
84    /// creates the [`ServerConfig`], which let's the `CertificateResolver` provide the needed
85    /// certificates.
86    #[instrument(name = "create_tls_server", skip(router))]
87    pub async fn new(
88        router: Router,
89        options: WebhookOptions,
90    ) -> Result<(Self, mpsc::Receiver<Certificate>)> {
91        let (certificate_tx, certificate_rx) = mpsc::channel(1);
92
93        let WebhookOptions {
94            socket_addr,
95            subject_alterative_dns_names,
96        } = options;
97
98        let cert_resolver = CertificateResolver::new(subject_alterative_dns_names, certificate_tx)
99            .await
100            .context(CreateCertificateResolverSnafu)?;
101        let cert_resolver = Arc::new(cert_resolver);
102
103        let tls_provider =
104            CryptoProvider::get_default().context(NoDefaultCryptoProviderInstalledSnafu)?;
105
106        let mut config = ServerConfig::builder_with_provider(tls_provider.clone())
107            .with_protocol_versions(&[&TLS12, &TLS13])
108            .context(SetSafeTlsProtocolVersionsSnafu)?
109            .with_no_client_auth()
110            .with_cert_resolver(cert_resolver.clone());
111        config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
112
113        let tls_server = Self {
114            config,
115            cert_resolver,
116            socket_addr,
117            router,
118        };
119
120        Ok((tls_server, certificate_rx))
121    }
122
123    /// Runs the TLS server by listening for incoming TCP connections on the
124    /// bound socket address. It only accepts TLS connections. Internally each
125    /// TLS stream get handled by a Hyper service, which in turn is an Axum
126    /// router.
127    ///
128    /// It also starts a background task to rotate the certificate as needed.
129    pub async fn run(self) -> Result<()> {
130        let start = tokio::time::Instant::now() + *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL;
131        let mut interval = tokio::time::interval_at(start, *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL);
132
133        let tls_acceptor = TlsAcceptor::from(Arc::new(self.config));
134        let tcp_listener =
135            TcpListener::bind(self.socket_addr)
136                .await
137                .context(BindTcpListenerSnafu {
138                    socket_addr: self.socket_addr,
139                })?;
140
141        // To be able to extract the connect info from incoming requests, it is
142        // required to turn the router into a Tower service which is capable of
143        // doing that. Calling `into_make_service_with_connect_info` returns a
144        // new struct `IntoMakeServiceWithConnectInfo` which implements the
145        // Tower Service trait. This service is called after the TCP connection
146        // has been accepted.
147        //
148        // Inspired by:
149        // - https://github.com/tokio-rs/axum/discussions/2397
150        // - https://github.com/tokio-rs/axum/blob/b02ce307371a973039018a13fa012af14775948c/examples/serve-with-hyper/src/main.rs#L98
151
152        let mut router = self
153            .router
154            .into_make_service_with_connect_info::<SocketAddr>();
155
156        loop {
157            let tls_acceptor = tls_acceptor.clone();
158
159            // Wait for either a new TCP connection or the certificate rotation interval tick
160            tokio::select! {
161                // We opt for a biased execution of arms to make sure we always check if the
162                // certificate needs rotation based on the interval. This ensures, we always use
163                // a valid certificate for the TLS connection.
164                biased;
165
166                // This is cancellation-safe. If this branch is cancelled, the tick is NOT consumed.
167                // As such, we will not miss rotating the certificate.
168                _ = interval.tick() => {
169                    self.cert_resolver
170                        .rotate_certificate()
171                        .await
172                        .context(RotateCertificateSnafu)?
173                }
174
175                // This is cancellation-safe. If cancelled, no new connections are accepted.
176                tcp_connection = tcp_listener.accept() => {
177                    let (tcp_stream, remote_addr) = match tcp_connection {
178                        Ok((stream, addr)) => (stream, addr),
179                        Err(err) => {
180                            tracing::trace!(%err, "failed to accept incoming TCP connection");
181                            continue;
182                        }
183                    };
184
185                    // Here, the connect info is extracted by calling Tower's Service
186                    // trait function on `IntoMakeServiceWithConnectInfo`
187                    let tower_service: Result<_, Infallible> = router.call(remote_addr).await;
188                    let tower_service = tower_service.expect("Infallible error can never happen");
189
190                    let span = tracing::debug_span!("accept tcp connection");
191                    tokio::spawn(
192                        async move {
193                            Self::handle_request(tcp_stream, remote_addr, tls_acceptor, tower_service, self.socket_addr)
194                            .instrument(span)
195                            .await
196                        }
197                    );
198                }
199            };
200        }
201    }
202
203    async fn handle_request(
204        tcp_stream: TcpStream,
205        remote_addr: SocketAddr,
206        tls_acceptor: TlsAcceptor,
207        tower_service: AddExtension<Router, ConnectInfo<SocketAddr>>,
208        socket_addr: SocketAddr,
209    ) {
210        let span = tracing::trace_span!(
211            "accept tls connection",
212            "otel.kind" = ?SpanKind::Server,
213            { semconv::attribute::OTEL_STATUS_CODE } = Empty,
214            { semconv::attribute::OTEL_STATUS_DESCRIPTION } = Empty,
215            { semconv::trace::CLIENT_ADDRESS } = remote_addr.ip().to_string(),
216            { semconv::trace::CLIENT_PORT } = remote_addr.port() as i64,
217            { semconv::trace::SERVER_ADDRESS } = Empty,
218            { semconv::trace::SERVER_PORT } = Empty,
219            { semconv::trace::NETWORK_PEER_ADDRESS } = remote_addr.ip().to_string(),
220            { semconv::trace::NETWORK_PEER_PORT } = remote_addr.port() as i64,
221            { semconv::trace::NETWORK_LOCAL_ADDRESS } = Empty,
222            { semconv::trace::NETWORK_LOCAL_PORT } = Empty,
223            { semconv::trace::NETWORK_TRANSPORT } = "tcp",
224            { semconv::trace::NETWORK_TYPE } = socket_addr.semantic_convention_network_type(),
225        );
226
227        if let Ok(local_addr) = tcp_stream.local_addr() {
228            let addr = &local_addr.ip().to_string();
229            let port = local_addr.port();
230            span.record(semconv::trace::SERVER_ADDRESS, addr)
231                .record(semconv::trace::SERVER_PORT, port as i64)
232                .record(semconv::trace::NETWORK_LOCAL_ADDRESS, addr)
233                .record(semconv::trace::NETWORK_LOCAL_PORT, port as i64);
234        }
235
236        // Wait for tls handshake to happen
237        let tls_stream = match tls_acceptor
238            .accept(tcp_stream)
239            .instrument(span.clone())
240            .await
241        {
242            Ok(tls_stream) => tls_stream,
243            Err(err) => {
244                span.record(semconv::attribute::OTEL_STATUS_CODE, "Error")
245                    .record(semconv::attribute::OTEL_STATUS_DESCRIPTION, err.to_string());
246                tracing::trace!(%remote_addr, "error during tls handshake connection");
247                return;
248            }
249        };
250
251        // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
252        // `TokioIo` converts between them.
253        let tls_stream = TokioIo::new(tls_stream);
254
255        // Hyper also has its own `Service` trait and doesn't use tower. We can use
256        // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
257        // `tower::Service::call`.
258        let hyper_service = service_fn(move |request: Request<Incoming>| {
259            // This carries the current context with the trace id so that the TraceLayer can use that as a parent
260            let otel_context = Span::current().context();
261            // We need to clone here, because oneshot consumes self
262            tower_service
263                .clone()
264                .oneshot(request)
265                .with_context(otel_context)
266        });
267
268        let span = tracing::debug_span!("serve connection");
269        hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
270            .serve_connection_with_upgrades(tls_stream, hyper_service)
271            .instrument(span.clone())
272            .await
273            .unwrap_or_else(|err| {
274                span.record(semconv::attribute::OTEL_STATUS_CODE, "Error")
275                    .record(semconv::attribute::OTEL_STATUS_DESCRIPTION, err.to_string());
276                tracing::warn!(%err, %remote_addr, "failed to serve connection");
277            })
278    }
279}
280
/// Extension trait which maps a socket address to the value recorded for the `network.type`
/// span attribute.
pub trait SocketAddrExt {
    /// Returns `"ipv4"` for an IPv4 socket address and `"ipv6"` for an IPv6 socket address.
    fn semantic_convention_network_type(&self) -> &'static str;
}

impl SocketAddrExt for SocketAddr {
    fn semantic_convention_network_type(&self) -> &'static str {
        // A `SocketAddr` is either V4 or V6, so everything that is not IPv4 is IPv6.
        if self.is_ipv4() { "ipv4" } else { "ipv6" }
    }
}
293
294// TODO (@NickLarsenNZ): impl record_error(err: impl Error) for Span as a shortcut to set otel.status_* fields
295// TODO (@NickLarsenNZ): wrap tracing::span macros to automatically add otel fields