diff --git a/src/admin/socket.rs b/src/admin/socket.rs
index eb11cfe..7689f30 100644
--- a/src/admin/socket.rs
+++ b/src/admin/socket.rs
@@ -9,6 +9,8 @@ use tokio::net::UnixListener;
 use tokio::sync::Mutex;
 use tracing::{info, warn};
 
+use crate::shutdown::GracefulShutdown;
+
 use crate::config::ConfigReloadHandle;
 
 #[derive(Debug, thiserror::Error)]
@@ -70,7 +72,10 @@ impl AdminSocket {
     }
 }
 
-pub async fn start_admin_socket(admin_socket: Arc) -> Result<(), AdminSocketError> {
+pub async fn start_admin_socket(
+    admin_socket: Arc,
+    shutdown: Arc,
+) -> Result<(), AdminSocketError> {
     if admin_socket.socket_path.is_empty() {
         info!("admin socket disabled (empty path)");
         return Err(AdminSocketError::Disabled);
@@ -96,19 +101,48 @@ pub async fn start_admin_socket(admin_socket: Arc) -> Result<(), Ad
 
     info!("admin socket listening on {}", socket_path);
 
+    let mut shutdown_rx = shutdown.subscribe();
+
     loop {
-        match listener.accept().await {
-            Ok((stream, _addr)) => {
-                let admin_socket = admin_socket.clone();
-                tokio::spawn(async move {
-                    handle_connection(stream, admin_socket).await;
-                });
+        tokio::select! {
+            result = listener.accept() => {
+                match result {
+                    Ok((stream, _addr)) => {
+                        let admin_socket = admin_socket.clone();
+                        tokio::spawn(async move {
+                            handle_connection(stream, admin_socket).await;
+                        });
+                    }
+                    Err(e) => {
+                        warn!("failed to accept admin socket connection: {}", e);
+                    }
+                }
             }
-            Err(e) => {
-                warn!("failed to accept admin socket connection: {}", e);
+            _ = shutdown_rx.changed() => {
+                info!("admin socket shutting down");
+                break;
             }
         }
     }
+
+    cleanup_socket_file(socket_path).await;
+
+    Ok(())
+}
+
+/// Removes the admin socket file at `path` once the listener loop has exited.
+///
+/// Best-effort: a file that is already gone is not an error, and any other
+/// failure is logged as a warning instead of being returned.
+async fn cleanup_socket_file(path: &str) {
+    // Remove unconditionally rather than probing with `Path::exists()` first:
+    // the probe is a synchronous filesystem call inside async code, and the
+    // file can disappear between the check and the removal anyway.
+    match tokio::fs::remove_file(path).await {
+        Ok(()) => {}
+        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
+        Err(e) => warn!("failed to remove admin socket file {}: {}", path, e),
+    }
 }
 
 async fn cleanup_stale_socket(path: &str) -> Result<(), AdminSocketError> {
@@ -508,7 +542,7 @@ upstream = "127.0.0.1:8080"
             dir.path().join("config.toml").to_string_lossy().to_string(),
         ));
 
-        let result = start_admin_socket(admin_socket).await;
+        let result = start_admin_socket(admin_socket, Arc::new(GracefulShutdown::new(30))).await;
         assert!(matches!(result, Err(AdminSocketError::Disabled)));
     }
 
@@ -531,7 +565,7 @@ upstream = "127.0.0.1:8080"
             dir.path().join("config.toml").to_string_lossy().to_string(),
         ));
 
-        let result = start_admin_socket(admin_socket).await;
+        let result = start_admin_socket(admin_socket, Arc::new(GracefulShutdown::new(30))).await;
         assert!(matches!(result, Err(AdminSocketError::SocketInUse(_))));
     }
 
diff --git a/src/main.rs b/src/main.rs
index 9cf3cc0..60ff4c7 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -65,6 +65,10 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> Resu
     let dynamic_config: DynamicConfig = loaded_config.dynamic_config;
     let config_arc = Arc::new(ArcSwap::from_pointee(dynamic_config));
 
+    let shutdown = Arc::new(GracefulShutdown::new(
+        loaded_config.static_config.shutdown_timeout_secs,
+    ));
+
     let rate_limiter = Arc::new(RateLimiter::new(config_arc.clone()));
 
     let http_client = create_http_client();
@@ -81,6 +85,12 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> Resu
         loaded_config.static_config.clone(),
     ));
 
