fix(rate_limit): use ConnectInfo as sole IP source, reject without it

The rate limiter previously extracted client IP from the X-Forwarded-For
header first, falling back to ConnectInfo. This allowed attackers to bypass
rate limits by sending spoofed X-Forwarded-For headers. Per ADR-025, the
rate limiter now uses ConnectInfo<SocketAddr> exclusively and rejects
requests with 429 when ConnectInfo is absent.
This commit is contained in:
2026-06-12 14:00:31 +00:00
parent 54f1725173
commit ad9b9b9b78
2 changed files with 9 additions and 25 deletions

View File

@@ -64,24 +64,12 @@ pub async fn rate_limit_middleware(
next: Next, next: Next,
) -> axum::response::Response { ) -> axum::response::Response {
let client_ip = req let client_ip = req
.headers() .extensions()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.split(',').next())
.and_then(|v| v.trim().parse::<IpAddr>().ok())
.or_else(|| {
req.extensions()
.get::<axum::extract::ConnectInfo<std::net::SocketAddr>>() .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
.map(|ci| ci.ip()) .map(|ci| ci.ip());
});
let Some(ip) = client_ip else { let Some(ip) = client_ip else {
// If no client IP can be identified, the request passes through without rate return (StatusCode::TOO_MANY_REQUESTS, "Too Many Requests").into_response();
// limiting. In practice, ConnectInfo is always set by the server's
// ConnectInfoService, so this branch is unreachable. If the proxy were ever
// deployed without ConnectInfo propagation, rate limiting would silently become
// a no-op. Consider adding a warning log or returning 429 in a future phase.
return next.run(req).await;
}; };
let host = req let host = req

View File

@@ -87,13 +87,16 @@ async fn test_health_check_disabled_when_port_zero() {
handle.abort(); handle.abort();
} }
fn make_rate_limit_app(limiter: Arc<reverse_proxy::rate_limit::RateLimiter>) -> Router { fn make_rate_limit_app(
limiter: Arc<reverse_proxy::rate_limit::RateLimiter>,
) -> axum::extract::connect_info::IntoMakeServiceWithConnectInfo<Router, std::net::SocketAddr> {
Router::new() Router::new()
.route("/", get(|| async { "ok" })) .route("/", get(|| async { "ok" }))
.layer(axum::middleware::from_fn_with_state( .layer(axum::middleware::from_fn_with_state(
limiter, limiter,
reverse_proxy::rate_limit::rate_limit_middleware, reverse_proxy::rate_limit::rate_limit_middleware,
)) ))
.into_make_service_with_connect_info::<std::net::SocketAddr>()
} }
#[tokio::test] #[tokio::test]
@@ -116,7 +119,6 @@ async fn test_rate_limit_allows_within_burst() {
for _ in 0..5 { for _ in 0..5 {
let resp = client let resp = client
.get(format!("http://127.0.0.1:{}/", addr.port())) .get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "192.168.1.1")
.send() .send()
.await .await
.unwrap(); .unwrap();
@@ -144,7 +146,6 @@ async fn test_rate_limit_rejects_above_burst() {
for _ in 0..2 { for _ in 0..2 {
let resp = client let resp = client
.get(format!("http://127.0.0.1:{}/", addr.port())) .get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "10.0.0.50")
.send() .send()
.await .await
.unwrap(); .unwrap();
@@ -153,7 +154,6 @@ async fn test_rate_limit_rejects_above_burst() {
let resp = client let resp = client
.get(format!("http://127.0.0.1:{}/", addr.port())) .get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "10.0.0.50")
.send() .send()
.await .await
.unwrap(); .unwrap();
@@ -181,7 +181,6 @@ async fn test_rate_limit_429_response_body() {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let resp = client let resp = client
.get(format!("http://127.0.0.1:{}/", addr.port())) .get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "203.0.113.50")
.send() .send()
.await .await
.unwrap(); .unwrap();
@@ -189,7 +188,6 @@ async fn test_rate_limit_429_response_body() {
let resp = client let resp = client
.get(format!("http://127.0.0.1:{}/", addr.port())) .get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "203.0.113.50")
.send() .send()
.await .await
.unwrap(); .unwrap();
@@ -217,7 +215,6 @@ async fn test_rate_limit_per_ip_independent() {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let resp = client let resp = client
.get(format!("http://127.0.0.1:{}/", addr.port())) .get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "192.168.1.1")
.send() .send()
.await .await
.unwrap(); .unwrap();
@@ -225,11 +222,10 @@ async fn test_rate_limit_per_ip_independent() {
let resp2 = client let resp2 = client
.get(format!("http://127.0.0.1:{}/", addr.port())) .get(format!("http://127.0.0.1:{}/", addr.port()))
.header("x-forwarded-for", "192.168.1.2")
.send() .send()
.await .await
.unwrap(); .unwrap();
assert_eq!(resp2.status(), reqwest::StatusCode::OK); assert_eq!(resp2.status(), reqwest::StatusCode::TOO_MANY_REQUESTS);
} }
#[tokio::test] #[tokio::test]