Implement token bucket rate limiting with IPv6 /64 normalization

- Add TokenBucket with nodelay semantics (nginx limit_req burst nodelay)
- Per-IP rate limiting: IPv4 /32, IPv6 /64 prefix normalization
- DashMap for concurrent access, ArcSwap for lock-free config reads
- Background eviction task for stale entry cleanup
- 429 response with plain text body, RATE_LIMIT log prefix
- Config reload adopts new rate/burst on next request without clearing state
- Unit tests for bucket algorithm and IPv6 normalization
- Integration tests for 429 responses and per-IP independence
This commit is contained in:
2026-06-11 13:01:25 +00:00
parent f1cada010f
commit 2791070971
5 changed files with 592 additions and 3 deletions

View File

@@ -1,5 +1,12 @@
mod helpers;
use std::sync::Arc;
use std::time::Duration;
use arc_swap::ArcSwap;
use axum::routing::get;
use axum::Router;
#[tokio::test]
async fn test_upstream_spawn_and_connect() {
let upstream = helpers::http_test_helper::TestUpstream::spawn_ok().await;
@@ -71,3 +78,173 @@ async fn test_health_check_disabled_when_port_zero() {
assert_ne!(addr.port(), 0);
handle.abort();
}
fn make_rate_limit_app(limiter: Arc<reverse_proxy::rate_limit::RateLimiter>) -> Router {
Router::new()
.route("/", get(|| async { "ok" }))
.layer(axum::middleware::from_fn_with_state(
limiter,
reverse_proxy::rate_limit::rate_limit_middleware,
))
}
#[tokio::test]
async fn test_rate_limit_allows_within_burst() {
let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config();
config.rate_limit = reverse_proxy::config::RateLimitConfig {
requests_per_second: 10,
burst: 5,
};
let config_arc = Arc::new(ArcSwap::from_pointee(config));
let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc));
let app = make_rate_limit_app(limiter);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async { axum::serve(listener, app).await.unwrap() });
let client = reqwest::Client::new();
for _ in 0..5 {
let resp = client
.get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "192.168.1.1")
.send()
.await
.unwrap();
assert_eq!(resp.status(), reqwest::StatusCode::OK);
}
}
#[tokio::test]
async fn test_rate_limit_rejects_above_burst() {
let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config();
config.rate_limit = reverse_proxy::config::RateLimitConfig {
requests_per_second: 10,
burst: 2,
};
let config_arc = Arc::new(ArcSwap::from_pointee(config));
let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc));
let app = make_rate_limit_app(limiter);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async { axum::serve(listener, app).await.unwrap() });
let client = reqwest::Client::new();
for _ in 0..2 {
let resp = client
.get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "10.0.0.50")
.send()
.await
.unwrap();
assert_eq!(resp.status(), reqwest::StatusCode::OK);
}
let resp = client
.get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "10.0.0.50")
.send()
.await
.unwrap();
assert_eq!(resp.status(), reqwest::StatusCode::TOO_MANY_REQUESTS);
let body = resp.text().await.unwrap();
assert_eq!(body, "Too Many Requests");
}
#[tokio::test]
async fn test_rate_limit_429_response_body() {
let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config();
config.rate_limit = reverse_proxy::config::RateLimitConfig {
requests_per_second: 10,
burst: 1,
};
let config_arc = Arc::new(ArcSwap::from_pointee(config));
let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc));
let app = make_rate_limit_app(limiter);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async { axum::serve(listener, app).await.unwrap() });
let client = reqwest::Client::new();
let resp = client
.get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "203.0.113.50")
.send()
.await
.unwrap();
assert_eq!(resp.status(), reqwest::StatusCode::OK);
let resp = client
.get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "203.0.113.50")
.send()
.await
.unwrap();
assert_eq!(resp.status(), reqwest::StatusCode::TOO_MANY_REQUESTS);
let body = resp.text().await.unwrap();
assert_eq!(body, "Too Many Requests");
}
#[tokio::test]
async fn test_rate_limit_per_ip_independent() {
let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config();
config.rate_limit = reverse_proxy::config::RateLimitConfig {
requests_per_second: 10,
burst: 1,
};
let config_arc = Arc::new(ArcSwap::from_pointee(config));
let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc));
let app = make_rate_limit_app(limiter);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async { axum::serve(listener, app).await.unwrap() });
let client = reqwest::Client::new();
let resp = client
.get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "192.168.1.1")
.send()
.await
.unwrap();
assert_eq!(resp.status(), reqwest::StatusCode::OK);
let resp2 = client
.get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "192.168.1.2")
.send()
.await
.unwrap();
assert_eq!(resp2.status(), reqwest::StatusCode::OK);
}
#[tokio::test]
async fn test_rate_limit_eviction_task() {
let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config();
config.rate_limit = reverse_proxy::config::RateLimitConfig {
requests_per_second: 10,
burst: 20,
};
let config_arc = Arc::new(ArcSwap::from_pointee(config));
let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc));
limiter.check_and_consume(std::net::IpAddr::from([192, 168, 1, 1]));
let handle = reverse_proxy::rate_limit::start_eviction_task(
limiter.clone(),
Duration::from_millis(50),
Duration::from_millis(100),
);
tokio::time::sleep(Duration::from_millis(200)).await;
assert!(!limiter.contains_ip(std::net::IpAddr::from([192, 168, 1, 1])));
handle.abort();
}