1use std::{convert::Infallible, net::SocketAddr, sync::Arc};
4
5use axum::{
6 Router,
7 extract::{ConnectInfo, Request},
8 middleware::AddExtension,
9};
10use hyper::{body::Incoming, service::service_fn};
11use hyper_util::rt::{TokioExecutor, TokioIo};
12use opentelemetry::trace::{FutureExt, SpanKind};
13use opentelemetry_semantic_conventions as semconv;
14use snafu::{ResultExt, Snafu};
15use stackable_shared::time::Duration;
16use tokio::{
17 net::{TcpListener, TcpStream},
18 sync::mpsc,
19};
20use tokio_rustls::{
21 TlsAcceptor,
22 rustls::{
23 ServerConfig,
24 crypto::ring::default_provider,
25 version::{TLS12, TLS13},
26 },
27};
28use tower::{Service, ServiceExt};
29use tracing::{Instrument, Span, field::Empty, instrument};
30use tracing_opentelemetry::OpenTelemetrySpanExt;
31use x509_cert::Certificate;
32
33use crate::{
34 options::WebhookOptions,
35 tls::cert_resolver::{CertificateResolver, CertificateResolverError},
36};
37
38mod cert_resolver;
39
40pub const WEBHOOK_CA_LIFETIME: Duration = Duration::from_hours_unchecked(24);
41pub const WEBHOOK_CERTIFICATE_LIFETIME: Duration = Duration::from_hours_unchecked(24);
42pub const WEBHOOK_CERTIFICATE_ROTATION_INTERVAL: Duration = Duration::from_hours_unchecked(20);
43
44pub type Result<T, E = TlsServerError> = std::result::Result<T, E>;
45
46#[derive(Debug, Snafu)]
47pub enum TlsServerError {
48 #[snafu(display("failed to create certificate resolver"))]
49 CreateCertificateResolver { source: CertificateResolverError },
50
51 #[snafu(display("failed to create TCP listener by binding to socket address {socket_addr:?}"))]
52 BindTcpListener {
53 source: std::io::Error,
54 socket_addr: SocketAddr,
55 },
56
57 #[snafu(display("failed to rotate certificate"))]
58 RotateCertificate { source: CertificateResolverError },
59
60 #[snafu(display("failed to set safe TLS protocol versions"))]
61 SetSafeTlsProtocolVersions { source: tokio_rustls::rustls::Error },
62}
63
64pub struct TlsServer {
69 config: ServerConfig,
70 cert_resolver: Arc<CertificateResolver>,
71
72 socket_addr: SocketAddr,
73 router: Router,
74}
75
76impl TlsServer {
77 #[instrument(name = "create_tls_server", skip(router))]
84 pub async fn new(
85 router: Router,
86 options: WebhookOptions,
87 ) -> Result<(Self, mpsc::Receiver<Certificate>)> {
88 let (cert_tx, cert_rx) = mpsc::channel(1);
89
90 let WebhookOptions {
91 socket_addr,
92 subject_alterative_dns_names,
93 } = options;
94
95 let cert_resolver = CertificateResolver::new(subject_alterative_dns_names, cert_tx)
96 .await
97 .context(CreateCertificateResolverSnafu)?;
98 let cert_resolver = Arc::new(cert_resolver);
99
100 let tls_provider = default_provider();
101 let mut config = ServerConfig::builder_with_provider(tls_provider.into())
102 .with_protocol_versions(&[&TLS12, &TLS13])
103 .context(SetSafeTlsProtocolVersionsSnafu)?
104 .with_no_client_auth()
105 .with_cert_resolver(cert_resolver.clone());
106 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
107
108 let tls_server = Self {
109 config,
110 cert_resolver,
111 socket_addr,
112 router,
113 };
114
115 Ok((tls_server, cert_rx))
116 }
117
118 pub async fn run(self) -> Result<()> {
125 let start = tokio::time::Instant::now() + *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL;
126 let mut interval = tokio::time::interval_at(start, *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL);
127
128 let tls_acceptor = TlsAcceptor::from(Arc::new(self.config));
129 let tcp_listener =
130 TcpListener::bind(self.socket_addr)
131 .await
132 .context(BindTcpListenerSnafu {
133 socket_addr: self.socket_addr,
134 })?;
135
136 let mut router = self
148 .router
149 .into_make_service_with_connect_info::<SocketAddr>();
150
151 loop {
152 let tls_acceptor = tls_acceptor.clone();
153
154 tokio::select! {
156 biased;
160
161 _ = interval.tick() => {
164 self.cert_resolver
165 .rotate_certificate()
166 .await
167 .context(RotateCertificateSnafu)?
168 }
169
170 tcp_connection = tcp_listener.accept() => {
172 let (tcp_stream, remote_addr) = match tcp_connection {
173 Ok((stream, addr)) => (stream, addr),
174 Err(err) => {
175 tracing::trace!(%err, "failed to accept incoming TCP connection");
176 continue;
177 }
178 };
179
180 let tower_service: Result<_, Infallible> = router.call(remote_addr).await;
183 let tower_service = tower_service.expect("Infallible error can never happen");
184
185 let span = tracing::debug_span!("accept tcp connection");
186 tokio::spawn(
187 async move {
188 Self::handle_request(tcp_stream, remote_addr, tls_acceptor, tower_service, self.socket_addr)
189 .instrument(span)
190 .await
191 }
192 );
193 }
194 };
195 }
196 }
197
198 async fn handle_request(
199 tcp_stream: TcpStream,
200 remote_addr: SocketAddr,
201 tls_acceptor: TlsAcceptor,
202 tower_service: AddExtension<Router, ConnectInfo<SocketAddr>>,
203 socket_addr: SocketAddr,
204 ) {
205 let span = tracing::trace_span!(
206 "accept tls connection",
207 "otel.kind" = ?SpanKind::Server,
208 { semconv::attribute::OTEL_STATUS_CODE } = Empty,
209 { semconv::attribute::OTEL_STATUS_DESCRIPTION } = Empty,
210 { semconv::trace::CLIENT_ADDRESS } = remote_addr.ip().to_string(),
211 { semconv::trace::CLIENT_PORT } = remote_addr.port() as i64,
212 { semconv::trace::SERVER_ADDRESS } = Empty,
213 { semconv::trace::SERVER_PORT } = Empty,
214 { semconv::trace::NETWORK_PEER_ADDRESS } = remote_addr.ip().to_string(),
215 { semconv::trace::NETWORK_PEER_PORT } = remote_addr.port() as i64,
216 { semconv::trace::NETWORK_LOCAL_ADDRESS } = Empty,
217 { semconv::trace::NETWORK_LOCAL_PORT } = Empty,
218 { semconv::trace::NETWORK_TRANSPORT } = "tcp",
219 { semconv::trace::NETWORK_TYPE } = socket_addr.semantic_convention_network_type(),
220 );
221
222 if let Ok(local_addr) = tcp_stream.local_addr() {
223 let addr = &local_addr.ip().to_string();
224 let port = local_addr.port();
225 span.record(semconv::trace::SERVER_ADDRESS, addr)
226 .record(semconv::trace::SERVER_PORT, port as i64)
227 .record(semconv::trace::NETWORK_LOCAL_ADDRESS, addr)
228 .record(semconv::trace::NETWORK_LOCAL_PORT, port as i64);
229 }
230
231 let tls_stream = match tls_acceptor
233 .accept(tcp_stream)
234 .instrument(span.clone())
235 .await
236 {
237 Ok(tls_stream) => tls_stream,
238 Err(err) => {
239 span.record(semconv::attribute::OTEL_STATUS_CODE, "Error")
240 .record(semconv::attribute::OTEL_STATUS_DESCRIPTION, err.to_string());
241 tracing::trace!(%remote_addr, "error during tls handshake connection");
242 return;
243 }
244 };
245
246 let tls_stream = TokioIo::new(tls_stream);
249
250 let hyper_service = service_fn(move |request: Request<Incoming>| {
254 let otel_context = Span::current().context();
256 tower_service
258 .clone()
259 .oneshot(request)
260 .with_context(otel_context)
261 });
262
263 let span = tracing::debug_span!("serve connection");
264 hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
265 .serve_connection_with_upgrades(tls_stream, hyper_service)
266 .instrument(span.clone())
267 .await
268 .unwrap_or_else(|err| {
269 span.record(semconv::attribute::OTEL_STATUS_CODE, "Error")
270 .record(semconv::attribute::OTEL_STATUS_DESCRIPTION, err.to_string());
271 tracing::warn!(%err, %remote_addr, "failed to serve connection");
272 })
273 }
274}
275
/// Extension methods for [`SocketAddr`] used when filling in OpenTelemetry span attributes.
pub trait SocketAddrExt {
    /// Returns the semantic-convention network type of this address.
    ///
    /// The result is `"ipv4"` for IPv4 socket addresses and `"ipv6"` for IPv6 socket addresses.
    /// IPv4-mapped IPv6 addresses count as IPv6.
    fn semantic_convention_network_type(&self) -> &'static str;
}

impl SocketAddrExt for SocketAddr {
    fn semantic_convention_network_type(&self) -> &'static str {
        if self.is_ipv4() { "ipv4" } else { "ipv6" }
    }
}
288
289