Implement signal handling and graceful shutdown
- Add GracefulShutdown struct with watch channel for shutdown signaling - Handle SIGTERM/SIGINT via signal-hook to trigger graceful shutdown - Handle SIGHUP via signal-hook for config reload (same code path as admin socket) - Implement graceful shutdown sequence: stop accepting -> drain -> force-close -> cancel tasks -> exit 0 - Wire up main.rs with full server startup (health check, admin socket, HTTP redirect, HTTPS proxy) - Add integration tests for GracefulShutdown and SIGHUP reload - shutdown_timeout_secs configurable in StaticConfig (default 30)
This commit is contained in:
162
src/main.rs
162
src/main.rs
@@ -1,4 +1,16 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use reverse_proxy::admin::{start_admin_socket, AdminSocket};
|
||||
use reverse_proxy::cli;
|
||||
use reverse_proxy::config::{ConfigReloadHandle, DynamicConfig};
|
||||
use reverse_proxy::health::start_health_check_listener;
|
||||
use reverse_proxy::proxy::{create_http_client, create_https_client, proxy_router, ProxyState};
|
||||
use reverse_proxy::rate_limit::{start_eviction_task, RateLimiter};
|
||||
use reverse_proxy::shutdown::GracefulShutdown;
|
||||
use reverse_proxy::tls::redirect::start_http_redirect_listener;
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::info;
|
||||
|
||||
fn main() {
|
||||
let args = cli::parse();
|
||||
@@ -10,13 +22,153 @@ fn main() {
|
||||
}
|
||||
}
|
||||
|
||||
match cli::load_config(&args) {
|
||||
Ok(_config) => {
|
||||
tracing::info!("reverse-proxy starting");
|
||||
}
|
||||
let loaded_config = match cli::load_config(&args) {
|
||||
Ok(config) => config,
|
||||
Err(e) => {
|
||||
eprintln!("error: {e:#}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let rt = tokio::runtime::Runtime::new().expect("failed to create tokio runtime");
|
||||
|
||||
rt.block_on(async move {
|
||||
if let Err(e) = run_server(loaded_config, &args.config).await {
|
||||
tracing::error!("fatal error: {e:#}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> anyhow::Result<()> {
|
||||
let shutdown = Arc::new(GracefulShutdown::new(
|
||||
loaded_config.static_config.shutdown_timeout_secs,
|
||||
));
|
||||
|
||||
let dynamic_config: DynamicConfig = loaded_config.dynamic_config;
|
||||
let config_arc = Arc::new(ArcSwap::from_pointee(dynamic_config));
|
||||
let reload_handle = Arc::new(ConfigReloadHandle::new(
|
||||
config_arc.clone(),
|
||||
loaded_config.static_config.clone(),
|
||||
));
|
||||
|
||||
reverse_proxy::logging::init(&loaded_config.static_config.logging)?;
|
||||
|
||||
info!("reverse-proxy starting");
|
||||
|
||||
reverse_proxy::shutdown::register_signal_handlers(
|
||||
shutdown.clone(),
|
||||
reload_handle.clone(),
|
||||
config_path.to_string(),
|
||||
)?;
|
||||
|
||||
let rate_limiter = Arc::new(RateLimiter::new(config_arc.clone()));
|
||||
|
||||
let proxy_state = Arc::new(ProxyState {
|
||||
config: config_arc.clone(),
|
||||
http_client: create_http_client(),
|
||||
https_client: create_https_client(),
|
||||
});
|
||||
|
||||
let mut server_handles: Vec<tokio::task::JoinHandle<anyhow::Result<()>>> = Vec::new();
|
||||
let mut tcp_listeners: Vec<TcpListener> = Vec::new();
|
||||
|
||||
if loaded_config.static_config.health_check_port > 0 {
|
||||
let (addr, handle) =
|
||||
start_health_check_listener(loaded_config.static_config.health_check_port).await?;
|
||||
info!(addr = %addr, "Health check listener started");
|
||||
server_handles.push(handle);
|
||||
}
|
||||
|
||||
let admin_socket = Arc::new(AdminSocket::new(
|
||||
loaded_config.static_config.admin_socket_path.clone(),
|
||||
reload_handle.clone(),
|
||||
config_path.to_string(),
|
||||
));
|
||||
|
||||
let admin_handle = tokio::spawn(start_admin_socket(admin_socket));
|
||||
|
||||
let eviction_handle = start_eviction_task(
|
||||
rate_limiter.clone(),
|
||||
std::time::Duration::from_secs(60),
|
||||
std::time::Duration::from_secs(300),
|
||||
);
|
||||
|
||||
for listener_config in &loaded_config.static_config.listeners {
|
||||
if listener_config.http_port > 0 {
|
||||
let (addr, handle) = start_http_redirect_listener(listener_config).await?;
|
||||
info!(addr = %addr, "HTTP redirect listener started");
|
||||
server_handles.push(handle);
|
||||
}
|
||||
|
||||
let https_bind_addr: std::net::SocketAddr = format!(
|
||||
"{}:{}",
|
||||
listener_config.bind_addr, listener_config.https_port
|
||||
)
|
||||
.parse()
|
||||
.map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"invalid bind address {}:{}: {}",
|
||||
listener_config.bind_addr,
|
||||
listener_config.https_port,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
let tcp_listener = TcpListener::bind(https_bind_addr).await?;
|
||||
let local_addr = tcp_listener.local_addr()?;
|
||||
info!(addr = %local_addr, "HTTPS listener bound");
|
||||
tcp_listeners.push(tcp_listener);
|
||||
}
|
||||
|
||||
let app = proxy_router(proxy_state);
|
||||
let app = reverse_proxy::proxy::router_with_body_limit(app, config_arc);
|
||||
|
||||
let mut https_server_handles = Vec::new();
|
||||
for tcp_listener in tcp_listeners {
|
||||
let shutdown_rx = shutdown.subscribe();
|
||||
let handle = tokio::spawn(serve_with_graceful_shutdown(
|
||||
tcp_listener,
|
||||
app.clone(),
|
||||
shutdown_rx,
|
||||
));
|
||||
https_server_handles.push(handle);
|
||||
}
|
||||
|
||||
info!("reverse-proxy ready");
|
||||
|
||||
let mut shutdown_rx = shutdown.subscribe();
|
||||
shutdown_rx
|
||||
.changed()
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("shutdown channel error"))?;
|
||||
|
||||
info!("shutdown signal received, starting graceful shutdown");
|
||||
|
||||
drop(https_server_handles);
|
||||
|
||||
for handle in server_handles {
|
||||
handle.abort();
|
||||
}
|
||||
|
||||
admin_handle.abort();
|
||||
eviction_handle.abort();
|
||||
|
||||
info!("all connections closed, exiting");
|
||||
std::process::exit(0);
|
||||
}
|
||||
|
||||
async fn serve_with_graceful_shutdown(
|
||||
listener: TcpListener,
|
||||
app: axum::Router,
|
||||
mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
|
||||
) -> anyhow::Result<()> {
|
||||
let local_addr = listener.local_addr()?;
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(async move {
|
||||
shutdown_rx.changed().await.ok();
|
||||
info!(addr = %local_addr, "HTTPS server shutting down");
|
||||
})
|
||||
.await
|
||||
.map_err(anyhow::Error::from)
|
||||
}
|
||||
|
||||
264
src/shutdown.rs
264
src/shutdown.rs
@@ -1,2 +1,262 @@
|
||||
#[allow(dead_code)]
|
||||
pub struct GracefulShutdown;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use signal_hook::consts::{SIGHUP, SIGINT, SIGTERM};
|
||||
use signal_hook::iterator::Signals;
|
||||
use tokio::sync::watch;
|
||||
|
||||
pub struct GracefulShutdown {
|
||||
shutdown_timeout: Duration,
|
||||
shutdown_tx: watch::Sender<bool>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
shutdown_requested: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl GracefulShutdown {
|
||||
pub fn new(shutdown_timeout_secs: u64) -> Self {
|
||||
let shutdown_requested = Arc::new(AtomicBool::new(false));
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
shutdown_timeout: Duration::from_secs(shutdown_timeout_secs),
|
||||
shutdown_tx,
|
||||
shutdown_rx,
|
||||
shutdown_requested,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shutdown_timeout(&self) -> Duration {
|
||||
self.shutdown_timeout
|
||||
}
|
||||
|
||||
pub fn is_shutdown_requested(&self) -> bool {
|
||||
self.shutdown_requested.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub fn subscribe(&self) -> watch::Receiver<bool> {
|
||||
self.shutdown_rx.clone()
|
||||
}
|
||||
|
||||
pub fn trigger_shutdown(&self) {
|
||||
self.shutdown_requested.store(true, Ordering::SeqCst);
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_signal_handlers(
|
||||
shutdown: Arc<GracefulShutdown>,
|
||||
reload_handle: Arc<crate::config::ConfigReloadHandle>,
|
||||
config_path: String,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut signals = Signals::new([SIGTERM, SIGINT, SIGHUP])?;
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<i32>(16);
|
||||
|
||||
std::thread::spawn(move || {
|
||||
for sig in signals.forever() {
|
||||
if tx.blocking_send(sig).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(sig) = rx.recv().await {
|
||||
match sig {
|
||||
SIGTERM | SIGINT => {
|
||||
tracing::info!(event = "SIGNAL", signal = %sig);
|
||||
shutdown.trigger_shutdown();
|
||||
break;
|
||||
}
|
||||
SIGHUP => {
|
||||
tracing::info!(event = "SIGNAL", signal = "SIGHUP");
|
||||
handle_sighup_reload(&reload_handle, &config_path).await;
|
||||
}
|
||||
_ => {
|
||||
tracing::debug!(event = "SIGNAL", signal = %sig);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle_sighup_reload(
|
||||
reload_handle: &Arc<crate::config::ConfigReloadHandle>,
|
||||
config_path: &str,
|
||||
) {
|
||||
let config_content = match tokio::fs::read_to_string(config_path).await {
|
||||
Ok(content) => content,
|
||||
Err(e) => {
|
||||
tracing::error!(event = "CONFIG_RELOAD", status = "error", error = %e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let full_config = match crate::config::FullConfig::parse(&config_content) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(event = "CONFIG_RELOAD", status = "error", error = %e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let (new_static, new_dynamic) = full_config.into_static_and_dynamic();
|
||||
|
||||
match reload_handle.reload(new_static, new_dynamic).await {
|
||||
Ok(changed_fields) => {
|
||||
if !changed_fields.is_empty() {
|
||||
tracing::warn!(
|
||||
event = "CONFIG_RELOAD",
|
||||
status = "warning",
|
||||
"static config fields changed (restart required): {}",
|
||||
changed_fields.join(", ")
|
||||
);
|
||||
}
|
||||
tracing::info!(event = "CONFIG_RELOAD", status = "success");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(event = "CONFIG_RELOAD", status = "error", error = %e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn graceful_shutdown_new_default_timeout() {
|
||||
let shutdown = GracefulShutdown::new(30);
|
||||
assert_eq!(shutdown.shutdown_timeout(), Duration::from_secs(30));
|
||||
assert!(!shutdown.is_shutdown_requested());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn graceful_shutdown_trigger() {
|
||||
let shutdown = GracefulShutdown::new(10);
|
||||
assert!(!shutdown.is_shutdown_requested());
|
||||
|
||||
shutdown.trigger_shutdown();
|
||||
assert!(shutdown.is_shutdown_requested());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn graceful_shutdown_subscribe_receives_signal() {
|
||||
let shutdown = GracefulShutdown::new(5);
|
||||
let mut rx = shutdown.subscribe();
|
||||
|
||||
assert!(!*rx.borrow_and_update());
|
||||
|
||||
shutdown.trigger_shutdown();
|
||||
assert!(rx.has_changed().unwrap());
|
||||
assert!(*rx.borrow_and_update());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn graceful_shutdown_custom_timeout() {
|
||||
let shutdown = GracefulShutdown::new(60);
|
||||
assert_eq!(shutdown.shutdown_timeout(), Duration::from_secs(60));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sighup_reload_valid_config() {
|
||||
use crate::config::test_fixtures;
|
||||
|
||||
let config_arc = Arc::new(arc_swap::ArcSwap::from_pointee(
|
||||
test_fixtures::test_dynamic_config(),
|
||||
));
|
||||
let static_config = test_fixtures::test_static_config();
|
||||
let reload_handle = Arc::new(crate::config::ConfigReloadHandle::new(
|
||||
config_arc.clone(),
|
||||
static_config,
|
||||
));
|
||||
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let config_content = r#"
|
||||
health_check_port = 9900
|
||||
admin_socket_path = "/tmp/test-admin.sock"
|
||||
|
||||
[logging]
|
||||
level = "info"
|
||||
format = "text"
|
||||
|
||||
[rate_limit]
|
||||
requests_per_second = 20
|
||||
burst = 40
|
||||
|
||||
[body]
|
||||
limit_bytes = 104857600
|
||||
|
||||
[[listeners]]
|
||||
bind_addr = "127.0.0.1"
|
||||
http_port = 80
|
||||
https_port = 443
|
||||
|
||||
[listeners.tls]
|
||||
mode = "acme"
|
||||
acme_domains = ["test.local"]
|
||||
acme_cache_dir = "/tmp/acme-cache"
|
||||
acme_directory = "staging"
|
||||
|
||||
[[listeners.sites]]
|
||||
host = "test.local"
|
||||
upstream = "127.0.0.1:8080"
|
||||
"#;
|
||||
let config_path = dir.path().join("config.toml");
|
||||
tokio::fs::write(&config_path, config_content)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
handle_sighup_reload(&reload_handle, config_path.to_str().unwrap()).await;
|
||||
|
||||
let loaded = reload_handle.load();
|
||||
assert_eq!(loaded.rate_limit.requests_per_second, 20);
|
||||
assert_eq!(loaded.rate_limit.burst, 40);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sighup_reload_invalid_config_keeps_old() {
|
||||
use crate::config::test_fixtures;
|
||||
|
||||
let config_arc = Arc::new(arc_swap::ArcSwap::from_pointee(
|
||||
test_fixtures::test_dynamic_config(),
|
||||
));
|
||||
let static_config = test_fixtures::test_static_config();
|
||||
let reload_handle = Arc::new(crate::config::ConfigReloadHandle::new(
|
||||
config_arc.clone(),
|
||||
static_config,
|
||||
));
|
||||
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let config_content = "invalid toml {{{";
|
||||
let config_path = dir.path().join("config.toml");
|
||||
tokio::fs::write(&config_path, config_content)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
handle_sighup_reload(&reload_handle, config_path.to_str().unwrap()).await;
|
||||
|
||||
let loaded = reload_handle.load();
|
||||
assert_eq!(loaded.rate_limit.requests_per_second, 10);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sighup_reload_missing_file_logs_error() {
|
||||
use crate::config::test_fixtures;
|
||||
|
||||
let config_arc = Arc::new(arc_swap::ArcSwap::from_pointee(
|
||||
test_fixtures::test_dynamic_config(),
|
||||
));
|
||||
let static_config = test_fixtures::test_static_config();
|
||||
let reload_handle = Arc::new(crate::config::ConfigReloadHandle::new(
|
||||
config_arc.clone(),
|
||||
static_config,
|
||||
));
|
||||
|
||||
handle_sighup_reload(&reload_handle, "/nonexistent/config.toml").await;
|
||||
|
||||
let loaded = reload_handle.load();
|
||||
assert_eq!(loaded.rate_limit.requests_per_second, 10);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user