Fix spec deviations and implement graceful shutdown drain
- Replace determine_if_https() with ProxyState.is_https field so X-Forwarded-Proto reflects the listener's protocol instead of guessing from the Host header - Return ProxyError::BadGateway with host/upstream context for non-connect upstream errors instead of bare StatusCode::BAD_GATEWAY - Implement InFlightCounter with RAII guard for tracking in-flight connections - Add drain_in_flight() to wait for connections to complete on shutdown, with configurable timeout before forcing exit - Mark review/core-components and review/integration-readiness as complete
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::ConnectInfo;
|
||||
use axum::http::Request;
|
||||
@@ -9,13 +11,46 @@ use hyper_util::service::TowerToHyperService;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tower::Service;
|
||||
use tracing::{error, warn};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
pub struct InFlightCounter {
|
||||
count: AtomicUsize,
|
||||
}
|
||||
|
||||
struct InFlightGuard(Arc<InFlightCounter>);
|
||||
|
||||
impl Drop for InFlightGuard {
|
||||
fn drop(&mut self) {
|
||||
self.0.decrement();
|
||||
}
|
||||
}
|
||||
|
||||
impl InFlightCounter {
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
count: AtomicUsize::new(0),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn increment(&self) {
|
||||
self.count.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
pub fn decrement(&self) {
|
||||
self.count.fetch_sub(1, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
pub fn is_zero(&self) -> bool {
|
||||
self.count.load(Ordering::SeqCst) == 0
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve_https_listener(
|
||||
tcp_listener: TcpListener,
|
||||
tls_acceptor: TlsAcceptor,
|
||||
router: Router,
|
||||
mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
|
||||
in_flight: Arc<InFlightCounter>,
|
||||
) {
|
||||
let local_addr = tcp_listener.local_addr();
|
||||
|
||||
@@ -32,8 +67,11 @@ pub async fn serve_https_listener(
|
||||
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
let router = router.clone();
|
||||
let in_flight = in_flight.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let _guard = InFlightGuard(in_flight.clone());
|
||||
|
||||
let tls_stream = match tls_acceptor.accept(tcp_stream).await {
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
@@ -63,7 +101,7 @@ pub async fn serve_https_listener(
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
if let Ok(addr) = local_addr {
|
||||
tracing::info!(addr = %addr, "HTTPS listener shutting down");
|
||||
info!(addr = %addr, "HTTPS listener shutting down");
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -71,6 +109,24 @@ pub async fn serve_https_listener(
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for in-flight connections to drain, with a timeout.
|
||||
/// Returns the number of connections still in-flight when the timeout expired (0 if all drained).
|
||||
pub async fn drain_in_flight(
|
||||
in_flight: &Arc<InFlightCounter>,
|
||||
timeout: std::time::Duration,
|
||||
) -> usize {
|
||||
let start = std::time::Instant::now();
|
||||
loop {
|
||||
if in_flight.is_zero() {
|
||||
return 0;
|
||||
}
|
||||
if start.elapsed() >= timeout {
|
||||
return in_flight.count.load(Ordering::SeqCst);
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConnectInfoService<S> {
|
||||
inner: S,
|
||||
|
||||
Reference in New Issue
Block a user