Files
reverse-proxy/src/tls/config.rs
glm-5.1 fbae1c464e Implement DynamicConfig with ArcSwap hot-reload and ConfigReloadHandle
Add ConfigReloadHandle with Arc<ArcSwap<DynamicConfig>> for lock-free reads
on the request hot path and tokio::sync::Mutex-serialized reload. Add static
config change detection via diff_static_config(). Add DynamicConfig validation
(rate_limit, body_limit, site checks). Add PartialEq derives to config types.
Include unit tests for ArcSwap swap visibility, invalid config rejection, and
concurrent reload serialization.
2026-06-11 12:42:16 +00:00

328 lines
12 KiB
Rust

use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
use anyhow::{bail, Context, Result};
use rustls::crypto::aws_lc_rs::cipher_suite;
use rustls::crypto::aws_lc_rs::{default_provider, kx_group};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::server::{ClientHello, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use rustls::version::{TLS12, TLS13};
use rustls::ServerConfig;
use rustls::SupportedCipherSuite;
use rustls_pemfile;
/// Cipher suites installed into the provider returned by `crypto_provider`.
///
/// The list holds four TLS 1.2 ECDHE AES-GCM suites (ECDSA and RSA variants)
/// followed by three TLS 1.3 suites. No TLS 1.2 ChaCha20-Poly1305 suite is
/// listed; the unit tests in this file assert that none is present.
// NOTE(review): `crypto_provider` reads this static, so the `dead_code`
// allowance below looks unnecessary — confirm and drop it.
#[allow(dead_code)]
static RESTRICTED_CIPHER_SUITES: &[SupportedCipherSuite] = &[
cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
cipher_suite::TLS13_AES_128_GCM_SHA256,
cipher_suite::TLS13_AES_256_GCM_SHA384,
cipher_suite::TLS13_CHACHA20_POLY1305_SHA256,
];
/// Builds the rustls crypto provider used by the server-config builders in this file.
///
/// Starts from the aws-lc-rs default provider and overrides two fields: the
/// cipher suites (replaced by `RESTRICTED_CIPHER_SUITES`) and the key-exchange
/// groups (`X25519`, `SECP256R1`, `SECP384R1`, in that order). Every other
/// provider field keeps its default value.
pub(crate) fn crypto_provider() -> Arc<rustls::crypto::CryptoProvider> {
    let mut restricted = default_provider();
    restricted.cipher_suites = RESTRICTED_CIPHER_SUITES.to_vec();
    restricted.kx_groups = vec![kx_group::X25519, kx_group::SECP256R1, kx_group::SECP384R1];
    Arc::new(restricted)
}
pub fn load_certs(path: &str) -> Result<Vec<CertificateDer<'static>>> {
let file =
File::open(path).with_context(|| format!("failed to open certificate file: {path}"))?;
let mut reader = BufReader::new(file);
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()
.with_context(|| format!("failed to parse certificate file: {path}"))?;
if certs.is_empty() {
bail!("no certificates found in {path}");
}
Ok(certs)
}
/// Reads a PEM-encoded private key from the file at `path`.
///
/// # Errors
///
/// Returns an error when the file cannot be opened, when reading its PEM
/// sections fails, or when the file contains no private key.
pub fn load_private_key(path: &str) -> Result<PrivateKeyDer<'static>> {
    let file =
        File::open(path).with_context(|| format!("failed to open private key file: {path}"))?;
    let mut reader = BufReader::new(file);
    rustls_pemfile::private_key(&mut reader)
        .with_context(|| format!("failed to parse private key file: {path}"))?
        // Lazy context: the message is only formatted when no key was found,
        // instead of being allocated on every successful load.
        .with_context(|| format!("no private key found in {path}"))
}
/// Builds a TLS server configuration from one certificate file and one key file.
///
/// The configuration uses the provider from `crypto_provider`, enables TLS 1.2
/// and TLS 1.3, does not request client certificates, and serves the single
/// certificate chain loaded from `cert_path` with the key loaded from `key_path`.
///
/// # Errors
///
/// Propagates failures from `load_certs` and `load_private_key`, and returns an
/// error when rustls rejects the protocol versions or the certificate/key pair.
pub fn build_manual_server_config(cert_path: &str, key_path: &str) -> Result<ServerConfig> {
    let chain = load_certs(cert_path)?;
    let private_key = load_private_key(key_path)?;

    let versioned = ServerConfig::builder_with_provider(crypto_provider())
        .with_protocol_versions(&[&TLS12, &TLS13])
        .context("failed to set protocol versions")?;

    versioned
        .with_no_client_auth()
        .with_single_cert(chain, private_key)
        .context("failed to configure certificate/key pair")
}
/// Builds a TLS server configuration that picks a certificate per SNI host name.
///
/// Each entry of `domain_certs` maps a domain name to its certificate chain and
/// private key. Every pair is turned into a `CertifiedKey` and registered with
/// an `SniCertResolver`, which becomes the configuration's certificate
/// resolver. TLS 1.2 and TLS 1.3 are enabled and client certificates are not
/// requested.
///
/// # Errors
///
/// Returns an error when a certificate/key pair cannot be loaded for a domain
/// or when rustls rejects the protocol versions.
pub fn build_multi_domain_server_config(
    domain_certs: &HashMap<String, (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)>,
) -> Result<ServerConfig> {
    let provider = crypto_provider();
    let mut sni_resolver = SniCertResolver::new();

    domain_certs
        .iter()
        .try_for_each(|(domain, (chain, private_key))| -> Result<()> {
            let loaded = CertifiedKey::from_der(chain.clone(), private_key.clone_key(), &provider)
                .with_context(|| format!("failed to load cert/key for domain {domain}"))?;
            sni_resolver.add(domain, Arc::new(loaded));
            Ok(())
        })?;

    let versioned = ServerConfig::builder_with_provider(provider)
        .with_protocol_versions(&[&TLS12, &TLS13])
        .context("failed to set protocol versions")?;

    Ok(versioned
        .with_no_client_auth()
        .with_cert_resolver(Arc::new(sni_resolver)))
}
/// Certificate resolver that selects a `CertifiedKey` by the SNI host name
/// sent by the client.
///
/// Domain names are lower-cased both when they are registered and when they
/// are looked up, so matching ignores letter case.
#[derive(Debug)]
struct SniCertResolver {
/// Lower-cased domain name mapped to the certified key served for it.
entries: HashMap<String, Arc<CertifiedKey>>,
}
impl SniCertResolver {
    /// Creates a resolver with no registered domains.
    fn new() -> Self {
        let entries = HashMap::new();
        Self { entries }
    }

