Implement token bucket rate limiting with IPv6 /64 normalization
- Add TokenBucket with nodelay semantics (nginx limit_req burst nodelay) - Per-IP rate limiting: IPv4 /32, IPv6 /64 prefix normalization - DashMap for concurrent access, ArcSwap for lock-free config reads - Background eviction task for stale entry cleanup - 429 response with plain text body, RATE_LIMIT log prefix - Config reload adopts new rate/burst on next request without clearing state - Unit tests for bucket algorithm and IPv6 normalization - Integration tests for 429 responses and per-IP independence
This commit is contained in:
23
Cargo.lock
generated
23
Cargo.lock
generated
@@ -489,6 +489,20 @@ version = "0.8.21"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dashmap"
|
||||||
|
version = "6.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"crossbeam-utils",
|
||||||
|
"hashbrown 0.14.5",
|
||||||
|
"lock_api",
|
||||||
|
"once_cell",
|
||||||
|
"parking_lot_core",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "data-encoding"
|
name = "data-encoding"
|
||||||
version = "2.11.0"
|
version = "2.11.0"
|
||||||
@@ -789,6 +803,12 @@ dependencies = [
|
|||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hashbrown"
|
||||||
|
version = "0.14.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hashbrown"
|
name = "hashbrown"
|
||||||
version = "0.17.1"
|
version = "0.17.1"
|
||||||
@@ -1064,7 +1084,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9"
|
checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"equivalent",
|
"equivalent",
|
||||||
"hashbrown",
|
"hashbrown 0.17.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1584,6 +1604,7 @@ dependencies = [
|
|||||||
"arc-swap",
|
"arc-swap",
|
||||||
"axum",
|
"axum",
|
||||||
"clap",
|
"clap",
|
||||||
|
"dashmap",
|
||||||
"futures",
|
"futures",
|
||||||
"hyper",
|
"hyper",
|
||||||
"rcgen",
|
"rcgen",
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ signal-hook = "=0.3.18"
|
|||||||
anyhow = "=1.0.102"
|
anyhow = "=1.0.102"
|
||||||
thiserror = "=2.0.18"
|
thiserror = "=2.0.18"
|
||||||
futures = "=0.3.31"
|
futures = "=0.3.31"
|
||||||
|
dashmap = "=6.1"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
rcgen = "=0.13"
|
rcgen = "=0.13"
|
||||||
|
|||||||
@@ -1,2 +1,185 @@
|
|||||||
#[allow(dead_code)]
|
use std::net::{IpAddr, Ipv6Addr};
|
||||||
pub struct TokenBucket;
|
use std::time::Instant;
|
||||||
|
|
||||||
|
pub struct TokenBucket {
|
||||||
|
pub tokens: f64,
|
||||||
|
pub last_refill: Instant,
|
||||||
|
pub rate: f64,
|
||||||
|
pub max: u32,
|
||||||
|
pub last_access: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TokenBucket {
|
||||||
|
pub fn new(rate: f64, max: u32) -> Self {
|
||||||
|
Self {
|
||||||
|
tokens: max as f64,
|
||||||
|
last_refill: Instant::now(),
|
||||||
|
rate,
|
||||||
|
max,
|
||||||
|
last_access: Instant::now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn try_consume(&mut self, rate: f64, max: u32) -> bool {
|
||||||
|
self.refill(rate, max);
|
||||||
|
self.last_access = Instant::now();
|
||||||
|
|
||||||
|
if self.tokens >= 1.0 {
|
||||||
|
self.tokens -= 1.0;
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn refill(&mut self, rate: f64, max: u32) {
|
||||||
|
let now = Instant::now();
|
||||||
|
let elapsed = now.duration_since(self.last_refill).as_millis() as f64;
|
||||||
|
let tokens_to_add = (elapsed * rate) / 1000.0;
|
||||||
|
self.tokens = (self.tokens + tokens_to_add).min(max as f64);
|
||||||
|
self.last_refill = now;
|
||||||
|
|
||||||
|
if max < self.max {
|
||||||
|
self.tokens = self.tokens.min(max as f64);
|
||||||
|
}
|
||||||
|
self.max = max;
|
||||||
|
self.rate = rate;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn normalize_ip(ip: IpAddr) -> IpAddr {
|
||||||
|
match ip {
|
||||||
|
IpAddr::V4(_) => ip,
|
||||||
|
IpAddr::V6(v6) => {
|
||||||
|
let segments = v6.segments();
|
||||||
|
let normalized = Ipv6Addr::new(
|
||||||
|
segments[0],
|
||||||
|
segments[1],
|
||||||
|
segments[2],
|
||||||
|
segments[3],
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
IpAddr::V6(normalized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||||
|
use std::thread;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_bucket_starts_full() {
|
||||||
|
let bucket = TokenBucket::new(10.0, 20);
|
||||||
|
assert_eq!(bucket.tokens, 20.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_bucket_consume_within_burst() {
|
||||||
|
let mut bucket = TokenBucket::new(10.0, 20);
|
||||||
|
for _ in 0..20 {
|
||||||
|
assert!(bucket.try_consume(10.0, 20));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_bucket_reject_when_empty() {
|
||||||
|
let mut bucket = TokenBucket::new(10.0, 5);
|
||||||
|
for _ in 0..5 {
|
||||||
|
assert!(bucket.try_consume(10.0, 5));
|
||||||
|
}
|
||||||
|
assert!(!bucket.try_consume(10.0, 5));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_bucket_refills_over_time() {
|
||||||
|
let mut bucket = TokenBucket::new(1000.0, 5);
|
||||||
|
for _ in 0..5 {
|
||||||
|
assert!(bucket.try_consume(1000.0, 5));
|
||||||
|
}
|
||||||
|
assert!(!bucket.try_consume(1000.0, 5));
|
||||||
|
|
||||||
|
thread::sleep(Duration::from_millis(10));
|
||||||
|
assert!(bucket.try_consume(1000.0, 5));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_bucket_caps_at_max() {
|
||||||
|
let mut bucket = TokenBucket::new(10.0, 5);
|
||||||
|
for _ in 0..5 {
|
||||||
|
assert!(bucket.try_consume(10.0, 5));
|
||||||
|
}
|
||||||
|
thread::sleep(Duration::from_millis(500));
|
||||||
|
bucket.try_consume(10.0, 5);
|
||||||
|
assert!(bucket.tokens <= 5.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_bucket_config_reload_caps_tokens() {
|
||||||
|
let mut bucket = TokenBucket::new(10.0, 20);
|
||||||
|
assert!(bucket.try_consume(10.0, 20));
|
||||||
|
|
||||||
|
let new_max = 10;
|
||||||
|
bucket.refill(10.0, new_max);
|
||||||
|
assert!(bucket.tokens <= new_max as f64 + 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn token_bucket_config_reload_adopts_new_rate() {
|
||||||
|
let mut bucket = TokenBucket::new(10.0, 5);
|
||||||
|
for _ in 0..5 {
|
||||||
|
assert!(bucket.try_consume(10.0, 5));
|
||||||
|
}
|
||||||
|
assert!(!bucket.try_consume(10.0, 5));
|
||||||
|
|
||||||
|
thread::sleep(Duration::from_millis(200));
|
||||||
|
assert!(bucket.try_consume(50.0, 5));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn normalize_ipv4_unchanged() {
|
||||||
|
let ip = IpAddr::from(Ipv4Addr::new(192, 168, 1, 1));
|
||||||
|
assert_eq!(normalize_ip(ip), ip);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn normalize_ipv6_truncates_to_64_prefix() {
|
||||||
|
let ip = IpAddr::from(Ipv6Addr::new(
|
||||||
|
0x2001, 0x0db8, 0x85a3, 0x0001, 0x1234, 0x5678, 0x9abc, 0xdef0,
|
||||||
|
));
|
||||||
|
let normalized = normalize_ip(ip);
|
||||||
|
let expected = IpAddr::from(Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0001, 0, 0, 0, 0));
|
||||||
|
assert_eq!(normalized, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn normalize_ipv6_already_64_prefix() {
|
||||||
|
let ip = IpAddr::from(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 0));
|
||||||
|
let normalized = normalize_ip(ip);
|
||||||
|
assert_eq!(normalized, ip);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn normalize_ipv6_link_local() {
|
||||||
|
let ip = IpAddr::from(Ipv6Addr::new(
|
||||||
|
0xfe80, 0, 0, 0, 0x1234, 0x5678, 0x9abc, 0xdef0,
|
||||||
|
));
|
||||||
|
let normalized = normalize_ip(ip);
|
||||||
|
let expected = IpAddr::from(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 0));
|
||||||
|
assert_eq!(normalized, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn normalize_loopback_ipv6() {
|
||||||
|
let ip = IpAddr::from(Ipv6Addr::LOCALHOST);
|
||||||
|
let normalized = normalize_ip(ip);
|
||||||
|
let expected = IpAddr::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0));
|
||||||
|
assert_eq!(normalized, expected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1 +1,208 @@
|
|||||||
pub mod bucket;
|
pub mod bucket;
|
||||||
|
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use arc_swap::ArcSwap;
|
||||||
|
use axum::extract::Request;
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use axum::middleware::Next;
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
use crate::config::DynamicConfig;
|
||||||
|
use bucket::{normalize_ip, TokenBucket};
|
||||||
|
|
||||||
|
pub struct RateLimiter {
|
||||||
|
buckets: DashMap<IpAddr, TokenBucket>,
|
||||||
|
config: Arc<ArcSwap<DynamicConfig>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RateLimiter {
|
||||||
|
pub fn new(config: Arc<ArcSwap<DynamicConfig>>) -> Self {
|
||||||
|
Self {
|
||||||
|
buckets: DashMap::new(),
|
||||||
|
config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check_and_consume(&self, ip: IpAddr) -> bool {
|
||||||
|
let normalized = normalize_ip(ip);
|
||||||
|
let config = self.config.load();
|
||||||
|
let rate = config.rate_limit.requests_per_second as f64;
|
||||||
|
let burst = config.rate_limit.burst;
|
||||||
|
|
||||||
|
let entry = self.buckets.entry(normalized);
|
||||||
|
match entry {
|
||||||
|
dashmap::mapref::entry::Entry::Occupied(mut occupied) => {
|
||||||
|
occupied.get_mut().try_consume(rate, burst)
|
||||||
|
}
|
||||||
|
dashmap::mapref::entry::Entry::Vacant(vacant) => {
|
||||||
|
let mut bucket = TokenBucket::new(rate, burst);
|
||||||
|
let result = bucket.try_consume(rate, burst);
|
||||||
|
vacant.insert(bucket);
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn evict_stale(&self, max_age: Duration) {
|
||||||
|
let cutoff = std::time::Instant::now() - max_age;
|
||||||
|
self.buckets.retain(|_, bucket| bucket.last_access > cutoff);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn contains_ip(&self, ip: IpAddr) -> bool {
|
||||||
|
self.buckets.contains_key(&normalize_ip(ip))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn rate_limit_middleware(
|
||||||
|
axum::extract::State(limiter): axum::extract::State<Arc<RateLimiter>>,
|
||||||
|
req: Request,
|
||||||
|
next: Next,
|
||||||
|
) -> axum::response::Response {
|
||||||
|
let client_ip = req
|
||||||
|
.headers()
|
||||||
|
.get("x-forwarded-for")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|v| v.split(',').next())
|
||||||
|
.and_then(|v| v.trim().parse::<IpAddr>().ok())
|
||||||
|
.or_else(|| {
|
||||||
|
req.extensions()
|
||||||
|
.get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
|
||||||
|
.map(|ci| ci.ip())
|
||||||
|
});
|
||||||
|
|
||||||
|
let Some(ip) = client_ip else {
|
||||||
|
return next.run(req).await;
|
||||||
|
};
|
||||||
|
|
||||||
|
let host = req
|
||||||
|
.headers()
|
||||||
|
.get("host")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("-");
|
||||||
|
|
||||||
|
let path = req.uri().path();
|
||||||
|
|
||||||
|
if !limiter.check_and_consume(ip) {
|
||||||
|
warn!(
|
||||||
|
"RATE_LIMIT client_ip={} host={} path={} status=429",
|
||||||
|
ip, host, path
|
||||||
|
);
|
||||||
|
return (StatusCode::TOO_MANY_REQUESTS, "Too Many Requests").into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
next.run(req).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn start_eviction_task(
|
||||||
|
limiter: Arc<RateLimiter>,
|
||||||
|
interval: Duration,
|
||||||
|
max_age: Duration,
|
||||||
|
) -> tokio::task::JoinHandle<()> {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut interval_timer = tokio::time::interval(interval);
|
||||||
|
loop {
|
||||||
|
interval_timer.tick().await;
|
||||||
|
limiter.evict_stale(max_age);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::config::test_fixtures;
|
||||||
|
use std::net::Ipv6Addr;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
fn make_limiter(rps: u32, burst: u32) -> Arc<RateLimiter> {
|
||||||
|
let mut config = test_fixtures::test_dynamic_config();
|
||||||
|
config.rate_limit = crate::config::RateLimitConfig {
|
||||||
|
requests_per_second: rps,
|
||||||
|
burst,
|
||||||
|
};
|
||||||
|
let config_arc = Arc::new(ArcSwap::from_pointee(config));
|
||||||
|
Arc::new(RateLimiter::new(config_arc))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn check_and_consume_allows_within_burst() {
|
||||||
|
let limiter = make_limiter(10, 5);
|
||||||
|
for _ in 0..5 {
|
||||||
|
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn check_and_consume_rejects_above_burst() {
|
||||||
|
let limiter = make_limiter(10, 5);
|
||||||
|
for _ in 0..5 {
|
||||||
|
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
|
||||||
|
}
|
||||||
|
assert!(!limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn check_and_consume_per_ip_independent() {
|
||||||
|
let limiter = make_limiter(10, 5);
|
||||||
|
for _ in 0..5 {
|
||||||
|
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
|
||||||
|
}
|
||||||
|
assert!(!limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
|
||||||
|
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 2])));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn check_and_consume_ipv6_normalized_to_64() {
|
||||||
|
let limiter = make_limiter(10, 5);
|
||||||
|
let ip1 = IpAddr::from(Ipv6Addr::new(
|
||||||
|
0x2001, 0x0db8, 0x85a3, 0x0001, 0x1111, 0x2222, 0x3333, 0x4444,
|
||||||
|
));
|
||||||
|
let ip2 = IpAddr::from(Ipv6Addr::new(
|
||||||
|
0x2001, 0x0db8, 0x85a3, 0x0001, 0x5555, 0x6666, 0x7777, 0x8888,
|
||||||
|
));
|
||||||
|
for _ in 0..5 {
|
||||||
|
assert!(limiter.check_and_consume(ip1));
|
||||||
|
}
|
||||||
|
assert!(!limiter.check_and_consume(ip2));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn evict_removes_stale_entries() {
|
||||||
|
let limiter = make_limiter(10, 20);
|
||||||
|
let ip1 = IpAddr::from([192, 168, 1, 1]);
|
||||||
|
let ip2 = IpAddr::from([192, 168, 1, 2]);
|
||||||
|
|
||||||
|
limiter.check_and_consume(ip1);
|
||||||
|
std::thread::sleep(Duration::from_millis(50));
|
||||||
|
limiter.check_and_consume(ip2);
|
||||||
|
|
||||||
|
limiter.evict_stale(Duration::from_millis(25));
|
||||||
|
assert!(!limiter.buckets.contains_key(&normalize_ip(ip1)));
|
||||||
|
assert!(limiter.buckets.contains_key(&normalize_ip(ip2)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_reload_adopted_on_next_request() {
|
||||||
|
let config = test_fixtures::test_dynamic_config();
|
||||||
|
let config_arc = Arc::new(ArcSwap::from_pointee(config));
|
||||||
|
let limiter = Arc::new(RateLimiter::new(config_arc.clone()));
|
||||||
|
|
||||||
|
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
|
||||||
|
|
||||||
|
let mut new_config = test_fixtures::test_dynamic_config();
|
||||||
|
new_config.rate_limit = crate::config::RateLimitConfig {
|
||||||
|
requests_per_second: 10,
|
||||||
|
burst: 2,
|
||||||
|
};
|
||||||
|
config_arc.store(Arc::new(new_config));
|
||||||
|
|
||||||
|
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
|
||||||
|
assert!(limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
|
||||||
|
assert!(!limiter.check_and_consume(IpAddr::from([192, 168, 1, 1])));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
mod helpers;
|
mod helpers;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use arc_swap::ArcSwap;
|
||||||
|
use axum::routing::get;
|
||||||
|
use axum::Router;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_upstream_spawn_and_connect() {
|
async fn test_upstream_spawn_and_connect() {
|
||||||
let upstream = helpers::http_test_helper::TestUpstream::spawn_ok().await;
|
let upstream = helpers::http_test_helper::TestUpstream::spawn_ok().await;
|
||||||
@@ -71,3 +78,173 @@ async fn test_health_check_disabled_when_port_zero() {
|
|||||||
assert_ne!(addr.port(), 0);
|
assert_ne!(addr.port(), 0);
|
||||||
handle.abort();
|
handle.abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn make_rate_limit_app(limiter: Arc<reverse_proxy::rate_limit::RateLimiter>) -> Router {
|
||||||
|
Router::new()
|
||||||
|
.route("/", get(|| async { "ok" }))
|
||||||
|
.layer(axum::middleware::from_fn_with_state(
|
||||||
|
limiter,
|
||||||
|
reverse_proxy::rate_limit::rate_limit_middleware,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_rate_limit_allows_within_burst() {
|
||||||
|
let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config();
|
||||||
|
config.rate_limit = reverse_proxy::config::RateLimitConfig {
|
||||||
|
requests_per_second: 10,
|
||||||
|
burst: 5,
|
||||||
|
};
|
||||||
|
let config_arc = Arc::new(ArcSwap::from_pointee(config));
|
||||||
|
let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc));
|
||||||
|
|
||||||
|
let app = make_rate_limit_app(limiter);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
tokio::spawn(async { axum::serve(listener, app).await.unwrap() });
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
for _ in 0..5 {
|
||||||
|
let resp = client
|
||||||
|
.get(format!("http://127.0.0.1:{}/", addr.port()))
|
||||||
|
.header("x-forwarded-for", "192.168.1.1")
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::OK);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_rate_limit_rejects_above_burst() {
|
||||||
|
let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config();
|
||||||
|
config.rate_limit = reverse_proxy::config::RateLimitConfig {
|
||||||
|
requests_per_second: 10,
|
||||||
|
burst: 2,
|
||||||
|
};
|
||||||
|
let config_arc = Arc::new(ArcSwap::from_pointee(config));
|
||||||
|
let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc));
|
||||||
|
|
||||||
|
let app = make_rate_limit_app(limiter);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
tokio::spawn(async { axum::serve(listener, app).await.unwrap() });
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
for _ in 0..2 {
|
||||||
|
let resp = client
|
||||||
|
.get(format!("http://127.0.0.1:{}/", addr.port()))
|
||||||
|
.header("x-forwarded-for", "10.0.0.50")
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp = client
|
||||||
|
.get(format!("http://127.0.0.1:{}/", addr.port()))
|
||||||
|
.header("x-forwarded-for", "10.0.0.50")
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::TOO_MANY_REQUESTS);
|
||||||
|
let body = resp.text().await.unwrap();
|
||||||
|
assert_eq!(body, "Too Many Requests");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_rate_limit_429_response_body() {
|
||||||
|
let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config();
|
||||||
|
config.rate_limit = reverse_proxy::config::RateLimitConfig {
|
||||||
|
requests_per_second: 10,
|
||||||
|
burst: 1,
|
||||||
|
};
|
||||||
|
let config_arc = Arc::new(ArcSwap::from_pointee(config));
|
||||||
|
let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc));
|
||||||
|
|
||||||
|
let app = make_rate_limit_app(limiter);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
tokio::spawn(async { axum::serve(listener, app).await.unwrap() });
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let resp = client
|
||||||
|
.get(format!("http://127.0.0.1:{}/", addr.port()))
|
||||||
|
.header("x-forwarded-for", "203.0.113.50")
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::OK);
|
||||||
|
|
||||||
|
let resp = client
|
||||||
|
.get(format!("http://127.0.0.1:{}/", addr.port()))
|
||||||
|
.header("x-forwarded-for", "203.0.113.50")
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::TOO_MANY_REQUESTS);
|
||||||
|
let body = resp.text().await.unwrap();
|
||||||
|
assert_eq!(body, "Too Many Requests");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_rate_limit_per_ip_independent() {
|
||||||
|
let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config();
|
||||||
|
config.rate_limit = reverse_proxy::config::RateLimitConfig {
|
||||||
|
requests_per_second: 10,
|
||||||
|
burst: 1,
|
||||||
|
};
|
||||||
|
let config_arc = Arc::new(ArcSwap::from_pointee(config));
|
||||||
|
let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc));
|
||||||
|
|
||||||
|
let app = make_rate_limit_app(limiter);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
tokio::spawn(async { axum::serve(listener, app).await.unwrap() });
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let resp = client
|
||||||
|
.get(format!("http://127.0.0.1:{}/", addr.port()))
|
||||||
|
.header("x-forwarded-for", "192.168.1.1")
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), reqwest::StatusCode::OK);
|
||||||
|
|
||||||
|
let resp2 = client
|
||||||
|
.get(format!("http://127.0.0.1:{}/", addr.port()))
|
||||||
|
.header("x-forwarded-for", "192.168.1.2")
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp2.status(), reqwest::StatusCode::OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_rate_limit_eviction_task() {
|
||||||
|
let mut config = reverse_proxy::config::test_fixtures::test_dynamic_config();
|
||||||
|
config.rate_limit = reverse_proxy::config::RateLimitConfig {
|
||||||
|
requests_per_second: 10,
|
||||||
|
burst: 20,
|
||||||
|
};
|
||||||
|
let config_arc = Arc::new(ArcSwap::from_pointee(config));
|
||||||
|
let limiter = Arc::new(reverse_proxy::rate_limit::RateLimiter::new(config_arc));
|
||||||
|
|
||||||
|
limiter.check_and_consume(std::net::IpAddr::from([192, 168, 1, 1]));
|
||||||
|
|
||||||
|
let handle = reverse_proxy::rate_limit::start_eviction_task(
|
||||||
|
limiter.clone(),
|
||||||
|
Duration::from_millis(50),
|
||||||
|
Duration::from_millis(100),
|
||||||
|
);
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||||
|
|
||||||
|
assert!(!limiter.contains_ip(std::net::IpAddr::from([192, 168, 1, 1])));
|
||||||
|
|
||||||
|
handle.abort();
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user