Merge feat/proxy/headers-and-forwarding into main

This commit is contained in:
2026-06-11 13:24:40 +00:00
6 changed files with 435 additions and 163 deletions

25
Cargo.lock generated
View File

@@ -902,7 +902,9 @@ dependencies = [
"http", "http",
"hyper", "hyper",
"hyper-util", "hyper-util",
"log",
"rustls", "rustls",
"rustls-native-certs",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tower-service", "tower-service",
@@ -1227,7 +1229,7 @@ dependencies = [
"libc", "libc",
"log", "log",
"openssl", "openssl",
"openssl-probe", "openssl-probe 0.2.1",
"openssl-sys", "openssl-sys",
"schannel", "schannel",
"security-framework", "security-framework",
@@ -1344,6 +1346,12 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]] [[package]]
name = "openssl-probe" name = "openssl-probe"
version = "0.2.1" version = "0.2.1"
@@ -1608,10 +1616,13 @@ dependencies = [
"futures", "futures",
"http-body-util", "http-body-util",
"hyper", "hyper",
"hyper-rustls",
"hyper-util",
"rcgen", "rcgen",
"reqwest", "reqwest",
"rustls", "rustls",
"rustls-acme", "rustls-acme",
"rustls-native-certs",
"rustls-pemfile", "rustls-pemfile",
"rustls-pki-types", "rustls-pki-types",
"serde", "serde",
@@ -1710,6 +1721,18 @@ dependencies = [
"x509-parser", "x509-parser",
] ]
[[package]]
name = "rustls-native-certs"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3"
dependencies = [
"openssl-probe 0.1.6",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]] [[package]]
name = "rustls-pemfile" name = "rustls-pemfile"
version = "2.2.0" version = "2.2.0"

View File

@@ -16,6 +16,10 @@ path = "src/main.rs"
axum = "=0.8.9" axum = "=0.8.9"
tokio = { version = "=1.45.1", features = ["full"] } tokio = { version = "=1.45.1", features = ["full"] }
hyper = "=1.6.0" hyper = "=1.6.0"
hyper-util = { version = "=0.1.17", features = ["client-legacy", "http1", "http2", "tokio"] }
http-body-util = "=0.1.3"
hyper-rustls = { version = "=0.27.9", features = ["http1", "http2"] }
rustls-native-certs = "=0.8.1"
tower = "=0.5.2" tower = "=0.5.2"
rustls = { version = "=0.23.28", features = ["aws_lc_rs"] } rustls = { version = "=0.23.28", features = ["aws_lc_rs"] }
tokio-rustls = "=0.26.2" tokio-rustls = "=0.26.2"
@@ -31,7 +35,6 @@ clap = { version = "=4.6.1", features = ["derive"] }
signal-hook = "=0.3.18" signal-hook = "=0.3.18"
anyhow = "=1.0.102" anyhow = "=1.0.102"
thiserror = "=2.0.18" thiserror = "=2.0.18"
http-body-util = "=0.1.3"
futures = "=0.3.31" futures = "=0.3.31"
dashmap = "=6.1" dashmap = "=6.1"
serde_json = "=1.0.140" serde_json = "=1.0.140"

View File

