Merge remote-tracking branch 'origin/fix/fix/remove-dead-code-remnants'
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
|
#[cfg(test)]
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
struct KvVisitor {
|
struct KvVisitor {
|
||||||
pairs: Vec<(String, String)>,
|
pairs: Vec<(String, String)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
impl KvVisitor {
|
impl KvVisitor {
|
||||||
fn format(&self) -> String {
|
fn format(&self) -> String {
|
||||||
let parts: Vec<String> = self
|
let parts: Vec<String> = self
|
||||||
@@ -20,6 +22,7 @@ impl KvVisitor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
impl tracing::field::Visit for KvVisitor {
|
impl tracing::field::Visit for KvVisitor {
|
||||||
fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
|
fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
|
||||||
self.pairs
|
self.pairs
|
||||||
@@ -47,12 +50,6 @@ impl tracing::field::Visit for KvVisitor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format_event_fields(event: &tracing::Event<'_>) -> String {
|
|
||||||
let mut visitor = KvVisitor::default();
|
|
||||||
event.record(&mut visitor);
|
|
||||||
visitor.format()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! log_request {
|
macro_rules! log_request {
|
||||||
($client_ip:expr, $host:expr, $method:expr, $path:expr, $status:expr, $upstream:expr, $duration_ms:expr) => {
|
($client_ip:expr, $host:expr, $method:expr, $path:expr, $status:expr, $upstream:expr, $duration_ms:expr) => {
|
||||||
@@ -69,19 +66,6 @@ macro_rules! log_request {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
#[macro_export]
|
|
||||||
macro_rules! log_rate_limit {
|
|
||||||
($client_ip:expr, $host:expr, $path:expr, $status:expr) => {
|
|
||||||
tracing::warn!(
|
|
||||||
prefix = "RATE_LIMIT",
|
|
||||||
client_ip = %$client_ip,
|
|
||||||
host = %$host,
|
|
||||||
path = %$path,
|
|
||||||
status = %$status,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! log_upstream_error {
|
macro_rules! log_upstream_error {
|
||||||
($host:expr, $upstream:expr, $error:expr) => {
|
($host:expr, $upstream:expr, $error:expr) => {
|
||||||
@@ -94,17 +78,6 @@ macro_rules! log_upstream_error {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
#[macro_export]
|
|
||||||
macro_rules! log_config_reload {
|
|
||||||
($status:expr, $sites:expr) => {
|
|
||||||
tracing::info!(
|
|
||||||
prefix = "CONFIG_RELOAD",
|
|
||||||
status = %$status,
|
|
||||||
sites = %$sites,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -180,8 +153,6 @@ mod tests {
|
|||||||
"127.0.0.1:3000",
|
"127.0.0.1:3000",
|
||||||
45u64
|
45u64
|
||||||
);
|
);
|
||||||
log_rate_limit!("10.0.0.1", "example.com", "/login", 429u16);
|
|
||||||
log_upstream_error!("git.alk.dev", "127.0.0.1:3000", "connection refused");
|
log_upstream_error!("git.alk.dev", "127.0.0.1:3000", "connection refused");
|
||||||
log_config_reload!("success", 1u32);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,24 +8,16 @@ pub enum ProxyError {
|
|||||||
BadGateway { host: String, upstream: String },
|
BadGateway { host: String, upstream: String },
|
||||||
#[error("Gateway Timeout")]
|
#[error("Gateway Timeout")]
|
||||||
GatewayTimeout { host: String, upstream: String },
|
GatewayTimeout { host: String, upstream: String },
|
||||||
#[error("Payload Too Large")]
|
|
||||||
PayloadTooLarge,
|
|
||||||
#[error("Too Many Requests")]
|
#[error("Too Many Requests")]
|
||||||
TooManyRequests {
|
TooManyRequests {
|
||||||
client_ip: String,
|
client_ip: String,
|
||||||
host: String,
|
host: String,
|
||||||
path: String,
|
path: String,
|
||||||
},
|
},
|
||||||
#[error("Not Found")]
|
|
||||||
NotFound,
|
|
||||||
#[error("Bad Request")]
|
|
||||||
BadRequest,
|
|
||||||
#[error("upstream connection failed")]
|
#[error("upstream connection failed")]
|
||||||
UpstreamConnection(#[source] hyper_util::client::legacy::Error),
|
UpstreamConnection(#[source] hyper_util::client::legacy::Error),
|
||||||
#[error("upstream timeout")]
|
#[error("upstream timeout")]
|
||||||
UpstreamTimeout,
|
UpstreamTimeout,
|
||||||
#[error("upstream tls certificate validation failed")]
|
|
||||||
UpstreamTls(#[source] std::io::Error),
|
|
||||||
#[error("no matching site for host")]
|
#[error("no matching site for host")]
|
||||||
UnknownHost,
|
UnknownHost,
|
||||||
#[error("missing host header")]
|
#[error("missing host header")]
|
||||||
@@ -37,13 +29,11 @@ impl ProxyError {
|
|||||||
match self {
|
match self {
|
||||||
Self::BadGateway { .. } => StatusCode::BAD_GATEWAY,
|
Self::BadGateway { .. } => StatusCode::BAD_GATEWAY,
|
||||||
Self::GatewayTimeout { .. } => StatusCode::GATEWAY_TIMEOUT,
|
Self::GatewayTimeout { .. } => StatusCode::GATEWAY_TIMEOUT,
|
||||||
Self::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
|
|
||||||
Self::TooManyRequests { .. } => StatusCode::TOO_MANY_REQUESTS,
|
Self::TooManyRequests { .. } => StatusCode::TOO_MANY_REQUESTS,
|
||||||
Self::NotFound | Self::UnknownHost => StatusCode::NOT_FOUND,
|
|
||||||
Self::BadRequest | Self::MissingHost => StatusCode::BAD_REQUEST,
|
|
||||||
Self::UpstreamConnection(_) => StatusCode::BAD_GATEWAY,
|
Self::UpstreamConnection(_) => StatusCode::BAD_GATEWAY,
|
||||||
Self::UpstreamTimeout => StatusCode::GATEWAY_TIMEOUT,
|
Self::UpstreamTimeout => StatusCode::GATEWAY_TIMEOUT,
|
||||||
Self::UpstreamTls(_) => StatusCode::BAD_GATEWAY,
|
Self::UnknownHost => StatusCode::NOT_FOUND,
|
||||||
|
Self::MissingHost => StatusCode::BAD_REQUEST,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,13 +41,11 @@ impl ProxyError {
|
|||||||
match self {
|
match self {
|
||||||
Self::BadGateway { .. } => "Bad Gateway",
|
Self::BadGateway { .. } => "Bad Gateway",
|
||||||
Self::GatewayTimeout { .. } => "Gateway Timeout",
|
Self::GatewayTimeout { .. } => "Gateway Timeout",
|
||||||
Self::PayloadTooLarge => "Payload Too Large",
|
|
||||||
Self::TooManyRequests { .. } => "Too Many Requests",
|
Self::TooManyRequests { .. } => "Too Many Requests",
|
||||||
Self::NotFound | Self::UnknownHost => "Not Found",
|
|
||||||
Self::BadRequest | Self::MissingHost => "Bad Request",
|
|
||||||
Self::UpstreamConnection(_) => "Bad Gateway",
|
Self::UpstreamConnection(_) => "Bad Gateway",
|
||||||
Self::UpstreamTimeout => "Gateway Timeout",
|
Self::UpstreamTimeout => "Gateway Timeout",
|
||||||
Self::UpstreamTls(_) => "Bad Gateway",
|
Self::UnknownHost => "Not Found",
|
||||||
|
Self::MissingHost => "Bad Request",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -99,9 +87,6 @@ impl IntoResponse for ProxyError {
|
|||||||
Self::UpstreamTimeout => {
|
Self::UpstreamTimeout => {
|
||||||
tracing::warn!(status = 504, "upstream timeout");
|
tracing::warn!(status = 504, "upstream timeout");
|
||||||
}
|
}
|
||||||
Self::UpstreamTls(e) => {
|
|
||||||
tracing::warn!(error = %e, status = 502, "upstream TLS error");
|
|
||||||
}
|
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,23 +161,6 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn payload_too_large_response() {
|
|
||||||
let resp = into_response(ProxyError::PayloadTooLarge);
|
|
||||||
assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);
|
|
||||||
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
|
|
||||||
assert_eq!(&body[..], b"Payload Too Large");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn payload_too_large_content_type() {
|
|
||||||
let resp = into_response(ProxyError::PayloadTooLarge);
|
|
||||||
assert_eq!(
|
|
||||||
resp.headers().get("content-type").unwrap(),
|
|
||||||
"text/plain; charset=utf-8"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn too_many_requests_response() {
|
async fn too_many_requests_response() {
|
||||||
let resp = into_response(ProxyError::TooManyRequests {
|
let resp = into_response(ProxyError::TooManyRequests {
|
||||||
@@ -218,40 +186,6 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn not_found_response() {
|
|
||||||
let resp = into_response(ProxyError::NotFound);
|
|
||||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
|
||||||
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
|
|
||||||
assert_eq!(&body[..], b"Not Found");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn not_found_content_type() {
|
|
||||||
let resp = into_response(ProxyError::NotFound);
|
|
||||||
assert_eq!(
|
|
||||||
resp.headers().get("content-type").unwrap(),
|
|
||||||
"text/plain; charset=utf-8"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn bad_request_response() {
|
|
||||||
let resp = into_response(ProxyError::BadRequest);
|
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
|
||||||
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
|
|
||||||
assert_eq!(&body[..], b"Bad Request");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn bad_request_content_type() {
|
|
||||||
let resp = into_response(ProxyError::BadRequest);
|
|
||||||
assert_eq!(
|
|
||||||
resp.headers().get("content-type").unwrap(),
|
|
||||||
"text/plain; charset=utf-8"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn error_display_matches_body() {
|
fn error_display_matches_body() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -270,7 +204,6 @@ mod tests {
|
|||||||
.to_string(),
|
.to_string(),
|
||||||
"Gateway Timeout"
|
"Gateway Timeout"
|
||||||
);
|
);
|
||||||
assert_eq!(ProxyError::PayloadTooLarge.to_string(), "Payload Too Large");
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
ProxyError::TooManyRequests {
|
ProxyError::TooManyRequests {
|
||||||
client_ip: String::new(),
|
client_ip: String::new(),
|
||||||
@@ -280,7 +213,5 @@ mod tests {
|
|||||||
.to_string(),
|
.to_string(),
|
||||||
"Too Many Requests"
|
"Too Many Requests"
|
||||||
);
|
);
|
||||||
assert_eq!(ProxyError::NotFound.to_string(), "Not Found");
|
|
||||||
assert_eq!(ProxyError::BadRequest.to_string(), "Bad Request");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ impl AcmeTlsConfig {
|
|||||||
Ok(AcmeTlsSetup { resolver, state })
|
Ok(AcmeTlsSetup { resolver, state })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
pub fn directory_url(&self) -> &str {
|
pub fn directory_url(&self) -> &str {
|
||||||
match self.directory.as_str() {
|
match self.directory.as_str() {
|
||||||
"production" => LETS_ENCRYPT_PRODUCTION_DIRECTORY,
|
"production" => LETS_ENCRYPT_PRODUCTION_DIRECTORY,
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::BufReader;
|
use std::io::BufReader;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -7,8 +6,6 @@ use anyhow::{bail, Context, Result};
|
|||||||
use rustls::crypto::aws_lc_rs::cipher_suite;
|
use rustls::crypto::aws_lc_rs::cipher_suite;
|
||||||
use rustls::crypto::aws_lc_rs::{default_provider, kx_group};
|
use rustls::crypto::aws_lc_rs::{default_provider, kx_group};
|
||||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||||
use rustls::server::{ClientHello, ResolvesServerCert};
|
|
||||||
use rustls::sign::CertifiedKey;
|
|
||||||
use rustls::version::{TLS12, TLS13};
|
use rustls::version::{TLS12, TLS13};
|
||||||
use rustls::ServerConfig;
|
use rustls::ServerConfig;
|
||||||
use rustls::SupportedCipherSuite;
|
use rustls::SupportedCipherSuite;
|
||||||
@@ -75,56 +72,6 @@ pub fn build_manual_server_config(cert_path: &str, key_path: &str) -> Result<Ser
|
|||||||
Ok(config)
|
Ok(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_multi_domain_server_config(
|
|
||||||
domain_certs: &HashMap<String, (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)>,
|
|
||||||
) -> Result<ServerConfig> {
|
|
||||||
let provider = crypto_provider();
|
|
||||||
|
|
||||||
let mut resolver = SniCertResolver::new();
|
|
||||||
for (domain, (certs, key)) in domain_certs {
|
|
||||||
let certified_key = CertifiedKey::from_der(certs.clone(), key.clone_key(), &provider)
|
|
||||||
.with_context(|| format!("failed to load cert/key for domain {domain}"))?;
|
|
||||||
resolver.add(domain, Arc::new(certified_key));
|
|
||||||
}
|
|
||||||
|
|
||||||
let config = ServerConfig::builder_with_provider(provider)
|
|
||||||
.with_protocol_versions(&[&TLS12, &TLS13])
|
|
||||||
.with_context(|| "failed to set protocol versions")?
|
|
||||||
.with_no_client_auth()
|
|
||||||
.with_cert_resolver(Arc::new(resolver));
|
|
||||||
|
|
||||||
let mut config = config;
|
|
||||||
// Advertise HTTP/2 and HTTP/1.1 via ALPN so clients can negotiate HTTP/2.
|
|
||||||
// Note: acme-tls/1 is NOT included here — it's only needed for ACME mode.
|
|
||||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
|
||||||
|
|
||||||
Ok(config)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct SniCertResolver {
|
|
||||||
entries: HashMap<String, Arc<CertifiedKey>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SniCertResolver {
|
|
||||||
fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
entries: HashMap::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn add(&mut self, domain: &str, certified_key: Arc<CertifiedKey>) {
|
|
||||||
self.entries.insert(domain.to_lowercase(), certified_key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ResolvesServerCert for SniCertResolver {
|
|
||||||
fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
|
|
||||||
let server_name = client_hello.server_name()?;
|
|
||||||
self.entries.get(&server_name.to_lowercase()).cloned()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -264,55 +211,6 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_sni_resolver_known_domain() {
|
|
||||||
let (certs, key) = generate_test_cert("example.com");
|
|
||||||
let provider = crypto_provider();
|
|
||||||
let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap();
|
|
||||||
let mut resolver = SniCertResolver::new();
|
|
||||||
resolver.add("example.com", Arc::new(certified_key));
|
|
||||||
|
|
||||||
let resolved = resolver.entries.get("example.com");
|
|
||||||
assert!(resolved.is_some());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_sni_resolver_unknown_domain_returns_none() {
|
|
||||||
let (certs, key) = generate_test_cert("example.com");
|
|
||||||
let provider = crypto_provider();
|
|
||||||
let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap();
|
|
||||||
let mut resolver = SniCertResolver::new();
|
|
||||||
resolver.add("example.com", Arc::new(certified_key));
|
|
||||||
|
|
||||||
let resolved = resolver.entries.get("unknown.com");
|
|
||||||
assert!(resolved.is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_sni_resolver_case_insensitive() {
|
|
||||||
let (certs, key) = generate_test_cert("Example.COM");
|
|
||||||
let provider = crypto_provider();
|
|
||||||
let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap();
|
|
||||||
let mut resolver = SniCertResolver::new();
|
|
||||||
resolver.add("Example.COM", Arc::new(certified_key));
|
|
||||||
|
|
||||||
assert!(resolver.entries.contains_key("example.com"));
|
|
||||||
assert!(!resolver.entries.contains_key("Example.COM"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_build_multi_domain_server_config() {
|
|
||||||
let (certs1, key1) = generate_test_cert("site1.example.com");
|
|
||||||
let (certs2, key2) = generate_test_cert("site2.example.com");
|
|
||||||
|
|
||||||
let mut domain_certs = HashMap::new();
|
|
||||||
domain_certs.insert("site1.example.com".to_string(), (certs1, key1));
|
|
||||||
domain_certs.insert("site2.example.com".to_string(), (certs2, key2));
|
|
||||||
|
|
||||||
let config = build_multi_domain_server_config(&domain_certs).unwrap();
|
|
||||||
assert!(!config.ignore_client_order);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_load_certs_empty_file() {
|
fn test_load_certs_empty_file() {
|
||||||
let dir = tempfile::tempdir().unwrap();
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
|||||||
@@ -34,14 +34,4 @@ impl TestUpstream {
|
|||||||
pub async fn spawn_ok() -> Self {
|
pub async fn spawn_ok() -> Self {
|
||||||
Self::spawn(|| Router::new().route("/", get(|| async { "ok" }))).await
|
Self::spawn(|| Router::new().route("/", get(|| async { "ok" }))).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub fn url(&self) -> String {
|
|
||||||
format!("http://{}", self.addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub fn upstream_addr(&self) -> String {
|
|
||||||
format!("127.0.0.1:{}", self.addr.port())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user