Fix HTTP/2 support: use ALPN-based protocol detection and fallback to URI host

Two changes to properly support HTTP/2 clients:

1. server.rs: Detect ALPN protocol after TLS handshake and use
   hyper::server::conn::http2::Builder for H2 connections instead
   of the auto::Builder which failed to detect HTTP/2 over TLS.
   The auto::Builder's ReadVersion mechanism doesn't work reliably
   with tokio-rustls TlsStreams. For H1 connections, continue using
   auto::Builder with upgrade support.

2. handler.rs: Fallback to URI host when Host header is missing.
   In HTTP/2, the host is conveyed via :authority pseudo-header which
   hyper represents as the URI host, not a Host header.
This commit is contained in:
2026-06-12 06:14:46 +00:00
parent da28ea749d
commit 9ebb8ee7a8
2 changed files with 38 additions and 19 deletions

View File

@@ -39,11 +39,14 @@ async fn proxy_handler(
let host = req let host = req
.headers() .headers()
.get(axum::http::header::HOST) .get(axum::http::header::HOST)
.and_then(|v| v.to_str().ok()); .and_then(|v| v.to_str().ok())
.or_else(|| req.uri().host())
.unwrap_or_default();
let host = match host { let host = if host.is_empty() {
Some(h) => h, return ProxyError::MissingHost.into_response();
None => return ProxyError::MissingHost.into_response(), } else {
host
}; };
let config = state.config.load(); let config = state.config.load();

View File

@@ -6,6 +6,7 @@ use axum::extract::ConnectInfo;
use axum::http::Request; use axum::http::Request;
use axum::response::Response; use axum::response::Response;
use axum::Router; use axum::Router;
use hyper::body::Incoming;
use hyper_util::rt::TokioExecutor; use hyper_util::rt::TokioExecutor;
use hyper_util::service::TowerToHyperService; use hyper_util::service::TowerToHyperService;
use tokio::net::TcpListener; use tokio::net::TcpListener;
@@ -80,24 +81,40 @@ pub async fn serve_https_listener(
} }
}; };
let alpn = tls_stream.get_ref().1.alpn_protocol();
let is_h2 = alpn == Some(b"h2");
let svc = ConnectInfoService { let svc = ConnectInfoService {
inner: router.into_service::<hyper::body::Incoming>(), inner: router.into_service::<Incoming>(),
remote_addr, remote_addr,
}; };
let svc = TowerToHyperService::new(svc); let svc = TowerToHyperService::new(svc);
let io = hyper_util::rt::TokioIo::new(tls_stream); let io = hyper_util::rt::TokioIo::new(tls_stream);
if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) if is_h2 {
.serve_connection_with_upgrades(io, svc) let mut builder = hyper::server::conn::http2::Builder::new(TokioExecutor::new());
.await if let Err(e) = builder
{ .enable_connect_protocol()
if let Some(hyper_err) = e.downcast_ref::<hyper::Error>() { .serve_connection(io, svc)
if hyper_err.is_incomplete_message() { .await
return; {
} error!(error = %e, "HTTPS/2 connection error");
}
} else {
let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
builder.http2().enable_connect_protocol();
if let Err(e) = builder
.serve_connection_with_upgrades(io, svc)
.await
{
if let Some(hyper_err) = e.downcast_ref::<hyper::Error>() {
if hyper_err.is_incomplete_message() {
return;
}
}
error!(error = %e, "HTTPS connection error");
} }
error!(error = %e, "HTTPS connection error");
} }
}); });
} }
@@ -135,11 +152,10 @@ struct ConnectInfoService<S> {
remote_addr: SocketAddr, remote_addr: SocketAddr,
} }
impl<S, B> Service<Request<B>> for ConnectInfoService<S> impl<S> Service<Request<Incoming>> for ConnectInfoService<S>
where where
S: Service<Request<B>, Response = Response> + Clone + Send + 'static, S: Service<Request<Incoming>, Response = Response> + Clone + Send + 'static,
S::Future: Send + 'static, S::Future: Send + 'static,
B: Send + 'static,
{ {
type Response = S::Response; type Response = S::Response;
type Error = S::Error; type Error = S::Error;
@@ -152,7 +168,7 @@ where
self.inner.poll_ready(cx) self.inner.poll_ready(cx)
} }
fn call(&mut self, mut req: Request<B>) -> Self::Future { fn call(&mut self, mut req: Request<Incoming>) -> Self::Future {
req.extensions_mut().insert(ConnectInfo(self.remote_addr)); req.extensions_mut().insert(ConnectInfo(self.remote_addr));
self.inner.call(req) self.inner.call(req)
} }