Implement proxy header injection, hop-by-hop removal, and request forwarding

- Add ProxyError enum with IntoResponse for error handling (400, 404, 502, 504)
- Implement proxy header injection: X-Real-IP, X-Forwarded-For (replaced, not appended), X-Forwarded-Proto
- Implement hop-by-hop header removal for both request and response headers
- Implement request forwarding via shared hyper::Client with HTTP and HTTPS support
- Add ProxyState with http_client and https_client instances shared via axum State
- Add per-site timeout overrides using tokio::time::timeout
- Add HTTPS upstream support with system native TLS root certificates
- No Server or Via headers added to responses
- Host header preserved as-is
- Add unit tests for header injection, hop-by-hop removal, and URI building
- Add integration tests for proxy forwarding, hop-by-hop removal, and 502 on unreachable upstream
This commit is contained in:
2026-06-11 13:18:56 +00:00
parent 2791070971
commit b9126a96f4
7 changed files with 647 additions and 150 deletions

View File

@@ -248,3 +248,203 @@ async fn test_rate_limit_eviction_task() {
handle.abort();
}
#[tokio::test]
async fn test_proxy_forwards_request_to_upstream() {
use axum::body::Body;
use axum::extract::ConnectInfo;
use axum::http::{Request, StatusCode};
use reverse_proxy::config::dynamic_config::{BodyConfig, DynamicConfig, RateLimitConfig};
use reverse_proxy::config::SiteConfig;
use reverse_proxy::proxy::{create_http_client, create_https_client, proxy_router, ProxyState};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tower::ServiceExt;
let upstream = helpers::http_test_helper::TestUpstream::spawn(|| {
axum::Router::new().route(
"/test",
axum::routing::get(|req: axum::extract::Request| async move {
let x_real_ip = req
.headers()
.get("x-real-ip")
.and_then(|v| v.to_str().ok())
.unwrap_or("missing");
let x_fwd_for = req
.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.unwrap_or("missing");
let x_fwd_proto = req
.headers()
.get("x-forwarded-proto")
.and_then(|v| v.to_str().ok())
.unwrap_or("missing");
axum::response::IntoResponse::into_response(format!(
"ip={}|for={}|proto={}",
x_real_ip, x_fwd_for, x_fwd_proto
))
}),
)
})
.await;
let upstream_addr = format!("127.0.0.1:{}", upstream.addr.port());
let state = Arc::new(ProxyState {
config: Arc::new(ArcSwap::from_pointee(DynamicConfig::from_sites(
vec![SiteConfig {
host: "test.local".to_string(),
upstream: upstream_addr.clone(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}],
RateLimitConfig {
requests_per_second: 10,
burst: 20,
},
BodyConfig {
limit_bytes: 104857600,
},
))),
http_client: create_http_client(),
https_client: create_https_client(),
});
let router = proxy_router(state);
let req = Request::builder()
.method("GET")
.uri("/test")
.header("Host", "test.local")
.extension(ConnectInfo(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)),
54321,
)))
.body(Body::empty())
.unwrap();
let resp = router.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
let body_str = String::from_utf8(body.to_vec()).unwrap();
assert!(body_str.contains("ip=192.168.1.100"));
assert!(body_str.contains("for=192.168.1.100"));
assert!(body_str.contains("proto=https"));
let _ = upstream.shutdown_tx.send(());
}
#[tokio::test]
async fn test_proxy_removes_hop_by_hop_from_response() {
use axum::body::Body;
use axum::extract::ConnectInfo;
use axum::http::{Request, StatusCode};
use reverse_proxy::config::dynamic_config::{BodyConfig, DynamicConfig, RateLimitConfig};
use reverse_proxy::config::SiteConfig;
use reverse_proxy::proxy::{create_http_client, create_https_client, proxy_router, ProxyState};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tower::ServiceExt;
let upstream = helpers::http_test_helper::TestUpstream::spawn(|| {
axum::Router::new().route(
"/",
axum::routing::get(|| async {
([(axum::http::header::CONNECTION, "keep-alive")], "hello")
}),
)
})
.await;
let upstream_addr = format!("127.0.0.1:{}", upstream.addr.port());
let state = Arc::new(ProxyState {
config: Arc::new(ArcSwap::from_pointee(DynamicConfig::from_sites(
vec![SiteConfig {
host: "test.local".to_string(),
upstream: upstream_addr.clone(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}],
RateLimitConfig {
requests_per_second: 10,
burst: 20,
},
BodyConfig {
limit_bytes: 104857600,
},
))),
http_client: create_http_client(),
https_client: create_https_client(),
});
let router = proxy_router(state);
let req = Request::builder()
.method("GET")
.uri("/")
.header("Host", "test.local")
.extension(ConnectInfo(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
12345,
)))
.body(Body::empty())
.unwrap();
let resp = router.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().get("connection").is_none());
let _ = upstream.shutdown_tx.send(());
}
#[tokio::test]
async fn test_proxy_returns_502_on_unreachable_upstream() {
use axum::body::Body;
use axum::extract::ConnectInfo;
use axum::http::{Request, StatusCode};
use reverse_proxy::config::dynamic_config::{BodyConfig, DynamicConfig, RateLimitConfig};
use reverse_proxy::config::SiteConfig;
use reverse_proxy::proxy::{create_http_client, create_https_client, proxy_router, ProxyState};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tower::ServiceExt;
let state = Arc::new(ProxyState {
config: Arc::new(ArcSwap::from_pointee(DynamicConfig::from_sites(
vec![SiteConfig {
host: "unreachable.local".to_string(),
upstream: "127.0.0.1:1".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 1,
upstream_request_timeout_secs: 2,
}],
RateLimitConfig {
requests_per_second: 10,
burst: 20,
},
BodyConfig {
limit_bytes: 104857600,
},
))),
http_client: create_http_client(),
https_client: create_https_client(),
});
let router = proxy_router(state);
let req = Request::builder()
.method("GET")
.uri("/")
.header("Host", "unreachable.local")
.extension(ConnectInfo(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
12345,
)))
.body(Body::empty())
.unwrap();
let resp = router.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
}