Implement graceful shutdown for listeners, admin socket, eviction task, and ACME
- Replace handle.abort() for HTTPS server tasks with timeout-based join, allowing in-flight requests to drain before forceful shutdown - Add shutdown_rx to start_admin_socket with tokio::select! for clean accept loop exit and Unix socket file cleanup on shutdown - Add shutdown_rx to start_eviction_task with tokio::select! for cancellable eviction loop - Add shutdown channel to spawn_acme_state for cancellable ACME state machine via tokio::select! - Pass Arc<GracefulShutdown> through setup_tls to ACME state machine - Move GracefulShutdown creation before admin socket and TLS setup - Update integration test for new start_eviction_task signature
This commit is contained in:
@@ -9,6 +9,8 @@ use tokio::net::UnixListener;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::shutdown::GracefulShutdown;
|
||||
|
||||
use crate::config::ConfigReloadHandle;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@@ -70,7 +72,10 @@ impl AdminSocket {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_admin_socket(admin_socket: Arc<AdminSocket>) -> Result<(), AdminSocketError> {
|
||||
pub async fn start_admin_socket(
|
||||
admin_socket: Arc<AdminSocket>,
|
||||
shutdown: Arc<GracefulShutdown>,
|
||||
) -> Result<(), AdminSocketError> {
|
||||
if admin_socket.socket_path.is_empty() {
|
||||
info!("admin socket disabled (empty path)");
|
||||
return Err(AdminSocketError::Disabled);
|
||||
@@ -96,19 +101,41 @@ pub async fn start_admin_socket(admin_socket: Arc<AdminSocket>) -> Result<(), Ad
|
||||
|
||||
info!("admin socket listening on {}", socket_path);
|
||||
|
||||
let mut shutdown_rx = shutdown.subscribe();
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((stream, _addr)) => {
|
||||
let admin_socket = admin_socket.clone();
|
||||
tokio::spawn(async move {
|
||||
handle_connection(stream, admin_socket).await;
|
||||
});
|
||||
tokio::select! {
|
||||
result = listener.accept() => {
|
||||
match result {
|
||||
Ok((stream, _addr)) => {
|
||||
let admin_socket = admin_socket.clone();
|
||||
tokio::spawn(async move {
|
||||
handle_connection(stream, admin_socket).await;
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("failed to accept admin socket connection: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("failed to accept admin socket connection: {}", e);
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("admin socket shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cleanup_socket_file(socket_path).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cleanup_socket_file(path: &str) {
|
||||
if Path::new(path).exists() {
|
||||
if let Err(e) = tokio::fs::remove_file(path).await {
|
||||
warn!("failed to remove admin socket file {}: {}", path, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn cleanup_stale_socket(path: &str) -> Result<(), AdminSocketError> {
|
||||
@@ -507,7 +534,7 @@ upstream = "127.0.0.1:8080"
|
||||
dir.path().join("config.toml").to_string_lossy().to_string(),
|
||||
));
|
||||
|
||||
let result = start_admin_socket(admin_socket).await;
|
||||
let result = start_admin_socket(admin_socket, Arc::new(GracefulShutdown::new(30))).await;
|
||||
assert!(matches!(result, Err(AdminSocketError::Disabled)));
|
||||
}
|
||||
|
||||
@@ -530,7 +557,7 @@ upstream = "127.0.0.1:8080"
|
||||
dir.path().join("config.toml").to_string_lossy().to_string(),
|
||||
));
|
||||
|
||||
let result = start_admin_socket(admin_socket).await;
|
||||
let result = start_admin_socket(admin_socket, Arc::new(GracefulShutdown::new(30))).await;
|
||||
assert!(matches!(result, Err(AdminSocketError::SocketInUse(_))));
|
||||
}
|
||||
|
||||
|
||||
32
src/main.rs
32
src/main.rs
@@ -65,6 +65,10 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> Resu
|
||||
let dynamic_config: DynamicConfig = loaded_config.dynamic_config;
|
||||
let config_arc = Arc::new(ArcSwap::from_pointee(dynamic_config));
|
||||
|
||||
let shutdown = Arc::new(GracefulShutdown::new(
|
||||
loaded_config.static_config.shutdown_timeout_secs,
|
||||
));
|
||||
|
||||
let rate_limiter = Arc::new(RateLimiter::new(config_arc.clone()));
|
||||
|
||||
let http_client = create_http_client();
|
||||
@@ -81,6 +85,12 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> Resu
|
||||
loaded_config.static_config.clone(),
|
||||
));
|
||||
|
||||
reverse_proxy::shutdown::register_signal_handlers(
|
||||
shutdown.clone(),
|
||||
reload_handle.clone(),
|
||||
config_path.to_string(),
|
||||
)?;
|
||||
|
||||
if loaded_config.static_config.health_check_port > 0 {
|
||||
let (health_addr, _health_handle) =
|
||||
health::start_health_check_listener(loaded_config.static_config.health_check_port)
|
||||
@@ -96,8 +106,9 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> Resu
|
||||
config_path.to_string(),
|
||||
));
|
||||
let admin_socket_clone = admin_socket.clone();
|
||||
let shutdown_for_admin = shutdown.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = start_admin_socket(admin_socket_clone).await {
|
||||
if let Err(e) = start_admin_socket(admin_socket_clone, shutdown_for_admin).await {
|
||||
match e {
|
||||
AdminSocketError::Disabled => {}
|
||||
AdminSocketError::SocketInUse(path) => {
|
||||
@@ -150,7 +161,7 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> Resu
|
||||
|
||||
let mut tls_acceptors = Vec::new();
|
||||
for (listener_config, _) in &bound_listeners {
|
||||
let tls_mode = setup_tls(&listener_config.tls).context(format!(
|
||||
let tls_mode = setup_tls(&listener_config.tls, shutdown.clone()).context(format!(
|
||||
"failed to setup TLS for listener {}",
|
||||
listener_config.bind_addr
|
||||
))?;
|
||||
@@ -175,20 +186,11 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> Resu
|
||||
}
|
||||
}
|
||||
|
||||
let shutdown = Arc::new(GracefulShutdown::new(
|
||||
loaded_config.static_config.shutdown_timeout_secs,
|
||||
));
|
||||
|
||||
reverse_proxy::shutdown::register_signal_handlers(
|
||||
shutdown.clone(),
|
||||
reload_handle.clone(),
|
||||
config_path.to_string(),
|
||||
)?;
|
||||
|
||||
let _eviction_handle = start_eviction_task(
|
||||
rate_limiter.clone(),
|
||||
std::time::Duration::from_secs(60),
|
||||
std::time::Duration::from_secs(300),
|
||||
shutdown.subscribe(),
|
||||
);
|
||||
|
||||
let app = build_router(proxy_state.clone(), config_arc.clone(), rate_limiter);
|
||||
@@ -230,8 +232,12 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> Resu
|
||||
|
||||
info!("shutdown signal received, starting graceful shutdown");
|
||||
|
||||
let shutdown_timeout = shutdown.shutdown_timeout();
|
||||
for handle in https_server_handles {
|
||||
handle.abort();
|
||||
let result = tokio::time::timeout(shutdown_timeout, handle).await;
|
||||
if result.is_err() {
|
||||
warn!("shutdown timeout expired, aborting listener task");
|
||||
}
|
||||
}
|
||||
|
||||
let remaining = drain_in_flight(&in_flight, shutdown.shutdown_timeout()).await;
|
||||
|
||||
@@ -102,12 +102,20 @@ pub fn start_eviction_task(
|
||||
limiter: Arc<RateLimiter>,
|
||||
interval: Duration,
|
||||
max_age: Duration,
|
||||
mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
tokio::spawn(async move {
|
||||
let mut interval_timer = tokio::time::interval(interval);
|
||||
loop {
|
||||
interval_timer.tick().await;
|
||||
limiter.evict_stale(max_age);
|
||||
tokio::select! {
|
||||
_ = interval_timer.tick() => {
|
||||
limiter.evict_stale(max_age);
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
tracing::info!("rate limiter eviction task shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ use tracing::info;
|
||||
use super::acme::{spawn_acme_state, AcmeTlsConfig};
|
||||
use super::config::crypto_provider;
|
||||
use crate::config::static_config::TlsConfig;
|
||||
use crate::shutdown::GracefulShutdown;
|
||||
|
||||
const ACME_TLS_ALPN_01: &[u8] = b"acme-tls/1";
|
||||
|
||||
@@ -41,7 +42,7 @@ pub enum TlsMode {
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn setup_tls(tls_config: &TlsConfig) -> Result<TlsMode> {
|
||||
pub fn setup_tls(tls_config: &TlsConfig, shutdown: Arc<GracefulShutdown>) -> Result<TlsMode> {
|
||||
match tls_config.mode.as_str() {
|
||||
"manual" => {
|
||||
if tls_config.cert_path.is_empty() {
|
||||
@@ -75,7 +76,7 @@ pub fn setup_tls(tls_config: &TlsConfig) -> Result<TlsMode> {
|
||||
|
||||
let default_config = build_acme_server_config(resolver.clone())?;
|
||||
|
||||
spawn_acme_state(state, tls_config.acme_domains.clone());
|
||||
spawn_acme_state(state, tls_config.acme_domains.clone(), shutdown);
|
||||
|
||||
info!(
|
||||
domains = ?tls_config.acme_domains,
|
||||
@@ -136,7 +137,8 @@ mod tests {
|
||||
cert_path: String::new(),
|
||||
key_path: "/some/key.pem".to_string(),
|
||||
};
|
||||
let result = setup_tls(&tls_config);
|
||||
let shutdown = Arc::new(GracefulShutdown::new(30));
|
||||
let result = setup_tls(&tls_config, shutdown);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("cert_path"));
|
||||
@@ -153,7 +155,8 @@ mod tests {
|
||||
cert_path: "/some/cert.pem".to_string(),
|
||||
key_path: String::new(),
|
||||
};
|
||||
let result = setup_tls(&tls_config);
|
||||
let shutdown = Arc::new(GracefulShutdown::new(30));
|
||||
let result = setup_tls(&tls_config, shutdown);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("key_path"));
|
||||
@@ -170,7 +173,8 @@ mod tests {
|
||||
cert_path: String::new(),
|
||||
key_path: String::new(),
|
||||
};
|
||||
let result = setup_tls(&tls_config);
|
||||
let shutdown = Arc::new(GracefulShutdown::new(30));
|
||||
let result = setup_tls(&tls_config, shutdown);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("acme_domains"));
|
||||
@@ -187,7 +191,8 @@ mod tests {
|
||||
cert_path: String::new(),
|
||||
key_path: String::new(),
|
||||
};
|
||||
let result = setup_tls(&tls_config);
|
||||
let shutdown = Arc::new(GracefulShutdown::new(30));
|
||||
let result = setup_tls(&tls_config, shutdown);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("acme_cache_dir"));
|
||||
@@ -204,9 +209,10 @@ mod tests {
|
||||
cert_path: String::new(),
|
||||
key_path: String::new(),
|
||||
};
|
||||
let result = setup_tls(&tls_config);
|
||||
let shutdown = Arc::new(GracefulShutdown::new(30));
|
||||
let result = setup_tls(&tls_config, shutdown);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("unknown TLS mode"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
173
src/tls/acme.rs
173
src/tls/acme.rs
@@ -6,6 +6,8 @@ use rustls_acme::caches::DirCache;
|
||||
use rustls_acme::{AcmeConfig, AcmeState, EventError, EventOk, ResolvesServerCertAcme};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::shutdown::GracefulShutdown;
|
||||
|
||||
#[allow(dead_code)]
|
||||
const LETS_ENCRYPT_PRODUCTION_DIRECTORY: &str = "https://acme-v02.api.letsencrypt.org/directory";
|
||||
#[allow(dead_code)]
|
||||
@@ -66,93 +68,106 @@ impl AcmeTlsConfig {
|
||||
pub fn spawn_acme_state(
|
||||
state: AcmeState<std::io::Error, std::io::Error>,
|
||||
domains: Vec<String>,
|
||||
shutdown: Arc<GracefulShutdown>,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
tokio::spawn(async move {
|
||||
use futures::StreamExt;
|
||||
let mut state = state;
|
||||
let mut shutdown_rx = shutdown.subscribe();
|
||||
loop {
|
||||
match state.next().await {
|
||||
Some(Ok(event)) => match event {
|
||||
EventOk::DeployedCachedCert => {
|
||||
info!(
|
||||
domains = ?domains,
|
||||
"ACME: deployed cached certificate"
|
||||
);
|
||||
tokio::select! {
|
||||
event = state.next() => {
|
||||
match event {
|
||||
Some(Ok(event)) => match event {
|
||||
EventOk::DeployedCachedCert => {
|
||||
info!(
|
||||
domains = ?domains,
|
||||
"ACME: deployed cached certificate"
|
||||
);
|
||||
}
|
||||
EventOk::DeployedNewCert => {
|
||||
info!(
|
||||
domains = ?domains,
|
||||
"ACME: deployed new certificate"
|
||||
);
|
||||
}
|
||||
EventOk::CertCacheStore => {
|
||||
info!(
|
||||
domains = ?domains,
|
||||
"ACME: certificate stored to cache"
|
||||
);
|
||||
}
|
||||
EventOk::AccountCacheStore => {
|
||||
info!(
|
||||
domains = ?domains,
|
||||
"ACME: account stored to cache"
|
||||
);
|
||||
}
|
||||
},
|
||||
Some(Err(err)) => match &err {
|
||||
EventError::CertCacheLoad(e) => {
|
||||
error!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: certificate cache load failed"
|
||||
);
|
||||
}
|
||||
EventError::AccountCacheLoad(e) => {
|
||||
error!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: account cache load failed"
|
||||
);
|
||||
}
|
||||
EventError::CertCacheStore(e) => {
|
||||
warn!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: certificate cache store failed"
|
||||
);
|
||||
}
|
||||
EventError::AccountCacheStore(e) => {
|
||||
warn!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: account cache store failed"
|
||||
);
|
||||
}
|
||||
EventError::CachedCertParse(e) => {
|
||||
error!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: cached certificate parse failed"
|
||||
);
|
||||
}
|
||||
EventError::Order(e) => {
|
||||
warn!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: certificate order failed, will retry"
|
||||
);
|
||||
}
|
||||
EventError::NewCertParse(e) => {
|
||||
error!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: new certificate parse failed"
|
||||
);
|
||||
}
|
||||
},
|
||||
None => {
|
||||
info!(
|
||||
domains = ?domains,
|
||||
"ACME: state machine ended"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
EventOk::DeployedNewCert => {
|
||||
info!(
|
||||
domains = ?domains,
|
||||
"ACME: deployed new certificate"
|
||||
);
|
||||
}
|
||||
EventOk::CertCacheStore => {
|
||||
info!(
|
||||
domains = ?domains,
|
||||
"ACME: certificate stored to cache"
|
||||
);
|
||||
}
|
||||
EventOk::AccountCacheStore => {
|
||||
info!(
|
||||
domains = ?domains,
|
||||
"ACME: account stored to cache"
|
||||
);
|
||||
}
|
||||
},
|
||||
Some(Err(err)) => match &err {
|
||||
EventError::CertCacheLoad(e) => {
|
||||
error!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: certificate cache load failed"
|
||||
);
|
||||
}
|
||||
EventError::AccountCacheLoad(e) => {
|
||||
error!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: account cache load failed"
|
||||
);
|
||||
}
|
||||
EventError::CertCacheStore(e) => {
|
||||
warn!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: certificate cache store failed"
|
||||
);
|
||||
}
|
||||
EventError::AccountCacheStore(e) => {
|
||||
warn!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: account cache store failed"
|
||||
);
|
||||
}
|
||||
EventError::CachedCertParse(e) => {
|
||||
error!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: cached certificate parse failed"
|
||||
);
|
||||
}
|
||||
EventError::Order(e) => {
|
||||
warn!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: certificate order failed, will retry"
|
||||
);
|
||||
}
|
||||
EventError::NewCertParse(e) => {
|
||||
error!(
|
||||
domains = ?domains,
|
||||
error = ?e,
|
||||
"ACME: new certificate parse failed"
|
||||
);
|
||||
}
|
||||
},
|
||||
None => {
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!(
|
||||
domains = ?domains,
|
||||
"ACME: state machine ended"
|
||||
"ACME: state machine shutting down"
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -244,10 +244,12 @@ async fn test_rate_limit_eviction_task() {
|
||||
|
||||
limiter.check_and_consume(std::net::IpAddr::from([192, 168, 1, 1]));
|
||||
|
||||
let shutdown = Arc::new(reverse_proxy::shutdown::GracefulShutdown::new(30));
|
||||
let handle = reverse_proxy::rate_limit::start_eviction_task(
|
||||
limiter.clone(),
|
||||
Duration::from_millis(50),
|
||||
Duration::from_millis(100),
|
||||
shutdown.subscribe(),
|
||||
);
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
|
||||
Reference in New Issue
Block a user