diff --git a/src/logging/format.rs b/src/logging/format.rs index 56745f6..638886f 100644 --- a/src/logging/format.rs +++ b/src/logging/format.rs @@ -1,8 +1,10 @@ +#[cfg(test)] #[derive(Default)] struct KvVisitor { pairs: Vec<(String, String)>, } +#[cfg(test)] impl KvVisitor { fn format(&self) -> String { let parts: Vec = self @@ -20,6 +22,7 @@ impl KvVisitor { } } +#[cfg(test)] impl tracing::field::Visit for KvVisitor { fn record_str(&mut self, field: &tracing::field::Field, value: &str) { self.pairs @@ -47,12 +50,6 @@ impl tracing::field::Visit for KvVisitor { } } -pub fn format_event_fields(event: &tracing::Event<'_>) -> String { - let mut visitor = KvVisitor::default(); - event.record(&mut visitor); - visitor.format() -} - #[macro_export] macro_rules! log_request { ($client_ip:expr, $host:expr, $method:expr, $path:expr, $status:expr, $upstream:expr, $duration_ms:expr) => { @@ -69,19 +66,6 @@ macro_rules! log_request { }; } -#[macro_export] -macro_rules! log_rate_limit { - ($client_ip:expr, $host:expr, $path:expr, $status:expr) => { - tracing::warn!( - prefix = "RATE_LIMIT", - client_ip = %$client_ip, - host = %$host, - path = %$path, - status = %$status, - ) - }; -} - #[macro_export] macro_rules! log_upstream_error { ($host:expr, $upstream:expr, $error:expr) => { @@ -94,17 +78,6 @@ macro_rules! log_upstream_error { }; } -#[macro_export] -macro_rules! log_config_reload { - ($status:expr, $sites:expr) => { - tracing::info!( - prefix = "CONFIG_RELOAD", - status = %$status, - sites = %$sites, - ) - }; -} - #[cfg(test)] mod tests { use super::*; @@ -180,8 +153,6 @@ mod tests { "127.0.0.1:3000", 45u64 ); - log_rate_limit!("10.0.0.1", "example.com", "/login", 429u16); log_upstream_error!("git.alk.dev", "127.0.0.1:3000", "connection refused"); - log_config_reload!("success", 1u32); } } diff --git a/src/proxy/error.rs b/src/proxy/error.rs index 59f17a2..5480755 100644 --- a/src/proxy/error.rs +++ b/src/proxy/error.rs @@ -8,24 +8,16 @@ pub enum ProxyError { BadGateway { host: String, upstream: String }, #[error("Gateway Timeout")] GatewayTimeout { host: String, upstream: String }, - #[error("Payload Too Large")] - PayloadTooLarge, #[error("Too Many Requests")] TooManyRequests { client_ip: String, host: String, path: String, }, - #[error("Not Found")] - NotFound, - #[error("Bad Request")] - BadRequest, #[error("upstream connection failed")] UpstreamConnection(#[source] hyper_util::client::legacy::Error), #[error("upstream timeout")] UpstreamTimeout, - #[error("upstream tls certificate validation failed")] - UpstreamTls(#[source] std::io::Error), #[error("no matching site for host")] UnknownHost, #[error("missing host header")] @@ -37,13 +29,11 @@ impl ProxyError { match self { Self::BadGateway { .. } => StatusCode::BAD_GATEWAY, Self::GatewayTimeout { .. } => StatusCode::GATEWAY_TIMEOUT, - Self::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE, Self::TooManyRequests { .. } => StatusCode::TOO_MANY_REQUESTS, - Self::NotFound | Self::UnknownHost => StatusCode::NOT_FOUND, - Self::BadRequest | Self::MissingHost => StatusCode::BAD_REQUEST, Self::UpstreamConnection(_) => StatusCode::BAD_GATEWAY, Self::UpstreamTimeout => StatusCode::GATEWAY_TIMEOUT, - Self::UpstreamTls(_) => StatusCode::BAD_GATEWAY, + Self::UnknownHost => StatusCode::NOT_FOUND, + Self::MissingHost => StatusCode::BAD_REQUEST, } } @@ -51,13 +41,11 @@ impl ProxyError { match self { Self::BadGateway { .. } => "Bad Gateway", Self::GatewayTimeout { .. } => "Gateway Timeout", - Self::PayloadTooLarge => "Payload Too Large", Self::TooManyRequests { .. } => "Too Many Requests", - Self::NotFound | Self::UnknownHost => "Not Found", - Self::BadRequest | Self::MissingHost => "Bad Request", Self::UpstreamConnection(_) => "Bad Gateway", Self::UpstreamTimeout => "Gateway Timeout", - Self::UpstreamTls(_) => "Bad Gateway", + Self::UnknownHost => "Not Found", + Self::MissingHost => "Bad Request", } } } @@ -99,9 +87,6 @@ impl IntoResponse for ProxyError { Self::UpstreamTimeout => { tracing::warn!(status = 504, "upstream timeout"); } - Self::UpstreamTls(e) => { - tracing::warn!(error = %e, status = 502, "upstream TLS error"); - } _ => {} } @@ -176,23 +161,6 @@ mod tests { ); } - #[tokio::test] - async fn payload_too_large_response() { - let resp = into_response(ProxyError::PayloadTooLarge); - assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE); - let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap(); - assert_eq!(&body[..], b"Payload Too Large"); - } - - #[tokio::test] - async fn payload_too_large_content_type() { - let resp = into_response(ProxyError::PayloadTooLarge); - assert_eq!( - resp.headers().get("content-type").unwrap(), - "text/plain; charset=utf-8" - ); - } - #[tokio::test] async fn too_many_requests_response() { let resp = into_response(ProxyError::TooManyRequests { @@ -218,40 +186,6 @@ mod tests { ); } - #[tokio::test] - async fn not_found_response() { - let resp = into_response(ProxyError::NotFound); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); - let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap(); - assert_eq!(&body[..], b"Not Found"); - } - - #[tokio::test] - async fn not_found_content_type() { - let resp = into_response(ProxyError::NotFound); - assert_eq!( - resp.headers().get("content-type").unwrap(), - "text/plain; charset=utf-8" - ); - } - - #[tokio::test] - async fn bad_request_response() { - let resp = into_response(ProxyError::BadRequest); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap(); - assert_eq!(&body[..], b"Bad Request"); - } - - #[tokio::test] - async fn bad_request_content_type() { - let resp = into_response(ProxyError::BadRequest); - assert_eq!( - resp.headers().get("content-type").unwrap(), - "text/plain; charset=utf-8" - ); - } - #[test] fn error_display_matches_body() { assert_eq!( @@ -270,7 +204,6 @@ mod tests { .to_string(), "Gateway Timeout" ); - assert_eq!(ProxyError::PayloadTooLarge.to_string(), "Payload Too Large"); assert_eq!( ProxyError::TooManyRequests { client_ip: String::new(), @@ -280,7 +213,5 @@ mod tests { .to_string(), "Too Many Requests" ); - assert_eq!(ProxyError::NotFound.to_string(), "Not Found"); - assert_eq!(ProxyError::BadRequest.to_string(), "Bad Request"); } } diff --git a/src/tls/acme.rs b/src/tls/acme.rs index 278acd1..f08b789 100644 --- a/src/tls/acme.rs +++ b/src/tls/acme.rs @@ -50,6 +50,7 @@ impl AcmeTlsConfig { Ok(AcmeTlsSetup { resolver, state }) } + #[cfg(test)] pub fn directory_url(&self) -> &str { match self.directory.as_str() { "production" => LETS_ENCRYPT_PRODUCTION_DIRECTORY, diff --git a/src/tls/config.rs b/src/tls/config.rs index 3d75442..7f449ff 100644 --- a/src/tls/config.rs +++ b/src/tls/config.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::fs::File; use std::io::BufReader; use std::sync::Arc; @@ -7,8 +6,6 @@ use anyhow::{bail, Context, Result}; use rustls::crypto::aws_lc_rs::cipher_suite; use rustls::crypto::aws_lc_rs::{default_provider, kx_group}; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; -use rustls::server::{ClientHello, ResolvesServerCert}; -use rustls::sign::CertifiedKey; use rustls::version::{TLS12, TLS13}; use rustls::ServerConfig; use rustls::SupportedCipherSuite; @@ -75,56 +72,6 @@ pub fn build_manual_server_config(cert_path: &str, key_path: &str) -> Result>, PrivateKeyDer<'static>)>, -) -> Result { - let provider = crypto_provider(); - - let mut resolver = SniCertResolver::new(); - for (domain, (certs, key)) in domain_certs { - let certified_key = CertifiedKey::from_der(certs.clone(), key.clone_key(), &provider) - .with_context(|| format!("failed to load cert/key for domain {domain}"))?; - resolver.add(domain, Arc::new(certified_key)); - } - - let config = ServerConfig::builder_with_provider(provider) - .with_protocol_versions(&[&TLS12, &TLS13]) - .with_context(|| "failed to set protocol versions")? - .with_no_client_auth() - .with_cert_resolver(Arc::new(resolver)); - - let mut config = config; - // Advertise HTTP/2 and HTTP/1.1 via ALPN so clients can negotiate HTTP/2. - // Note: acme-tls/1 is NOT included here — it's only needed for ACME mode. - config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - Ok(config) -} - -#[derive(Debug)] -struct SniCertResolver { - entries: HashMap>, -} - -impl SniCertResolver { - fn new() -> Self { - Self { - entries: HashMap::new(), - } - } - - fn add(&mut self, domain: &str, certified_key: Arc) { - self.entries.insert(domain.to_lowercase(), certified_key); - } -} - -impl ResolvesServerCert for SniCertResolver { - fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { - let server_name = client_hello.server_name()?; - self.entries.get(&server_name.to_lowercase()).cloned() - } -} - #[cfg(test)] mod tests { use super::*; @@ -264,55 +211,6 @@ mod tests { .unwrap(); } - #[test] - fn test_sni_resolver_known_domain() { - let (certs, key) = generate_test_cert("example.com"); - let provider = crypto_provider(); - let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap(); - let mut resolver = SniCertResolver::new(); - resolver.add("example.com", Arc::new(certified_key)); - - let resolved = resolver.entries.get("example.com"); - assert!(resolved.is_some()); - } - - #[test] - fn test_sni_resolver_unknown_domain_returns_none() { - let (certs, key) = generate_test_cert("example.com"); - let provider = crypto_provider(); - let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap(); - let mut resolver = SniCertResolver::new(); - resolver.add("example.com", Arc::new(certified_key)); - - let resolved = resolver.entries.get("unknown.com"); - assert!(resolved.is_none()); - } - - #[test] - fn test_sni_resolver_case_insensitive() { - let (certs, key) = generate_test_cert("Example.COM"); - let provider = crypto_provider(); - let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap(); - let mut resolver = SniCertResolver::new(); - resolver.add("Example.COM", Arc::new(certified_key)); - - assert!(resolver.entries.contains_key("example.com")); - assert!(!resolver.entries.contains_key("Example.COM")); - } - - #[test] - fn test_build_multi_domain_server_config() { - let (certs1, key1) = generate_test_cert("site1.example.com"); - let (certs2, key2) = generate_test_cert("site2.example.com"); - - let mut domain_certs = HashMap::new(); - domain_certs.insert("site1.example.com".to_string(), (certs1, key1)); - domain_certs.insert("site2.example.com".to_string(), (certs2, key2)); - - let config = build_multi_domain_server_config(&domain_certs).unwrap(); - assert!(!config.ignore_client_order); - } - #[test] fn test_load_certs_empty_file() { let dir = tempfile::tempdir().unwrap(); diff --git a/tests/helpers/http_test_helper.rs b/tests/helpers/http_test_helper.rs index 511a9c0..89a07ee 100644 --- a/tests/helpers/http_test_helper.rs +++ b/tests/helpers/http_test_helper.rs @@ -34,14 +34,4 @@ impl TestUpstream { pub async fn spawn_ok() -> Self { Self::spawn(|| Router::new().route("/", get(|| async { "ok" }))).await } - - #[allow(dead_code)] - pub fn url(&self) -> String { - format!("http://{}", self.addr) - } - - #[allow(dead_code)] - pub fn upstream_addr(&self) -> String { - format!("127.0.0.1:{}", self.addr.port()) - } }