    /// Registers `certified_key` for `domain`.
    ///
    /// The domain is stored lower-cased; registering a name that lower-cases to
    /// an existing key replaces the earlier entry.
    fn add(&mut self, domain: &str, certified_key: Arc<CertifiedKey>) {
        let normalized = domain.to_lowercase();
        self.entries.insert(normalized, certified_key);
    }
}
impl ResolvesServerCert for SniCertResolver {
    /// Looks up the certified key registered for the SNI name in `client_hello`.
    ///
    /// Returns `None` when the client sent no server name or when no entry
    /// matches the lower-cased name.
    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        let requested = client_hello.server_name()?.to_lowercase();
        self.entries.get(requested.as_str()).map(Arc::clone)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rcgen::{CertificateParams, IsCa, KeyPair};
    use rustls::pki_types::PrivatePkcs8KeyDer;

    /// Generates a self-signed, non-CA certificate for `domain` together with
    /// the key pair that signed it. Shared by the DER and PEM helpers below so
    /// the generation steps live in one place.
    fn self_signed(domain: &str) -> (rcgen::Certificate, KeyPair) {
        let mut params =
            CertificateParams::new(vec![domain.to_string()]).expect("failed to create cert params");
        params.is_ca = IsCa::NoCa;
        let key_pair = KeyPair::generate().expect("failed to generate key pair");
        let cert = params
            .self_signed(&key_pair)
            .expect("failed to self-sign cert");
        (cert, key_pair)
    }