@@ -19,6 +19,16 @@ pub enum ProxyError {
NotFound, NotFound,
#[error("Bad Request")] #[error("Bad Request")]
BadRequest, BadRequest,
#[error("upstream connection failed")]
UpstreamConnection(#[source] hyper_util::client::legacy::Error),
#[error("upstream timeout")]
UpstreamTimeout,
#[error("upstream tls certificate validation failed")]
UpstreamTls(#[source] std::io::Error),
#[error("no matching site for host")]
UnknownHost,
#[error("missing host header")]
MissingHost,
} }
impl ProxyError { impl ProxyError {
@@ -28,8 +38,11 @@ impl ProxyError {
Self::GatewayTimeout { .. } => StatusCode::GATEWAY_TIMEOUT, Self::GatewayTimeout { .. } => StatusCode::GATEWAY_TIMEOUT,
Self::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE, Self::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
Self::TooManyRequests { .. } => StatusCode::TOO_MANY_REQUESTS, Self::TooManyRequests { .. } => StatusCode::TOO_MANY_REQUESTS,
Self::NotFound => StatusCode::NOT_FOUND, Self::NotFound | Self::UnknownHost => StatusCode::NOT_FOUND,
Self::BadRequest => StatusCode::BAD_REQUEST, Self::BadRequest | Self::MissingHost => StatusCode::BAD_REQUEST,
Self::UpstreamConnection(_) => StatusCode::BAD_GATEWAY,
Self::UpstreamTimeout => StatusCode::GATEWAY_TIMEOUT,
Self::UpstreamTls(_) => StatusCode::BAD_GATEWAY,
} }
} }
@@ -39,8 +52,11 @@ impl ProxyError {
Self::GatewayTimeout { .. } => "Gateway Timeout", Self::GatewayTimeout { .. } => "Gateway Timeout",
Self::PayloadTooLarge => "Payload Too Large", Self::PayloadTooLarge => "Payload Too Large",
Self::TooManyRequests { .. } => "Too Many Requests", Self::TooManyRequests { .. } => "Too Many Requests",
Self::NotFound => "Not Found", Self::NotFound | Self::UnknownHost => "Not Found",
Self::BadRequest => "Bad Request", Self::BadRequest | Self::MissingHost => "Bad Request",
Self::UpstreamConnection(_) => "Bad Gateway",
Self::UpstreamTimeout => "Gateway Timeout",
Self::UpstreamTls(_) => "Bad Gateway",
} }
} }
} }
@@ -76,6 +92,15 @@ impl IntoResponse for ProxyError {
path path
); );
} }
Self::UpstreamConnection(e) => {
tracing::warn!(error = %e, status = 502, "upstream connection failed");
}
Self::UpstreamTimeout => {
tracing::warn!(status = 504, "upstream timeout");
}
Self::UpstreamTls(e) => {
tracing::warn!(error = %e, status = 502, "upstream TLS error");
}
_ => {} _ => {}
} }

View File

