Implement token bucket rate limiting with IPv6 /64 normalization

- Add TokenBucket with nodelay semantics (nginx limit_req burst nodelay)
- Per-IP rate limiting: IPv4 /32, IPv6 /64 prefix normalization
- DashMap for concurrent access, ArcSwap for lock-free config reads
- Background eviction task for stale entry cleanup
- 429 response with plain text body, RATE_LIMIT log prefix
- Config reload adopts new rate/burst on next request without clearing state
- Unit tests for bucket algorithm and IPv6 normalization
- Integration tests for 429 responses and per-IP independence
This commit is contained in:
2026-06-11 13:01:25 +00:00
parent f1cada010f
commit 2791070971
5 changed files with 592 additions and 3 deletions

View File

@@ -1 +1,208 @@
pub mod bucket;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use arc_swap::ArcSwap;
use axum::extract::Request;
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::IntoResponse;
use dashmap::DashMap;
use tracing::warn;
use crate::config::DynamicConfig;
use bucket::{normalize_ip, TokenBucket};
pub struct RateLimiter {
buckets: DashMap<IpAddr, TokenBucket>,
config: Arc<ArcSwap<DynamicConfig>>,
}
impl RateLimiter {
pub fn new(config: Arc<ArcSwap<DynamicConfig>>) -> Self {
Self {
buckets: DashMap::new(),
config,
}
}
pub fn check_and_consume(&self, ip: IpAddr) -> bool {
let normalized = normalize_ip(ip);
let config = self.config.load();
let rate = config.rate_limit.requests_per_second as f64;
let burst = config.rate_limit.burst;
let entry = self.buckets.entry(normalized);
match entry {
dashmap::mapref::entry::Entry::Occupied(mut occupied) => {
occupied.get_mut().try_consume(rate, burst)
}
dashmap::mapref::entry::Entry::Vacant(vacant) => {
let mut bucket = TokenBucket::new(rate, burst);
let result = bucket.try_consume(rate, burst);
vacant.insert(bucket);
result
}
}
}
pub fn evict_stale(&self, max_age: Duration) {
let cutoff = std::time::Instant::now() - max_age;
self.buckets.retain(|_, bucket| bucket.last_access > cutoff);
}
pub fn contains_ip(&self, ip: IpAddr) -> bool {
self.buckets.contains_key(&normalize_ip(ip))
}
}
pub async fn rate_limit_middleware(
axum::extract::State(limiter): axum::extract::State<Arc<RateLimiter>>,
req: Request,
next: Next,
) -> axum::response::Response {
let client_ip = req
.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.split(',').next())
.and_then(|v| v.trim().parse::<IpAddr>().ok())
.or_else(|| {
req.extensions()
.get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
.map(|ci| ci.ip())
});
let Some(ip) = client_ip else {
return next.run(req).await;
};
let host = req
.headers()
.get("host")
.and_then(|v| v.to_str().ok())
.unwrap_or("-");
let path = req.uri().path();
if !limiter.check_and_consume(ip) {
warn!(
"RATE_LIMIT client_ip={} host={} path={} status=429",
ip, host, path
);
return (StatusCode::TOO_MANY_REQUESTS, "Too Many Requests").into_response();
}
next.run(req).await
}
pub fn start_eviction_task(
limiter: Arc<RateLimiter>,
interval: Duration,
max_age: Duration,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let mut interval_timer = tokio::time::interval(interval);
loop {
interval_timer.tick().await;
limiter.evict_stale(max_age);
}
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::test_fixtures;
use std::net::Ipv6Addr;
use std::time::Duration;
fn make_limiter(rps: u32, burst: u32) -> Arc<RateLimiter> {
let mut config = test_fixtures::test_dynamic_config();
config.rate_limit = crate::config::RateLimitConfig {
requests_per_second: rps,
burst,
};
let config_arc = Arc::new(ArcSwap::from_pointee(config));
Arc::new(RateLimiter::new(config_arc))
}
#[test]
fn check_and_consume_allows_within_burst() {
let limiter = make_limiter(10, 5);
for _ in 0..5 {
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
}
}
#[test]
fn check_and_consume_rejects_above_burst() {
let limiter = make_limiter(10, 5);
for _ in 0..5 {
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
}
assert!(!limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
}
#[test]
fn check_and_consume_per_ip_independent() {
let limiter = make_limiter(10, 5);
for _ in 0..5 {
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
}
assert!(!limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 2])));
}
#[test]
fn check_and_consume_ipv6_normalized_to_64() {
let limiter = make_limiter(10, 5);
let ip1 = IpAddr::from(Ipv6Addr::new(
0x2001, 0x0db8, 0x85a3, 0x0001, 0x1111, 0x2222, 0x3333, 0x4444,
));
let ip2 = IpAddr::from(Ipv6Addr::new(
0x2001, 0x0db8, 0x85a3, 0x0001, 0x5555, 0x6666, 0x7777, 0x8888,
));
for _ in 0..5 {
assert!(limiter.check_and_consume(ip1));
}
assert!(!limiter.check_and_consume(ip2));
}
#[test]
fn evict_removes_stale_entries() {
let limiter = make_limiter(10, 20);
let ip1 = IpAddr::from([192, 168, 1, 1]);
let ip2 = IpAddr::from([192, 168, 1, 2]);
limiter.check_and_consume(ip1);
std::thread::sleep(Duration::from_millis(50));
limiter.check_and_consume(ip2);
limiter.evict_stale(Duration::from_millis(25));
assert!(!limiter.buckets.contains_key(&normalize_ip(ip1)));
assert!(limiter.buckets.contains_key(&normalize_ip(ip2)));
}
#[test]
fn config_reload_adopted_on_next_request() {
let config = test_fixtures::test_dynamic_config();
let config_arc = Arc::new(ArcSwap::from_pointee(config));
let limiter = Arc::new(RateLimiter::new(config_arc.clone()));
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
let mut new_config = test_fixtures::test_dynamic_config();
new_config.rate_limit = crate::config::RateLimitConfig {
requests_per_second: 10,
burst: 2,
};
config_arc.store(Arc::new(new_config));
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
assert!(!limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
}
}