diff --git a/Cargo.lock b/Cargo.lock index 47788c4..1d84785 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -572,9 +572,9 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "futures" -version = "0.3.32" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -1157,6 +1157,7 @@ checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" dependencies = [ "aws-lc-rs", "pem", + "ring", "rustls-pki-types", "time", "yasna", @@ -1179,13 +1180,16 @@ dependencies = [ "arc-swap", "axum", "clap", + "futures", "hyper", + "rcgen", "rustls", "rustls-acme", "rustls-pemfile", "rustls-pki-types", "serde", "signal-hook", + "tempfile", "thiserror 2.0.18", "tokio", "tokio-rustls", @@ -1498,6 +1502,19 @@ dependencies = [ "syn", ] +[[package]] +name = "tempfile" +version = "3.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "thiserror" version = "1.0.69" diff --git a/Cargo.toml b/Cargo.toml index 8fc3a52..9854bc9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,10 @@ version = "0.1.0" edition = "2021" license = "MIT OR Apache-2.0" +[lib] +name = "reverse_proxy" +path = "src/lib.rs" + [[bin]] name = "reverse-proxy" path = "src/main.rs" @@ -26,4 +30,9 @@ rustls-pki-types = "=1.12.0" clap = { version = "=4.6.1", features = ["derive"] } signal-hook = "=0.3.18" anyhow = "=1.0.102" -thiserror = "=2.0.18" \ No newline at end of file +thiserror = "=2.0.18" +futures = "=0.3.31" + +[dev-dependencies] +rcgen = "=0.13" +tempfile = "=3" \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..59977c2 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,8 @@ +pub mod admin; +pub mod config; +pub mod health; +pub mod logging; +pub mod proxy; +pub mod rate_limit; +pub mod shutdown; +pub mod tls; diff --git a/src/tls/acceptor.rs b/src/tls/acceptor.rs index 9240abf..4208a74 100644 --- a/src/tls/acceptor.rs +++ b/src/tls/acceptor.rs @@ -1,2 +1,232 @@ +use std::sync::Arc; + +use anyhow::{bail, Context, Result}; +use rustls::version::{TLS12, TLS13}; +use rustls::ServerConfig; +use tracing::info; + +use super::acme::{spawn_acme_state, AcmeTlsConfig}; +use super::config::crypto_provider; +use crate::config::static_config::TlsConfig; + +const ACME_TLS_ALPN_01: &[u8] = b"acme-tls/1"; + #[allow(dead_code)] -pub struct TlsAcceptor; +fn build_acme_server_config( + resolver: Arc, +) -> Result> { + let provider = crypto_provider(); + let config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&TLS12, &TLS13]) + .context("failed to set TLS protocol versions")? + .with_no_client_auth() + .with_cert_resolver(resolver); + let mut config = (*Arc::new(config)).clone(); + config.alpn_protocols = vec![ + b"h2".to_vec(), + b"http/1.1".to_vec(), + ACME_TLS_ALPN_01.to_vec(), + ]; + Ok(Arc::new(config)) +} + +#[allow(dead_code)] +fn build_acme_challenge_config( + resolver: Arc, +) -> Arc { + let provider = crypto_provider(); + let mut config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&TLS12, &TLS13]) + .expect("valid protocol versions") + .with_no_client_auth() + .with_cert_resolver(resolver); + config.alpn_protocols = vec![ACME_TLS_ALPN_01.to_vec()]; + Arc::new(config) +} + +#[allow(dead_code)] +#[derive(Debug)] +pub enum TlsMode { + Manual(Arc), + Acme { + default_config: Arc, + challenge_config: Arc, + resolver: Arc, + }, +} + +#[allow(dead_code)] +pub fn setup_tls(tls_config: &TlsConfig) -> Result { + match tls_config.mode.as_str() { + "manual" => { + if tls_config.cert_path.is_empty() { + bail!("manual TLS mode requires cert_path"); + } + if tls_config.key_path.is_empty() { + bail!("manual TLS mode requires key_path"); + } + let config = super::config::build_manual_server_config( + &tls_config.cert_path, + &tls_config.key_path, + )?; + Ok(TlsMode::Manual(Arc::new(config))) + } + "acme" => { + if tls_config.acme_domains.is_empty() { + bail!("ACME TLS mode requires at least one domain in acme_domains"); + } + if tls_config.acme_cache_dir.is_empty() { + bail!("ACME TLS mode requires acme_cache_dir"); + } + + let acme_tls_config = AcmeTlsConfig { + domains: tls_config.acme_domains.clone(), + cache_dir: tls_config.acme_cache_dir.clone().into(), + directory: tls_config.acme_directory.clone(), + contact: vec![], + }; + + let super::acme::AcmeTlsSetup { resolver, state } = acme_tls_config.setup()?; + + let default_config = build_acme_server_config(resolver.clone())?; + let challenge_config = build_acme_challenge_config(resolver.clone()); + + spawn_acme_state(state, tls_config.acme_domains.clone()); + + info!( + domains = ?tls_config.acme_domains, + "ACME TLS mode initialized" + ); + + Ok(TlsMode::Acme { + default_config, + challenge_config, + resolver, + }) + } + other => { + bail!("unknown TLS mode: '{}', expected 'manual' or 'acme'", other); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_acme_tls_alpn_value() { + assert_eq!(ACME_TLS_ALPN_01, b"acme-tls/1"); + } + + fn make_test_resolver() -> Arc { + let temp_dir = tempfile::tempdir().expect("failed to create temp dir"); + let config = rustls_acme::AcmeConfig::new(["test.example.com"]) + .cache(rustls_acme::caches::DirCache::new( + temp_dir.path().to_path_buf(), + )) + .directory("https://acme-staging-v02.api.letsencrypt.org/directory"); + let state = config.state(); + state.resolver() + } + + #[test] + fn test_build_acme_server_config() { + let resolver = make_test_resolver(); + let config = build_acme_server_config(resolver); + assert!(config.is_ok()); + + let config = config.unwrap(); + assert!(config.alpn_protocols.contains(&b"h2".to_vec())); + assert!(config.alpn_protocols.contains(&b"http/1.1".to_vec())); + assert!(config.alpn_protocols.contains(&ACME_TLS_ALPN_01.to_vec())); + } + + #[test] + fn test_build_acme_challenge_config() { + let resolver = make_test_resolver(); + let config = build_acme_challenge_config(resolver); + assert_eq!(config.alpn_protocols.len(), 1); + assert_eq!(config.alpn_protocols[0], ACME_TLS_ALPN_01); + } + + #[test] + fn test_setup_tls_manual_missing_cert_path() { + let tls_config = TlsConfig { + mode: "manual".to_string(), + acme_domains: vec![], + acme_cache_dir: String::new(), + acme_directory: "production".to_string(), + cert_path: String::new(), + key_path: "/some/key.pem".to_string(), + }; + let result = setup_tls(&tls_config); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("cert_path")); + } + + #[test] + fn test_setup_tls_manual_missing_key_path() { + let tls_config = TlsConfig { + mode: "manual".to_string(), + acme_domains: vec![], + acme_cache_dir: String::new(), + acme_directory: "production".to_string(), + cert_path: "/some/cert.pem".to_string(), + key_path: String::new(), + }; + let result = setup_tls(&tls_config); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("key_path")); + } + + #[test] + fn test_setup_tls_acme_missing_domains() { + let tls_config = TlsConfig { + mode: "acme".to_string(), + acme_domains: vec![], + acme_cache_dir: "/tmp/cache".to_string(), + acme_directory: "staging".to_string(), + cert_path: String::new(), + key_path: String::new(), + }; + let result = setup_tls(&tls_config); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("acme_domains")); + } + + #[test] + fn test_setup_tls_acme_missing_cache_dir() { + let tls_config = TlsConfig { + mode: "acme".to_string(), + acme_domains: vec!["example.com".to_string()], + acme_cache_dir: String::new(), + acme_directory: "staging".to_string(), + cert_path: String::new(), + key_path: String::new(), + }; + let result = setup_tls(&tls_config); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("acme_cache_dir")); + } + + #[test] + fn test_setup_tls_unknown_mode() { + let tls_config = TlsConfig { + mode: "invalid".to_string(), + acme_domains: vec![], + acme_cache_dir: String::new(), + acme_directory: "production".to_string(), + cert_path: String::new(), + key_path: String::new(), + }; + let result = setup_tls(&tls_config); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("unknown TLS mode")); + } +} diff --git a/src/tls/acme.rs b/src/tls/acme.rs new file mode 100644 index 0000000..e6c20d5 --- /dev/null +++ b/src/tls/acme.rs @@ -0,0 +1,227 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use anyhow::Result; +use rustls_acme::caches::DirCache; +use rustls_acme::{AcmeConfig, AcmeState, EventError, EventOk, ResolvesServerCertAcme}; +use tracing::{error, info, warn}; + +#[allow(dead_code)] +const LETS_ENCRYPT_PRODUCTION_DIRECTORY: &str = "https://acme-v02.api.letsencrypt.org/directory"; +#[allow(dead_code)] +const LETS_ENCRYPT_STAGING_DIRECTORY: &str = + "https://acme-staging-v02.api.letsencrypt.org/directory"; + +#[allow(dead_code)] +pub struct AcmeTlsConfig { + pub domains: Vec, + pub cache_dir: PathBuf, + pub directory: String, + pub contact: Vec, +} + +#[allow(dead_code)] +pub struct AcmeTlsSetup { + pub resolver: Arc, + pub state: AcmeState, +} + +impl AcmeTlsConfig { + pub fn setup(self) -> Result { + let directory_url = match self.directory.as_str() { + "production" => LETS_ENCRYPT_PRODUCTION_DIRECTORY.to_string(), + "staging" => LETS_ENCRYPT_STAGING_DIRECTORY.to_string(), + other => other.to_string(), + }; + + let acme_config = AcmeConfig::new(self.domains.clone()) + .cache(DirCache::new(self.cache_dir.clone())) + .directory(&directory_url) + .contact(self.contact.iter().map(|c| c.as_str())); + + let state = acme_config.state(); + let resolver = state.resolver(); + + info!( + domains = ?self.domains, + cache_dir = %self.cache_dir.display(), + directory = %directory_url, + "ACME state machine created" + ); + + Ok(AcmeTlsSetup { resolver, state }) + } + + #[allow(dead_code)] + pub fn directory_url(&self) -> &str { + match self.directory.as_str() { + "production" => LETS_ENCRYPT_PRODUCTION_DIRECTORY, + "staging" => LETS_ENCRYPT_STAGING_DIRECTORY, + other => other, + } + } +} + +#[allow(dead_code)] +pub fn spawn_acme_state( + state: AcmeState, + domains: Vec, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + use futures::StreamExt; + let mut state = state; + loop { + match state.next().await { + Some(Ok(event)) => match event { + EventOk::DeployedCachedCert => { + info!( + domains = ?domains, + "ACME: deployed cached certificate" + ); + } + EventOk::DeployedNewCert => { + info!( + domains = ?domains, + "ACME: deployed new certificate" + ); + } + EventOk::CertCacheStore => { + info!( + domains = ?domains, + "ACME: certificate stored to cache" + ); + } + EventOk::AccountCacheStore => { + info!( + domains = ?domains, + "ACME: account stored to cache" + ); + } + }, + Some(Err(err)) => match &err { + EventError::CertCacheLoad(e) => { + error!( + domains = ?domains, + error = ?e, + "ACME: certificate cache load failed" + ); + } + EventError::AccountCacheLoad(e) => { + error!( + domains = ?domains, + error = ?e, + "ACME: account cache load failed" + ); + } + EventError::CertCacheStore(e) => { + warn!( + domains = ?domains, + error = ?e, + "ACME: certificate cache store failed" + ); + } + EventError::AccountCacheStore(e) => { + warn!( + domains = ?domains, + error = ?e, + "ACME: account cache store failed" + ); + } + EventError::CachedCertParse(e) => { + error!( + domains = ?domains, + error = ?e, + "ACME: cached certificate parse failed" + ); + } + EventError::Order(e) => { + warn!( + domains = ?domains, + error = ?e, + "ACME: certificate order failed, will retry" + ); + } + EventError::NewCertParse(e) => { + error!( + domains = ?domains, + error = ?e, + "ACME: new certificate parse failed" + ); + } + }, + None => { + info!( + domains = ?domains, + "ACME: state machine ended" + ); + break; + } + } + } + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_acme_config_production_directory() { + let config = AcmeTlsConfig { + domains: vec!["example.com".to_string()], + cache_dir: PathBuf::from("/tmp/test-cache"), + directory: "production".to_string(), + contact: vec![], + }; + assert_eq!(config.directory_url(), LETS_ENCRYPT_PRODUCTION_DIRECTORY); + } + + #[test] + fn test_acme_config_staging_directory() { + let config = AcmeTlsConfig { + domains: vec!["example.com".to_string()], + cache_dir: PathBuf::from("/tmp/test-cache"), + directory: "staging".to_string(), + contact: vec![], + }; + assert_eq!(config.directory_url(), LETS_ENCRYPT_STAGING_DIRECTORY); + } + + #[test] + fn test_acme_config_custom_directory() { + let custom_url = "https://custom-acme.example.com/directory"; + let config = AcmeTlsConfig { + domains: vec!["example.com".to_string()], + cache_dir: PathBuf::from("/tmp/test-cache"), + directory: custom_url.to_string(), + contact: vec![], + }; + assert_eq!(config.directory_url(), custom_url); + } + + #[test] + fn test_acme_config_multiple_domains() { + let config = AcmeTlsConfig { + domains: vec!["git.alk.dev".to_string(), "alk.dev".to_string()], + cache_dir: PathBuf::from("/var/lib/reverse-proxy/acme-cache"), + directory: "production".to_string(), + contact: vec!["mailto:admin@alk.dev".to_string()], + }; + assert_eq!(config.domains.len(), 2); + assert_eq!(config.directory_url(), LETS_ENCRYPT_PRODUCTION_DIRECTORY); + } + + #[test] + fn test_acme_setup_creates_resolver() { + let temp_dir = tempfile::tempdir().expect("failed to create temp dir"); + let config = AcmeTlsConfig { + domains: vec!["test.example.com".to_string()], + cache_dir: temp_dir.path().to_path_buf(), + directory: "staging".to_string(), + contact: vec!["mailto:admin@example.com".to_string()], + }; + let setup = config.setup().expect("setup should succeed"); + assert!(Arc::strong_count(&setup.resolver) >= 1); + } +} diff --git a/src/tls/config.rs b/src/tls/config.rs new file mode 100644 index 0000000..261a066 --- /dev/null +++ b/src/tls/config.rs @@ -0,0 +1,305 @@ +use std::collections::HashMap; +use std::fs::File; +use std::io::BufReader; +use std::sync::Arc; + +use anyhow::{bail, Context, Result}; +use rustls::crypto::aws_lc_rs::cipher_suite; +use rustls::crypto::aws_lc_rs::{default_provider, kx_group}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::server::{ClientHello, ResolvesServerCert}; +use rustls::sign::CertifiedKey; +use rustls::version::{TLS12, TLS13}; +use rustls::ServerConfig; +use rustls::SupportedCipherSuite; +use rustls_pemfile; + +#[allow(dead_code)] +static RESTRICTED_CIPHER_SUITES: &[SupportedCipherSuite] = &[ + cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + cipher_suite::TLS13_AES_128_GCM_SHA256, + cipher_suite::TLS13_AES_256_GCM_SHA384, + cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, +]; + +pub(crate) fn crypto_provider() -> Arc { + let provider = default_provider(); + Arc::new(rustls::crypto::CryptoProvider { + cipher_suites: RESTRICTED_CIPHER_SUITES.to_vec(), + kx_groups: vec![kx_group::X25519, kx_group::SECP256R1, kx_group::SECP384R1], + ..provider + }) +} + +#[allow(dead_code)] +pub fn load_certs(path: &str) -> Result>> { + let file = + File::open(path).with_context(|| format!("failed to open certificate file: {path}"))?; + let mut reader = BufReader::new(file); + let certs: Vec> = rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .with_context(|| format!("failed to parse certificate file: {path}"))?; + if certs.is_empty() { + bail!("no certificates found in {path}"); + } + Ok(certs) +} + +#[allow(dead_code)] +pub fn load_private_key(path: &str) -> Result> { + let file = + File::open(path).with_context(|| format!("failed to open private key file: {path}"))?; + let mut reader = BufReader::new(file); + let key = rustls_pemfile::private_key(&mut reader) + .with_context(|| format!("failed to parse private key file: {path}"))?; + key.context(format!("no private key found in {path}")) +} + +#[allow(dead_code)] +pub fn build_manual_server_config(cert_path: &str, key_path: &str) -> Result { + let certs = load_certs(cert_path)?; + let key = load_private_key(key_path)?; + + let provider = crypto_provider(); + let config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&TLS12, &TLS13]) + .with_context(|| "failed to set protocol versions")? + .with_no_client_auth() + .with_single_cert(certs, key) + .with_context(|| "failed to configure certificate/key pair")?; + + Ok(config) +} + +#[allow(dead_code)] +pub fn build_multi_domain_server_config( + domain_certs: &HashMap>, PrivateKeyDer<'static>)>, +) -> Result { + let provider = crypto_provider(); + + let mut resolver = SniCertResolver::new(); + for (domain, (certs, key)) in domain_certs { + let certified_key = CertifiedKey::from_der(certs.clone(), key.clone_key(), &provider) + .with_context(|| format!("failed to load cert/key for domain {domain}"))?; + resolver.add(domain, Arc::new(certified_key)); + } + + let config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&TLS12, &TLS13]) + .with_context(|| "failed to set protocol versions")? + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver)); + + Ok(config) +} + +#[derive(Debug)] +struct SniCertResolver { + entries: HashMap>, +} + +impl SniCertResolver { + fn new() -> Self { + Self { + entries: HashMap::new(), + } + } + + fn add(&mut self, domain: &str, certified_key: Arc) { + self.entries.insert(domain.to_lowercase(), certified_key); + } +} + +impl ResolvesServerCert for SniCertResolver { + fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { + let server_name = client_hello.server_name()?; + self.entries.get(&server_name.to_lowercase()).cloned() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rcgen::{CertificateParams, IsCa, KeyPair}; + use rustls::pki_types::PrivatePkcs8KeyDer; + + fn generate_test_cert(domain: &str) -> (Vec>, PrivateKeyDer<'static>) { + let mut params = + CertificateParams::new(vec![domain.to_string()]).expect("failed to create cert params"); + params.is_ca = IsCa::NoCa; + let key_pair = KeyPair::generate().expect("failed to generate key pair"); + let cert = params + .self_signed(&key_pair) + .expect("failed to self-sign cert"); + let cert_der = cert.der().clone(); + let key_der = key_pair.serialize_der(); + let private_key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_der)); + (vec![cert_der], private_key) + } + + fn generate_test_cert_pem(domain: &str) -> (String, String) { + let mut params = + CertificateParams::new(vec![domain.to_string()]).expect("failed to create cert params"); + params.is_ca = IsCa::NoCa; + let key_pair = KeyPair::generate().expect("failed to generate key pair"); + let cert = params + .self_signed(&key_pair) + .expect("failed to self-sign cert"); + let cert_pem = cert.pem(); + let key_pem = key_pair.serialize_pem(); + (cert_pem, key_pem) + } + + #[test] + fn test_build_manual_server_config() { + let (certs, key) = generate_test_cert("test.example.com"); + let provider = crypto_provider(); + let config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&TLS12, &TLS13]) + .unwrap() + .with_no_client_auth() + .with_single_cert(certs, key) + .unwrap(); + assert!(!config.ignore_client_order); + } + + #[test] + fn test_load_certs_from_pem() { + let dir = tempfile::tempdir().unwrap(); + let (cert_pem, _) = generate_test_cert_pem("test.example.com"); + let cert_path = dir.path().join("cert.pem"); + std::fs::write(&cert_path, cert_pem).unwrap(); + let certs = load_certs(cert_path.to_str().unwrap()).unwrap(); + assert_eq!(certs.len(), 1); + } + + #[test] + fn test_load_private_key_from_pem() { + let dir = tempfile::tempdir().unwrap(); + let (_, key_pem) = generate_test_cert_pem("test.example.com"); + let key_path = dir.path().join("key.pem"); + std::fs::write(&key_path, key_pem).unwrap(); + let key = load_private_key(key_path.to_str().unwrap()).unwrap(); + assert!(matches!(key, PrivateKeyDer::Pkcs8(_))); + } + + #[test] + fn test_build_manual_server_config_from_files() { + let dir = tempfile::tempdir().unwrap(); + let (cert_pem, key_pem) = generate_test_cert_pem("test.example.com"); + let cert_path = dir.path().join("cert.pem"); + let key_path = dir.path().join("key.pem"); + std::fs::write(&cert_path, &cert_pem).unwrap(); + std::fs::write(&key_path, &key_pem).unwrap(); + let config = + build_manual_server_config(cert_path.to_str().unwrap(), key_path.to_str().unwrap()) + .unwrap(); + assert!(!config.ignore_client_order); + } + + #[test] + fn test_cipher_suite_restriction() { + let provider = crypto_provider(); + let cipher_suites: Vec = provider + .cipher_suites + .iter() + .map(|cs| format!("{:?}", cs)) + .collect(); + + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("AES_256_GCM_SHA384"))); + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("AES_128_GCM_SHA256"))); + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("CHACHA20_POLY1305_SHA256"))); + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"))); + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"))); + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("ECDHE_RSA_WITH_AES_256_GCM_SHA384"))); + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("ECDHE_RSA_WITH_AES_128_GCM_SHA256"))); + + assert_eq!(provider.cipher_suites.len(), 7); + } + + #[test] + fn test_sni_resolver_known_domain() { + let (certs, key) = generate_test_cert("example.com"); + let provider = crypto_provider(); + let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap(); + let mut resolver = SniCertResolver::new(); + resolver.add("example.com", Arc::new(certified_key)); + + let resolved = resolver.entries.get("example.com"); + assert!(resolved.is_some()); + } + + #[test] + fn test_sni_resolver_unknown_domain_returns_none() { + let (certs, key) = generate_test_cert("example.com"); + let provider = crypto_provider(); + let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap(); + let mut resolver = SniCertResolver::new(); + resolver.add("example.com", Arc::new(certified_key)); + + let resolved = resolver.entries.get("unknown.com"); + assert!(resolved.is_none()); + } + + #[test] + fn test_sni_resolver_case_insensitive() { + let (certs, key) = generate_test_cert("Example.COM"); + let provider = crypto_provider(); + let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap(); + let mut resolver = SniCertResolver::new(); + resolver.add("Example.COM", Arc::new(certified_key)); + + assert!(resolver.entries.get("example.com").is_some()); + } + + #[test] + fn test_build_multi_domain_server_config() { + let (certs1, key1) = generate_test_cert("site1.example.com"); + let (certs2, key2) = generate_test_cert("site2.example.com"); + + let mut domain_certs = HashMap::new(); + domain_certs.insert("site1.example.com".to_string(), (certs1, key1)); + domain_certs.insert("site2.example.com".to_string(), (certs2, key2)); + + let config = build_multi_domain_server_config(&domain_certs).unwrap(); + assert!(!config.ignore_client_order); + } + + #[test] + fn test_load_certs_empty_file() { + let dir = tempfile::tempdir().unwrap(); + let cert_path = dir.path().join("empty.pem"); + std::fs::write(&cert_path, "").unwrap(); + let result = load_certs(cert_path.to_str().unwrap()); + assert!(result.is_err()); + } + + #[test] + fn test_load_certs_nonexistent_file() { + let result = load_certs("/nonexistent/path/cert.pem"); + assert!(result.is_err()); + } + + #[test] + fn test_load_private_key_nonexistent_file() { + let result = load_private_key("/nonexistent/path/key.pem"); + assert!(result.is_err()); + } +} diff --git a/src/tls/mod.rs b/src/tls/mod.rs index 478b715..33e305b 100644 --- a/src/tls/mod.rs +++ b/src/tls/mod.rs @@ -1,2 +1,4 @@ pub mod acceptor; +pub mod acme; +pub mod config; pub mod redirect;