diff --git a/src/rate_limit/mod.rs b/src/rate_limit/mod.rs index 576f2ac..ba7d4b1 100644 --- a/src/rate_limit/mod.rs +++ b/src/rate_limit/mod.rs @@ -206,4 +206,108 @@ mod tests { assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1]))); assert!(!limiter.check_and_consume(IpAddr::from([192, 168, 1, 1]))); } + + #[tokio::test] + async fn middleware_uses_connect_info_without_xff_header() { + let limiter = make_limiter(10, 5); + + let app = axum::Router::new() + .route("/", axum::routing::get(|| async { "ok" })) + .layer(axum::middleware::from_fn_with_state( + limiter, + rate_limit_middleware, + )) + .into_make_service_with_connect_info::(); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); + + let client = reqwest::Client::new(); + for _ in 0..5 { + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::OK); + } + + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::TOO_MANY_REQUESTS); + } + + #[tokio::test] + async fn middleware_rejects_without_connect_info() { + let limiter = make_limiter(10, 20); + + let app = axum::Router::new() + .route("/", axum::routing::get(|| async { "ok" })) + .layer(axum::middleware::from_fn_with_state( + limiter, + rate_limit_middleware, + )); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); + + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::TOO_MANY_REQUESTS); + } + + #[tokio::test] + async fn middleware_ignores_xff_header_same_bucket() { + let limiter = make_limiter(10, 2); + + let app = axum::Router::new() + .route("/", axum::routing::get(|| async { "ok" })) + .layer(axum::middleware::from_fn_with_state( + limiter, + rate_limit_middleware, + )) + .into_make_service_with_connect_info::(); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { axum::serve(listener, app).await.unwrap() }); + + let client = reqwest::Client::new(); + + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .header("X-Forwarded-For", "10.0.0.1") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::OK); + + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .header("X-Forwarded-For", "10.0.0.2") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::OK); + + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .header("X-Forwarded-For", "10.0.0.3") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::TOO_MANY_REQUESTS); + } } diff --git a/tests/integration_test.rs b/tests/integration_test.rs index d42ec74..a66876e 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -228,6 +228,81 @@ async fn test_rate_limit_per_ip_independent() { assert_eq!(resp2.status(), reqwest::StatusCode::TOO_MANY_REQUESTS); } +#[tokio::test] +async fn test_rate_limit_without_connect_info_rejected_with_429() { + let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config(); + config.rate_limit = reverse_proxy::config::RateLimitConfig { + requests_per_second: 10, + burst: 20, + }; + let config_arc = Arc::new(ArcSwap::from_pointee(config)); + let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc)); + + let app = Router::new().route("/", get(|| async { "ok" })).layer( + axum::middleware::from_fn_with_state( + limiter, + reverse_proxy::rate_limit::rate_limit_middleware, + ), + ); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async { axum::serve(listener, app).await.unwrap() }); + + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::TOO_MANY_REQUESTS); + let body = resp.text().await.unwrap(); + assert_eq!(body, "Too Many Requests"); +} + +#[tokio::test] +async fn test_rate_limit_xff_header_ignored_same_bucket() { + let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config(); + config.rate_limit = reverse_proxy::config::RateLimitConfig { + requests_per_second: 10, + burst: 2, + }; + let config_arc = Arc::new(ArcSwap::from_pointee(config)); + let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc)); + + let app = make_rate_limit_app(limiter); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async { axum::serve(listener, app).await.unwrap() }); + + let client = reqwest::Client::new(); + + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .header("X-Forwarded-For", "10.0.0.1") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::OK); + + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .header("X-Forwarded-For", "10.0.0.2") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::OK); + + let resp = client + .get(format!("http://127.0.0.1:{}/", addr.port())) + .header("X-Forwarded-For", "10.0.0.3") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::TOO_MANY_REQUESTS); +} + #[tokio::test] async fn test_rate_limit_eviction_task() { let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config();