1use std::{convert::Infallible, net::SocketAddr, sync::Arc};
4
5use axum::{
6 Router,
7 extract::{ConnectInfo, Request},
8 middleware::AddExtension,
9};
10use hyper::{body::Incoming, service::service_fn};
11use hyper_util::rt::{TokioExecutor, TokioIo};
12use opentelemetry::trace::{FutureExt, SpanKind};
13use opentelemetry_semantic_conventions as semconv;
14use snafu::{OptionExt, ResultExt, Snafu};
15use stackable_shared::time::Duration;
16use tokio::{
17 net::{TcpListener, TcpStream},
18 sync::mpsc,
19};
20use tokio_rustls::{
21 TlsAcceptor,
22 rustls::{
23 ServerConfig,
24 crypto::CryptoProvider,
25 version::{TLS12, TLS13},
26 },
27};
28use tower::{Service, ServiceExt};
29use tracing::{Instrument, Span, field::Empty, instrument};
30use tracing_opentelemetry::OpenTelemetrySpanExt;
31use x509_cert::Certificate;
32
33use crate::{
34 options::WebhookOptions,
35 tls::cert_resolver::{CertificateResolver, CertificateResolverError},
36};
37
/// Certificate resolver that supplies the server certificate to rustls (via
/// `ServerConfig::with_cert_resolver`) and can rotate it on demand.
mod cert_resolver;

/// Lifetime of the webhook CA (24 hours).
///
/// NOTE(review): not referenced in this file; presumably consumed by
/// `cert_resolver` when issuing the CA — confirm.
pub const WEBHOOK_CA_LIFETIME: Duration = Duration::from_hours_unchecked(24);

/// Lifetime of the webhook certificate (24 hours).
///
/// NOTE(review): not referenced in this file; presumably consumed by
/// `cert_resolver` when issuing the certificate — confirm.
pub const WEBHOOK_CERTIFICATE_LIFETIME: Duration = Duration::from_hours_unchecked(24);

/// Interval (20 hours) at which [`TlsServer::run`] asks the certificate resolver
/// to rotate the certificate. The first rotation happens one interval after
/// `run` starts.
///
/// This is shorter than [`WEBHOOK_CERTIFICATE_LIFETIME`], which leaves a
/// margin for replacing the certificate before it expires (assuming the
/// lifetime is used as the certificate validity — confirm in `cert_resolver`).
pub const WEBHOOK_CERTIFICATE_ROTATION_INTERVAL: Duration = Duration::from_hours_unchecked(20);

/// Result type used by this module; the error type defaults to [`TlsServerError`].
pub type Result<T, E = TlsServerError> = std::result::Result<T, E>;
45
/// Errors returned while creating ([`TlsServer::new`]) or running
/// ([`TlsServer::run`]) a [`TlsServer`].
#[derive(Debug, Snafu)]
pub enum TlsServerError {
    /// [`CertificateResolver::new`] failed while constructing the server.
    #[snafu(display("failed to create certificate resolver"))]
    CreateCertificateResolver { source: CertificateResolverError },

    /// The TCP listener could not be bound to `socket_addr`.
    #[snafu(display("failed to create TCP listener by binding to socket address {socket_addr:?}"))]
    BindTcpListener {
        source: std::io::Error,
        socket_addr: SocketAddr,
    },

    /// A periodic certificate rotation failed; this stops [`TlsServer::run`].
    #[snafu(display("failed to rotate certificate"))]
    RotateCertificate { source: CertificateResolverError },

    /// rustls rejected the TLS 1.2 / TLS 1.3 protocol version selection for
    /// the installed crypto provider.
    #[snafu(display("failed to set safe TLS protocol versions"))]
    SetSafeTlsProtocolVersions { source: tokio_rustls::rustls::Error },

