diff --git a/Cargo.lock b/Cargo.lock index b7909ec..c763f26 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1531,6 +1531,7 @@ dependencies = [ "rustls-pki-types", "serde", "signal-hook", + "tempfile", "thiserror 2.0.18", "tokio", "tokio-rustls", @@ -1907,15 +1908,15 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.27.0" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ "fastrand", "getrandom 0.3.4", "once_cell", "rustix", - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index b886953..0059aaf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,4 +34,5 @@ thiserror = "=2.0.18" [dev-dependencies] rcgen = "=0.13" -reqwest = { version = "=0.12", features = ["json"] } \ No newline at end of file +reqwest = { version = "=0.12", features = ["json"] } +tempfile = "=3.20" \ No newline at end of file diff --git a/src/tls/config.rs b/src/tls/config.rs new file mode 100644 index 0000000..f08a647 --- /dev/null +++ b/src/tls/config.rs @@ -0,0 +1,332 @@ +use std::collections::HashMap; +use std::fs::File; +use std::io::BufReader; +use std::sync::Arc; + +use anyhow::{bail, Context, Result}; +use rustls::crypto::aws_lc_rs; +use rustls::crypto::aws_lc_rs::cipher_suite::{ + TLS13_AES_128_GCM_SHA256, TLS13_AES_256_GCM_SHA384, TLS13_CHACHA20_POLY1305_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, +}; +use rustls::crypto::CryptoProvider; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::server::{ClientHello, ResolvesServerCert}; +use rustls::sign::CertifiedKey; +use rustls::{ServerConfig, SupportedCipherSuite}; +use rustls_pemfile; + +static RESTRICTED_CIPHER_SUITES: &[SupportedCipherSuite] = &[ + TLS13_AES_256_GCM_SHA384, + TLS13_AES_128_GCM_SHA256, + TLS13_CHACHA20_POLY1305_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, +]; + +fn crypto_provider() -> Arc { + let mut provider = aws_lc_rs::default_provider(); + provider.cipher_suites = RESTRICTED_CIPHER_SUITES.to_vec(); + Arc::new(provider) +} + +pub fn load_certs(path: &str) -> Result>> { + let file = + File::open(path).with_context(|| format!("failed to open certificate file: {path}"))?; + let mut reader = BufReader::new(file); + let certs: Vec> = rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .with_context(|| format!("failed to parse certificate file: {path}"))?; + if certs.is_empty() { + bail!("no certificates found in {path}"); + } + Ok(certs) +} + +pub fn load_private_key(path: &str) -> Result> { + let file = + File::open(path).with_context(|| format!("failed to open private key file: {path}"))?; + let mut reader = BufReader::new(file); + let key = rustls_pemfile::private_key(&mut reader) + .with_context(|| format!("failed to parse private key file: {path}"))?; + key.context(format!("no private key found in {path}")) +} + +pub fn build_manual_server_config(cert_path: &str, key_path: &str) -> Result { + let certs = load_certs(cert_path)?; + let key = load_private_key(key_path)?; + + let provider = crypto_provider(); + let config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&rustls::version::TLS12, &rustls::version::TLS13]) + .with_context(|| "failed to set protocol versions")? + .with_no_client_auth() + .with_single_cert(certs, key) + .with_context(|| "failed to configure certificate/key pair")?; + + Ok(config) +} + +pub fn build_multi_domain_server_config( + domain_certs: &HashMap>, PrivateKeyDer<'static>)>, +) -> Result { + let provider = crypto_provider(); + + let mut resolver = SniCertResolver::new(); + for (domain, (certs, key)) in domain_certs { + let certified_key = CertifiedKey::from_der(certs.clone(), key.clone_key(), &provider) + .with_context(|| format!("failed to load cert/key for domain {domain}"))?; + resolver.add(domain, Arc::new(certified_key)); + } + + let config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&rustls::version::TLS12, &rustls::version::TLS13]) + .with_context(|| "failed to set protocol versions")? + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver)); + + Ok(config) +} + +#[derive(Debug)] +struct SniCertResolver { + entries: HashMap>, +} + +impl SniCertResolver { + fn new() -> Self { + Self { + entries: HashMap::new(), + } + } + + fn add(&mut self, domain: &str, certified_key: Arc) { + self.entries.insert(domain.to_lowercase(), certified_key); + } +} + +impl ResolvesServerCert for SniCertResolver { + fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { + let server_name = client_hello.server_name()?; + self.entries.get(&server_name.to_lowercase()).cloned() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rcgen::{CertificateParams, IsCa, KeyPair}; + use rustls::pki_types::PrivatePkcs8KeyDer; + + fn generate_test_cert(domain: &str) -> (Vec>, PrivateKeyDer<'static>) { + let mut params = + CertificateParams::new(vec![domain.to_string()]).expect("failed to create cert params"); + params.is_ca = IsCa::NoCa; + let key_pair = KeyPair::generate().expect("failed to generate key pair"); + let cert = params + .self_signed(&key_pair) + .expect("failed to self-sign cert"); + let cert_der = cert.der().clone(); + let key_der = PrivateKeyDer::from(PrivatePkcs8KeyDer::from(key_pair.serialize_der())); + (vec![cert_der], key_der) + } + + fn generate_test_cert_pem(domain: &str) -> (String, String) { + let mut params = + CertificateParams::new(vec![domain.to_string()]).expect("failed to create cert params"); + params.is_ca = IsCa::NoCa; + let key_pair = KeyPair::generate().expect("failed to generate key pair"); + let cert = params + .self_signed(&key_pair) + .expect("failed to self-sign cert"); + let cert_pem = cert.pem(); + let key_pem = key_pair.serialize_pem(); + (cert_pem, key_pem) + } + + #[test] + fn test_build_manual_server_config() { + let (certs, key) = generate_test_cert("test.example.com"); + let provider = crypto_provider(); + let config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&rustls::version::TLS12, &rustls::version::TLS13]) + .unwrap() + .with_no_client_auth() + .with_single_cert(certs, key) + .unwrap(); + assert!(!config.ignore_client_order); + } + + #[test] + fn test_load_certs_from_pem() { + let dir = tempfile::tempdir().unwrap(); + let (cert_pem, _) = generate_test_cert_pem("test.example.com"); + let cert_path = dir.path().join("cert.pem"); + std::fs::write(&cert_path, cert_pem).unwrap(); + let certs = load_certs(cert_path.to_str().unwrap()).unwrap(); + assert_eq!(certs.len(), 1); + } + + #[test] + fn test_load_private_key_from_pem() { + let dir = tempfile::tempdir().unwrap(); + let (_, key_pem) = generate_test_cert_pem("test.example.com"); + let key_path = dir.path().join("key.pem"); + std::fs::write(&key_path, key_pem).unwrap(); + let key = load_private_key(key_path.to_str().unwrap()).unwrap(); + assert!(matches!(key, PrivateKeyDer::Pkcs8(_))); + } + + #[test] + fn test_build_manual_server_config_from_files() { + let dir = tempfile::tempdir().unwrap(); + let (cert_pem, key_pem) = generate_test_cert_pem("test.example.com"); + let cert_path = dir.path().join("cert.pem"); + let key_path = dir.path().join("key.pem"); + std::fs::write(&cert_path, &cert_pem).unwrap(); + std::fs::write(&key_path, &key_pem).unwrap(); + let config = + build_manual_server_config(cert_path.to_str().unwrap(), key_path.to_str().unwrap()) + .unwrap(); + assert!(!config.ignore_client_order); + } + + #[test] + fn test_cipher_suite_restriction() { + let provider = crypto_provider(); + assert_eq!(provider.cipher_suites.len(), 7); + + let has_tls13_aes_256 = provider.cipher_suites.iter().any(|cs| { + format!("{cs:?}").contains("AES_256_GCM_SHA384") && format!("{cs:?}").contains("TLS13") + }); + let has_tls13_aes_128 = provider.cipher_suites.iter().any(|cs| { + format!("{cs:?}").contains("AES_128_GCM_SHA256") && format!("{cs:?}").contains("TLS13") + }); + let has_tls13_chacha = provider + .cipher_suites + .iter() + .any(|cs| format!("{cs:?}").contains("CHACHA20_POLY1305_SHA256")); + let has_ecdsa_aes256 = provider + .cipher_suites + .iter() + .any(|cs| format!("{cs:?}").contains("ECDHE_ECDSA_WITH_AES_256_GCM_SHA384")); + let has_ecdsa_aes128 = provider + .cipher_suites + .iter() + .any(|cs| format!("{cs:?}").contains("ECDHE_ECDSA_WITH_AES_128_GCM_SHA256")); + let has_rsa_aes256 = provider + .cipher_suites + .iter() + .any(|cs| format!("{cs:?}").contains("ECDHE_RSA_WITH_AES_256_GCM_SHA384")); + let has_rsa_aes128 = provider + .cipher_suites + .iter() + .any(|cs| format!("{cs:?}").contains("ECDHE_RSA_WITH_AES_128_GCM_SHA256")); + + assert!(has_tls13_aes_256); + assert!(has_tls13_aes_128); + assert!(has_tls13_chacha); + assert!(has_ecdsa_aes256); + assert!(has_ecdsa_aes128); + assert!(has_rsa_aes256); + assert!(has_rsa_aes128); + } + + #[test] + fn test_no_chacha20_for_tls12() { + let provider = crypto_provider(); + let tls12_chacha = provider.cipher_suites.iter().any(|cs| { + let dbg = format!("{cs:?}"); + dbg.contains("ECDHE") && dbg.contains("CHACHA20") + }); + assert!( + !tls12_chacha, + "TLS 1.2 ChaCha20 suites should not be present" + ); + } + + #[test] + fn test_protocol_versions_configured() { + let (certs, key) = generate_test_cert("test.example.com"); + let provider = crypto_provider(); + let _config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&rustls::version::TLS12, &rustls::version::TLS13]) + .unwrap() + .with_no_client_auth() + .with_single_cert(certs, key) + .unwrap(); + } + + #[test] + fn test_sni_resolver_known_domain() { + let (certs, key) = generate_test_cert("example.com"); + let provider = crypto_provider(); + let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap(); + let mut resolver = SniCertResolver::new(); + resolver.add("example.com", Arc::new(certified_key)); + + let resolved = resolver.entries.get("example.com"); + assert!(resolved.is_some()); + } + + #[test] + fn test_sni_resolver_unknown_domain_returns_none() { + let (certs, key) = generate_test_cert("example.com"); + let provider = crypto_provider(); + let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap(); + let mut resolver = SniCertResolver::new(); + resolver.add("example.com", Arc::new(certified_key)); + + let resolved = resolver.entries.get("unknown.com"); + assert!(resolved.is_none()); + } + + #[test] + fn test_sni_resolver_normalizes_domain_to_lowercase() { + let (certs, key) = generate_test_cert("Example.COM"); + let provider = crypto_provider(); + let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap(); + let mut resolver = SniCertResolver::new(); + resolver.add("Example.COM", Arc::new(certified_key)); + + assert!(resolver.entries.contains_key("example.com")); + assert!(!resolver.entries.contains_key("Example.COM")); + } + + #[test] + fn test_build_multi_domain_server_config() { + let (certs1, key1) = generate_test_cert("site1.example.com"); + let (certs2, key2) = generate_test_cert("site2.example.com"); + + let mut domain_certs = HashMap::new(); + domain_certs.insert("site1.example.com".to_string(), (certs1, key1)); + domain_certs.insert("site2.example.com".to_string(), (certs2, key2)); + + let config = build_multi_domain_server_config(&domain_certs).unwrap(); + assert!(!config.ignore_client_order); + } + + #[test] + fn test_load_certs_empty_file() { + let dir = tempfile::tempdir().unwrap(); + let cert_path = dir.path().join("empty.pem"); + std::fs::write(&cert_path, "").unwrap(); + let result = load_certs(cert_path.to_str().unwrap()); + assert!(result.is_err()); + } + + #[test] + fn test_load_certs_nonexistent_file() { + let result = load_certs("/nonexistent/path/cert.pem"); + assert!(result.is_err()); + } + + #[test] + fn test_load_private_key_nonexistent_file() { + let result = load_private_key("/nonexistent/path/key.pem"); + assert!(result.is_err()); + } +} diff --git a/src/tls/mod.rs b/src/tls/mod.rs index 478b715..736b6ba 100644 --- a/src/tls/mod.rs +++ b/src/tls/mod.rs @@ -1,2 +1,3 @@ pub mod acceptor; +pub mod config; pub mod redirect;