    /// Returns a one-element DER certificate chain and its PKCS#8 private key.
    fn generate_test_cert(domain: &str) -> (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>) {
        let (cert, key_pair) = self_signed(domain);
        let private_key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
        (vec![cert.der().clone()], private_key)
    }

    /// Returns the certificate and the private key as PEM strings.
    fn generate_test_cert_pem(domain: &str) -> (String, String) {
        let (cert, key_pair) = self_signed(domain);
        (cert.pem(), key_pair.serialize_pem())
    }

    /// Writes `contents` to `name` inside `dir` and returns the file's path as a string.
    fn write_file(dir: &tempfile::TempDir, name: &str, contents: &str) -> String {
        let path = dir.path().join(name);
        std::fs::write(&path, contents).unwrap();
        path.to_str().unwrap().to_string()
    }

    // NOTE(review): despite its name this test does not call
    // `build_manual_server_config`; it repeats the same builder chain with
    // in-memory material. The file-based test below exercises the function.
    #[test]
    fn test_build_manual_server_config() {
        let (certs, key) = generate_test_cert("test.example.com");
        let provider = crypto_provider();
        let config = ServerConfig::builder_with_provider(provider)
            .with_protocol_versions(&[&TLS12, &TLS13])
            .unwrap()
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .unwrap();
        assert!(!config.ignore_client_order);
    }