    /// `CryptoProvider::get_default()` returned `None`: no process-wide default
    /// rustls crypto provider has been installed.
    #[snafu(display("no default rustls CryptoProvider installed"))]
    NoDefaultCryptoProviderInstalled,
}
66
/// HTTPS server that serves an axum [`Router`] over TLS, using a certificate
/// supplied (and periodically rotated) by a [`CertificateResolver`].
///
/// Create it with [`TlsServer::new`] and drive it with [`TlsServer::run`].
pub struct TlsServer {
    /// rustls configuration built in [`TlsServer::new`]: TLS 1.2/1.3 only, no
    /// client authentication, ALPN `h2` and `http/1.1`, and `cert_resolver` as
    /// the certificate source.
    config: ServerConfig,
    /// Certificate resolver; the same `Arc` is also installed in `config`, so
    /// rotating through this handle affects the certificates rustls serves.
    cert_resolver: Arc<CertificateResolver>,

    /// Address the TCP listener is bound to.
    socket_addr: SocketAddr,
    /// Router handling the HTTP requests received over TLS.
    router: Router,
}
78
79impl TlsServer {
80 #[instrument(name = "create_tls_server", skip(router))]
87 pub async fn new(
88 router: Router,
89 options: WebhookOptions,
90 ) -> Result<(Self, mpsc::Receiver<Certificate>)> {
91 let (certificate_tx, certificate_rx) = mpsc::channel(1);
92
93 let WebhookOptions {
94 socket_addr,
95 subject_alterative_dns_names,
96 } = options;
97
98 let cert_resolver = CertificateResolver::new(subject_alterative_dns_names, certificate_tx)
99 .await
100 .context(CreateCertificateResolverSnafu)?;
101 let cert_resolver = Arc::new(cert_resolver);
102
103 let tls_provider =
104 CryptoProvider::get_default().context(NoDefaultCryptoProviderInstalledSnafu)?;
105
106 let mut config = ServerConfig::builder_with_provider(tls_provider.clone())
107 .with_protocol_versions(&[&TLS12, &TLS13])
108 .context(SetSafeTlsProtocolVersionsSnafu)?
109 .with_no_client_auth()
110 .with_cert_resolver(cert_resolver.clone());
111 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
112
113 let tls_server = Self {
114 config,
115 cert_resolver,
116 socket_addr,
117 router,
118 };
119
120 Ok((tls_server, certificate_rx))
121 }
122
123 pub async fn run(self) -> Result<()> {
130 let start = tokio::time::Instant::now() + *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL;
131 let mut interval = tokio::time::interval_at(start, *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL);
132
133 let tls_acceptor = TlsAcceptor::from(Arc::new(self.config));
134 let tcp_listener =
135 TcpListener::bind(self.socket_addr)
136 .await
137 .context(BindTcpListenerSnafu {
138 socket_addr: self.socket_addr,
139 })?;
140
141 let mut router = self
153 .router
154 .into_make_service_with_connect_info::<SocketAddr>();
155
156 loop {
157 let tls_acceptor = tls_acceptor.clone();
158
159 tokio::select! {
161 biased;
165
166 _ = interval.tick() => {
169 self.cert_resolver
170 .rotate_certificate()
171 .await
172 .context(RotateCertificateSnafu)?
173 }
174
175 tcp_connection = tcp_listener.accept() => {
177 let (tcp_stream, remote_addr) = match tcp_connection {
178 Ok((stream, addr)) => (stream, addr),
179 Err(err) => {
180 tracing::trace!(%err, "failed to accept incoming TCP connection");
181 continue;
182 }
183 };
184
185 let tower_service: Result<_, Infallible> = router.call(remote_addr).await;
188 let tower_service = tower_service.expect("Infallible error can never happen");
189
190 let span = tracing::debug_span!("accept tcp connection");
191 tokio::spawn(
192 async move {
193 Self::handle_request(tcp_stream, remote_addr, tls_acceptor, tower_service, self.socket_addr)
194 .instrument(span)
195 .await
196 }
197 );
198 }
199 };
200 }
201 }
202
203 async fn handle_request(
204 tcp_stream: TcpStream,
205 remote_addr: SocketAddr,
206 tls_acceptor: TlsAcceptor,
207 tower_service: AddExtension<Router, ConnectInfo<SocketAddr>>,
208 socket_addr: SocketAddr,
209 ) {
210 let span = tracing::trace_span!(
211 "accept tls connection",
212 "otel.kind" = ?SpanKind::Server,
213 { semconv::attribute::OTEL_STATUS_CODE } = Empty,
214 { semconv::attribute::OTEL_STATUS_DESCRIPTION } = Empty,
215 { semconv::trace::CLIENT_ADDRESS } = remote_addr.ip().to_string(),
216 { semconv::trace::CLIENT_PORT } = remote_addr.port() as i64,
217 { semconv::trace::SERVER_ADDRESS } = Empty,
218 { semconv::trace::SERVER_PORT } = Empty,
219 { semconv::trace::NETWORK_PEER_ADDRESS } = remote_addr.ip().to_string(),
220 { semconv::trace::NETWORK_PEER_PORT } = remote_addr.port() as i64,
221 { semconv::trace::NETWORK_LOCAL_ADDRESS } = Empty,
222 { semconv::trace::NETWORK_LOCAL_PORT } = Empty,
223 { semconv::trace::NETWORK_TRANSPORT } = "tcp",
224 { semconv::trace::NETWORK_TYPE } = socket_addr.semantic_convention_network_type(),
225 );
226
227 if let Ok(local_addr) = tcp_stream.local_addr() {
228 let addr = &local_addr.ip().to_string();
229 let port = local_addr.port();
230 span.record(semconv::trace::SERVER_ADDRESS, addr)
231 .record(semconv::trace::SERVER_PORT, port as i64)
232 .record(semconv::trace::NETWORK_LOCAL_ADDRESS, addr)
233 .record(semconv::trace::NETWORK_LOCAL_PORT, port as i64);
234 }
235
236 let tls_stream = match tls_acceptor
238 .accept(tcp_stream)
239 .instrument(span.clone())
240 .await
241 {
242 Ok(tls_stream) => tls_stream,
243 Err(err) => {
244 span.record(semconv::attribute::OTEL_STATUS_CODE, "Error")
245 .record(semconv::attribute::OTEL_STATUS_DESCRIPTION, err.to_string());
246 tracing::trace!(%remote_addr, "error during tls handshake connection");
247 return;
248 }
249 };
250
251 let tls_stream = TokioIo::new(tls_stream);
254
255 let hyper_service = service_fn(move |request: Request<Incoming>| {
259 let otel_context = Span::current().context();
261 tower_service
263 .clone()
264 .oneshot(request)
265 .with_context(otel_context)
266 });
267
268 let span = tracing::debug_span!("serve connection");
269 hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
270 .serve_connection_with_upgrades(tls_stream, hyper_service)
271 .instrument(span.clone())
272 .await
273 .unwrap_or_else(|err| {
274 span.record(semconv::attribute::OTEL_STATUS_CODE, "Error")
275 .record(semconv::attribute::OTEL_STATUS_DESCRIPTION, err.to_string());
276 tracing::warn!(%err, %remote_addr, "failed to serve connection");
277 })
278 }
279}
280
/// Extension trait that derives OpenTelemetry semantic-convention values from a
/// socket address.
pub trait SocketAddrExt {
    /// Returns the value for the `network.type` attribute: `"ipv4"` for an IPv4
    /// socket address and `"ipv6"` for an IPv6 one (including IPv4-mapped IPv6
    /// addresses).
    fn semantic_convention_network_type(&self) -> &'static str;
}

impl SocketAddrExt for SocketAddr {
    fn semantic_convention_network_type(&self) -> &'static str {
        if self.is_ipv4() {
            "ipv4"
        } else {
            "ipv6"
        }
    }
}
293
294