diff --git a/Cargo.lock b/Cargo.lock index 5ac2c75..80e5554 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -489,6 +489,20 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.11.0" @@ -789,6 +803,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.17.1" @@ -1064,7 +1084,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.17.1", ] [[package]] @@ -1584,6 +1604,7 @@ dependencies = [ "arc-swap", "axum", "clap", + "dashmap", "futures", "hyper", "rcgen", diff --git a/Cargo.toml b/Cargo.toml index 64b2969..aac46d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ signal-hook = "=0.3.18" anyhow = "=1.0.102" thiserror = "=2.0.18" futures = "=0.3.31" +dashmap = "=6.1" [dev-dependencies] rcgen = "=0.13" diff --git a/src/rate_limit/bucket.rs b/src/rate_limit/bucket.rs index dfeb146..58cb7b1 100644 --- a/src/rate_limit/bucket.rs +++ b/src/rate_limit/bucket.rs @@ -1,2 +1,185 @@ -#[allow(dead_code)] -pub struct TokenBucket; +use std::net::{IpAddr, Ipv6Addr}; +use std::time::Instant; + +pub struct TokenBucket { + pub tokens: f64, + pub last_refill: Instant, + pub rate: f64, + pub max: u32, + pub last_access: Instant, +} + +impl TokenBucket { + pub fn new(rate: f64, max: u32) -> Self { + Self { + tokens: max as f64, + last_refill: Instant::now(), + rate, + max, + last_access: Instant::now(), + } + } + + pub fn try_consume(&mut self, rate: f64, max: u32) -> bool { + self.refill(rate, max); + self.last_access = Instant::now(); + + if self.tokens >= 1.0 { + self.tokens -= 1.0; + true + } else { + false + } + } + + fn refill(&mut self, rate: f64, max: u32) { + let now = Instant::now(); + let elapsed = now.duration_since(self.last_refill).as_millis() as f64; + let tokens_to_add = (elapsed * rate) / 1000.0; + self.tokens = (self.tokens + tokens_to_add).min(max as f64); + self.last_refill = now; + + if max < self.max { + self.tokens = self.tokens.min(max as f64); + } + self.max = max; + self.rate = rate; + } +} + +pub fn normalize_ip(ip: IpAddr) -> IpAddr { + match ip { + IpAddr::V4(_) => ip, + IpAddr::V6(v6) => { + let segments = v6.segments(); + let normalized = Ipv6Addr::new( + segments[0], + segments[1], + segments[2], + segments[3], + 0, + 0, + 0, + 0, + ); + IpAddr::V6(normalized) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{Ipv4Addr, Ipv6Addr}; + use std::thread; + use std::time::Duration; + + #[test] + fn token_bucket_starts_full() { + let bucket = TokenBucket::new(10.0, 20); + assert_eq!(bucket.tokens, 20.0); + } + + #[test] + fn token_bucket_consume_within_burst() { + let mut bucket = TokenBucket::new(10.0, 20); + for _ in 0..20 { + assert!(bucket.try_consume(10.0, 20)); + } + } + + #[test] + fn token_bucket_reject_when_empty() { + let mut bucket = TokenBucket::new(10.0, 5); + for _ in 0..5 { + assert!(bucket.try_consume(10.0, 5)); + } + assert!(!bucket.try_consume(10.0, 5)); + } + + #[test] + fn token_bucket_refills_over_time() { + let mut bucket = TokenBucket::new(1000.0, 5); + for _ in 0..5 { + assert!(bucket.try_consume(1000.0, 5)); + } + assert!(!bucket.try_consume(1000.0, 5)); + + thread::sleep(Duration::from_millis(10)); + assert!(bucket.try_consume(1000.0, 5)); + } + + #[test] + fn token_bucket_caps_at_max() { + let mut bucket = TokenBucket::new(10.0, 5); + for _ in 0..5 { + assert!(bucket.try_consume(10.0, 5)); + } + thread::sleep(Duration::from_millis(500)); + bucket.try_consume(10.0, 5); + assert!(bucket.tokens <= 5.5); + } + + #[test] + fn token_bucket_config_reload_caps_tokens() { + let mut bucket = TokenBucket::new(10.0, 20); + assert!(bucket.try_consume(10.0, 20)); + + let new_max = 10; + bucket.refill(10.0, new_max); + assert!(bucket.tokens <= new_max as f64 + 1.0); + } + + #[test] + fn token_bucket_config_reload_adopts_new_rate() { + let mut bucket = TokenBucket::new(10.0, 5); + for _ in 0..5 { + assert!(bucket.try_consume(10.0, 5)); + } + assert!(!bucket.try_consume(10.0, 5)); + + thread::sleep(Duration::from_millis(200)); + assert!(bucket.try_consume(50.0, 5)); + } + + #[test] + fn normalize_ipv4_unchanged() { + let ip = IpAddr::from(Ipv4Addr::new(192, 168, 1, 1)); + assert_eq!(normalize_ip(ip), ip); + } + + #[test] + fn normalize_ipv6_truncates_to_64_prefix() { + let ip = IpAddr::from(Ipv6Addr::new( + 0x2001, 0x0db8, 0x85a3, 0x0001, 0x1234, 0x5678, 0x9abc, 0xdef0, + )); + let normalized = normalize_ip(ip); + let expected = IpAddr::from(Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0001, 0, 0, 0, 0)); + assert_eq!(normalized, expected); + } + + #[test] + fn normalize_ipv6_already_64_prefix() { + let ip = IpAddr::from(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 0)); + let normalized = normalize_ip(ip); + assert_eq!(normalized, ip); + } + + #[test] + fn normalize_ipv6_link_local() { + let ip = IpAddr::from(Ipv6Addr::new( + 0xfe80, 0, 0, 0, 0x1234, 0x5678, 0x9abc, 0xdef0, + )); + let normalized = normalize_ip(ip); + let expected = IpAddr::from(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 0)); + assert_eq!(normalized, expected); + } + + #[test] + fn normalize_loopback_ipv6() { + let ip = IpAddr::from(Ipv6Addr::LOCALHOST); + let normalized = normalize_ip(ip); + let expected = IpAddr::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)); + assert_eq!(normalized, expected); + } +} diff --git a/src/rate_limit/mod.rs b/src/rate_limit/mod.rs index d25f087..47be64d 100644 --- a/src/rate_limit/mod.rs +++ b/src/rate_limit/mod.rs @@ -1 +1,208 @@ pub mod bucket; + +use std::net::IpAddr; +use std::sync::Arc; +use std::time::Duration; + +use arc_swap::ArcSwap; +use axum::extract::Request; +use axum::http::StatusCode; +use axum::middleware::Next; +use axum::response::IntoResponse; +use dashmap::DashMap; +use tracing::warn; + +use crate::config::DynamicConfig; +use bucket::{normalize_ip, TokenBucket}; + +pub struct RateLimiter { + buckets: DashMap, + config: Arc>, +} + +impl RateLimiter { + pub fn new(config: Arc>) -> Self { + Self { + buckets: DashMap::new(), + config, + } + } + + pub fn check_and_consume(&self, ip: IpAddr) -> bool { + let normalized = normalize_ip(ip); + let config = self.config.load(); + let rate = config.rate_limit.requests_per_second as f64; + let burst = config.rate_limit.burst; + + let entry = self.buckets.entry(normalized); + match entry { + dashmap::mapref::entry::Entry::Occupied(mut occupied) => { + occupied.get_mut().try_consume(rate, burst) + } + dashmap::mapref::entry::Entry::Vacant(vacant) => { + let mut bucket = TokenBucket::new(rate, burst); + let result = bucket.try_consume(rate, burst); + vacant.insert(bucket); + result + } + } + } + + pub fn evict_stale(&self, max_age: Duration) { + let cutoff = std::time::Instant::now() - max_age; + self.buckets.retain(|_, bucket| bucket.last_access > cutoff); + } + + pub fn contains_ip(&self, ip: IpAddr) -> bool { + self.buckets.contains_key(&normalize_ip(ip)) + } +} + +pub async fn rate_limit_middleware( + axum::extract::State(limiter): axum::extract::State>, + req: Request, + next: Next, +) -> axum::response::Response { + let client_ip = req + .headers() + .get("x-forwarded-for") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.split(',').next()) + .and_then(|v| v.trim().parse::().ok()) + .or_else(|| { + req.extensions() + .get::>() + .map(|ci| ci.ip()) + }); + + let Some(ip) = client_ip else { + return next.run(req).await; + }; + + let host = req + .headers() + .get("host") + .and_then(|v| v.to_str().ok()) + .unwrap_or("-"); + + let path = req.uri().path(); + + if !limiter.check_and_consume(ip) { + warn!( + "RATE_LIMIT client_ip={} host={} path={} status=429", + ip, host, path + ); + return (StatusCode::TOO_MANY_REQUESTS, "Too Many Requests").into_response(); + } + + next.run(req).await +} + +pub fn start_eviction_task( + limiter: Arc, + interval: Duration, + max_age: Duration, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + let mut interval_timer = tokio::time::interval(interval); + loop { + interval_timer.tick().await; + limiter.evict_stale(max_age); + } + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::test_fixtures; + use std::net::Ipv6Addr; + use std::time::Duration; + + fn make_limiter(rps: u32, burst: u32) -> Arc { + let mut config = test_fixtures::test_dynamic_config(); + config.rate_limit = crate::config::RateLimitConfig { + requests_per_second: rps, + burst, + }; + let config_arc = Arc::new(ArcSwap::from_pointee(config)); + Arc::new(RateLimiter::new(config_arc)) + } + + #[test] + fn check_and_consume_allows_within_burst() { + let limiter = make_limiter(10, 5); + for _ in 0..5 { + assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1]))); + } + } + + #[test] + fn check_and_consume_rejects_above_burst() { + let limiter = make_limiter(10, 5); + for _ in 0..5 { + assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1]))); + } + assert!(!limiter.check_and_consume(IpAddr::from([192, 168, 1, 1]))); + } + + #[test] + fn check_and_consume_per_ip_independent() { + let limiter = make_limiter(10, 5); + for _ in 0..5 { + assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1]))); + } + assert!(!limiter.check_and_consume(IpAddr::from([192, 168, 1, 1]))); + assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 2]))); + } + + #[test] + fn check_and_consume_ipv6_normalized_to_64() { + let limiter = make_limiter(10, 5); + let ip1 = IpAddr::from(Ipv6Addr::new( + 0x2001, 0x0db8, 0x85a3, 0x0001, 0x1111, 0x2222, 0x3333, 0x4444, + )); + let ip2 = IpAddr::from(Ipv6Addr::new( + 0x2001, 0x0db8, 0x85a3, 0x0001, 0x5555, 0x6666, 0x7777, 0x8888, + )); + for _ in 0..5 { + assert!(limiter.check_and_consume(ip1)); + } + assert!(!limiter.check_and_consume(ip2)); + } + + #[test] + fn evict_removes_stale_entries() { + let limiter = make_limiter(10, 20); + let ip1 = IpAddr::from([192, 168, 1, 1]); + let ip2 = IpAddr::from([192, 168, 1, 2]); + + limiter.check_and_consume(ip1); + std::thread::sleep(Duration::from_millis(50)); + limiter.check_and_consume(ip2); + + limiter.evict_stale(Duration::from_millis(25)); + assert!(!limiter.buckets.contains_key(&normalize_ip(ip1))); + assert!(limiter.buckets.contains_key(&normalize_ip(ip2))); + } + + #[test] + fn config_reload_adopted_on_next_request() { + let config = test_fixtures::test_dynamic_config(); + let config_arc = Arc::new(ArcSwap::from_pointee(config)); + let limiter = Arc::new(RateLimiter::new(config_arc.clone())); + + assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1]))); + + let mut new_config = test_fixtures::test_dynamic_config(); + new_config.rate_limit = crate::config::RateLimitConfig { + requests_per_second: 10, + burst: 2, + }; + config_arc.store(Arc::new(new_config)); + + assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1]))); + assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1]))); + assert!(!limiter.check_and_consume(IpAddr::from([192, 168, 1, 1]))); + } +} diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 2da36e7..05e16fa 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -1,5 +1,12 @@ mod helpers; +use std::sync::Arc; +use std::time::Duration; + +use arc_swap::ArcSwap; +use axum::routing::get; +use axum::Router; + #[tokio::test] async fn test_upstream_spawn_and_connect() { let upstream = helpers::http_test_helper::TestUpstream::spawn_ok().await; @@ -71,3 +78,173 @@ async fn test_health_check_disabled_when_port_zero() { assert_ne!(addr.port(), 0); handle.abort(); } + +fn make_rate_limit_app(limiter: Arc) -> Router { + Router::new() + .route("/", get(|| async { "ok" })) + .layer(axum::middleware::from_fn_with_state( + limiter, + reverse_proxy::rate_limit::rate_limit_middleware, + )) +} + +#[tokio::test] +async fn test_rate_limit_allows_within_burst() { + let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config(); + config.rate_limit = reverse_proxy::config::RateLimitConfig { + requests_per_second: 10, + burst: 5, + }; + let config_arc = Arc::new(ArcSwap::from_pointee(config)); + let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc)); + + let app = make_rate_limit_app(limiter); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async { axum::serve(listener, app).await.unwrap() }); + + let client = reqwest::Client::new(); + for _ in 0..5 { + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .header("x-forwarded-for", "192.168.1.1") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::OK); + } +} + +#[tokio::test] +async fn test_rate_limit_rejects_above_burst() { + let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config(); + config.rate_limit = reverse_proxy::config::RateLimitConfig { + requests_per_second: 10, + burst: 2, + }; + let config_arc = Arc::new(ArcSwap::from_pointee(config)); + let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc)); + + let app = make_rate_limit_app(limiter); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async { axum::serve(listener, app).await.unwrap() }); + + let client = reqwest::Client::new(); + for _ in 0..2 { + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .header("x-forwarded-for", "10.0.0.50") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::OK); + } + + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .header("x-forwarded-for", "10.0.0.50") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::TOO_MANY_REQUESTS); + let body = resp.text().await.unwrap(); + assert_eq!(body, "Too Many Requests"); +} + +#[tokio::test] +async fn test_rate_limit_429_response_body() { + let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config(); + config.rate_limit = reverse_proxy::config::RateLimitConfig { + requests_per_second: 10, + burst: 1, + }; + let config_arc = Arc::new(ArcSwap::from_pointee(config)); + let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc)); + + let app = make_rate_limit_app(limiter); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async { axum::serve(listener, app).await.unwrap() }); + + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .header("x-forwarded-for", "203.0.113.50") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::OK); + + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .header("x-forwarded-for", "203.0.113.50") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::TOO_MANY_REQUESTS); + let body = resp.text().await.unwrap(); + assert_eq!(body, "Too Many Requests"); +} + +#[tokio::test] +async fn test_rate_limit_per_ip_independent() { + let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config(); + config.rate_limit = reverse_proxy::config::RateLimitConfig { + requests_per_second: 10, + burst: 1, + }; + let config_arc = Arc::new(ArcSwap::from_pointee(config)); + let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc)); + + let app = make_rate_limit_app(limiter); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async { axum::serve(listener, app).await.unwrap() }); + + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .header("x-forwarded-for", "192.168.1.1") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::OK); + + let resp2 = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .header("x-forwarded-for", "192.168.1.2") + .send() + .await + .unwrap(); + assert_eq!(resp2.status(), reqwest::StatusCode::OK); +} + +#[tokio::test] +async fn test_rate_limit_eviction_task() { + let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config(); + config.rate_limit = reverse_proxy::config::RateLimitConfig { + requests_per_second: 10, + burst: 20, + }; + let config_arc = Arc::new(ArcSwap::from_pointee(config)); + let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc)); + + limiter.check_and_consume(std::net::IpAddr::from([192, 168, 1, 1])); + + let handle = reverse_proxy::rate_limit::start_eviction_task( + limiter.clone(), + Duration::from_millis(50), + Duration::from_millis(100), + ); + + tokio::time::sleep(Duration::from_millis(200)).await; + + assert!(!limiter.contains_ip(std::net::IpAddr::from([192, 168, 1, 1]))); + + handle.abort(); +}