diff --git a/.worktrees/feat/config/dynamic-config b/.worktrees/feat/config/dynamic-config new file mode 160000 index 0000000..fbae1c4 --- /dev/null +++ b/.worktrees/feat/config/dynamic-config @@ -0,0 +1 @@ +Subproject commit fbae1c464eecde91494a23c518230754d6963196 diff --git a/.worktrees/feat/config/validation b/.worktrees/feat/config/validation new file mode 160000 index 0000000..468adb2 --- /dev/null +++ b/.worktrees/feat/config/validation @@ -0,0 +1 @@ +Subproject commit 468adb21de5804a4887081335448cc608ac2c116 diff --git a/.worktrees/feat/ops/logging b/.worktrees/feat/ops/logging new file mode 160000 index 0000000..36319db --- /dev/null +++ b/.worktrees/feat/ops/logging @@ -0,0 +1 @@ +Subproject commit 36319db10e1f72bc42315234a985f7c0836ee1f7 diff --git a/src/config/dynamic_config.rs b/src/config/dynamic_config.rs index 7e2840b..21038be 100644 --- a/src/config/dynamic_config.rs +++ b/src/config/dynamic_config.rs @@ -1,7 +1,14 @@ +use std::sync::Arc; + +use arc_swap::ArcSwap; use serde::Deserialize; +use tokio::sync::Mutex; + +use super::static_config::StaticConfig; +use super::validation::validate; #[allow(dead_code)] -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct DynamicConfig { pub sites: Vec, pub rate_limit: RateLimitConfig, @@ -9,7 +16,7 @@ pub struct DynamicConfig { } #[allow(dead_code)] -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct SiteConfig { pub host: String, pub upstream: String, @@ -37,14 +44,182 @@ fn default_request_timeout() -> u64 { } #[allow(dead_code)] -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct RateLimitConfig { pub requests_per_second: u32, pub burst: u32, } #[allow(dead_code)] -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct BodyConfig { pub limit_bytes: u64, } + +#[allow(dead_code)] +pub struct ConfigReloadHandle { + config: Arc>, + static_config: StaticConfig, + reload_mutex: Mutex<()>, +} + +#[allow(dead_code)] +impl ConfigReloadHandle { + pub fn new(config: Arc>, static_config: StaticConfig) -> Self { + Self { + config, + static_config, + reload_mutex: Mutex::new(()), + } + } + + pub fn load(&self) -> Arc { + self.config.load_full() + } + + pub async fn reload( + &self, + new_static: StaticConfig, + new_dynamic: DynamicConfig, + ) -> anyhow::Result> { + let _guard = self.reload_mutex.lock().await; + + validate(&new_static, &new_dynamic, false).map_err(|errors| { + anyhow::anyhow!( + "{}", + errors + .iter() + .map(|e| e.to_string()) + .collect::>() + .join("; ") + ) + })?; + + let changed_fields = diff_static_config(&self.static_config, &new_static); + + self.config.store(Arc::new(new_dynamic)); + + Ok(changed_fields) + } +} + +fn diff_static_config(old: &StaticConfig, new: &StaticConfig) -> Vec { + let mut changes = Vec::new(); + + if old.listeners != new.listeners { + changes.push("listeners".to_string()); + } + if old.allow_wildcard_bind != new.allow_wildcard_bind { + changes.push("allow_wildcard_bind".to_string()); + } + if old.health_check_port != new.health_check_port { + changes.push("health_check_port".to_string()); + } + if old.admin_socket_path != new.admin_socket_path { + changes.push("admin_socket_path".to_string()); + } + if old.shutdown_timeout_secs != new.shutdown_timeout_secs { + changes.push("shutdown_timeout_secs".to_string()); + } + if old.logging != new.logging { + changes.push("logging".to_string()); + } + + changes +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::test_fixtures; + + #[test] + fn arcswap_swap_visible_after_reload() { + let initial = test_fixtures::test_dynamic_config(); + let config_arc = Arc::new(ArcSwap::from_pointee(initial.clone())); + let static_config = test_fixtures::test_static_config(); + let handle = ConfigReloadHandle::new(config_arc.clone(), static_config); + + let loaded = handle.load(); + assert_eq!(loaded.sites.len(), 1); + assert_eq!(loaded.rate_limit.requests_per_second, 10); + + let mut new_dynamic = initial.clone(); + new_dynamic.rate_limit.requests_per_second = 50; + new_dynamic.sites.push(SiteConfig { + host: "new.test".to_string(), + upstream: "127.0.0.1:9090".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }); + + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(handle.reload(test_fixtures::test_static_config(), new_dynamic)) + .unwrap(); + + let loaded = handle.load(); + assert_eq!(loaded.sites.len(), 2); + assert_eq!(loaded.rate_limit.requests_per_second, 50); + } + + #[test] + fn reload_rejects_invalid_config() { + let initial = test_fixtures::test_dynamic_config(); + let config_arc = Arc::new(ArcSwap::from_pointee(initial.clone())); + let static_config = test_fixtures::test_static_config(); + let handle = ConfigReloadHandle::new(config_arc.clone(), static_config); + + let mut invalid_dynamic = initial.clone(); + invalid_dynamic.rate_limit.requests_per_second = 0; + + let rt = tokio::runtime::Runtime::new().unwrap(); + let result = + rt.block_on(handle.reload(test_fixtures::test_static_config(), invalid_dynamic)); + assert!(result.is_err()); + + let loaded = config_arc.load(); + assert_eq!(loaded.rate_limit.requests_per_second, 10); + } + + #[tokio::test] + async fn concurrent_reload_serialization() { + let initial = test_fixtures::test_dynamic_config(); + let config_arc = Arc::new(ArcSwap::from_pointee(initial.clone())); + let static_config = test_fixtures::test_static_config(); + let handle = Arc::new(ConfigReloadHandle::new(config_arc.clone(), static_config)); + + let mut handles = Vec::new(); + for i in 1..=5u32 { + let h = handle.clone(); + let initial = initial.clone(); + handles.push(tokio::spawn(async move { + let mut dynamic = initial.clone(); + dynamic.rate_limit.requests_per_second = i * 10; + h.reload(test_fixtures::test_static_config(), dynamic).await + })); + } + + for h in handles { + h.await.unwrap().unwrap(); + } + + let loaded = config_arc.load(); + let rps = loaded.rate_limit.requests_per_second; + assert!((rps == 10) || (rps == 20) || (rps == 30) || (rps == 40) || (rps == 50)); + } + + #[test] + fn static_config_diff_detects_changes() { + let old = test_fixtures::test_static_config(); + let mut new = old.clone(); + assert!(diff_static_config(&old, &new).is_empty()); + + new.health_check_port = 8080; + new.logging.level = "debug".to_string(); + let changes = diff_static_config(&old, &new); + assert!(changes.contains(&"health_check_port".to_string())); + assert!(changes.contains(&"logging".to_string())); + assert_eq!(changes.len(), 2); + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 072ddc9..cf746ca 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,3 +2,9 @@ pub mod dynamic_config; pub mod static_config; pub mod test_fixtures; pub mod validation; + +pub use dynamic_config::{ + BodyConfig, ConfigReloadHandle, DynamicConfig, RateLimitConfig, SiteConfig, +}; +pub use static_config::{ListenerConfig, LoggingConfig, StaticConfig, TlsConfig}; +pub use validation::{validate, ValidationError}; diff --git a/src/config/static_config.rs b/src/config/static_config.rs index 500dc66..444d415 100644 --- a/src/config/static_config.rs +++ b/src/config/static_config.rs @@ -1,7 +1,7 @@ use serde::Deserialize; #[allow(dead_code)] -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, PartialEq)] pub struct StaticConfig { pub listeners: Vec, #[serde(default)] @@ -32,7 +32,7 @@ fn default_shutdown_timeout_secs() -> u64 { } #[allow(dead_code)] -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, PartialEq)] pub struct ListenerConfig { pub bind_addr: String, #[serde(default = "default_http_port")] @@ -55,7 +55,7 @@ fn default_https_port() -> u16 { } #[allow(dead_code)] -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, PartialEq)] pub struct TlsConfig { pub mode: String, #[serde(default)] @@ -76,7 +76,7 @@ fn default_acme_directory() -> String { } #[allow(dead_code)] -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, PartialEq)] pub struct LoggingConfig { #[serde(default = "default_log_level")] pub level: String, diff --git a/src/config/test_fixtures.rs b/src/config/test_fixtures.rs index c29a6cc..76479cd 100644 --- a/src/config/test_fixtures.rs +++ b/src/config/test_fixtures.rs @@ -8,12 +8,12 @@ pub fn test_static_config() -> StaticConfig { http_port: 80, https_port: 443, tls: TlsConfig { - mode: "manual".to_string(), - acme_domains: vec![], - acme_cache_dir: String::new(), - acme_directory: "production".to_string(), - cert_path: "/tmp/test-cert.pem".to_string(), - key_path: "/tmp/test-key.pem".to_string(), + mode: "acme".to_string(), + acme_domains: vec!["test.local".to_string()], + acme_cache_dir: "/tmp/acme-cache".to_string(), + acme_directory: "staging".to_string(), + cert_path: String::new(), + key_path: String::new(), }, sites: vec![], }], diff --git a/src/config/validation.rs b/src/config/validation.rs index 191ce72..50dfc47 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -81,7 +81,6 @@ pub fn validate( let allow_wildcard = static_config.allow_wildcard_bind || cli_allow_wildcard_bind; - // Rule 1: At least one listener if static_config.listeners.is_empty() { errors.push(ValidationError::NoListeners); } @@ -90,14 +89,12 @@ pub fn validate( let mut http_bind_keys = HashSet::new(); for listener in &static_config.listeners { - // Rule 2: Wildcard bind address if listener.bind_addr == "0.0.0.0" && !allow_wildcard { errors.push(ValidationError::WildcardBindNotAllowed { bind_addr: listener.bind_addr.clone(), }); } - // Rule 3: Unique bind_addr:https_port let https_key = (listener.bind_addr.as_str(), listener.https_port); if !https_bind_keys.insert(https_key) { errors.push(ValidationError::DuplicateHttpsBind { @@ -106,7 +103,6 @@ pub fn validate( }); } - // Rule 10: Unique bind_addr:http_port (if http_port > 0) if listener.http_port > 0 { let http_key = (listener.bind_addr.as_str(), listener.http_port); if !http_bind_keys.insert(http_key) { @@ -117,7 +113,6 @@ pub fn validate( } } - // Rule 12: https_port must be 1-65535 if listener.https_port == 0 { errors.push(ValidationError::HttpsPortInvalid { bind_addr: listener.bind_addr.clone(), @@ -125,7 +120,6 @@ pub fn validate( }); } - // Rule 11: http_port and https_port must differ if listener.http_port > 0 && listener.http_port == listener.https_port { errors.push(ValidationError::HttpsAndHttpPortSame { bind_addr: listener.bind_addr.clone(), @@ -134,10 +128,8 @@ pub fn validate( }); } - // Rule 4 & 5: TLS mode validation match listener.tls.mode.as_str() { "acme" => { - // Rule 4: ACME domains must be non-empty if listener.tls.acme_domains.is_empty() { errors.push(ValidationError::AcmeDomainsEmpty { bind_addr: listener.bind_addr.clone(), @@ -148,12 +140,10 @@ pub fn validate( let cert_empty = listener.tls.cert_path.is_empty(); let key_empty = listener.tls.key_path.is_empty(); if cert_empty || key_empty { - // Rule 5: Both paths must be set errors.push(ValidationError::ManualCertMissing { bind_addr: listener.bind_addr.clone(), }); } else { - // Rule 5: Files must be readable let cert_path = Path::new(&listener.tls.cert_path); if !cert_path.exists() { errors.push(ValidationError::CertPathNotReadable { @@ -176,7 +166,6 @@ pub fn validate( } } - // Rule 14: health_check_port conflicts if static_config.health_check_port > 0 { for listener in &static_config.listeners { if static_config.health_check_port == listener.https_port { @@ -196,50 +185,42 @@ pub fn validate( } } - // Site validation let mut site_hosts: HashSet = HashSet::new(); for listener in &static_config.listeners { for site in &listener.sites { - // Rule 6: host must be set if site.host.is_empty() { errors.push(ValidationError::SiteHostEmpty { host: String::new(), }); } - // Rule 6: upstream must be set if site.upstream.is_empty() { errors.push(ValidationError::SiteUpstreamEmpty { host: site.host.clone(), }); } - // Rule 16: Normalize hostname and check validity let normalized_host = site.host.to_lowercase(); - // Rule 7: Unique hosts (case-insensitive) if !site_hosts.insert(normalized_host.clone()) { errors.push(ValidationError::DuplicateSiteHost { host: normalized_host, }); } - // Rule 15: Host must not contain port if site.host.contains(':') { errors.push(ValidationError::SiteHostContainsPort { host: site.host.clone(), }); } - // Rule 16: Host must be a valid hostname if !is_valid_hostname(&site.host) { errors.push(ValidationError::SiteHostInvalid { host: site.host.clone(), }); } - // Rule 17: Upstream must be host:port format if !site.upstream.is_empty() && !is_valid_upstream(&site.upstream) { errors.push(ValidationError::UpstreamInvalid { host: site.host.clone(), @@ -247,7 +228,6 @@ pub fn validate( }); } - // Rule 18: upstream_scheme must be "http" or "https" if site.upstream_scheme != "http" && site.upstream_scheme != "https" { errors.push(ValidationError::UpstreamSchemeInvalid { host: site.host.clone(), @@ -257,12 +237,10 @@ pub fn validate( } } - // Rule 8: requests_per_second > 0 if dynamic_config.rate_limit.requests_per_second == 0 { errors.push(ValidationError::RequestsPerSecondZero { value: 0 }); } - // Rule 9: body limit_bytes > 0 if dynamic_config.body.limit_bytes == 0 { errors.push(ValidationError::BodyLimitBytesZero { value: 0 }); } diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 72945d1..2da36e7 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -33,8 +33,9 @@ fn test_config_fixtures() { #[tokio::test] async fn test_health_check_local_port_returns_200() { - let (addr, handle) = - reverse_proxy::health::start_health_check_listener(0).await.unwrap(); + let (addr, handle) = reverse_proxy::health::start_health_check_listener(0) + .await + .unwrap(); let client = reqwest::Client::new(); let resp = client @@ -52,8 +53,9 @@ async fn test_health_check_local_port_returns_200() { #[tokio::test] async fn test_health_check_local_port_binds_localhost() { - let (addr, handle) = - reverse_proxy::health::start_health_check_listener(0).await.unwrap(); + let (addr, handle) = reverse_proxy::health::start_health_check_listener(0) + .await + .unwrap(); assert!(addr.ip().is_loopback()); assert_eq!(addr.ip().to_string(), "127.0.0.1"); @@ -68,4 +70,4 @@ async fn test_health_check_disabled_when_port_zero() { let (addr, handle) = result.unwrap(); assert_ne!(addr.port(), 0); handle.abort(); -} \ No newline at end of file +}