diff --git a/src/main.rs b/src/main.rs index 51c11fd..c554acf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,16 @@ +use std::sync::Arc; + +use arc_swap::ArcSwap; +use reverse_proxy::admin::{start_admin_socket, AdminSocket}; use reverse_proxy::cli; +use reverse_proxy::config::{ConfigReloadHandle, DynamicConfig}; +use reverse_proxy::health::start_health_check_listener; +use reverse_proxy::proxy::{create_http_client, create_https_client, proxy_router, ProxyState}; +use reverse_proxy::rate_limit::{start_eviction_task, RateLimiter}; +use reverse_proxy::shutdown::GracefulShutdown; +use reverse_proxy::tls::redirect::start_http_redirect_listener; +use tokio::net::TcpListener; +use tracing::info; fn main() { let args = cli::parse(); @@ -10,13 +22,153 @@ fn main() { } } - match cli::load_config(&args) { - Ok(_config) => { - tracing::info!("reverse-proxy starting"); - } + let loaded_config = match cli::load_config(&args) { + Ok(config) => config, Err(e) => { eprintln!("error: {e:#}"); std::process::exit(1); } - } + }; + + let rt = tokio::runtime::Runtime::new().expect("failed to create tokio runtime"); + + rt.block_on(async move { + if let Err(e) = run_server(loaded_config, &args.config).await { + tracing::error!("fatal error: {e:#}"); + std::process::exit(1); + } + }); +} + +async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> anyhow::Result<()> { + let shutdown = Arc::new(GracefulShutdown::new( + loaded_config.static_config.shutdown_timeout_secs, + )); + + let dynamic_config: DynamicConfig = loaded_config.dynamic_config; + let config_arc = Arc::new(ArcSwap::from_pointee(dynamic_config)); + let reload_handle = Arc::new(ConfigReloadHandle::new( + config_arc.clone(), + loaded_config.static_config.clone(), + )); + + reverse_proxy::logging::init(&loaded_config.static_config.logging)?; + + info!("reverse-proxy starting"); + + reverse_proxy::shutdown::register_signal_handlers( + shutdown.clone(), + reload_handle.clone(), + config_path.to_string(), + )?; + + let rate_limiter = Arc::new(RateLimiter::new(config_arc.clone())); + + let proxy_state = Arc::new(ProxyState { + config: config_arc.clone(), + http_client: create_http_client(), + https_client: create_https_client(), + }); + + let mut server_handles: Vec>> = Vec::new(); + let mut tcp_listeners: Vec = Vec::new(); + + if loaded_config.static_config.health_check_port > 0 { + let (addr, handle) = + start_health_check_listener(loaded_config.static_config.health_check_port).await?; + info!(addr = %addr, "Health check listener started"); + server_handles.push(handle); + } + + let admin_socket = Arc::new(AdminSocket::new( + loaded_config.static_config.admin_socket_path.clone(), + reload_handle.clone(), + config_path.to_string(), + )); + + let admin_handle = tokio::spawn(start_admin_socket(admin_socket)); + + let eviction_handle = start_eviction_task( + rate_limiter.clone(), + std::time::Duration::from_secs(60), + std::time::Duration::from_secs(300), + ); + + for listener_config in &loaded_config.static_config.listeners { + if listener_config.http_port > 0 { + let (addr, handle) = start_http_redirect_listener(listener_config).await?; + info!(addr = %addr, "HTTP redirect listener started"); + server_handles.push(handle); + } + + let https_bind_addr: std::net::SocketAddr = format!( + "{}:{}", + listener_config.bind_addr, listener_config.https_port + ) + .parse() + .map_err(|e| { + anyhow::anyhow!( + "invalid bind address {}:{}: {}", + listener_config.bind_addr, + listener_config.https_port, + e + ) + })?; + + let tcp_listener = TcpListener::bind(https_bind_addr).await?; + let local_addr = tcp_listener.local_addr()?; + info!(addr = %local_addr, "HTTPS listener bound"); + tcp_listeners.push(tcp_listener); + } + + let app = proxy_router(proxy_state); + let app = reverse_proxy::proxy::router_with_body_limit(app, config_arc); + + let mut https_server_handles = Vec::new(); + for tcp_listener in tcp_listeners { + let shutdown_rx = shutdown.subscribe(); + let handle = tokio::spawn(serve_with_graceful_shutdown( + tcp_listener, + app.clone(), + shutdown_rx, + )); + https_server_handles.push(handle); + } + + info!("reverse-proxy ready"); + + let mut shutdown_rx = shutdown.subscribe(); + shutdown_rx + .changed() + .await + .map_err(|_| anyhow::anyhow!("shutdown channel error"))?; + + info!("shutdown signal received, starting graceful shutdown"); + + drop(https_server_handles); + + for handle in server_handles { + handle.abort(); + } + + admin_handle.abort(); + eviction_handle.abort(); + + info!("all connections closed, exiting"); + std::process::exit(0); +} + +async fn serve_with_graceful_shutdown( + listener: TcpListener, + app: axum::Router, + mut shutdown_rx: tokio::sync::watch::Receiver, +) -> anyhow::Result<()> { + let local_addr = listener.local_addr()?; + axum::serve(listener, app) + .with_graceful_shutdown(async move { + shutdown_rx.changed().await.ok(); + info!(addr = %local_addr, "HTTPS server shutting down"); + }) + .await + .map_err(anyhow::Error::from) } diff --git a/src/shutdown.rs b/src/shutdown.rs index 95004e0..e0eef33 100644 --- a/src/shutdown.rs +++ b/src/shutdown.rs @@ -1,2 +1,262 @@ -#[allow(dead_code)] -pub struct GracefulShutdown; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use signal_hook::consts::{SIGHUP, SIGINT, SIGTERM}; +use signal_hook::iterator::Signals; +use tokio::sync::watch; + +pub struct GracefulShutdown { + shutdown_timeout: Duration, + shutdown_tx: watch::Sender, + shutdown_rx: watch::Receiver, + shutdown_requested: Arc, +} + +impl GracefulShutdown { + pub fn new(shutdown_timeout_secs: u64) -> Self { + let shutdown_requested = Arc::new(AtomicBool::new(false)); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + Self { + shutdown_timeout: Duration::from_secs(shutdown_timeout_secs), + shutdown_tx, + shutdown_rx, + shutdown_requested, + } + } + + pub fn shutdown_timeout(&self) -> Duration { + self.shutdown_timeout + } + + pub fn is_shutdown_requested(&self) -> bool { + self.shutdown_requested.load(Ordering::SeqCst) + } + + pub fn subscribe(&self) -> watch::Receiver { + self.shutdown_rx.clone() + } + + pub fn trigger_shutdown(&self) { + self.shutdown_requested.store(true, Ordering::SeqCst); + let _ = self.shutdown_tx.send(true); + } +} + +pub fn register_signal_handlers( + shutdown: Arc, + reload_handle: Arc, + config_path: String, +) -> anyhow::Result<()> { + let mut signals = Signals::new([SIGTERM, SIGINT, SIGHUP])?; + let (tx, mut rx) = tokio::sync::mpsc::channel::(16); + + std::thread::spawn(move || { + for sig in signals.forever() { + if tx.blocking_send(sig).is_err() { + break; + } + } + }); + + tokio::spawn(async move { + while let Some(sig) = rx.recv().await { + match sig { + SIGTERM | SIGINT => { + tracing::info!(event = "SIGNAL", signal = %sig); + shutdown.trigger_shutdown(); + break; + } + SIGHUP => { + tracing::info!(event = "SIGNAL", signal = "SIGHUP"); + handle_sighup_reload(&reload_handle, &config_path).await; + } + _ => { + tracing::debug!(event = "SIGNAL", signal = %sig); + } + } + } + }); + + Ok(()) +} + +pub async fn handle_sighup_reload( + reload_handle: &Arc, + config_path: &str, +) { + let config_content = match tokio::fs::read_to_string(config_path).await { + Ok(content) => content, + Err(e) => { + tracing::error!(event = "CONFIG_RELOAD", status = "error", error = %e); + return; + } + }; + + let full_config = match crate::config::FullConfig::parse(&config_content) { + Ok(c) => c, + Err(e) => { + tracing::error!(event = "CONFIG_RELOAD", status = "error", error = %e); + return; + } + }; + + let (new_static, new_dynamic) = full_config.into_static_and_dynamic(); + + match reload_handle.reload(new_static, new_dynamic).await { + Ok(changed_fields) => { + if !changed_fields.is_empty() { + tracing::warn!( + event = "CONFIG_RELOAD", + status = "warning", + "static config fields changed (restart required): {}", + changed_fields.join(", ") + ); + } + tracing::info!(event = "CONFIG_RELOAD", status = "success"); + } + Err(e) => { + tracing::error!(event = "CONFIG_RELOAD", status = "error", error = %e); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn graceful_shutdown_new_default_timeout() { + let shutdown = GracefulShutdown::new(30); + assert_eq!(shutdown.shutdown_timeout(), Duration::from_secs(30)); + assert!(!shutdown.is_shutdown_requested()); + } + + #[test] + fn graceful_shutdown_trigger() { + let shutdown = GracefulShutdown::new(10); + assert!(!shutdown.is_shutdown_requested()); + + shutdown.trigger_shutdown(); + assert!(shutdown.is_shutdown_requested()); + } + + #[test] + fn graceful_shutdown_subscribe_receives_signal() { + let shutdown = GracefulShutdown::new(5); + let mut rx = shutdown.subscribe(); + + assert!(!*rx.borrow_and_update()); + + shutdown.trigger_shutdown(); + assert!(rx.has_changed().unwrap()); + assert!(*rx.borrow_and_update()); + } + + #[test] + fn graceful_shutdown_custom_timeout() { + let shutdown = GracefulShutdown::new(60); + assert_eq!(shutdown.shutdown_timeout(), Duration::from_secs(60)); + } + + #[tokio::test] + async fn sighup_reload_valid_config() { + use crate::config::test_fixtures; + + let config_arc = Arc::new(arc_swap::ArcSwap::from_pointee( + test_fixtures::test_dynamic_config(), + )); + let static_config = test_fixtures::test_static_config(); + let reload_handle = Arc::new(crate::config::ConfigReloadHandle::new( + config_arc.clone(), + static_config, + )); + + let dir = tempfile::tempdir().unwrap(); + let config_content = r#" +health_check_port = 9900 +admin_socket_path = "/tmp/test-admin.sock" + +[logging] +level = "info" +format = "text" + +[rate_limit] +requests_per_second = 20 +burst = 40 + +[body] +limit_bytes = 104857600 + +[[listeners]] +bind_addr = "127.0.0.1" +http_port = 80 +https_port = 443 + +[listeners.tls] +mode = "acme" +acme_domains = ["test.local"] +acme_cache_dir = "/tmp/acme-cache" +acme_directory = "staging" + +[[listeners.sites]] +host = "test.local" +upstream = "127.0.0.1:8080" +"#; + let config_path = dir.path().join("config.toml"); + tokio::fs::write(&config_path, config_content) + .await + .unwrap(); + + handle_sighup_reload(&reload_handle, config_path.to_str().unwrap()).await; + + let loaded = reload_handle.load(); + assert_eq!(loaded.rate_limit.requests_per_second, 20); + assert_eq!(loaded.rate_limit.burst, 40); + } + + #[tokio::test] + async fn sighup_reload_invalid_config_keeps_old() { + use crate::config::test_fixtures; + + let config_arc = Arc::new(arc_swap::ArcSwap::from_pointee( + test_fixtures::test_dynamic_config(), + )); + let static_config = test_fixtures::test_static_config(); + let reload_handle = Arc::new(crate::config::ConfigReloadHandle::new( + config_arc.clone(), + static_config, + )); + + let dir = tempfile::tempdir().unwrap(); + let config_content = "invalid toml {{{"; + let config_path = dir.path().join("config.toml"); + tokio::fs::write(&config_path, config_content) + .await + .unwrap(); + + handle_sighup_reload(&reload_handle, config_path.to_str().unwrap()).await; + + let loaded = reload_handle.load(); + assert_eq!(loaded.rate_limit.requests_per_second, 10); + } + + #[tokio::test] + async fn sighup_reload_missing_file_logs_error() { + use crate::config::test_fixtures; + + let config_arc = Arc::new(arc_swap::ArcSwap::from_pointee( + test_fixtures::test_dynamic_config(), + )); + let static_config = test_fixtures::test_static_config(); + let reload_handle = Arc::new(crate::config::ConfigReloadHandle::new( + config_arc.clone(), + static_config, + )); + + handle_sighup_reload(&reload_handle, "/nonexistent/config.toml").await; + + let loaded = reload_handle.load(); + assert_eq!(loaded.rate_limit.requests_per_second, 10); + } +} diff --git a/tests/integration_test.rs b/tests/integration_test.rs index e02c04e..2aa686e 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -742,3 +742,147 @@ async fn test_body_limit_empty_body_request_succeeds() { let _ = server.shutdown_tx.send(()); } + +#[tokio::test] +async fn test_graceful_shutdown_trigger() { + let shutdown = Arc::new(reverse_proxy::shutdown::GracefulShutdown::new(30)); + + assert!(!shutdown.is_shutdown_requested()); + + let mut rx = shutdown.subscribe(); + assert!(!*rx.borrow_and_update()); + + shutdown.trigger_shutdown(); + + assert!(shutdown.is_shutdown_requested()); + assert!(rx.has_changed().unwrap()); + assert!(*rx.borrow_and_update()); +} + +#[tokio::test] +async fn test_graceful_shutdown_custom_timeout() { + let shutdown = reverse_proxy::shutdown::GracefulShutdown::new(60); + assert_eq!(shutdown.shutdown_timeout(), Duration::from_secs(60)); +} + +#[tokio::test] +async fn test_graceful_shutdown_subscribe_multiple_receivers() { + let shutdown = Arc::new(reverse_proxy::shutdown::GracefulShutdown::new(10)); + + let mut rx1 = shutdown.subscribe(); + let mut rx2 = shutdown.subscribe(); + + assert!(!*rx1.borrow_and_update()); + assert!(!*rx2.borrow_and_update()); + + shutdown.trigger_shutdown(); + + assert!(rx1.has_changed().unwrap()); + assert!(rx2.has_changed().unwrap()); +} + +#[tokio::test] +async fn test_sighup_config_reload_valid_config() { + let config_arc = Arc::new(ArcSwap::from_pointee( + reverse_proxy::config::test_fixtures::test_dynamic_config(), + )); + let static_config = reverse_proxy::config::test_fixtures::test_static_config(); + let reload_handle = Arc::new(reverse_proxy::config::ConfigReloadHandle::new( + config_arc.clone(), + static_config, + )); + + let dir = tempfile::tempdir().unwrap(); + let config_content = r#" +health_check_port = 9900 +admin_socket_path = "/tmp/test-admin.sock" + +[logging] +level = "info" +format = "text" + +[rate_limit] +requests_per_second = 20 +burst = 40 + +[body] +limit_bytes = 104857600 + +[[listeners]] +bind_addr = "127.0.0.1" +http_port = 80 +https_port = 443 + +[listeners.tls] +mode = "acme" +acme_domains = ["test.local"] +acme_cache_dir = "/tmp/acme-cache" +acme_directory = "staging" + +[[listeners.sites]] +host = "test.local" +upstream = "127.0.0.1:8080" +"#; + let config_path = dir.path().join("config.toml"); + tokio::fs::write(&config_path, config_content) + .await + .unwrap(); + + let config_path_str = config_path.to_str().unwrap().to_string(); + reverse_proxy::shutdown::handle_sighup_reload(&reload_handle, &config_path_str).await; + + let loaded = reload_handle.load(); + assert_eq!(loaded.rate_limit.requests_per_second, 20); + assert_eq!(loaded.rate_limit.burst, 40); +} + +#[tokio::test] +async fn test_sighup_config_reload_invalid_config_keeps_old() { + let config_arc = Arc::new(ArcSwap::from_pointee( + reverse_proxy::config::test_fixtures::test_dynamic_config(), + )); + let static_config = reverse_proxy::config::test_fixtures::test_static_config(); + let reload_handle = Arc::new(reverse_proxy::config::ConfigReloadHandle::new( + config_arc.clone(), + static_config, + )); + + let dir = tempfile::tempdir().unwrap(); + let config_content = "invalid toml {{{"; + let config_path = dir.path().join("config.toml"); + tokio::fs::write(&config_path, config_content) + .await + .unwrap(); + + let config_path_str = config_path.to_str().unwrap().to_string(); + let _ = reverse_proxy::shutdown::handle_sighup_reload(&reload_handle, &config_path_str).await; + + let loaded = reload_handle.load(); + assert_eq!(loaded.rate_limit.requests_per_second, 10); +} + +#[tokio::test] +async fn test_graceful_shutdown_with_health_check() { + let (addr, handle) = reverse_proxy::health::start_health_check_listener(0) + .await + .unwrap(); + + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://127.0.0.1:{}/health", addr.port())) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::OK); + + let shutdown = Arc::new(reverse_proxy::shutdown::GracefulShutdown::new(5)); + let rx = shutdown.subscribe(); + + assert!(!shutdown.is_shutdown_requested()); + + shutdown.trigger_shutdown(); + assert!(shutdown.is_shutdown_requested()); + assert!(rx.has_changed().unwrap()); + + handle.abort(); +}