@@ -1,24 +1,38 @@
use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::get;
use axum::Router;
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use axum::body::Body;
use axum::extract::{ConnectInfo, State};
use axum::http::{Request, StatusCode, Uri};
use axum::response::{IntoResponse, Response};
use axum::routing::get;
use axum::Router;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use tracing::warn;
use crate::config::dynamic_config::DynamicConfig; use crate::config::dynamic_config::DynamicConfig;
use crate::proxy::error::ProxyError; use crate::proxy::error::ProxyError;
use crate::proxy::headers::{inject_proxy_headers, remove_hop_by_hop};
pub struct ProxyState {
pub config: Arc<ArcSwap<DynamicConfig>>,
pub http_client: Client<HttpConnector, Body>,
pub https_client: Client<hyper_rustls::HttpsConnector<HttpConnector>, Body>,
}
async fn health_handler() -> impl IntoResponse { async fn health_handler() -> impl IntoResponse {
StatusCode::OK StatusCode::OK
} }
async fn proxy_handler( async fn proxy_handler(
State(state): State<Arc<ArcSwap<DynamicConfig>>>, ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
req: axum::http::Request<axum::body::Body>, State(state): State<Arc<ProxyState>>,
) -> impl IntoResponse { mut req: Request<Body>,
) -> Response {
if req.uri().path() == "/health" { if req.uri().path() == "/health" {
return StatusCode::OK.into_response(); return StatusCode::OK.into_response();
} }
@@ -30,17 +44,132 @@ async fn proxy_handler(
let host = match host { let host = match host {
Some(h) => h, Some(h) => h,
None => return ProxyError::BadRequest.into_response(), None => return ProxyError::MissingHost.into_response(),
}; };
let config = state.load(); let config = state.config.load();
match config.lookup(host) { let site = match config.lookup(host) {
Some(_site) => StatusCode::OK.into_response(), Some(s) => s.clone(),
None => ProxyError::NotFound.into_response(), None => return ProxyError::UnknownHost.into_response(),
};
let is_https = determine_if_https(host);
inject_proxy_headers(req.headers_mut(), remote_addr, is_https);
remove_hop_by_hop(req.headers_mut());
let upstream_scheme = site.upstream_scheme.clone();
let upstream = site.upstream.clone();
let upstream_uri = build_upstream_uri(&upstream_scheme, &upstream, req.uri());
let upstream_req = match build_upstream_request(req, &upstream_uri) {
Ok(r) => r,
Err(e) => {
warn!(error = %e, "failed to build upstream request");
return StatusCode::BAD_GATEWAY.into_response();
}
};
let request_timeout = Duration::from_secs(site.upstream_request_timeout_secs);
let result = if upstream_scheme == "https" {
tokio::time::timeout(request_timeout, state.https_client.request(upstream_req)).await
} else {
tokio::time::timeout(request_timeout, state.http_client.request(upstream_req)).await
};
match result {
Ok(Ok(upstream_resp)) => {
let (mut parts, body) = upstream_resp.into_parts();
remove_hop_by_hop(&mut parts.headers);
parts.headers.remove("server");
let body = Body::new(body);
Response::from_parts(parts, body)
}
Ok(Err(e)) => {
if e.is_connect() {
ProxyError::UpstreamConnection(e).into_response()
} else {
warn!(error = %e, "upstream request failed");
StatusCode::BAD_GATEWAY.into_response()
}
}
Err(_) => ProxyError::UpstreamTimeout.into_response(),
} }
} }
pub fn proxy_router(state: Arc<ArcSwap<DynamicConfig>>) -> Router { fn determine_if_https(host: &str) -> bool {
let port_str = host.split(':').nth(1);
if let Some(port) = port_str {
if let Ok(p) = port.parse::<u16>() {
return p == 443;
}
}
true
}
fn build_upstream_uri(scheme: &str, upstream: &str, original_uri: &Uri) -> Uri {
let path = original_uri.path();
let query = original_uri
.query()
.map(|q| format!("?{}", q))
.unwrap_or_default();
let uri_string = format!("{}://{}{}{}", scheme, upstream, path, query);
uri_string.parse::<Uri>().unwrap_or_else(|_| {
format!("{}://{}{}", scheme, upstream, path)
.parse::<Uri>()
.unwrap()
})
}
fn build_upstream_request(req: Request<Body>, upstream_uri: &Uri) -> anyhow::Result<Request<Body>> {
let mut builder = Request::builder()
.method(req.method().clone())
.uri(upstream_uri.clone());
for (name, value) in req.headers().iter() {
builder = builder.header(name.as_str(), value);
}
builder.body(req.into_body()).map_err(Into::into)
}
pub fn create_http_client() -> Client<HttpConnector, Body> {
Client::builder(TokioExecutor::new())
.pool_idle_timeout(Duration::from_secs(90))
.build_http()
}
pub fn create_https_client() -> Client<hyper_rustls::HttpsConnector<HttpConnector>, Body> {
let tls_config = rustls::ClientConfig::builder()
.with_root_certificates(root_certs())
.with_no_client_auth();
let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
.with_tls_config(tls_config)
.https_or_http()
.enable_http1()
.build();
Client::builder(TokioExecutor::new())
.pool_idle_timeout(Duration::from_secs(90))
.build(https_connector)
}
fn root_certs() -> rustls::RootCertStore {
let mut roots = rustls::RootCertStore::empty();
let result = rustls_native_certs::load_native_certs();
for cert in result.certs {
roots.add(cert).ok();
}
if !result.errors.is_empty() {
for err in &result.errors {
warn!(error = %err, "failed to load native certificate");
}
}
roots
}
pub fn proxy_router(state: Arc<ProxyState>) -> Router {
Router::new() Router::new()
.route("/health", get(health_handler)) .route("/health", get(health_handler))
.fallback(proxy_handler) .fallback(proxy_handler)
@@ -50,204 +179,163 @@ pub fn proxy_router(state: Arc<ArcSwap<DynamicConfig>>) -> Router {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::config::dynamic_config::{BodyConfig, RateLimitConfig}; use crate::config::dynamic_config::{BodyConfig, DynamicConfig, RateLimitConfig};
use crate::config::SiteConfig; use crate::config::SiteConfig;
use axum::body::Body; use axum::body::Body;
use axum::http::{Request, Response}; use axum::http::Request;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tower::ServiceExt; use tower::ServiceExt;
fn make_config_with_sites(sites: Vec<SiteConfig>) -> Arc<ArcSwap<DynamicConfig>> { fn make_proxy_state(sites: Vec<SiteConfig>) -> Arc<ProxyState> {
Arc::new(ArcSwap::from_pointee(DynamicConfig::from_sites( Arc::new(ProxyState {
sites, config: Arc::new(ArcSwap::from_pointee(DynamicConfig::from_sites(
RateLimitConfig { sites,
requests_per_second: 10, RateLimitConfig {
burst: 20, requests_per_second: 10,
}, burst: 20,
BodyConfig { },
limit_bytes: 104857600, BodyConfig {
}, limit_bytes: 104857600,
))) },
))),
http_client: create_http_client(),
https_client: create_https_client(),
})
} }
async fn send_request( fn make_request_with_connect_info(
router: &mut Router,
method: &str, method: &str,
uri: &str, uri: &str,
host: Option<&str>, host: Option<&str>,
) -> Response<axum::body::Body> { remote_addr: SocketAddr,
) -> Request<Body> {
let mut builder = Request::builder().method(method).uri(uri); let mut builder = Request::builder().method(method).uri(uri);
if let Some(h) = host { if let Some(h) = host {
builder = builder.header("Host", h); builder = builder.header("Host", h);
} }
let req = builder.body(Body::empty()).unwrap(); let mut req = builder.body(Body::empty()).unwrap();
router.oneshot(req).await.unwrap() req.extensions_mut().insert(ConnectInfo(remote_addr));
req
} }
#[tokio::test] #[tokio::test]
async fn health_path_returns_200_regardless_of_host() { async fn health_path_returns_200_regardless_of_host() {
let state = make_config_with_sites(vec![SiteConfig { let state = make_proxy_state(vec![SiteConfig {
host: "example.com".to_string(), host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(), upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(), upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5, upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60, upstream_request_timeout_secs: 60,
}]); }]);
let mut router = proxy_router(state); let router = proxy_router(state);
let req = make_request_with_connect_info(
let resp = send_request(&mut router, "GET", "/health", None).await; "GET",
"/health",
None,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
);
let resp = router.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
#[tokio::test] #[tokio::test]
async fn health_with_unknown_host_returns_200() { async fn health_with_unknown_host_returns_200() {
let state = make_config_with_sites(vec![SiteConfig { let state = make_proxy_state(vec![SiteConfig {
host: "example.com".to_string(), host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(), upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(), upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5, upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60, upstream_request_timeout_secs: 60,
}]); }]);
let mut router = proxy_router(state); let router = proxy_router(state);
let req = make_request_with_connect_info(
let resp = send_request(&mut router, "GET", "/health", Some("unknown.host")).await; "GET",
"/health",
Some("unknown.host"),
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
);
let resp = router.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
#[tokio::test] #[tokio::test]
async fn missing_host_returns_400() { async fn missing_host_returns_400() {
let state = make_config_with_sites(vec![SiteConfig { let state = make_proxy_state(vec![SiteConfig {
host: "example.com".to_string(), host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(), upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(), upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5, upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60, upstream_request_timeout_secs: 60,
}]); }]);
let mut router = proxy_router(state); let router = proxy_router(state);
let req = make_request_with_connect_info(
let resp = send_request(&mut router, "GET", "/some/path", None).await; "GET",
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); "/some/path",
assert_eq!( None,
resp.headers().get("content-type").unwrap(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
"text/plain; charset=utf-8"
); );
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap(); let resp = router.oneshot(req).await.unwrap();
assert_eq!(&body[..], b"Bad Request"); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
} }
#[tokio::test] #[tokio::test]
async fn unknown_host_returns_404() { async fn unknown_host_returns_404() {
let state = make_config_with_sites(vec![SiteConfig { let state = make_proxy_state(vec![SiteConfig {
host: "example.com".to_string(), host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(), upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(), upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5, upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60, upstream_request_timeout_secs: 60,
}]); }]);
let mut router = proxy_router(state); let router = proxy_router(state);
let req = make_request_with_connect_info(
let resp = send_request(&mut router, "GET", "/some/path", Some("unknown.host")).await; "GET",
assert_eq!(resp.status(), StatusCode::NOT_FOUND); "/some/path",
assert_eq!( Some("unknown.host"),
resp.headers().get("content-type").unwrap(), SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
"text/plain; charset=utf-8"
); );
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap(); let resp = router.oneshot(req).await.unwrap();
assert_eq!(&body[..], b"Not Found");
}
#[tokio::test]
async fn known_host_returns_200() {
let state = make_config_with_sites(vec![SiteConfig {
host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}]);
let mut router = proxy_router(state);
let resp = send_request(&mut router, "GET", "/some/path", Some("example.com")).await;
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn host_matching_is_case_insensitive() {
let state = make_config_with_sites(vec![SiteConfig {
host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}]);
let mut router = proxy_router(state);
let resp = send_request(&mut router, "GET", "/path", Some("EXAMPLE.COM")).await;
assert_eq!(resp.status(), StatusCode::OK);
let resp = send_request(&mut router, "GET", "/path", Some("Example.Com")).await;
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn host_with_port_stripped() {
let state = make_config_with_sites(vec![SiteConfig {
host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}]);
let mut router = proxy_router(state);
let resp = send_request(&mut router, "GET", "/path", Some("example.com:443")).await;
assert_eq!(resp.status(), StatusCode::OK);
let resp = send_request(&mut router, "GET", "/path", Some("EXAMPLE.COM:8443")).await;
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn routing_table_update_visible_immediately() {
let state = make_config_with_sites(vec![SiteConfig {
host: "example.com".to_string(),
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
}]);
let mut router = proxy_router(state.clone());
let resp = send_request(&mut router, "GET", "/path", Some("new.example.com")).await;
assert_eq!(resp.status(), StatusCode::NOT_FOUND); assert_eq!(resp.status(), StatusCode::NOT_FOUND);
}
let new_config = DynamicConfig::from_sites( #[test]
vec![ fn test_determine_if_https_port_443() {
SiteConfig { assert!(determine_if_https("example.com:443"));
host: "example.com".to_string(), }
upstream: "127.0.0.1:8080".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
},
SiteConfig {
host: "new.example.com".to_string(),
upstream: "127.0.0.1:9090".to_string(),
upstream_scheme: "http".to_string(),
upstream_connect_timeout_secs: 5,
upstream_request_timeout_secs: 60,
},
],
RateLimitConfig {
requests_per_second: 10,
burst: 20,
},
BodyConfig {
limit_bytes: 104857600,
},
);
state.store(Arc::new(new_config));
let resp = send_request(&mut router, "GET", "/path", Some("new.example.com")).await; #[test]
assert_eq!(resp.status(), StatusCode::OK); fn test_determine_if_https_port_80() {
assert!(!determine_if_https("example.com:80"));
}
#[test]
fn test_determine_if_https_no_port() {
assert!(determine_if_https("example.com"));
}
#[test]
fn test_determine_if_https_port_8443() {
assert!(!determine_if_https("example.com:8443"));
}
#[test]
fn test_build_upstream_uri_with_query() {
let uri: Uri = "/path?foo=bar".parse().unwrap();
let result = build_upstream_uri("http", "127.0.0.1:8080", &uri);
assert_eq!(result.to_string(), "http://127.0.0.1:8080/path?foo=bar");
}
#[test]
fn test_build_upstream_uri_without_query() {
let uri: Uri = "/path".parse().unwrap();
let result = build_upstream_uri("http", "127.0.0.1:8080", &uri);
assert_eq!(result.to_string(), "http://127.0.0.1:8080/path");
}
#[test]
fn test_build_upstream_uri_https() {
let uri: Uri = "/secure".parse().unwrap();
let result = build_upstream_uri("https", "upstream.example.com", &uri);
assert_eq!(result.to_string(), "https://upstream.example.com/secure");
} }
} }

View File

@@ -1,2 +1,134 @@
#[allow(dead_code)] use axum::http::{HeaderMap, HeaderName, HeaderValue};
pub struct ProxyHeaders; use std::net::SocketAddr;
const HOP_BY_HOP: &[&str] = &[
"connection",
"keep-alive",
"proxy-authorization",
"proxy-authenticate",
"te",
"trailers",
"transfer-encoding",
"upgrade",
];
pub fn remove_hop_by_hop(headers: &mut HeaderMap) {
for &name in HOP_BY_HOP {
headers.remove(name);
}
}
pub fn inject_proxy_headers(headers: &mut HeaderMap, remote_addr: SocketAddr, is_https: bool) {
let ip_str = remote_addr.ip().to_string();
let ip_value =
HeaderValue::from_str(&ip_str).unwrap_or_else(|_| HeaderValue::from_static("0.0.0.0"));
headers.insert(HeaderName::from_static("x-real-ip"), ip_value.clone());
headers.insert(HeaderName::from_static("x-forwarded-for"), ip_value);
let proto_value = if is_https {
HeaderValue::from_static("https")
} else {
HeaderValue::from_static("http")
};
headers.insert(HeaderName::from_static("x-forwarded-proto"), proto_value);
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr};
fn make_headers_with_hop_by_hop() -> HeaderMap {
let mut h = HeaderMap::new();
h.insert("connection", HeaderValue::from_static("keep-alive"));
h.insert("keep-alive", HeaderValue::from_static("timeout=5"));
h.insert("proxy-authorization", HeaderValue::from_static("Basic abc"));
h.insert(
"proxy-authenticate",
HeaderValue::from_static("Basic realm=x"),
);
h.insert("te", HeaderValue::from_static("trailers"));
h.insert("trailers", HeaderValue::from_static("chunked"));
h.insert("transfer-encoding", HeaderValue::from_static("chunked"));
h.insert("upgrade", HeaderValue::from_static("websocket"));
h.insert("content-type", HeaderValue::from_static("text/html"));
h.insert("accept", HeaderValue::from_static("*/*"));
h
}
#[test]
fn remove_hop_by_hop_removes_all_listed_headers() {
let mut h = make_headers_with_hop_by_hop();
remove_hop_by_hop(&mut h);
assert!(h.get("connection").is_none());
assert!(h.get("keep-alive").is_none());
assert!(h.get("proxy-authorization").is_none());
assert!(h.get("proxy-authenticate").is_none());
assert!(h.get("te").is_none());
assert!(h.get("trailers").is_none());
assert!(h.get("transfer-encoding").is_none());
assert!(h.get("upgrade").is_none());
}
#[test]
fn remove_hop_by_hop_preserves_other_headers() {
let mut h = make_headers_with_hop_by_hop();
remove_hop_by_hop(&mut h);
assert_eq!(h.get("content-type").unwrap(), "text/html");
assert_eq!(h.get("accept").unwrap(), "*/*");
}
#[test]
fn inject_proxy_headers_sets_x_real_ip() {
let mut h = HeaderMap::new();
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 12345);
inject_proxy_headers(&mut h, addr, true);
assert_eq!(h.get("x-real-ip").unwrap(), "192.168.1.1");
}
#[test]
fn inject_proxy_headers_replaces_x_forwarded_for() {
let mut h = HeaderMap::new();
h.insert(
"x-forwarded-for",
HeaderValue::from_static("10.0.0.1, 10.0.0.2"),
);
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 12345);
inject_proxy_headers(&mut h, addr, true);
assert_eq!(h.get("x-forwarded-for").unwrap(), "192.168.1.1");
}
#[test]
fn inject_proxy_headers_sets_x_forwarded_proto_https() {
let mut h = HeaderMap::new();
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 443);
inject_proxy_headers(&mut h, addr, true);
assert_eq!(h.get("x-forwarded-proto").unwrap(), "https");
}
#[test]
fn inject_proxy_headers_sets_x_forwarded_proto_http() {
let mut h = HeaderMap::new();
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 80);
inject_proxy_headers(&mut h, addr, false);
assert_eq!(h.get("x-forwarded-proto").unwrap(), "http");
}
#[test]
fn inject_proxy_headers_preserves_host() {
let mut h = HeaderMap::new();
h.insert("host", HeaderValue::from_static("example.com"));
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 443);
inject_proxy_headers(&mut h, addr, true);
assert_eq!(h.get("host").unwrap(), "example.com");
}
#[test]
fn remove_hop_by_hop_empty_headers() {
let mut h = HeaderMap::new();
remove_hop_by_hop(&mut h);
assert!(h.is_empty());
}
}

View File

@@ -4,6 +4,7 @@ pub mod handler;
pub mod headers; pub mod headers;
pub use crate::config::dynamic_config::normalize_host; pub use crate::config::dynamic_config::normalize_host;
pub use handler::{create_http_client, create_https_client, proxy_router, ProxyState};
use std::sync::Arc; use std::sync::Arc;