Merge feat/proxy/headers-and-forwarding into main
This commit is contained in:
25
Cargo.lock
generated
25
Cargo.lock
generated
@@ -902,7 +902,9 @@ dependencies = [
|
|||||||
"http",
|
"http",
|
||||||
"hyper",
|
"hyper",
|
||||||
"hyper-util",
|
"hyper-util",
|
||||||
|
"log",
|
||||||
"rustls",
|
"rustls",
|
||||||
|
"rustls-native-certs",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
@@ -1227,7 +1229,7 @@ dependencies = [
|
|||||||
"libc",
|
"libc",
|
||||||
"log",
|
"log",
|
||||||
"openssl",
|
"openssl",
|
||||||
"openssl-probe",
|
"openssl-probe 0.2.1",
|
||||||
"openssl-sys",
|
"openssl-sys",
|
||||||
"schannel",
|
"schannel",
|
||||||
"security-framework",
|
"security-framework",
|
||||||
@@ -1344,6 +1346,12 @@ dependencies = [
|
|||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "openssl-probe"
|
||||||
|
version = "0.1.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "openssl-probe"
|
name = "openssl-probe"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
@@ -1608,10 +1616,13 @@ dependencies = [
|
|||||||
"futures",
|
"futures",
|
||||||
"http-body-util",
|
"http-body-util",
|
||||||
"hyper",
|
"hyper",
|
||||||
|
"hyper-rustls",
|
||||||
|
"hyper-util",
|
||||||
"rcgen",
|
"rcgen",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"rustls",
|
"rustls",
|
||||||
"rustls-acme",
|
"rustls-acme",
|
||||||
|
"rustls-native-certs",
|
||||||
"rustls-pemfile",
|
"rustls-pemfile",
|
||||||
"rustls-pki-types",
|
"rustls-pki-types",
|
||||||
"serde",
|
"serde",
|
||||||
@@ -1710,6 +1721,18 @@ dependencies = [
|
|||||||
"x509-parser",
|
"x509-parser",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustls-native-certs"
|
||||||
|
version = "0.8.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3"
|
||||||
|
dependencies = [
|
||||||
|
"openssl-probe 0.1.6",
|
||||||
|
"rustls-pki-types",
|
||||||
|
"schannel",
|
||||||
|
"security-framework",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustls-pemfile"
|
name = "rustls-pemfile"
|
||||||
version = "2.2.0"
|
version = "2.2.0"
|
||||||
|
|||||||
@@ -16,6 +16,10 @@ path = "src/main.rs"
|
|||||||
axum = "=0.8.9"
|
axum = "=0.8.9"
|
||||||
tokio = { version = "=1.45.1", features = ["full"] }
|
tokio = { version = "=1.45.1", features = ["full"] }
|
||||||
hyper = "=1.6.0"
|
hyper = "=1.6.0"
|
||||||
|
hyper-util = { version = "=0.1.17", features = ["client-legacy", "http1", "http2", "tokio"] }
|
||||||
|
http-body-util = "=0.1.3"
|
||||||
|
hyper-rustls = { version = "=0.27.9", features = ["http1", "http2"] }
|
||||||
|
rustls-native-certs = "=0.8.1"
|
||||||
tower = "=0.5.2"
|
tower = "=0.5.2"
|
||||||
rustls = { version = "=0.23.28", features = ["aws_lc_rs"] }
|
rustls = { version = "=0.23.28", features = ["aws_lc_rs"] }
|
||||||
tokio-rustls = "=0.26.2"
|
tokio-rustls = "=0.26.2"
|
||||||
@@ -31,7 +35,6 @@ clap = { version = "=4.6.1", features = ["derive"] }
|
|||||||
signal-hook = "=0.3.18"
|
signal-hook = "=0.3.18"
|
||||||
anyhow = "=1.0.102"
|
anyhow = "=1.0.102"
|
||||||
thiserror = "=2.0.18"
|
thiserror = "=2.0.18"
|
||||||
http-body-util = "=0.1.3"
|
|
||||||
futures = "=0.3.31"
|
futures = "=0.3.31"
|
||||||
dashmap = "=6.1"
|
dashmap = "=6.1"
|
||||||
serde_json = "=1.0.140"
|
serde_json = "=1.0.140"
|
||||||
|
|||||||
@@ -19,6 +19,16 @@ pub enum ProxyError {
|
|||||||
NotFound,
|
NotFound,
|
||||||
#[error("Bad Request")]
|
#[error("Bad Request")]
|
||||||
BadRequest,
|
BadRequest,
|
||||||
|
#[error("upstream connection failed")]
|
||||||
|
UpstreamConnection(#[source] hyper_util::client::legacy::Error),
|
||||||
|
#[error("upstream timeout")]
|
||||||
|
UpstreamTimeout,
|
||||||
|
#[error("upstream tls certificate validation failed")]
|
||||||
|
UpstreamTls(#[source] std::io::Error),
|
||||||
|
#[error("no matching site for host")]
|
||||||
|
UnknownHost,
|
||||||
|
#[error("missing host header")]
|
||||||
|
MissingHost,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ProxyError {
|
impl ProxyError {
|
||||||
@@ -28,8 +38,11 @@ impl ProxyError {
|
|||||||
Self::GatewayTimeout { .. } => StatusCode::GATEWAY_TIMEOUT,
|
Self::GatewayTimeout { .. } => StatusCode::GATEWAY_TIMEOUT,
|
||||||
Self::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
|
Self::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
|
||||||
Self::TooManyRequests { .. } => StatusCode::TOO_MANY_REQUESTS,
|
Self::TooManyRequests { .. } => StatusCode::TOO_MANY_REQUESTS,
|
||||||
Self::NotFound => StatusCode::NOT_FOUND,
|
Self::NotFound | Self::UnknownHost => StatusCode::NOT_FOUND,
|
||||||
Self::BadRequest => StatusCode::BAD_REQUEST,
|
Self::BadRequest | Self::MissingHost => StatusCode::BAD_REQUEST,
|
||||||
|
Self::UpstreamConnection(_) => StatusCode::BAD_GATEWAY,
|
||||||
|
Self::UpstreamTimeout => StatusCode::GATEWAY_TIMEOUT,
|
||||||
|
Self::UpstreamTls(_) => StatusCode::BAD_GATEWAY,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,8 +52,11 @@ impl ProxyError {
|
|||||||
Self::GatewayTimeout { .. } => "Gateway Timeout",
|
Self::GatewayTimeout { .. } => "Gateway Timeout",
|
||||||
Self::PayloadTooLarge => "Payload Too Large",
|
Self::PayloadTooLarge => "Payload Too Large",
|
||||||
Self::TooManyRequests { .. } => "Too Many Requests",
|
Self::TooManyRequests { .. } => "Too Many Requests",
|
||||||
Self::NotFound => "Not Found",
|
Self::NotFound | Self::UnknownHost => "Not Found",
|
||||||
Self::BadRequest => "Bad Request",
|
Self::BadRequest | Self::MissingHost => "Bad Request",
|
||||||
|
Self::UpstreamConnection(_) => "Bad Gateway",
|
||||||
|
Self::UpstreamTimeout => "Gateway Timeout",
|
||||||
|
Self::UpstreamTls(_) => "Bad Gateway",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -76,6 +92,15 @@ impl IntoResponse for ProxyError {
|
|||||||
path
|
path
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
Self::UpstreamConnection(e) => {
|
||||||
|
tracing::warn!(error = %e, status = 502, "upstream connection failed");
|
||||||
|
}
|
||||||
|
Self::UpstreamTimeout => {
|
||||||
|
tracing::warn!(status = 504, "upstream timeout");
|
||||||
|
}
|
||||||
|
Self::UpstreamTls(e) => {
|
||||||
|
tracing::warn!(error = %e, status = 502, "upstream TLS error");
|
||||||
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +1,38 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
use axum::extract::State;
|
|
||||||
use axum::http::StatusCode;
|
|
||||||
use axum::response::IntoResponse;
|
|
||||||
use axum::routing::get;
|
|
||||||
use axum::Router;
|
|
||||||
|
|
||||||
use arc_swap::ArcSwap;
|
use arc_swap::ArcSwap;
|
||||||
|
use axum::body::Body;
|
||||||
|
use axum::extract::{ConnectInfo, State};
|
||||||
|
use axum::http::{Request, StatusCode, Uri};
|
||||||
|
use axum::response::{IntoResponse, Response};
|
||||||
|
use axum::routing::get;
|
||||||
|
use axum::Router;
|
||||||
|
use hyper_util::client::legacy::connect::HttpConnector;
|
||||||
|
use hyper_util::client::legacy::Client;
|
||||||
|
use hyper_util::rt::TokioExecutor;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
use crate::config::dynamic_config::DynamicConfig;
|
use crate::config::dynamic_config::DynamicConfig;
|
||||||
use crate::proxy::error::ProxyError;
|
use crate::proxy::error::ProxyError;
|
||||||
|
use crate::proxy::headers::{inject_proxy_headers, remove_hop_by_hop};
|
||||||
|
|
||||||
|
pub struct ProxyState {
|
||||||
|
pub config: Arc<ArcSwap<DynamicConfig>>,
|
||||||
|
pub http_client: Client<HttpConnector, Body>,
|
||||||
|
pub https_client: Client<hyper_rustls::HttpsConnector<HttpConnector>, Body>,
|
||||||
|
}
|
||||||
|
|
||||||
async fn health_handler() -> impl IntoResponse {
|
async fn health_handler() -> impl IntoResponse {
|
||||||
StatusCode::OK
|
StatusCode::OK
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn proxy_handler(
|
async fn proxy_handler(
|
||||||
State(state): State<Arc<ArcSwap<DynamicConfig>>>,
|
ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
|
||||||
req: axum::http::Request<axum::body::Body>,
|
State(state): State<Arc<ProxyState>>,
|
||||||
) -> impl IntoResponse {
|
mut req: Request<Body>,
|
||||||
|
) -> Response {
|
||||||
if req.uri().path() == "/health" {
|
if req.uri().path() == "/health" {
|
||||||
return StatusCode::OK.into_response();
|
return StatusCode::OK.into_response();
|
||||||
}
|
}
|
||||||
@@ -30,17 +44,132 @@ async fn proxy_handler(
|
|||||||
|
|
||||||
let host = match host {
|
let host = match host {
|
||||||
Some(h) => h,
|
Some(h) => h,
|
||||||
None => return ProxyError::BadRequest.into_response(),
|
None => return ProxyError::MissingHost.into_response(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let config = state.load();
|
let config = state.config.load();
|
||||||
match config.lookup(host) {
|
let site = match config.lookup(host) {
|
||||||
Some(_site) => StatusCode::OK.into_response(),
|
Some(s) => s.clone(),
|
||||||
None => ProxyError::NotFound.into_response(),
|
None => return ProxyError::UnknownHost.into_response(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let is_https = determine_if_https(host);
|
||||||
|
inject_proxy_headers(req.headers_mut(), remote_addr, is_https);
|
||||||
|
remove_hop_by_hop(req.headers_mut());
|
||||||
|
|
||||||
|
let upstream_scheme = site.upstream_scheme.clone();
|
||||||
|
let upstream = site.upstream.clone();
|
||||||
|
let upstream_uri = build_upstream_uri(&upstream_scheme, &upstream, req.uri());
|
||||||
|
|
||||||
|
let upstream_req = match build_upstream_request(req, &upstream_uri) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "failed to build upstream request");
|
||||||
|
return StatusCode::BAD_GATEWAY.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let request_timeout = Duration::from_secs(site.upstream_request_timeout_secs);
|
||||||
|
|
||||||
|
let result = if upstream_scheme == "https" {
|
||||||
|
tokio::time::timeout(request_timeout, state.https_client.request(upstream_req)).await
|
||||||
|
} else {
|
||||||
|
tokio::time::timeout(request_timeout, state.http_client.request(upstream_req)).await
|
||||||
|
};
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(Ok(upstream_resp)) => {
|
||||||
|
let (mut parts, body) = upstream_resp.into_parts();
|
||||||
|
remove_hop_by_hop(&mut parts.headers);
|
||||||
|
parts.headers.remove("server");
|
||||||
|
let body = Body::new(body);
|
||||||
|
Response::from_parts(parts, body)
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
if e.is_connect() {
|
||||||
|
ProxyError::UpstreamConnection(e).into_response()
|
||||||
|
} else {
|
||||||
|
warn!(error = %e, "upstream request failed");
|
||||||
|
StatusCode::BAD_GATEWAY.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => ProxyError::UpstreamTimeout.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn proxy_router(state: Arc<ArcSwap<DynamicConfig>>) -> Router {
|
fn determine_if_https(host: &str) -> bool {
|
||||||
|
let port_str = host.split(':').nth(1);
|
||||||
|
if let Some(port) = port_str {
|
||||||
|
if let Ok(p) = port.parse::<u16>() {
|
||||||
|
return p == 443;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_upstream_uri(scheme: &str, upstream: &str, original_uri: &Uri) -> Uri {
|
||||||
|
let path = original_uri.path();
|
||||||
|
let query = original_uri
|
||||||
|
.query()
|
||||||
|
.map(|q| format!("?{}", q))
|
||||||
|
.unwrap_or_default();
|
||||||
|
let uri_string = format!("{}://{}{}{}", scheme, upstream, path, query);
|
||||||
|
uri_string.parse::<Uri>().unwrap_or_else(|_| {
|
||||||
|
format!("{}://{}{}", scheme, upstream, path)
|
||||||
|
.parse::<Uri>()
|
||||||
|
.unwrap()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_upstream_request(req: Request<Body>, upstream_uri: &Uri) -> anyhow::Result<Request<Body>> {
|
||||||
|
let mut builder = Request::builder()
|
||||||
|
.method(req.method().clone())
|
||||||
|
.uri(upstream_uri.clone());
|
||||||
|
|
||||||
|
for (name, value) in req.headers().iter() {
|
||||||
|
builder = builder.header(name.as_str(), value);
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.body(req.into_body()).map_err(Into::into)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_http_client() -> Client<HttpConnector, Body> {
|
||||||
|
Client::builder(TokioExecutor::new())
|
||||||
|
.pool_idle_timeout(Duration::from_secs(90))
|
||||||
|
.build_http()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_https_client() -> Client<hyper_rustls::HttpsConnector<HttpConnector>, Body> {
|
||||||
|
let tls_config = rustls::ClientConfig::builder()
|
||||||
|
.with_root_certificates(root_certs())
|
||||||
|
.with_no_client_auth();
|
||||||
|
|
||||||
|
let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
|
||||||
|
.with_tls_config(tls_config)
|
||||||
|
.https_or_http()
|
||||||
|
.enable_http1()
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Client::builder(TokioExecutor::new())
|
||||||
|
.pool_idle_timeout(Duration::from_secs(90))
|
||||||
|
.build(https_connector)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn root_certs() -> rustls::RootCertStore {
|
||||||
|
let mut roots = rustls::RootCertStore::empty();
|
||||||
|
let result = rustls_native_certs::load_native_certs();
|
||||||
|
for cert in result.certs {
|
||||||
|
roots.add(cert).ok();
|
||||||
|
}
|
||||||
|
if !result.errors.is_empty() {
|
||||||
|
for err in &result.errors {
|
||||||
|
warn!(error = %err, "failed to load native certificate");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
roots
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn proxy_router(state: Arc<ProxyState>) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/health", get(health_handler))
|
.route("/health", get(health_handler))
|
||||||
.fallback(proxy_handler)
|
.fallback(proxy_handler)
|
||||||
@@ -50,204 +179,163 @@ pub fn proxy_router(state: Arc<ArcSwap<DynamicConfig>>) -> Router {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::config::dynamic_config::{BodyConfig, RateLimitConfig};
|
use crate::config::dynamic_config::{BodyConfig, DynamicConfig, RateLimitConfig};
|
||||||
use crate::config::SiteConfig;
|
use crate::config::SiteConfig;
|
||||||
use axum::body::Body;
|
use axum::body::Body;
|
||||||
use axum::http::{Request, Response};
|
use axum::http::Request;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
use tower::ServiceExt;
|
use tower::ServiceExt;
|
||||||
|
|
||||||
fn make_config_with_sites(sites: Vec<SiteConfig>) -> Arc<ArcSwap<DynamicConfig>> {
|
fn make_proxy_state(sites: Vec<SiteConfig>) -> Arc<ProxyState> {
|
||||||
Arc::new(ArcSwap::from_pointee(DynamicConfig::from_sites(
|
Arc::new(ProxyState {
|
||||||
sites,
|
config: Arc::new(ArcSwap::from_pointee(DynamicConfig::from_sites(
|
||||||
RateLimitConfig {
|
sites,
|
||||||
requests_per_second: 10,
|
RateLimitConfig {
|
||||||
burst: 20,
|
requests_per_second: 10,
|
||||||
},
|
burst: 20,
|
||||||
BodyConfig {
|
},
|
||||||
limit_bytes: 104857600,
|
BodyConfig {
|
||||||
},
|
limit_bytes: 104857600,
|
||||||
)))
|
},
|
||||||
|
))),
|
||||||
|
http_client: create_http_client(),
|
||||||
|
https_client: create_https_client(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_request(
|
fn make_request_with_connect_info(
|
||||||
router: &mut Router,
|
|
||||||
method: &str,
|
method: &str,
|
||||||
uri: &str,
|
uri: &str,
|
||||||
host: Option<&str>,
|
host: Option<&str>,
|
||||||
) -> Response<axum::body::Body> {
|
remote_addr: SocketAddr,
|
||||||
|
) -> Request<Body> {
|
||||||
let mut builder = Request::builder().method(method).uri(uri);
|
let mut builder = Request::builder().method(method).uri(uri);
|
||||||
if let Some(h) = host {
|
if let Some(h) = host {
|
||||||
builder = builder.header("Host", h);
|
builder = builder.header("Host", h);
|
||||||
}
|
}
|
||||||
let req = builder.body(Body::empty()).unwrap();
|
let mut req = builder.body(Body::empty()).unwrap();
|
||||||
router.oneshot(req).await.unwrap()
|
req.extensions_mut().insert(ConnectInfo(remote_addr));
|
||||||
|
req
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn health_path_returns_200_regardless_of_host() {
|
async fn health_path_returns_200_regardless_of_host() {
|
||||||
let state = make_config_with_sites(vec![SiteConfig {
|
let state = make_proxy_state(vec![SiteConfig {
|
||||||
host: "example.com".to_string(),
|
host: "example.com".to_string(),
|
||||||
upstream: "127.0.0.1:8080".to_string(),
|
upstream: "127.0.0.1:8080".to_string(),
|
||||||
upstream_scheme: "http".to_string(),
|
upstream_scheme: "http".to_string(),
|
||||||
upstream_connect_timeout_secs: 5,
|
upstream_connect_timeout_secs: 5,
|
||||||
upstream_request_timeout_secs: 60,
|
upstream_request_timeout_secs: 60,
|
||||||
}]);
|
}]);
|
||||||
let mut router = proxy_router(state);
|
let router = proxy_router(state);
|
||||||
|
let req = make_request_with_connect_info(
|
||||||
let resp = send_request(&mut router, "GET", "/health", None).await;
|
"GET",
|
||||||
|
"/health",
|
||||||
|
None,
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
|
||||||
|
);
|
||||||
|
let resp = router.oneshot(req).await.unwrap();
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
assert_eq!(resp.status(), StatusCode::OK);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn health_with_unknown_host_returns_200() {
|
async fn health_with_unknown_host_returns_200() {
|
||||||
let state = make_config_with_sites(vec![SiteConfig {
|
let state = make_proxy_state(vec![SiteConfig {
|
||||||
host: "example.com".to_string(),
|
host: "example.com".to_string(),
|
||||||
upstream: "127.0.0.1:8080".to_string(),
|
upstream: "127.0.0.1:8080".to_string(),
|
||||||
upstream_scheme: "http".to_string(),
|
upstream_scheme: "http".to_string(),
|
||||||
upstream_connect_timeout_secs: 5,
|
upstream_connect_timeout_secs: 5,
|
||||||
upstream_request_timeout_secs: 60,
|
upstream_request_timeout_secs: 60,
|
||||||
}]);
|
}]);
|
||||||
let mut router = proxy_router(state);
|
let router = proxy_router(state);
|
||||||
|
let req = make_request_with_connect_info(
|
||||||
let resp = send_request(&mut router, "GET", "/health", Some("unknown.host")).await;
|
"GET",
|
||||||
|
"/health",
|
||||||
|
Some("unknown.host"),
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
|
||||||
|
);
|
||||||
|
let resp = router.oneshot(req).await.unwrap();
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
assert_eq!(resp.status(), StatusCode::OK);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn missing_host_returns_400() {
|
async fn missing_host_returns_400() {
|
||||||
let state = make_config_with_sites(vec![SiteConfig {
|
let state = make_proxy_state(vec![SiteConfig {
|
||||||
host: "example.com".to_string(),
|
host: "example.com".to_string(),
|
||||||
upstream: "127.0.0.1:8080".to_string(),
|
upstream: "127.0.0.1:8080".to_string(),
|
||||||
upstream_scheme: "http".to_string(),
|
upstream_scheme: "http".to_string(),
|
||||||
upstream_connect_timeout_secs: 5,
|
upstream_connect_timeout_secs: 5,
|
||||||
upstream_request_timeout_secs: 60,
|
upstream_request_timeout_secs: 60,
|
||||||
}]);
|
}]);
|
||||||
let mut router = proxy_router(state);
|
let router = proxy_router(state);
|
||||||
|
let req = make_request_with_connect_info(
|
||||||
let resp = send_request(&mut router, "GET", "/some/path", None).await;
|
"GET",
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
"/some/path",
|
||||||
assert_eq!(
|
None,
|
||||||
resp.headers().get("content-type").unwrap(),
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
|
||||||
"text/plain; charset=utf-8"
|
|
||||||
);
|
);
|
||||||
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
|
let resp = router.oneshot(req).await.unwrap();
|
||||||
assert_eq!(&body[..], b"Bad Request");
|
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn unknown_host_returns_404() {
|
async fn unknown_host_returns_404() {
|
||||||
let state = make_config_with_sites(vec![SiteConfig {
|
let state = make_proxy_state(vec![SiteConfig {
|
||||||
host: "example.com".to_string(),
|
host: "example.com".to_string(),
|
||||||
upstream: "127.0.0.1:8080".to_string(),
|
upstream: "127.0.0.1:8080".to_string(),
|
||||||
upstream_scheme: "http".to_string(),
|
upstream_scheme: "http".to_string(),
|
||||||
upstream_connect_timeout_secs: 5,
|
upstream_connect_timeout_secs: 5,
|
||||||
upstream_request_timeout_secs: 60,
|
upstream_request_timeout_secs: 60,
|
||||||
}]);
|
}]);
|
||||||
let mut router = proxy_router(state);
|
let router = proxy_router(state);
|
||||||
|
let req = make_request_with_connect_info(
|
||||||
let resp = send_request(&mut router, "GET", "/some/path", Some("unknown.host")).await;
|
"GET",
|
||||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
"/some/path",
|
||||||
assert_eq!(
|
Some("unknown.host"),
|
||||||
resp.headers().get("content-type").unwrap(),
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
|
||||||
"text/plain; charset=utf-8"
|
|
||||||
);
|
);
|
||||||
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
|
let resp = router.oneshot(req).await.unwrap();
|
||||||
assert_eq!(&body[..], b"Not Found");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn known_host_returns_200() {
|
|
||||||
let state = make_config_with_sites(vec![SiteConfig {
|
|
||||||
host: "example.com".to_string(),
|
|
||||||
upstream: "127.0.0.1:8080".to_string(),
|
|
||||||
upstream_scheme: "http".to_string(),
|
|
||||||
upstream_connect_timeout_secs: 5,
|
|
||||||
upstream_request_timeout_secs: 60,
|
|
||||||
}]);
|
|
||||||
let mut router = proxy_router(state);
|
|
||||||
|
|
||||||
let resp = send_request(&mut router, "GET", "/some/path", Some("example.com")).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn host_matching_is_case_insensitive() {
|
|
||||||
let state = make_config_with_sites(vec![SiteConfig {
|
|
||||||
host: "example.com".to_string(),
|
|
||||||
upstream: "127.0.0.1:8080".to_string(),
|
|
||||||
upstream_scheme: "http".to_string(),
|
|
||||||
upstream_connect_timeout_secs: 5,
|
|
||||||
upstream_request_timeout_secs: 60,
|
|
||||||
}]);
|
|
||||||
let mut router = proxy_router(state);
|
|
||||||
|
|
||||||
let resp = send_request(&mut router, "GET", "/path", Some("EXAMPLE.COM")).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
let resp = send_request(&mut router, "GET", "/path", Some("Example.Com")).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn host_with_port_stripped() {
|
|
||||||
let state = make_config_with_sites(vec![SiteConfig {
|
|
||||||
host: "example.com".to_string(),
|
|
||||||
upstream: "127.0.0.1:8080".to_string(),
|
|
||||||
upstream_scheme: "http".to_string(),
|
|
||||||
upstream_connect_timeout_secs: 5,
|
|
||||||
upstream_request_timeout_secs: 60,
|
|
||||||
}]);
|
|
||||||
let mut router = proxy_router(state);
|
|
||||||
|
|
||||||
let resp = send_request(&mut router, "GET", "/path", Some("example.com:443")).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
let resp = send_request(&mut router, "GET", "/path", Some("EXAMPLE.COM:8443")).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn routing_table_update_visible_immediately() {
|
|
||||||
let state = make_config_with_sites(vec![SiteConfig {
|
|
||||||
host: "example.com".to_string(),
|
|
||||||
upstream: "127.0.0.1:8080".to_string(),
|
|
||||||
upstream_scheme: "http".to_string(),
|
|
||||||
upstream_connect_timeout_secs: 5,
|
|
||||||
upstream_request_timeout_secs: 60,
|
|
||||||
}]);
|
|
||||||
let mut router = proxy_router(state.clone());
|
|
||||||
|
|
||||||
let resp = send_request(&mut router, "GET", "/path", Some("new.example.com")).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
||||||
|
}
|
||||||
|
|
||||||
let new_config = DynamicConfig::from_sites(
|
#[test]
|
||||||
vec![
|
fn test_determine_if_https_port_443() {
|
||||||
SiteConfig {
|
assert!(determine_if_https("example.com:443"));
|
||||||
host: "example.com".to_string(),
|
}
|
||||||
upstream: "127.0.0.1:8080".to_string(),
|
|
||||||
upstream_scheme: "http".to_string(),
|
|
||||||
upstream_connect_timeout_secs: 5,
|
|
||||||
upstream_request_timeout_secs: 60,
|
|
||||||
},
|
|
||||||
SiteConfig {
|
|
||||||
host: "new.example.com".to_string(),
|
|
||||||
upstream: "127.0.0.1:9090".to_string(),
|
|
||||||
upstream_scheme: "http".to_string(),
|
|
||||||
upstream_connect_timeout_secs: 5,
|
|
||||||
upstream_request_timeout_secs: 60,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
RateLimitConfig {
|
|
||||||
requests_per_second: 10,
|
|
||||||
burst: 20,
|
|
||||||
},
|
|
||||||
BodyConfig {
|
|
||||||
limit_bytes: 104857600,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
state.store(Arc::new(new_config));
|
|
||||||
|
|
||||||
let resp = send_request(&mut router, "GET", "/path", Some("new.example.com")).await;
|
#[test]
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
fn test_determine_if_https_port_80() {
|
||||||
|
assert!(!determine_if_https("example.com:80"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_determine_if_https_no_port() {
|
||||||
|
assert!(determine_if_https("example.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_determine_if_https_port_8443() {
|
||||||
|
assert!(!determine_if_https("example.com:8443"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_upstream_uri_with_query() {
|
||||||
|
let uri: Uri = "/path?foo=bar".parse().unwrap();
|
||||||
|
let result = build_upstream_uri("http", "127.0.0.1:8080", &uri);
|
||||||
|
assert_eq!(result.to_string(), "http://127.0.0.1:8080/path?foo=bar");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_upstream_uri_without_query() {
|
||||||
|
let uri: Uri = "/path".parse().unwrap();
|
||||||
|
let result = build_upstream_uri("http", "127.0.0.1:8080", &uri);
|
||||||
|
assert_eq!(result.to_string(), "http://127.0.0.1:8080/path");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_upstream_uri_https() {
|
||||||
|
let uri: Uri = "/secure".parse().unwrap();
|
||||||
|
let result = build_upstream_uri("https", "upstream.example.com", &uri);
|
||||||
|
assert_eq!(result.to_string(), "https://upstream.example.com/secure");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,2 +1,134 @@
|
|||||||
#[allow(dead_code)]
|
use axum::http::{HeaderMap, HeaderName, HeaderValue};
|
||||||
pub struct ProxyHeaders;
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
|
const HOP_BY_HOP: &[&str] = &[
|
||||||
|
"connection",
|
||||||
|
"keep-alive",
|
||||||
|
"proxy-authorization",
|
||||||
|
"proxy-authenticate",
|
||||||
|
"te",
|
||||||
|
"trailers",
|
||||||
|
"transfer-encoding",
|
||||||
|
"upgrade",
|
||||||
|
];
|
||||||
|
|
||||||
|
pub fn remove_hop_by_hop(headers: &mut HeaderMap) {
|
||||||
|
for &name in HOP_BY_HOP {
|
||||||
|
headers.remove(name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn inject_proxy_headers(headers: &mut HeaderMap, remote_addr: SocketAddr, is_https: bool) {
|
||||||
|
let ip_str = remote_addr.ip().to_string();
|
||||||
|
let ip_value =
|
||||||
|
HeaderValue::from_str(&ip_str).unwrap_or_else(|_| HeaderValue::from_static("0.0.0.0"));
|
||||||
|
|
||||||
|
headers.insert(HeaderName::from_static("x-real-ip"), ip_value.clone());
|
||||||
|
|
||||||
|
headers.insert(HeaderName::from_static("x-forwarded-for"), ip_value);
|
||||||
|
|
||||||
|
let proto_value = if is_https {
|
||||||
|
HeaderValue::from_static("https")
|
||||||
|
} else {
|
||||||
|
HeaderValue::from_static("http")
|
||||||
|
};
|
||||||
|
headers.insert(HeaderName::from_static("x-forwarded-proto"), proto_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr};
|
||||||
|
|
||||||
|
fn make_headers_with_hop_by_hop() -> HeaderMap {
|
||||||
|
let mut h = HeaderMap::new();
|
||||||
|
h.insert("connection", HeaderValue::from_static("keep-alive"));
|
||||||
|
h.insert("keep-alive", HeaderValue::from_static("timeout=5"));
|
||||||
|
h.insert("proxy-authorization", HeaderValue::from_static("Basic abc"));
|
||||||
|
h.insert(
|
||||||
|
"proxy-authenticate",
|
||||||
|
HeaderValue::from_static("Basic realm=x"),
|
||||||
|
);
|
||||||
|
h.insert("te", HeaderValue::from_static("trailers"));
|
||||||
|
h.insert("trailers", HeaderValue::from_static("chunked"));
|
||||||
|
h.insert("transfer-encoding", HeaderValue::from_static("chunked"));
|
||||||
|
h.insert("upgrade", HeaderValue::from_static("websocket"));
|
||||||
|
h.insert("content-type", HeaderValue::from_static("text/html"));
|
||||||
|
h.insert("accept", HeaderValue::from_static("*/*"));
|
||||||
|
h
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn remove_hop_by_hop_removes_all_listed_headers() {
|
||||||
|
let mut h = make_headers_with_hop_by_hop();
|
||||||
|
remove_hop_by_hop(&mut h);
|
||||||
|
assert!(h.get("connection").is_none());
|
||||||
|
assert!(h.get("keep-alive").is_none());
|
||||||
|
assert!(h.get("proxy-authorization").is_none());
|
||||||
|
assert!(h.get("proxy-authenticate").is_none());
|
||||||
|
assert!(h.get("te").is_none());
|
||||||
|
assert!(h.get("trailers").is_none());
|
||||||
|
assert!(h.get("transfer-encoding").is_none());
|
||||||
|
assert!(h.get("upgrade").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn remove_hop_by_hop_preserves_other_headers() {
|
||||||
|
let mut h = make_headers_with_hop_by_hop();
|
||||||
|
remove_hop_by_hop(&mut h);
|
||||||
|
assert_eq!(h.get("content-type").unwrap(), "text/html");
|
||||||
|
assert_eq!(h.get("accept").unwrap(), "*/*");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn inject_proxy_headers_sets_x_real_ip() {
|
||||||
|
let mut h = HeaderMap::new();
|
||||||
|
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 12345);
|
||||||
|
inject_proxy_headers(&mut h, addr, true);
|
||||||
|
assert_eq!(h.get("x-real-ip").unwrap(), "192.168.1.1");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn inject_proxy_headers_replaces_x_forwarded_for() {
|
||||||
|
let mut h = HeaderMap::new();
|
||||||
|
h.insert(
|
||||||
|
"x-forwarded-for",
|
||||||
|
HeaderValue::from_static("10.0.0.1, 10.0.0.2"),
|
||||||
|
);
|
||||||
|
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 12345);
|
||||||
|
inject_proxy_headers(&mut h, addr, true);
|
||||||
|
assert_eq!(h.get("x-forwarded-for").unwrap(), "192.168.1.1");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn inject_proxy_headers_sets_x_forwarded_proto_https() {
|
||||||
|
let mut h = HeaderMap::new();
|
||||||
|
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 443);
|
||||||
|
inject_proxy_headers(&mut h, addr, true);
|
||||||
|
assert_eq!(h.get("x-forwarded-proto").unwrap(), "https");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn inject_proxy_headers_sets_x_forwarded_proto_http() {
|
||||||
|
let mut h = HeaderMap::new();
|
||||||
|
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 80);
|
||||||
|
inject_proxy_headers(&mut h, addr, false);
|
||||||
|
assert_eq!(h.get("x-forwarded-proto").unwrap(), "http");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn inject_proxy_headers_preserves_host() {
|
||||||
|
let mut h = HeaderMap::new();
|
||||||
|
h.insert("host", HeaderValue::from_static("example.com"));
|
||||||
|
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 443);
|
||||||
|
inject_proxy_headers(&mut h, addr, true);
|
||||||
|
assert_eq!(h.get("host").unwrap(), "example.com");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn remove_hop_by_hop_empty_headers() {
|
||||||
|
let mut h = HeaderMap::new();
|
||||||
|
remove_hop_by_hop(&mut h);
|
||||||
|
assert!(h.is_empty());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ pub mod handler;
|
|||||||
pub mod headers;
|
pub mod headers;
|
||||||
|
|
||||||
pub use crate::config::dynamic_config::normalize_host;
|
pub use crate::config::dynamic_config::normalize_host;
|
||||||
|
pub use handler::{create_http_client, create_https_client, proxy_router, ProxyState};
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user