diff --git a/.worktrees/feat/config/cli-parsing b/.worktrees/feat/config/cli-parsing new file mode 160000 index 0000000..2791070 --- /dev/null +++ b/.worktrees/feat/config/cli-parsing @@ -0,0 +1 @@ +Subproject commit 2791070971b1cd2df0c6f3b29aa2f3cf058e8f5a diff --git a/.worktrees/feat/ops/admin-socket b/.worktrees/feat/ops/admin-socket new file mode 160000 index 0000000..f1cada0 --- /dev/null +++ b/.worktrees/feat/ops/admin-socket @@ -0,0 +1 @@ +Subproject commit f1cada010f2a123f4629908e271f18b592f3bc7e diff --git a/.worktrees/feat/ops/body-size-limit b/.worktrees/feat/ops/body-size-limit new file mode 160000 index 0000000..5fa0fc6 --- /dev/null +++ b/.worktrees/feat/ops/body-size-limit @@ -0,0 +1 @@ +Subproject commit 5fa0fc600ee97b4e2af1efd72639bbd90dc48465 diff --git a/.worktrees/feat/proxy/error-responses b/.worktrees/feat/proxy/error-responses new file mode 160000 index 0000000..23ed5cd --- /dev/null +++ b/.worktrees/feat/proxy/error-responses @@ -0,0 +1 @@ +Subproject commit 23ed5cde27df7990b9170bd3d87cafdf95e18747 diff --git a/.worktrees/feat/proxy/headers-and-forwarding b/.worktrees/feat/proxy/headers-and-forwarding new file mode 160000 index 0000000..2791070 --- /dev/null +++ b/.worktrees/feat/proxy/headers-and-forwarding @@ -0,0 +1 @@ +Subproject commit 2791070971b1cd2df0c6f3b29aa2f3cf058e8f5a diff --git a/.worktrees/feat/tls/http-redirect b/.worktrees/feat/tls/http-redirect new file mode 160000 index 0000000..2791070 --- /dev/null +++ b/.worktrees/feat/tls/http-redirect @@ -0,0 +1 @@ +Subproject commit 2791070971b1cd2df0c6f3b29aa2f3cf058e8f5a diff --git a/Cargo.lock b/Cargo.lock index 80e5554..30de6d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1606,6 +1606,7 @@ dependencies = [ "clap", "dashmap", "futures", + "http-body-util", "hyper", "rcgen", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index aac46d7..565c5ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ clap = { version = "=4.6.1", features = ["derive"] } signal-hook = "=0.3.18" anyhow = "=1.0.102" thiserror = "=2.0.18" +http-body-util = "=0.1.3" futures = "=0.3.31" dashmap = "=6.1" diff --git a/src/proxy/body_limit.rs b/src/proxy/body_limit.rs new file mode 100644 index 0000000..7d231e7 --- /dev/null +++ b/src/proxy/body_limit.rs @@ -0,0 +1,51 @@ +use std::sync::Arc; + +use arc_swap::ArcSwap; +use axum::body::Body; +use axum::extract::State; +use axum::http::StatusCode; +use axum::response::IntoResponse; +use http_body_util::Limited; + +use crate::config::DynamicConfig; + +pub const DEFAULT_BODY_LIMIT_BYTES: u64 = 104_857_600; + +pub async fn body_limit_middleware( + State(config): State>>, + request: axum::extract::Request, + next: axum::middleware::Next, +) -> axum::response::Response { + let limit = config.load().body.limit_bytes; + let limit = if limit == 0 { + DEFAULT_BODY_LIMIT_BYTES + } else { + limit + }; + + if let Some(content_length) = request.headers().get("content-length") { + if let Ok(length_str) = content_length.to_str() { + if let Ok(length) = length_str.parse::() { + if length > limit { + return (StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large").into_response(); + } + } + } + } + + let (parts, body) = request.into_parts(); + let limited_body = Limited::new(body, limit as usize); + let request = axum::extract::Request::from_parts(parts, Body::new(limited_body)); + + next.run(request).await +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_body_limit_is_100mb() { + assert_eq!(DEFAULT_BODY_LIMIT_BYTES, 104_857_600); + } +} diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 4301186..fe54458 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,5 +1,22 @@ +pub mod body_limit; pub mod error; pub mod handler; pub mod headers; pub use crate::config::dynamic_config::normalize_host; + +use std::sync::Arc; + +use arc_swap::ArcSwap; + +use crate::config::DynamicConfig; + +pub fn router_with_body_limit( + router: axum::Router, + config: Arc>, +) -> axum::Router { + router.layer(axum::middleware::from_fn_with_state( + config, + body_limit::body_limit_middleware, + )) +} diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 05e16fa..fe03796 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -4,8 +4,13 @@ use std::sync::Arc; use std::time::Duration; use arc_swap::ArcSwap; -use axum::routing::get; +use axum::routing::{get, post}; use axum::Router; +use reverse_proxy::config::dynamic_config::{ + BodyConfig, DynamicConfig, RateLimitConfig, SiteConfig, +}; +use reverse_proxy::proxy::body_limit::DEFAULT_BODY_LIMIT_BYTES; +use reverse_proxy::proxy::router_with_body_limit; #[tokio::test] async fn test_upstream_spawn_and_connect() { @@ -248,3 +253,209 @@ async fn test_rate_limit_eviction_task() { handle.abort(); } + +fn test_dynamic_config_with_limit(limit_bytes: u64) -> Arc> { + let config = DynamicConfig { + sites: vec![SiteConfig { + host: "test.local".to_string(), + upstream: "127.0.0.1:8080".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }], + rate_limit: RateLimitConfig { + requests_per_second: 10, + burst: 20, + }, + body: BodyConfig { limit_bytes }, + routing_table: Default::default(), + }; + Arc::new(ArcSwap::from_pointee(config)) +} + +async fn spawn_server_with_limit(limit_bytes: u64) -> helpers::http_test_helper::TestUpstream { + let config = test_dynamic_config_with_limit(limit_bytes); + helpers::http_test_helper::TestUpstream::spawn(|| { + let app = Router::new().route( + "/", + post(|body: axum::body::Body| async move { + let _ = body; + "ok" + }), + ); + router_with_body_limit(app, config.clone()) + }) + .await +} + +#[tokio::test] +async fn test_body_limit_rejects_oversized_request() { + let server = spawn_server_with_limit(100).await; + let client = reqwest::Client::new(); + + let large_body = vec![0u8; 200]; + let resp = client + .post(format!("http://127.0.0.1:{}/", server.addr.port())) + .body(large_body) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), reqwest::StatusCode::PAYLOAD_TOO_LARGE); + let body = resp.text().await.unwrap(); + assert_eq!(body, "Payload Too Large"); + + let _ = server.shutdown_tx.send(()); +} + +#[tokio::test] +async fn test_body_limit_allows_request_within_limit() { + let server = spawn_server_with_limit(100).await; + let client = reqwest::Client::new(); + + let small_body = vec![0u8; 50]; + let resp = client + .post(format!("http://127.0.0.1:{}/", server.addr.port())) + .body(small_body) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), reqwest::StatusCode::OK); + + let _ = server.shutdown_tx.send(()); +} + +#[tokio::test] +async fn test_body_limit_allows_request_at_exact_limit() { + let server = spawn_server_with_limit(100).await; + let client = reqwest::Client::new(); + + let exact_body = vec![0u8; 100]; + let resp = client + .post(format!("http://127.0.0.1:{}/", server.addr.port())) + .body(exact_body) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), reqwest::StatusCode::OK); + + let _ = server.shutdown_tx.send(()); +} + +#[tokio::test] +async fn test_body_limit_content_length_header_rejection() { + let server = spawn_server_with_limit(100).await; + let client = reqwest::Client::new(); + + let resp = client + .post(format!("http://127.0.0.1:{}/", server.addr.port())) + .header("content-length", "200") + .body(vec![0u8; 200]) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), reqwest::StatusCode::PAYLOAD_TOO_LARGE); + let body = resp.text().await.unwrap(); + assert_eq!(body, "Payload Too Large"); + + let _ = server.shutdown_tx.send(()); +} + +#[tokio::test] +async fn test_body_limit_default_is_100mb() { + assert_eq!(DEFAULT_BODY_LIMIT_BYTES, 104_857_600); +} + +#[tokio::test] +async fn test_body_limit_config_reload_changes_limit() { + let config = test_dynamic_config_with_limit(100); + let config_clone = config.clone(); + + let server = helpers::http_test_helper::TestUpstream::spawn(|| { + let app = Router::new().route( + "/", + post(|body: axum::body::Body| async move { + let _ = body; + "ok" + }), + ); + router_with_body_limit(app, config_clone.clone()) + }) + .await; + + let client = reqwest::Client::new(); + + let small_body = vec![0u8; 50]; + let resp = client + .post(format!("http://127.0.0.1:{}/", server.addr.port())) + .body(small_body.clone()) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::OK); + + let medium_body = vec![0u8; 150]; + let resp = client + .post(format!("http://127.0.0.1:{}/", server.addr.port())) + .body(medium_body.clone()) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::PAYLOAD_TOO_LARGE); + + let new_config = DynamicConfig { + sites: vec![SiteConfig { + host: "test.local".to_string(), + upstream: "127.0.0.1:8080".to_string(), + upstream_scheme: "http".to_string(), + upstream_connect_timeout_secs: 5, + upstream_request_timeout_secs: 60, + }], + rate_limit: RateLimitConfig { + requests_per_second: 10, + burst: 20, + }, + body: BodyConfig { limit_bytes: 200 }, + routing_table: Default::default(), + }; + config.store(Arc::new(new_config)); + + let resp = client + .post(format!("http://127.0.0.1:{}/", server.addr.port())) + .body(medium_body) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::OK); + + let large_body = vec![0u8; 300]; + let resp = client + .post(format!("http://127.0.0.1:{}/", server.addr.port())) + .body(large_body) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), reqwest::StatusCode::PAYLOAD_TOO_LARGE); + + let _ = server.shutdown_tx.send(()); +} + +#[tokio::test] +async fn test_body_limit_empty_body_request_succeeds() { + let server = spawn_server_with_limit(100).await; + let client = reqwest::Client::new(); + + let resp = client + .post(format!("http://127.0.0.1:{}/", server.addr.port())) + .body("") + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), reqwest::StatusCode::OK); + + let _ = server.shutdown_tx.send(()); +}