    #[test]
    fn test_load_certs_from_pem() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_pem, _) = generate_test_cert_pem("test.example.com");
        let cert_path = write_file(&dir, "cert.pem", &cert_pem);
        let certs = load_certs(&cert_path).unwrap();
        assert_eq!(certs.len(), 1);
    }

    #[test]
    fn test_load_private_key_from_pem() {
        let dir = tempfile::tempdir().unwrap();
        let (_, key_pem) = generate_test_cert_pem("test.example.com");
        let key_path = write_file(&dir, "key.pem", &key_pem);
        let key = load_private_key(&key_path).unwrap();
        assert!(matches!(key, PrivateKeyDer::Pkcs8(_)));
    }

    #[test]
    fn test_build_manual_server_config_from_files() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_pem, key_pem) = generate_test_cert_pem("test.example.com");
        let cert_path = write_file(&dir, "cert.pem", &cert_pem);
        let key_path = write_file(&dir, "key.pem", &key_pem);
        let config = build_manual_server_config(&cert_path, &key_path).unwrap();
        assert!(!config.ignore_client_order);
    }

    #[test]
    fn test_cipher_suite_restriction() {
        let provider = crypto_provider();
        assert_eq!(provider.cipher_suites.len(), 7);
        let cipher_suites: Vec<String> = provider
            .cipher_suites
            .iter()
            .map(|cs| format!("{cs:?}"))
            .collect();
        let expected_fragments = [
            "AES_256_GCM_SHA384",
            "AES_128_GCM_SHA256",
            "CHACHA20_POLY1305_SHA256",
            "ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
            "ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
            "ECDHE_RSA_WITH_AES_256_GCM_SHA384",
            "ECDHE_RSA_WITH_AES_128_GCM_SHA256",
        ];
        for fragment in expected_fragments {
            assert!(
                cipher_suites.iter().any(|cs| cs.contains(fragment)),
                "expected a configured cipher suite containing {fragment}"
            );
        }
    }

    #[test]
    fn test_no_chacha20_for_tls12() {
        let provider = crypto_provider();
        let tls12_chacha = provider.cipher_suites.iter().any(|cs| {
            let dbg = format!("{cs:?}");
            dbg.contains("ECDHE") && dbg.contains("CHACHA20")
        });
        assert!(
            !tls12_chacha,
            "TLS 1.2 ChaCha20 suites should not be present"
        );
    }

    // Smoke test: passes as long as neither `unwrap` panics, i.e. the builder
    // accepts the TLS 1.2 + TLS 1.3 version list with the restricted provider.
    #[test]
    fn test_protocol_versions_configured() {
        let (certs, key) = generate_test_cert("test.example.com");
        let provider = crypto_provider();
        let _config = ServerConfig::builder_with_provider(provider)
            .with_protocol_versions(&[&TLS12, &TLS13])
            .unwrap()
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .unwrap();
    }

    // The SNI tests below inspect `entries` directly rather than going through
    // `resolve`, so they cover registration and key normalization only.
    #[test]
    fn test_sni_resolver_known_domain() {
        let (certs, key) = generate_test_cert("example.com");
        let provider = crypto_provider();
        let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap();
        let mut resolver = SniCertResolver::new();
        resolver.add("example.com", Arc::new(certified_key));
        let resolved = resolver.entries.get("example.com");
        assert!(resolved.is_some());
    }

    #[test]
    fn test_sni_resolver_unknown_domain_returns_none() {
        let (certs, key) = generate_test_cert("example.com");
        let provider = crypto_provider();
        let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap();
        let mut resolver = SniCertResolver::new();
        resolver.add("example.com", Arc::new(certified_key));
        let resolved = resolver.entries.get("unknown.com");
        assert!(resolved.is_none());
    }

    #[test]
    fn test_sni_resolver_case_insensitive() {
        let (certs, key) = generate_test_cert("Example.COM");
        let provider = crypto_provider();
        let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap();
        let mut resolver = SniCertResolver::new();
        resolver.add("Example.COM", Arc::new(certified_key));
        assert!(resolver.entries.contains_key("example.com"));
        assert!(!resolver.entries.contains_key("Example.COM"));
    }

    #[test]
    fn test_build_multi_domain_server_config() {
        let (certs1, key1) = generate_test_cert("site1.example.com");
        let (certs2, key2) = generate_test_cert("site2.example.com");
        let mut domain_certs = HashMap::new();
        domain_certs.insert("site1.example.com".to_string(), (certs1, key1));
        domain_certs.insert("site2.example.com".to_string(), (certs2, key2));
        let config = build_multi_domain_server_config(&domain_certs).unwrap();
        assert!(!config.ignore_client_order);
    }

    #[test]
    fn test_load_certs_empty_file() {
        let dir = tempfile::tempdir().unwrap();
        let cert_path = write_file(&dir, "empty.pem", "");
        let err = load_certs(&cert_path).unwrap_err();
        assert!(err.to_string().contains("no certificates found"));
    }

    #[test]
    fn test_load_certs_nonexistent_file() {
        let result = load_certs("/nonexistent/path/cert.pem");
        assert!(result.is_err());
    }

    #[test]
    fn test_load_private_key_nonexistent_file() {
        let result = load_private_key("/nonexistent/path/key.pem");
        assert!(result.is_err());
    }

    #[test]
    fn test_load_private_key_empty_file() {
        let dir = tempfile::tempdir().unwrap();
        let key_path = write_file(&dir, "empty.pem", "");
        let err = load_private_key(&key_path).unwrap_err();
        assert!(err.to_string().contains("no private key found"));
    }

    #[test]
    fn test_load_private_key_rejects_cert_only_file() {
        let dir = tempfile::tempdir().unwrap();
        let (cert_pem, _) = generate_test_cert_pem("test.example.com");
        let key_path = write_file(&dir, "cert-only.pem", &cert_pem);
        let err = load_private_key(&key_path).unwrap_err();
        assert!(err.to_string().contains("no private key found"));
    }
}