+    reverse_proxy::shutdown::register_signal_handlers(
+        shutdown.clone(),
+        reload_handle.clone(),
+        config_path.to_string(),
+    )?;
+
     if loaded_config.static_config.health_check_port > 0 {
         let (health_addr, _health_handle) =
             health::start_health_check_listener(loaded_config.static_config.health_check_port)
@@ -96,8 +106,9 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> Resu
         config_path.to_string(),
     ));
     let admin_socket_clone = admin_socket.clone();
+    let shutdown_for_admin = shutdown.clone();
     tokio::spawn(async move {
-        if let Err(e) = start_admin_socket(admin_socket_clone).await {
+        if let Err(e) = start_admin_socket(admin_socket_clone, shutdown_for_admin).await {
             match e {
                 AdminSocketError::Disabled => {}
                 AdminSocketError::SocketInUse(path) => {
@@ -150,7 +161,7 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> Resu
 
     let mut tls_acceptors = Vec::new();
     for (listener_config, _) in &bound_listeners {
-        let tls_mode = setup_tls(&listener_config.tls).context(format!(
+        let tls_mode = setup_tls(&listener_config.tls, shutdown.clone()).context(format!(
             "failed to setup TLS for listener {}",
             listener_config.bind_addr
         ))?;
@@ -175,20 +186,11 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> Resu
         }
     }
 
-    let shutdown = Arc::new(GracefulShutdown::new(
-        loaded_config.static_config.shutdown_timeout_secs,
-    ));
-
-    reverse_proxy::shutdown::register_signal_handlers(
-        shutdown.clone(),
-        reload_handle.clone(),
-        config_path.to_string(),
-    )?;
-
     let _eviction_handle = start_eviction_task(
         rate_limiter.clone(),
         std::time::Duration::from_secs(60),
         std::time::Duration::from_secs(300),
+        shutdown.subscribe(),
     );
 
     let app = build_router(proxy_state.clone(), config_arc.clone(), rate_limiter);
@@ -230,8 +232,16 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> Resu
 
     info!("shutdown signal received, starting graceful shutdown");
 
-    for handle in https_server_handles {
-        handle.abort();
+    // One deadline shared by every listener task, so the total wait is bounded by
+    // the shutdown timeout rather than by (number of listeners x timeout).
+    let listener_deadline = tokio::time::Instant::now() + shutdown.shutdown_timeout();
+    for mut handle in https_server_handles {
+        // Await through `&mut` so the handle survives a timeout: dropping a
+        // `JoinHandle` only detaches the task, it does not cancel it.
+        if tokio::time::timeout_at(listener_deadline, &mut handle).await.is_err() {
+            warn!("shutdown timeout expired, aborting listener task");
+            handle.abort();
+        }
     }
 
     let remaining = drain_in_flight(&in_flight, shutdown.shutdown_timeout()).await;
diff --git a/src/rate_limit/mod.rs b/src/rate_limit/mod.rs
index 47be64d..69bc6e5 100644
--- a/src/rate_limit/mod.rs
+++ b/src/rate_limit/mod.rs
@@ -102,12 +102,20 @@ pub fn start_eviction_task(
     limiter: Arc,
     interval: Duration,
     max_age: Duration,
+    mut shutdown_rx: tokio::sync::watch::Receiver,
 ) -> tokio::task::JoinHandle<()> {
     tokio::spawn(async move {
         let mut interval_timer = tokio::time::interval(interval);
         loop {
-            interval_timer.tick().await;
-            limiter.evict_stale(max_age);
+            tokio::select! {
+                _ = interval_timer.tick() => {
+                    limiter.evict_stale(max_age);
+                }
+                _ = shutdown_rx.changed() => {
+                    tracing::info!("rate limiter eviction task shutting down");
+                    break;
+                }
+            }
         }
     })
 }
diff --git a/src/tls/acceptor.rs b/src/tls/acceptor.rs
index d12e32a..6d90d43 100644
--- a/src/tls/acceptor.rs
+++ b/src/tls/acceptor.rs
@@ -8,6 +8,7 @@ use tracing::info;
 use super::acme::{spawn_acme_state, AcmeTlsConfig};
 use super::config::crypto_provider;
 use crate::config::static_config::TlsConfig;
+use crate::shutdown::GracefulShutdown;
 
 const ACME_TLS_ALPN_01: &[u8] = b"acme-tls/1";
 
@@ -41,7 +42,7 @@ pub enum TlsMode {
 }
 
 #[allow(dead_code)]
-pub fn setup_tls(tls_config: &TlsConfig) -> Result {
+pub fn setup_tls(tls_config: &TlsConfig, shutdown: Arc) -> Result {
     match tls_config.mode.as_str() {
         "manual" => {
             if tls_config.cert_path.is_empty() {
@@ -75,7 +76,7 @@ pub fn setup_tls(tls_config: &TlsConfig) -> Result {
 
             let default_config = build_acme_server_config(resolver.clone())?;
 
-            spawn_acme_state(state, tls_config.acme_domains.clone());
+            spawn_acme_state(state, tls_config.acme_domains.clone(), shutdown);
 
             info!(
                 domains = ?tls_config.acme_domains,
@@ -136,7 +137,8 @@ mod tests {
             cert_path: String::new(),
             key_path: "/some/key.pem".to_string(),
         };
-        let result = setup_tls(&tls_config);
+        let shutdown = Arc::new(GracefulShutdown::new(30));
+        let result = setup_tls(&tls_config, shutdown);
         assert!(result.is_err());
         let err = result.unwrap_err().to_string();
         assert!(err.contains("cert_path"));
@@ -153,7 +155,8 @@ mod tests {
             cert_path: "/some/cert.pem".to_string(),
             key_path: String::new(),
         };
-        let result = setup_tls(&tls_config);
+        let shutdown = Arc::new(GracefulShutdown::new(30));
+        let result = setup_tls(&tls_config, shutdown);
         assert!(result.is_err());
         let err = result.unwrap_err().to_string();
         assert!(err.contains("key_path"));
@@ -170,7 +173,8 @@ mod tests {
             cert_path: String::new(),
             key_path: String::new(),
         };
-        let result = setup_tls(&tls_config);
+        let shutdown = Arc::new(GracefulShutdown::new(30));
+        let result = setup_tls(&tls_config, shutdown);
         assert!(result.is_err());
         let err = result.unwrap_err().to_string();
         assert!(err.contains("acme_domains"));
@@ -187,7 +191,8 @@ mod tests {
             cert_path: String::new(),
             key_path: String::new(),
         };
-        let result = setup_tls(&tls_config);
+        let shutdown = Arc::new(GracefulShutdown::new(30));
+        let result = setup_tls(&tls_config, shutdown);
         assert!(result.is_err());
         let err = result.unwrap_err().to_string();
         assert!(err.contains("acme_cache_dir"));
@@ -204,7 +209,8 @@ mod tests {
             cert_path: String::new(),
             key_path: String::new(),
         };
-        let result = setup_tls(&tls_config);
+        let shutdown = Arc::new(GracefulShutdown::new(30));
+        let result = setup_tls(&tls_config, shutdown);
         assert!(result.is_err());
         let err = result.unwrap_err().to_string();
         assert!(err.contains("unknown TLS mode"));
diff --git a/src/tls/acme.rs b/src/tls/acme.rs
index e6c20d5..a6faa48 100644
--- a/src/tls/acme.rs
+++ b/src/tls/acme.rs
@@ -6,6 +6,8 @@ use rustls_acme::caches::DirCache;
 use rustls_acme::{AcmeConfig, AcmeState, EventError, EventOk, ResolvesServerCertAcme};
 use tracing::{error, info, warn};
 
+use crate::shutdown::GracefulShutdown;
+
 #[allow(dead_code)]
 const LETS_ENCRYPT_PRODUCTION_DIRECTORY: &str = "https://acme-v02.api.letsencrypt.org/directory";
 #[allow(dead_code)]
@@ -66,93 +68,106 @@ impl AcmeTlsConfig {
 pub fn spawn_acme_state(
     state: AcmeState,
     domains: Vec,
+    shutdown: Arc,
 ) -> tokio::task::JoinHandle<()> {
     tokio::spawn(async move {
         use futures::StreamExt;
         let mut state = state;
+        let mut shutdown_rx = shutdown.subscribe();
         loop {
-            match state.next().await {
-                Some(Ok(event)) => match event {
-                    EventOk::DeployedCachedCert => {
-                        info!(
-                            domains = ?domains,
-                            "ACME: deployed cached certificate"
-                        );
+            tokio::select! {
+                event = state.next() => {
+                    match event {
+                        Some(Ok(event)) => match event {
+                            EventOk::DeployedCachedCert => {
+                                info!(
+                                    domains = ?domains,
+                                    "ACME: deployed cached certificate"
+                                );
+                            }
+                            EventOk::DeployedNewCert => {
+                                info!(
+                                    domains = ?domains,
+                                    "ACME: deployed new certificate"
+                                );
+                            }
+                            EventOk::CertCacheStore => {
+                                info!(
+                                    domains = ?domains,
+                                    "ACME: certificate stored to cache"
+                                );
+                            }
+                            EventOk::AccountCacheStore => {
+                                info!(
+                                    domains = ?domains,
+                                    "ACME: account stored to cache"
+                                );
+                            }
+                        },
+                        Some(Err(err)) => match &err {
+                            EventError::CertCacheLoad(e) => {
+                                error!(
+                                    domains = ?domains,
+                                    error = ?e,
+                                    "ACME: certificate cache load failed"
+                                );
+                            }
+                            EventError::AccountCacheLoad(e) => {
+                                error!(
+                                    domains = ?domains,
+                                    error = ?e,
+                                    "ACME: account cache load failed"
+                                );
+                            }
+                            EventError::CertCacheStore(e) => {
+                                warn!(
+                                    domains = ?domains,
+                                    error = ?e,
+                                    "ACME: certificate cache store failed"
+                                );
+                            }
+                            EventError::AccountCacheStore(e) => {
+                                warn!(
+                                    domains = ?domains,
+                                    error = ?e,
+                                    "ACME: account cache store failed"
+                                );
+                            }
+                            EventError::CachedCertParse(e) => {
+                                error!(
+                                    domains = ?domains,
+                                    error = ?e,
+                                    "ACME: cached certificate parse failed"
+                                );
+                            }
+                            EventError::Order(e) => {
+                                warn!(
+                                    domains = ?domains,
+                                    error = ?e,
+                                    "ACME: certificate order failed, will retry"
+                                );
+                            }
+                            EventError::NewCertParse(e) => {
+                                error!(
+                                    domains = ?domains,
+                                    error = ?e,
+                                    "ACME: new certificate parse failed"
+                                );
+                            }
+                        },
+                        None => {
+                            info!(
+                                domains = ?domains,
+                                "ACME: state machine ended"
+                            );
+                            break;
+                        }
                     }
-                    EventOk::DeployedNewCert => {
-                        info!(
-                            domains = ?domains,
-                            "ACME: deployed new certificate"
-                        );
-                    }
-                    EventOk::CertCacheStore => {
-                        info!(
-                            domains = ?domains,
-                            "ACME: certificate stored to cache"
-                        );
-                    }
-                    EventOk::AccountCacheStore => {
-                        info!(
-                            domains = ?domains,
-                            "ACME: account stored to cache"
-                        );
-                    }
-                },
-                Some(Err(err)) => match &err {
-                    EventError::CertCacheLoad(e) => {
-                        error!(
-                            domains = ?domains,
-                            error = ?e,
-                            "ACME: certificate cache load failed"
-                        );
-                    }
-                    EventError::AccountCacheLoad(e) => {
-                        error!(
-                            domains = ?domains,
-                            error = ?e,
-                            "ACME: account cache load failed"
-                        );
-                    }
-                    EventError::CertCacheStore(e) => {
-                        warn!(
-                            domains = ?domains,
-                            error = ?e,
-                            "ACME: certificate cache store failed"
-                        );
-                    }
-                    EventError::AccountCacheStore(e) => {
-                        warn!(
-                            domains = ?domains,
-                            error = ?e,
-                            "ACME: account cache store failed"
-                        );
-                    }
-                    EventError::CachedCertParse(e) => {
-                        error!(
-                            domains = ?domains,
-                            error = ?e,
-                            "ACME: cached certificate parse failed"
-                        );
-                    }
-                    EventError::Order(e) => {
-                        warn!(
-                            domains = ?domains,
-                            error = ?e,
-                            "ACME: certificate order failed, will retry"
-                        );
-                    }
-                    EventError::NewCertParse(e) => {
-                        error!(
-                            domains = ?domains,
-                            error = ?e,
-                            "ACME: new certificate parse failed"
-                        );
-                    }
-                },
-                None => {
+                }
+                _ = shutdown_rx.changed() => {
                     info!(
                         domains = ?domains,
-                        "ACME: state machine ended"
+                        "ACME: state machine shutting down"
                     );
                     break;
                 }
diff --git a/tests/integration_test.rs b/tests/integration_test.rs
index 025a08a..c011a64 100644
--- a/tests/integration_test.rs
+++ b/tests/integration_test.rs
@@ -244,10 +244,12 @@ async fn test_rate_limit_eviction_task() {
 
     limiter.check_and_consume(std::net::IpAddr::from([192, 168, 1, 1]));
 
+    let shutdown = Arc::new(reverse_proxy::shutdown::GracefulShutdown::new(30));
     let handle = reverse_proxy::rate_limit::start_eviction_task(
         limiter.clone(),
         Duration::from_millis(50),
         Duration::from_millis(100),
+        shutdown.subscribe(),
     );
 
     tokio::time::sleep(Duration::from_millis(200)).await;