Implement proxy header injection, hop-by-hop removal, and request forwarding

- Add ProxyError enum with IntoResponse for error handling (400, 404, 502, 504)
- Implement proxy header injection: X-Real-IP, X-Forwarded-For (replaced, not appended), X-Forwarded-Proto
- Implement hop-by-hop header removal for both request and response headers
- Implement request forwarding via shared hyper::Client with HTTP and HTTPS support
- Add ProxyState with http_client and https_client instances shared via axum State
- Add per-site timeout overrides using tokio::time::timeout
- Add HTTPS upstream support with system native TLS root certificates
- No Server or Via headers added to responses
- Host header preserved as-is
- Add unit tests for header injection, hop-by-hop removal, and URI building
- Add integration tests for proxy forwarding, hop-by-hop removal, and 502 on unreachable upstream
This commit is contained in:
2026-06-11 13:18:56 +00:00
parent 2791070971
commit b9126a96f4
7 changed files with 647 additions and 150 deletions

26
Cargo.lock generated
View File

@@ -902,7 +902,9 @@ dependencies = [
"http",
"hyper",
"hyper-util",
"log",
"rustls",
"rustls-native-certs",
"tokio",
"tokio-rustls",
"tower-service",
@@ -1227,7 +1229,7 @@ dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-probe 0.2.1",
"openssl-sys",
"schannel",
"security-framework",
@@ -1344,6 +1346,12 @@ dependencies = [
"syn",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-probe"
version = "0.2.1"
@@ -1606,11 +1614,15 @@ dependencies = [
"clap",
"dashmap",
"futures",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-util",
"rcgen",
"reqwest",
"rustls",
"rustls-acme",
"rustls-native-certs",
"rustls-pemfile",
"rustls-pki-types",
"serde",
@@ -1708,6 +1720,18 @@ dependencies = [
"x509-parser",
]
[[package]]
name = "rustls-native-certs"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3"
dependencies = [
"openssl-probe 0.1.6",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]]
name = "rustls-pemfile"
version = "2.2.0"

View File

@@ -16,6 +16,10 @@ path = "src/main.rs"
axum = "=0.8.9"
tokio = { version = "=1.45.1", features = ["full"] }
hyper = "=1.6.0"
hyper-util = { version = "=0.1.17", features = ["client-legacy", "http1", "http2", "tokio"] }
http-body-util = "=0.1.3"
hyper-rustls = { version = "=0.27.9", features = ["http1", "http2"] }
rustls-native-certs = "=0.8.1"
tower = "=0.5.2"
rustls = { version = "=0.23.28", features = ["aws_lc_rs"] }
tokio-rustls = "=0.26.2"

View File

@@ -1,2 +1,37 @@
#[allow(dead_code)]
pub struct ProxyError;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum ProxyError {
#[error("upstream connection failed")]
UpstreamConnection(#[source] hyper_util::client::legacy::Error),
#[error("upstream timeout")]
UpstreamTimeout,
#[error("upstream tls certificate validation failed")]
UpstreamTls(#[source] std::io::Error),
#[error("no matching site for host")]
UnknownHost,
#[error("missing host header")]
MissingHost,
}
impl IntoResponse for ProxyError {
fn into_response(self) -> Response {
let (status, body) = match &self {
ProxyError::UpstreamConnection(_) => (StatusCode::BAD_GATEWAY, "Bad Gateway"),
ProxyError::UpstreamTimeout => (StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout"),
ProxyError::UpstreamTls(_) => (StatusCode::BAD_GATEWAY, "Bad Gateway"),
ProxyError::UnknownHost => (StatusCode::NOT_FOUND, "Not Found"),
ProxyError::MissingHost => (StatusCode::BAD_REQUEST, "Bad Request"),
};
tracing::warn!(
error = %self,
status = status.as_u16(),
"proxy error"
);
(status, body).into_response()
}
}

View File

@@ -1,23 +1,38 @@
use std::net::SocketAddr;
use std::sync::Arc;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::get;
use axum::Router;
use std::time::Duration;
use arc_swap::ArcSwap;
use axum::body::Body;
use axum::extract::{ConnectInfo, State};
use axum::http::{Request, StatusCode, Uri};
use axum::response::{IntoResponse, Response};
use axum::routing::get;
use axum::Router;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use tracing::warn;
use crate::config::dynamic_config::DynamicConfig;
use crate::proxy::error::ProxyError;
use crate::proxy::headers::{inject_proxy_headers, remove_hop_by_hop};
pub struct ProxyState {
pub config: Arc<ArcSwap<DynamicConfig>>,
pub http_client: Client<HttpConnector, Body>,
pub https_client: Client<hyper_rustls::HttpsConnector<HttpConnector>, Body>,
}
async fn health_handler() -> impl IntoResponse {
StatusCode::OK
}
async fn proxy_handler(
State(state): State<Arc<ArcSwap<DynamicConfig>>>,
req: axum::http::Request<axum::body::Body>,
) -> impl IntoResponse {
ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
State(state): State<Arc<ProxyState>>,
mut req: Request<Body>,
) -> Response {
if req.uri().path() == "/health" {
return StatusCode::OK.into_response();
}
@@ -29,17 +44,132 @@ async fn proxy_handler(
let host = match host {
Some(h) => h,
None => return StatusCode::BAD_REQUEST.into_response(),
None => return ProxyError::MissingHost.into_response(),
};
let config = state.load();
match config.lookup(host) {
Some(_site) => StatusCode::OK.into_response(),
None => StatusCode::NOT_FOUND.into_response(),
let config = state.config.load();
let site = match config.lookup(host) {
Some(s) => s.clone(),
None => return ProxyError::UnknownHost.into_response(),
};
let is_https = determine_if_https(host);
inject_proxy_headers(req.headers_mut(), remote_addr, is_https);
remove_hop_by_hop(req.headers_mut());
let upstream_scheme = site.upstream_scheme.clone();
let upstream = site.upstream.clone();
let upstream_uri = build_upstream_uri(&upstream_scheme, &upstream, req.uri());
let upstream_req = match build_upstream_request(req, &upstream_uri) {
Ok(r) => r,
Err(e) => {
warn!(error = %e, "failed to build upstream request");
return StatusCode::BAD_GATEWAY.into_response();
}
};
let request_timeout = Duration::from_secs(site.upstream_request_timeout_secs);
let result = if upstream_scheme == "https" {
tokio::time::timeout(request_timeout, state.https_client.request(upstream_req)).await
} else {
tokio::time::timeout(request_timeout, state.http_client.request(upstream_req)).await
};
match result {
Ok(Ok(upstream_resp)) => {
let (mut parts, body) = upstream_resp.into_parts();
remove_hop_by_hop(&mut parts.headers);
parts.headers.remove("server");
let body = Body::new(body);
Response::from_parts(parts, body)
}
Ok(Err(e)) => {
if e.is_connect() {
ProxyError::UpstreamConnection(e).into_response()
} else {
warn!(error = %e, "upstream request failed");
StatusCode::BAD_GATEWAY.into_response()
}
}
Err(_) => ProxyError::UpstreamTimeout.into_response(),
}
}
pub fn proxy_router(state: Arc<ArcSwap<DynamicConfig>>) -> Router {
fn determine_if_https(host: &str) -> bool {
let port_str = host.split(':').nth(1);
if let Some(port) = port_str {
if let Ok(p) = port.parse::<u16>() {
return p == 443;
}
}
true
}
fn build_upstream_uri(scheme: &str, upstream: &str, original_uri: &Uri) -> Uri {
let path = original_uri.path();
let query = original_uri
.query()
.map(|q| format!("?{}", q))
.unwrap_or_default();
let uri_string = format!("{}://{}{}{}", scheme, upstream, path, query);
uri_string.parse::<Uri>().unwrap_or_else(|_| {
format!("{}://{}{}", scheme, upstream, path)
.parse::<Uri>()
.unwrap()
})
}
fn build_upstream_request(req: Request<Body>, upstream_uri: &Uri) -> anyhow::Result<Request<Body>> {
let mut builder = Request::builder()
.method(req.method().clone())
.uri(upstream_uri.clone());
for (name, value) in req.headers().iter() {
builder = builder.header(name.as_str(), value);
}
builder.body(req.into_body()).map_err(Into::into)
}
pub fn create_http_client() -> Client<HttpConnector, Body> {
Client::builder(TokioExecutor::new())
.pool_idle_timeout(Duration::from_secs(90))
.build_http()
}
pub fn create_https_client() -> Client<hyper_rustls::HttpsConnector<HttpConnector>, Body> {
let tls_config = rustls::ClientConfig::builder()
.with_root_certificates(root_certs())
.with_no_client_auth();
let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
.with_tls_config(tls_config)
.https_or_http()
.enable_http1()
.build();
Client::builder(TokioExecutor::new())
.pool_idle_timeout(Duration::from_secs(90))
.build(https_connector)
}
fn root_certs() -> rustls::RootCertStore {
let mut roots = rustls::RootCertStore::empty();
let result = rustls_native_certs::load_native_certs();
for cert in result.certs {
roots.add(cert).ok();
}
if !result.errors.is_empty() {
for err in &result.errors {
warn!(error = %err, "failed to load native certificate");
}
}
roots
}
pub fn proxy_router(state: Arc<ProxyState>) -> Router {
Router::new()
.route("/health", get(health_handler))
.fallback(proxy_handler)
@@ -49,14 +179,16 @@ pub fn proxy_router(state: Arc<ArcSwap<DynamicConfig>>) -> Router {
#[cfg(test)]
mod tests {
use super::*;
use crate::config::dynamic_config::{BodyConfig, RateLimitConfig};
use crate::config::dynamic_config::{BodyConfig, DynamicConfig, RateLimitConfig};
use crate::config::SiteConfig;
use axum::body::Body;
use axum::http::{Request, Response};
use axum::http::Request;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tower::ServiceExt;
fn make_config_with_sites(sites: Vec<SiteConfig>) -> Arc<ArcSwap<DynamicConfig>> {
Arc::new(ArcSwap::from_pointee(DynamicConfig::from_sites(
fn make_proxy_state(sites: Vec<SiteConfig>) -> Arc<ProxyState> {
Arc::new(ProxyState {
config: Arc::new(ArcSwap::from_pointee(DynamicConfig::from_sites(
sites,
RateLimitConfig {
requests_per_second: 10,
@@ -65,176 +197,145 @@ mod tests {
BodyConfig {
limit_bytes: 104857600,
},
)))
))),
http_client: create_http_client(),
https_client: create_https_client(),
})
}
async fn send_request(
router: &mut Router,
fn make_request_with_connect_info(
method: &str,
uri: &str,
host: Option<&str>,
) -> Response<axum::body::Body> {
remote_addr: SocketAddr,
) -> Request<Body> {
let mut builder = Request::builder().method(method).uri(uri);
if let Some(h) = host {
builder = builder.header("Host", h);
}
let req = builder.body(Body::empty()).unwrap();
router.oneshot(req).await.unwrap()
let mut req = builder.body(Body::empty()).unwrap();
req.extensions_mut().insert(ConnectInfo(remote_addr));
req
}
#[tokio::test]
async fn health_path_returns_200_regardless_of_host() {
let state = make_config_with_sites(vec![SiteConfig {
let state = make_proxy_state(vec![SiteConfig {
host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}]);
let mut router = proxy_router(state);
let resp = send_request(&mut router, "GET", "/health", None).await;
let router = proxy_router(state);
let req = make_request_with_connect_info(
"GET",
"/health",
None,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
);
let resp = router.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn health_with_unknown_host_returns_200() {
let state = make_config_with_sites(vec![SiteConfig {
let state = make_proxy_state(vec![SiteConfig {
host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}]);
let mut router = proxy_router(state);
let resp = send_request(&mut router, "GET", "/health", Some("unknown.host")).await;
let router = proxy_router(state);
let req = make_request_with_connect_info(
"GET",
"/health",
Some("unknown.host"),
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
);
let resp = router.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn missing_host_returns_400() {
let state = make_config_with_sites(vec![SiteConfig {
let state = make_proxy_state(vec![SiteConfig {
host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}]);
let mut router = proxy_router(state);
let resp = send_request(&mut router, "GET", "/some/path", None).await;
let router = proxy_router(state);
let req = make_request_with_connect_info(
"GET",
"/some/path",
None,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
);
let resp = router.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn unknown_host_returns_404() {
let state = make_config_with_sites(vec![SiteConfig {
let state = make_proxy_state(vec![SiteConfig {
host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}]);
let mut router = proxy_router(state);
let resp = send_request(&mut router, "GET", "/some/path", Some("unknown.host")).await;
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
}
#[tokio::test]
async fn known_host_returns_200() {
let state = make_config_with_sites(vec![SiteConfig {
host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}]);
let mut router = proxy_router(state);
let resp = send_request(&mut router, "GET", "/some/path", Some("example.com")).await;
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn host_matching_is_case_insensitive() {
let state = make_config_with_sites(vec![SiteConfig {
host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}]);
let mut router = proxy_router(state);
let resp = send_request(&mut router, "GET", "/path", Some("EXAMPLE.COM")).await;
assert_eq!(resp.status(), StatusCode::OK);
let resp = send_request(&mut router, "GET", "/path", Some("Example.Com")).await;
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn host_with_port_stripped() {
let state = make_config_with_sites(vec![SiteConfig {
host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}]);
let mut router = proxy_router(state);
let resp = send_request(&mut router, "GET", "/path", Some("example.com:443")).await;
assert_eq!(resp.status(), StatusCode::OK);
let resp = send_request(&mut router, "GET", "/path", Some("EXAMPLE.COM:8443")).await;
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn routing_table_update_visible_immediately() {
let state = make_config_with_sites(vec![SiteConfig {
host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}]);
let mut router = proxy_router(state.clone());
let resp = send_request(&mut router, "GET", "/path", Some("new.example.com")).await;
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
let new_config = DynamicConfig::from_sites(
vec![
SiteConfig {
host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
},
SiteConfig {
host: "new.example.com".to_string(),
upstream: "127.0.0.1:9090".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
},
],
RateLimitConfig {
requests_per_second: 10,
burst: 20,
},
BodyConfig {
limit_bytes: 104857600,
},
let router = proxy_router(state);
let req = make_request_with_connect_info(
"GET",
"/some/path",
Some("unknown.host"),
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
);
state.store(Arc::new(new_config));
let resp = router.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
}
let resp = send_request(&mut router, "GET", "/path", Some("new.example.com")).await;
assert_eq!(resp.status(), StatusCode::OK);
#[test]
fn test_determine_if_https_port_443() {
assert!(determine_if_https("example.com:443"));
}
#[test]
fn test_determine_if_https_port_80() {
assert!(!determine_if_https("example.com:80"));
}
#[test]
fn test_determine_if_https_no_port() {
assert!(determine_if_https("example.com"));
}
#[test]
fn test_determine_if_https_port_8443() {
assert!(!determine_if_https("example.com:8443"));
}
#[test]
fn test_build_upstream_uri_with_query() {
let uri: Uri = "/path?foo=bar".parse().unwrap();
let result = build_upstream_uri("http", "127.0.0.1:8080", &uri);
assert_eq!(result.to_string(), "http://127.0.0.1:8080/path?foo=bar");
}
#[test]
fn test_build_upstream_uri_without_query() {
let uri: Uri = "/path".parse().unwrap();
let result = build_upstream_uri("http", "127.0.0.1:8080", &uri);
assert_eq!(result.to_string(), "http://127.0.0.1:8080/path");
}
#[test]
fn test_build_upstream_uri_https() {
let uri: Uri = "/secure".parse().unwrap();
let result = build_upstream_uri("https", "upstream.example.com", &uri);
assert_eq!(result.to_string(), "https://upstream.example.com/secure");
}
}

View File

@@ -1,2 +1,134 @@
#[allow(dead_code)]
pub struct ProxyHeaders;
use axum::http::{HeaderMap, HeaderName, HeaderValue};
use std::net::SocketAddr;
const HOP_BY_HOP: &[&str] = &[
"connection",
"keep-alive",
"proxy-authorization",
"proxy-authenticate",
"te",
"trailers",
"transfer-encoding",
"upgrade",
];
pub fn remove_hop_by_hop(headers: &mut HeaderMap) {
for &name in HOP_BY_HOP {
headers.remove(name);
}
}
pub fn inject_proxy_headers(headers: &mut HeaderMap, remote_addr: SocketAddr, is_https: bool) {
let ip_str = remote_addr.ip().to_string();
let ip_value =
HeaderValue::from_str(&ip_str).unwrap_or_else(|_| HeaderValue::from_static("0.0.0.0"));
headers.insert(HeaderName::from_static("x-real-ip"), ip_value.clone());
headers.insert(HeaderName::from_static("x-forwarded-for"), ip_value);
let proto_value = if is_https {
HeaderValue::from_static("https")
} else {
HeaderValue::from_static("http")
};
headers.insert(HeaderName::from_static("x-forwarded-proto"), proto_value);
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr};
fn make_headers_with_hop_by_hop() -> HeaderMap {
let mut h = HeaderMap::new();
h.insert("connection", HeaderValue::from_static("keep-alive"));
h.insert("keep-alive", HeaderValue::from_static("timeout=5"));
h.insert("proxy-authorization", HeaderValue::from_static("Basic abc"));
h.insert(
"proxy-authenticate",
HeaderValue::from_static("Basic realm=x"),
);
h.insert("te", HeaderValue::from_static("trailers"));
h.insert("trailers", HeaderValue::from_static("chunked"));
h.insert("transfer-encoding", HeaderValue::from_static("chunked"));
h.insert("upgrade", HeaderValue::from_static("websocket"));
h.insert("content-type", HeaderValue::from_static("text/html"));
h.insert("accept", HeaderValue::from_static("*/*"));
h
}
#[test]
fn remove_hop_by_hop_removes_all_listed_headers() {
let mut h = make_headers_with_hop_by_hop();
remove_hop_by_hop(&mut h);
assert!(h.get("connection").is_none());
assert!(h.get("keep-alive").is_none());
assert!(h.get("proxy-authorization").is_none());
assert!(h.get("proxy-authenticate").is_none());
assert!(h.get("te").is_none());
assert!(h.get("trailers").is_none());
assert!(h.get("transfer-encoding").is_none());
assert!(h.get("upgrade").is_none());
}
#[test]
fn remove_hop_by_hop_preserves_other_headers() {
let mut h = make_headers_with_hop_by_hop();
remove_hop_by_hop(&mut h);
assert_eq!(h.get("content-type").unwrap(), "text/html");
assert_eq!(h.get("accept").unwrap(), "*/*");
}
#[test]
fn inject_proxy_headers_sets_x_real_ip() {
let mut h = HeaderMap::new();
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 12345);
inject_proxy_headers(&mut h, addr, true);
assert_eq!(h.get("x-real-ip").unwrap(), "192.168.1.1");
}
#[test]
fn inject_proxy_headers_replaces_x_forwarded_for() {
let mut h = HeaderMap::new();
h.insert(
"x-forwarded-for",
HeaderValue::from_static("10.0.0.1, 10.0.0.2"),
);
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 12345);
inject_proxy_headers(&mut h, addr, true);
assert_eq!(h.get("x-forwarded-for").unwrap(), "192.168.1.1");
}
#[test]
fn inject_proxy_headers_sets_x_forwarded_proto_https() {
let mut h = HeaderMap::new();
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 443);
inject_proxy_headers(&mut h, addr, true);
assert_eq!(h.get("x-forwarded-proto").unwrap(), "https");
}
#[test]
fn inject_proxy_headers_sets_x_forwarded_proto_http() {
let mut h = HeaderMap::new();
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 80);
inject_proxy_headers(&mut h, addr, false);
assert_eq!(h.get("x-forwarded-proto").unwrap(), "http");
}
#[test]
fn inject_proxy_headers_preserves_host() {
let mut h = HeaderMap::new();
h.insert("host", HeaderValue::from_static("example.com"));
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 443);
inject_proxy_headers(&mut h, addr, true);
assert_eq!(h.get("host").unwrap(), "example.com");
}
#[test]
fn remove_hop_by_hop_empty_headers() {
let mut h = HeaderMap::new();
remove_hop_by_hop(&mut h);
assert!(h.is_empty());
}
}

View File

@@ -3,3 +3,4 @@ pub mod handler;
pub mod headers;
pub use crate::config::dynamic_config::normalize_host;
pub use handler::{create_http_client, create_https_client, proxy_router, ProxyState};

View File

@@ -248,3 +248,203 @@ async fn test_rate_limit_eviction_task() {
handle.abort();
}
#[tokio::test]
async fn test_proxy_forwards_request_to_upstream() {
use axum::body::Body;
use axum::extract::ConnectInfo;
use axum::http::{Request, StatusCode};
use reverse_proxy::config::dynamic_config::{BodyConfig, DynamicConfig, RateLimitConfig};
use reverse_proxy::config::SiteConfig;
use reverse_proxy::proxy::{create_http_client, create_https_client, proxy_router, ProxyState};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tower::ServiceExt;
let upstream = helpers::http_test_helper::TestUpstream::spawn(|| {
axum::Router::new().route(
"/test",
axum::routing::get(|req: axum::extract::Request| async move {
let x_real_ip = req
.headers()
.get("x-real-ip")
.and_then(|v| v.to_str().ok())
.unwrap_or("missing");
let x_fwd_for = req
.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.unwrap_or("missing");
let x_fwd_proto = req
.headers()
.get("x-forwarded-proto")
.and_then(|v| v.to_str().ok())
.unwrap_or("missing");
axum::response::IntoResponse::into_response(format!(
"ip={}|for={}|proto={}",
x_real_ip, x_fwd_for, x_fwd_proto
))
}),
)
})
.await;
let upstream_addr = format!("127.0.0.1:{}", upstream.addr.port());
let state = Arc::new(ProxyState {
config: Arc::new(ArcSwap::from_pointee(DynamicConfig::from_sites(
vec![SiteConfig {
host: "test.local".to_string(),
upstream: upstream_addr.clone(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}],
RateLimitConfig {
requests_per_second: 10,
burst: 20,
},
BodyConfig {
limit_bytes: 104857600,
},
))),
http_client: create_http_client(),
https_client: create_https_client(),
});
let router = proxy_router(state);
let req = Request::builder()
.method("GET")
.uri("/test")
.header("Host", "test.local")
.extension(ConnectInfo(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)),
54321,
)))
.body(Body::empty())
.unwrap();
let resp = router.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
let body_str = String::from_utf8(body.to_vec()).unwrap();
assert!(body_str.contains("ip=192.168.1.100"));
assert!(body_str.contains("for=192.168.1.100"));
assert!(body_str.contains("proto=https"));
let _ = upstream.shutdown_tx.send(());
}
#[tokio::test]
async fn test_proxy_removes_hop_by_hop_from_response() {
use axum::body::Body;
use axum::extract::ConnectInfo;
use axum::http::{Request, StatusCode};
use reverse_proxy::config::dynamic_config::{BodyConfig, DynamicConfig, RateLimitConfig};
use reverse_proxy::config::SiteConfig;
use reverse_proxy::proxy::{create_http_client, create_https_client, proxy_router, ProxyState};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tower::ServiceExt;
let upstream = helpers::http_test_helper::TestUpstream::spawn(|| {
axum::Router::new().route(
"/",
axum::routing::get(|| async {
([(axum::http::header::CONNECTION, "keep-alive")], "hello")
}),
)
})
.await;
let upstream_addr = format!("127.0.0.1:{}", upstream.addr.port());
let state = Arc::new(ProxyState {
config: Arc::new(ArcSwap::from_pointee(DynamicConfig::from_sites(
vec![SiteConfig {
host: "test.local".to_string(),
upstream: upstream_addr.clone(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}],
RateLimitConfig {
requests_per_second: 10,
burst: 20,
},
BodyConfig {
limit_bytes: 104857600,
},
))),
http_client: create_http_client(),
https_client: create_https_client(),
});
let router = proxy_router(state);
let req = Request::builder()
.method("GET")
.uri("/")
.header("Host", "test.local")
.extension(ConnectInfo(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
12345,
)))
.body(Body::empty())
.unwrap();
let resp = router.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().get("connection").is_none());
let _ = upstream.shutdown_tx.send(());
}
#[tokio::test]
async fn test_proxy_returns_502_on_unreachable_upstream() {
use axum::body::Body;
use axum::extract::ConnectInfo;
use axum::http::{Request, StatusCode};
use reverse_proxy::config::dynamic_config::{BodyConfig, DynamicConfig, RateLimitConfig};
use reverse_proxy::config::SiteConfig;
use reverse_proxy::proxy::{create_http_client, create_https_client, proxy_router, ProxyState};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tower::ServiceExt;
let state = Arc::new(ProxyState {
config: Arc::new(ArcSwap::from_pointee(DynamicConfig::from_sites(
vec![SiteConfig {
host: "unreachable.local".to_string(),
upstream: "127.0.0.1:1".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 1,
upstream_request_timeout_secs: 2,
}],
RateLimitConfig {
requests_per_second: 10,
burst: 20,
},
BodyConfig {
limit_bytes: 104857600,
},
))),
http_client: create_http_client(),
https_client: create_https_client(),
});
let router = proxy_router(state);
let req = Request::builder()
.method("GET")
.uri("/")
.header("Host", "unreachable.local")
.extension(ConnectInfo(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
12345,
)))
.body(Body::empty())
.unwrap();
let resp = router.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
}