From d5f5713debee026ea785b1dec08776324c0a57bf Mon Sep 17 00:00:00 2001 From: "glm-5.1" Date: Thu, 11 Jun 2026 12:57:31 +0000 Subject: [PATCH] Implement host-based routing with global routing table MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add routing table (HashMap) to DynamicConfig for O(1) host lookup. Implement normalize_host (lowercase + strip port) per RFC 7230 ยง2.7.3. Add proxy_handler that routes /health to 200, missing Host to 400, unknown host to 404, and known host to 200. Routing table updates atomically via ArcSwap. --- src/config/dynamic_config.rs | 179 +++++++++++++++++++++++++-- src/config/mod.rs | 3 +- src/config/test_fixtures.rs | 10 +- src/config/validation.rs | 10 +- src/proxy/handler.rs | 232 ++++++++++++++++++++++++++++++++++- src/proxy/mod.rs | 2 + 6 files changed, 413 insertions(+), 23 deletions(-) diff --git a/src/config/dynamic_config.rs b/src/config/dynamic_config.rs index 21038be..b7c9d7a 100644 --- a/src/config/dynamic_config.rs +++ b/src/config/dynamic_config.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::sync::Arc; use arc_swap::ArcSwap; @@ -7,15 +8,65 @@ use tokio::sync::Mutex; use super::static_config::StaticConfig; use super::validation::validate; -#[allow(dead_code)] -#[derive(Debug, Deserialize, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct DynamicConfig { + pub sites: Vec, + pub routing_table: HashMap, + pub rate_limit: RateLimitConfig, + pub body: BodyConfig, +} + +impl DynamicConfig { + pub fn from_sites( + sites: Vec, + rate_limit: RateLimitConfig, + body: BodyConfig, + ) -> Self { + let routing_table = build_routing_table(&sites); + Self { + sites, + routing_table, + rate_limit, + body, + } + } + + pub fn lookup(&self, host: &str) -> Option<&SiteConfig> { + self.routing_table.get(&normalize_host(host)) + } +} + +impl PartialEq for DynamicConfig { + fn eq(&self, other: &Self) -> bool { + self.sites == other.sites && self.rate_limit == other.rate_limit && self.body == other.body + } +} + +pub fn build_routing_table(sites: &[SiteConfig]) -> HashMap { + sites + .iter() + .map(|s| (s.host.to_lowercase(), s.clone())) + .collect() +} + +pub fn normalize_host(host: &str) -> String { + let lower = host.to_lowercase(); + lower.split(':').next().unwrap_or(&lower).to_string() +} + +#[derive(Debug, Deserialize, Clone, PartialEq)] +pub struct SerializableDynamicConfig { pub sites: Vec, pub rate_limit: RateLimitConfig, pub body: BodyConfig, } -#[allow(dead_code)] +impl From for DynamicConfig { + fn from(value: SerializableDynamicConfig) -> Self { + DynamicConfig::from_sites(value.sites, value.rate_limit, value.body) + } +} + #[derive(Debug, Deserialize, Clone, PartialEq)] pub struct SiteConfig { pub host: String, @@ -28,42 +79,35 @@ pub struct SiteConfig { pub upstream_request_timeout_secs: u64, } -#[allow(dead_code)] fn default_upstream_scheme() -> String { "http".to_string() } -#[allow(dead_code)] fn default_connect_timeout() -> u64 { 5 } -#[allow(dead_code)] fn default_request_timeout() -> u64 { 60 } -#[allow(dead_code)] #[derive(Debug, Deserialize, Clone, PartialEq)] pub struct RateLimitConfig { pub requests_per_second: u32, pub burst: u32, } -#[allow(dead_code)] #[derive(Debug, Deserialize, Clone, PartialEq)] pub struct BodyConfig { pub limit_bytes: u64, } -#[allow(dead_code)] pub struct ConfigReloadHandle { config: Arc>, static_config: StaticConfig, reload_mutex: Mutex<()>, } -#[allow(dead_code)] impl ConfigReloadHandle { pub fn new(config: Arc>, static_config: StaticConfig) -> Self { Self { @@ -153,6 +197,14 @@ mod tests { upstream_connect_timeout_secs: 5, upstream_request_timeout_secs: 60, }); + let new_dynamic = DynamicConfig::from_sites( + new_dynamic.sites, + RateLimitConfig { + requests_per_second: 50, + burst: new_dynamic.rate_limit.burst, + }, + new_dynamic.body, + ); let rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(handle.reload(test_fixtures::test_static_config(), new_dynamic)) @@ -196,6 +248,14 @@ mod tests { handles.push(tokio::spawn(async move { let mut dynamic = initial.clone(); dynamic.rate_limit.requests_per_second = i * 10; + let dynamic = DynamicConfig::from_sites( + dynamic.sites, + RateLimitConfig { + requests_per_second: i * 10, + burst: dynamic.rate_limit.burst, + }, + dynamic.body, + ); h.reload(test_fixtures::test_static_config(), dynamic).await })); } @@ -222,4 +282,103 @@ mod tests { assert!(changes.contains(&"logging".to_string())); assert_eq!(changes.len(), 2); } + + #[test] + fn normalize_host_converts_to_lowercase() { + assert_eq!(normalize_host("Git.Alk.DEV"), "git.alk.dev"); + } + + #[test] + fn normalize_host_strips_port() { + assert_eq!(normalize_host("git.alk.dev:443"), "git.alk.dev"); + assert_eq!(normalize_host("GIT.ALK.DEV:8443"), "git.alk.dev"); + } + + #[test] + fn normalize_host_no_port() { + assert_eq!(normalize_host("git.alk.dev"), "git.alk.dev"); + } + + #[test] + fn normalize_host_empty_string() { + assert_eq!(normalize_host(""), ""); + } + + #[test] + fn routing_table_lookup_finds_site() { + let config = test_fixtures::test_dynamic_config(); + let site = config.lookup("test.local"); + assert!(site.is_some()); + assert_eq!(site.unwrap().host, "test.local"); + } + + #[test] + fn routing_table_lookup_case_insensitive() { + let config = test_fixtures::test_dynamic_config(); + let site = config.lookup("TEST.LOCAL"); + assert!(site.is_some()); + assert_eq!(site.unwrap().host, "test.local"); + } + + #[test] + fn routing_table_lookup_strips_port() { + let config = test_fixtures::test_dynamic_config(); + let site = config.lookup("test.local:443"); + assert!(site.is_some()); + assert_eq!(site.unwrap().host, "test.local"); + } + + #[test] + fn routing_table_lookup_unknown_host() { + let config = test_fixtures::test_dynamic_config(); + let site = config.lookup("unknown.example"); + assert!(site.is_none()); + } + + #[test] + fn build_routing_table_multiple_sites() { + let sites = vec![ + SiteConfig { + host: "git.example.com".to_string(), + upstream: "127.0.0.1:3000".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }, + SiteConfig { + host: "www.example.com".to_string(), + upstream: "127.0.0.1:8080".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }, + ]; + let table = build_routing_table(&sites); + assert_eq!(table.len(), 2); + assert!(table.contains_key("git.example.com")); + assert!(table.contains_key("www.example.com")); + } + + #[test] + fn dynamic_config_from_sites_builds_routing_table() { + let sites = vec![SiteConfig { + host: "My.Site".to_string(), + upstream: "127.0.0.1:8080".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }]; + let config = DynamicConfig::from_sites( + sites, + RateLimitConfig { + requests_per_second: 10, + burst: 20, + }, + BodyConfig { + limit_bytes: 104857600, + }, + ); + assert!(config.routing_table.contains_key("my.site")); + assert_eq!(config.routing_table.len(), 1); + } } diff --git a/src/config/mod.rs b/src/config/mod.rs index cf746ca..78ceb8f 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -4,7 +4,8 @@ pub mod test_fixtures; pub mod validation; pub use dynamic_config::{ - BodyConfig, ConfigReloadHandle, DynamicConfig, RateLimitConfig, SiteConfig, + build_routing_table, normalize_host, BodyConfig, ConfigReloadHandle, DynamicConfig, + RateLimitConfig, SerializableDynamicConfig, SiteConfig, }; pub use static_config::{ListenerConfig, LoggingConfig, StaticConfig, TlsConfig}; pub use validation::{validate, ValidationError}; diff --git a/src/config/test_fixtures.rs b/src/config/test_fixtures.rs index 76479cd..1bb6e3f 100644 --- a/src/config/test_fixtures.rs +++ b/src/config/test_fixtures.rs @@ -26,22 +26,22 @@ pub fn test_static_config() -> StaticConfig { } pub fn test_dynamic_config() -> DynamicConfig { - DynamicConfig { - sites: vec![SiteConfig { + DynamicConfig::from_sites( + vec![SiteConfig { host: "test.local".to_string(), upstream: "127.0.0.1:8080".to_string(), upstream_scheme: "http".to_string(), upstream_connect_timeout_secs: 5, upstream_request_timeout_secs: 60, }], - rate_limit: RateLimitConfig { + RateLimitConfig { requests_per_second: 10, burst: 20, }, - body: BodyConfig { + BodyConfig { limit_bytes: 104857600, }, - } + ) } #[cfg(test)] diff --git a/src/config/validation.rs b/src/config/validation.rs index 50dfc47..fbe44e2 100644 --- a/src/config/validation.rs +++ b/src/config/validation.rs @@ -335,22 +335,22 @@ mod tests { } fn valid_dynamic_config() -> DynamicConfig { - DynamicConfig { - sites: vec![SiteConfig { + DynamicConfig::from_sites( + vec![SiteConfig { host: "test.local".to_string(), upstream: "127.0.0.1:8080".to_string(), upstream_scheme: "http".to_string(), upstream_connect_timeout_secs: 5, upstream_request_timeout_secs: 60, }], - rate_limit: RateLimitConfig { + RateLimitConfig { requests_per_second: 10, burst: 20, }, - body: BodyConfig { + BodyConfig { limit_bytes: 104857600, }, - } + ) } fn make_static_with_sites(sites: Vec, tls: TlsConfig) -> StaticConfig { diff --git a/src/proxy/handler.rs b/src/proxy/handler.rs index fc8a6c2..81d09fa 100644 --- a/src/proxy/handler.rs +++ b/src/proxy/handler.rs @@ -1,12 +1,240 @@ +use std::sync::Arc; + +use axum::extract::State; use axum::http::StatusCode; use axum::response::IntoResponse; use axum::routing::get; use axum::Router; +use arc_swap::ArcSwap; + +use crate::config::dynamic_config::DynamicConfig; + async fn health_handler() -> impl IntoResponse { StatusCode::OK } -pub fn health_route() -> Router { - Router::new().route("/health", get(health_handler)) +async fn proxy_handler( + State(state): State>>, + req: axum::http::Request, +) -> impl IntoResponse { + if req.uri().path() == "/health" { + return StatusCode::OK.into_response(); + } + + let host = req + .headers() + .get(axum::http::header::HOST) + .and_then(|v| v.to_str().ok()); + + let host = match host { + Some(h) => h, + None => return StatusCode::BAD_REQUEST.into_response(), + }; + + let config = state.load(); + match config.lookup(host) { + Some(_site) => StatusCode::OK.into_response(), + None => StatusCode::NOT_FOUND.into_response(), + } +} + +pub fn proxy_router(state: Arc>) -> Router { + Router::new() + .route("/health", get(health_handler)) + .fallback(proxy_handler) + .with_state(state) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::dynamic_config::{BodyConfig, RateLimitConfig}; + use crate::config::SiteConfig; + use axum::body::Body; + use axum::http::{Request, Response}; + use tower::ServiceExt; + + fn make_config_with_sites(sites: Vec) -> Arc> { + Arc::new(ArcSwap::from_pointee(DynamicConfig::from_sites( + sites, + RateLimitConfig { + requests_per_second: 10, + burst: 20, + }, + BodyConfig { + limit_bytes: 104857600, + }, + ))) + } + + async fn send_request( + router: &mut Router, + method: &str, + uri: &str, + host: Option<&str>, + ) -> Response { + let mut builder = Request::builder().method(method).uri(uri); + if let Some(h) = host { + builder = builder.header("Host", h); + } + let req = builder.body(Body::empty()).unwrap(); + router.oneshot(req).await.unwrap() + } + + #[tokio::test] + async fn health_path_returns_200_regardless_of_host() { + let state = make_config_with_sites(vec![SiteConfig { + host: "example.com".to_string(), + upstream: "127.0.0.1:8080".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }]); + let mut router = proxy_router(state); + + let resp = send_request(&mut router, "GET", "/health", None).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[tokio::test] + async fn health_with_unknown_host_returns_200() { + let state = make_config_with_sites(vec![SiteConfig { + host: "example.com".to_string(), + upstream: "127.0.0.1:8080".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }]); + let mut router = proxy_router(state); + + let resp = send_request(&mut router, "GET", "/health", Some("unknown.host")).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[tokio::test] + async fn missing_host_returns_400() { + let state = make_config_with_sites(vec![SiteConfig { + host: "example.com".to_string(), + upstream: "127.0.0.1:8080".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }]); + let mut router = proxy_router(state); + + let resp = send_request(&mut router, "GET", "/some/path", None).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[tokio::test] + async fn unknown_host_returns_404() { + let state = make_config_with_sites(vec![SiteConfig { + host: "example.com".to_string(), + upstream: "127.0.0.1:8080".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }]); + let mut router = proxy_router(state); + + let resp = send_request(&mut router, "GET", "/some/path", Some("unknown.host")).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + #[tokio::test] + async fn known_host_returns_200() { + let state = make_config_with_sites(vec![SiteConfig { + host: "example.com".to_string(), + upstream: "127.0.0.1:8080".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }]); + let mut router = proxy_router(state); + + let resp = send_request(&mut router, "GET", "/some/path", Some("example.com")).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[tokio::test] + async fn host_matching_is_case_insensitive() { + let state = make_config_with_sites(vec![SiteConfig { + host: "example.com".to_string(), + upstream: "127.0.0.1:8080".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }]); + let mut router = proxy_router(state); + + let resp = send_request(&mut router, "GET", "/path", Some("EXAMPLE.COM")).await; + assert_eq!(resp.status(), StatusCode::OK); + + let resp = send_request(&mut router, "GET", "/path", Some("Example.Com")).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[tokio::test] + async fn host_with_port_stripped() { + let state = make_config_with_sites(vec![SiteConfig { + host: "example.com".to_string(), + upstream: "127.0.0.1:8080".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }]); + let mut router = proxy_router(state); + + let resp = send_request(&mut router, "GET", "/path", Some("example.com:443")).await; + assert_eq!(resp.status(), StatusCode::OK); + + let resp = send_request(&mut router, "GET", "/path", Some("EXAMPLE.COM:8443")).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[tokio::test] + async fn routing_table_update_visible_immediately() { + let state = make_config_with_sites(vec![SiteConfig { + host: "example.com".to_string(), + upstream: "127.0.0.1:8080".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }]); + let mut router = proxy_router(state.clone()); + + let resp = send_request(&mut router, "GET", "/path", Some("new.example.com")).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + let new_config = DynamicConfig::from_sites( + vec![ + SiteConfig { + host: "example.com".to_string(), + upstream: "127.0.0.1:8080".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }, + SiteConfig { + host: "new.example.com".to_string(), + upstream: "127.0.0.1:9090".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }, + ], + RateLimitConfig { + requests_per_second: 10, + burst: 20, + }, + BodyConfig { + limit_bytes: 104857600, + }, + ); + state.store(Arc::new(new_config)); + + let resp = send_request(&mut router, "GET", "/path", Some("new.example.com")).await; + assert_eq!(resp.status(), StatusCode::OK); + } } diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 2d783aa..4301186 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,3 +1,5 @@ pub mod error; pub mod handler; pub mod headers; + +pub use crate::config::dynamic_config::normalize_host;