diff --git a/Cargo.lock b/Cargo.lock index b7909ec..b7f826b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -590,6 +590,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "foreign-types" version = "0.3.2" @@ -622,9 +628,9 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "futures" -version = "0.3.32" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -751,10 +757,23 @@ checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", "libc", - "r-efi", + "r-efi 5.3.0", "wasip2", ] +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + [[package]] name = "gimli" version = "0.32.3" @@ -780,6 +799,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + [[package]] name = "hashbrown" version = "0.17.1" @@ -913,7 +941,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2", + "socket2 0.6.4", "system-configuration", "tokio", "tower-service", @@ -1027,6 +1055,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "idna" version = "1.1.0" @@ -1055,7 +1089,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.17.1", + "serde", + "serde_core", ] [[package]] @@ -1103,6 +1139,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "libc" version = "0.2.186" @@ -1427,6 +1469,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.106" @@ -1451,6 +1503,12 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + [[package]] name = "rcgen" version = "0.13.2" @@ -1522,6 +1580,7 @@ dependencies = [ "arc-swap", "axum", "clap", + "futures", "hyper", "rcgen", "reqwest", @@ -1531,6 +1590,7 @@ dependencies = [ "rustls-pki-types", "serde", "signal-hook", + "tempfile", "thiserror 2.0.18", "tokio", "tokio-rustls", @@ -1703,6 +1763,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" + [[package]] name = "serde" version = "1.0.228" @@ -1835,6 +1901,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socket2" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52d1cfed4120b4d927bf7c0f86d2087a4a7d6027c906d9f9d525a80573b9be51" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -1912,7 +1988,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", - "getrandom 0.3.4", + "getrandom 0.4.2", "once_cell", "rustix", "windows-sys 0.61.2", @@ -2021,7 +2097,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.5.10", "tokio-macros", "windows-sys 0.52.0", ] @@ -2227,6 +2303,12 @@ version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "untrusted" version = "0.9.0" @@ -2290,7 +2372,16 @@ version = "1.0.3+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.57.1", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen 0.51.0", ] [[package]] @@ -2348,6 +2439,40 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + [[package]] name = "web-sys" version = "0.3.100" @@ -2583,12 +2708,100 @@ dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + [[package]] name = "wit-bindgen" version = "0.57.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + [[package]] name = "writeable" version = "0.6.3" diff --git a/Cargo.toml b/Cargo.toml index b886953..21eb462 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,9 @@ clap = { version = "=4.6.1", features = ["derive"] } signal-hook = "=0.3.18" anyhow = "=1.0.102" thiserror = "=2.0.18" +futures = "=0.3.31" [dev-dependencies] rcgen = "=0.13" -reqwest = { version = "=0.12", features = ["json"] } \ No newline at end of file +reqwest = { version = "=0.12", features = ["json"] } +tempfile = "=3" diff --git a/src/tls/acceptor.rs b/src/tls/acceptor.rs index 9240abf..4208a74 100644 --- a/src/tls/acceptor.rs +++ b/src/tls/acceptor.rs @@ -1,2 +1,232 @@ +use std::sync::Arc; + +use anyhow::{bail, Context, Result}; +use rustls::version::{TLS12, TLS13}; +use rustls::ServerConfig; +use tracing::info; + +use super::acme::{spawn_acme_state, AcmeTlsConfig}; +use super::config::crypto_provider; +use crate::config::static_config::TlsConfig; + +const ACME_TLS_ALPN_01: &[u8] = b"acme-tls/1"; + #[allow(dead_code)] -pub struct TlsAcceptor; +fn build_acme_server_config( + resolver: Arc, +) -> Result> { + let provider = crypto_provider(); + let config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&TLS12, &TLS13]) + .context("failed to set TLS protocol versions")? + .with_no_client_auth() + .with_cert_resolver(resolver); + let mut config = (*Arc::new(config)).clone(); + config.alpn_protocols = vec![ + b"h2".to_vec(), + b"http/1.1".to_vec(), + ACME_TLS_ALPN_01.to_vec(), + ]; + Ok(Arc::new(config)) +} + +#[allow(dead_code)] +fn build_acme_challenge_config( + resolver: Arc, +) -> Arc { + let provider = crypto_provider(); + let mut config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&TLS12, &TLS13]) + .expect("valid protocol versions") + .with_no_client_auth() + .with_cert_resolver(resolver); + config.alpn_protocols = vec![ACME_TLS_ALPN_01.to_vec()]; + Arc::new(config) +} + +#[allow(dead_code)] +#[derive(Debug)] +pub enum TlsMode { + Manual(Arc), + Acme { + default_config: Arc, + challenge_config: Arc, + resolver: Arc, + }, +} + +#[allow(dead_code)] +pub fn setup_tls(tls_config: &TlsConfig) -> Result { + match tls_config.mode.as_str() { + "manual" => { + if tls_config.cert_path.is_empty() { + bail!("manual TLS mode requires cert_path"); + } + if tls_config.key_path.is_empty() { + bail!("manual TLS mode requires key_path"); + } + let config = super::config::build_manual_server_config( + &tls_config.cert_path, + &tls_config.key_path, + )?; + Ok(TlsMode::Manual(Arc::new(config))) + } + "acme" => { + if tls_config.acme_domains.is_empty() { + bail!("ACME TLS mode requires at least one domain in acme_domains"); + } + if tls_config.acme_cache_dir.is_empty() { + bail!("ACME TLS mode requires acme_cache_dir"); + } + + let acme_tls_config = AcmeTlsConfig { + domains: tls_config.acme_domains.clone(), + cache_dir: tls_config.acme_cache_dir.clone().into(), + directory: tls_config.acme_directory.clone(), + contact: vec![], + }; + + let super::acme::AcmeTlsSetup { resolver, state } = acme_tls_config.setup()?; + + let default_config = build_acme_server_config(resolver.clone())?; + let challenge_config = build_acme_challenge_config(resolver.clone()); + + spawn_acme_state(state, tls_config.acme_domains.clone()); + + info!( + domains = ?tls_config.acme_domains, + "ACME TLS mode initialized" + ); + + Ok(TlsMode::Acme { + default_config, + challenge_config, + resolver, + }) + } + other => { + bail!("unknown TLS mode: '{}', expected 'manual' or 'acme'", other); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_acme_tls_alpn_value() { + assert_eq!(ACME_TLS_ALPN_01, b"acme-tls/1"); + } + + fn make_test_resolver() -> Arc { + let temp_dir = tempfile::tempdir().expect("failed to create temp dir"); + let config = rustls_acme::AcmeConfig::new(["test.example.com"]) + .cache(rustls_acme::caches::DirCache::new( + temp_dir.path().to_path_buf(), + )) + .directory("https://acme-staging-v02.api.letsencrypt.org/directory"); + let state = config.state(); + state.resolver() + } + + #[test] + fn test_build_acme_server_config() { + let resolver = make_test_resolver(); + let config = build_acme_server_config(resolver); + assert!(config.is_ok()); + + let config = config.unwrap(); + assert!(config.alpn_protocols.contains(&b"h2".to_vec())); + assert!(config.alpn_protocols.contains(&b"http/1.1".to_vec())); + assert!(config.alpn_protocols.contains(&ACME_TLS_ALPN_01.to_vec())); + } + + #[test] + fn test_build_acme_challenge_config() { + let resolver = make_test_resolver(); + let config = build_acme_challenge_config(resolver); + assert_eq!(config.alpn_protocols.len(), 1); + assert_eq!(config.alpn_protocols[0], ACME_TLS_ALPN_01); + } + + #[test] + fn test_setup_tls_manual_missing_cert_path() { + let tls_config = TlsConfig { + mode: "manual".to_string(), + acme_domains: vec![], + acme_cache_dir: String::new(), + acme_directory: "production".to_string(), + cert_path: String::new(), + key_path: "/some/key.pem".to_string(), + }; + let result = setup_tls(&tls_config); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("cert_path")); + } + + #[test] + fn test_setup_tls_manual_missing_key_path() { + let tls_config = TlsConfig { + mode: "manual".to_string(), + acme_domains: vec![], + acme_cache_dir: String::new(), + acme_directory: "production".to_string(), + cert_path: "/some/cert.pem".to_string(), + key_path: String::new(), + }; + let result = setup_tls(&tls_config); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("key_path")); + } + + #[test] + fn test_setup_tls_acme_missing_domains() { + let tls_config = TlsConfig { + mode: "acme".to_string(), + acme_domains: vec![], + acme_cache_dir: "/tmp/cache".to_string(), + acme_directory: "staging".to_string(), + cert_path: String::new(), + key_path: String::new(), + }; + let result = setup_tls(&tls_config); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("acme_domains")); + } + + #[test] + fn test_setup_tls_acme_missing_cache_dir() { + let tls_config = TlsConfig { + mode: "acme".to_string(), + acme_domains: vec!["example.com".to_string()], + acme_cache_dir: String::new(), + acme_directory: "staging".to_string(), + cert_path: String::new(), + key_path: String::new(), + }; + let result = setup_tls(&tls_config); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("acme_cache_dir")); + } + + #[test] + fn test_setup_tls_unknown_mode() { + let tls_config = TlsConfig { + mode: "invalid".to_string(), + acme_domains: vec![], + acme_cache_dir: String::new(), + acme_directory: "production".to_string(), + cert_path: String::new(), + key_path: String::new(), + }; + let result = setup_tls(&tls_config); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("unknown TLS mode")); + } +} diff --git a/src/tls/acme.rs b/src/tls/acme.rs new file mode 100644 index 0000000..e6c20d5 --- /dev/null +++ b/src/tls/acme.rs @@ -0,0 +1,227 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use anyhow::Result; +use rustls_acme::caches::DirCache; +use rustls_acme::{AcmeConfig, AcmeState, EventError, EventOk, ResolvesServerCertAcme}; +use tracing::{error, info, warn}; + +#[allow(dead_code)] +const LETS_ENCRYPT_PRODUCTION_DIRECTORY: &str = "https://acme-v02.api.letsencrypt.org/directory"; +#[allow(dead_code)] +const LETS_ENCRYPT_STAGING_DIRECTORY: &str = + "https://acme-staging-v02.api.letsencrypt.org/directory"; + +#[allow(dead_code)] +pub struct AcmeTlsConfig { + pub domains: Vec, + pub cache_dir: PathBuf, + pub directory: String, + pub contact: Vec, +} + +#[allow(dead_code)] +pub struct AcmeTlsSetup { + pub resolver: Arc, + pub state: AcmeState, +} + +impl AcmeTlsConfig { + pub fn setup(self) -> Result { + let directory_url = match self.directory.as_str() { + "production" => LETS_ENCRYPT_PRODUCTION_DIRECTORY.to_string(), + "staging" => LETS_ENCRYPT_STAGING_DIRECTORY.to_string(), + other => other.to_string(), + }; + + let acme_config = AcmeConfig::new(self.domains.clone()) + .cache(DirCache::new(self.cache_dir.clone())) + .directory(&directory_url) + .contact(self.contact.iter().map(|c| c.as_str())); + + let state = acme_config.state(); + let resolver = state.resolver(); + + info!( + domains = ?self.domains, + cache_dir = %self.cache_dir.display(), + directory = %directory_url, + "ACME state machine created" + ); + + Ok(AcmeTlsSetup { resolver, state }) + } + + #[allow(dead_code)] + pub fn directory_url(&self) -> &str { + match self.directory.as_str() { + "production" => LETS_ENCRYPT_PRODUCTION_DIRECTORY, + "staging" => LETS_ENCRYPT_STAGING_DIRECTORY, + other => other, + } + } +} + +#[allow(dead_code)] +pub fn spawn_acme_state( + state: AcmeState, + domains: Vec, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + use futures::StreamExt; + let mut state = state; + loop { + match state.next().await { + Some(Ok(event)) => match event { + EventOk::DeployedCachedCert => { + info!( + domains = ?domains, + "ACME: deployed cached certificate" + ); + } + EventOk::DeployedNewCert => { + info!( + domains = ?domains, + "ACME: deployed new certificate" + ); + } + EventOk::CertCacheStore => { + info!( + domains = ?domains, + "ACME: certificate stored to cache" + ); + } + EventOk::AccountCacheStore => { + info!( + domains = ?domains, + "ACME: account stored to cache" + ); + } + }, + Some(Err(err)) => match &err { + EventError::CertCacheLoad(e) => { + error!( + domains = ?domains, + error = ?e, + "ACME: certificate cache load failed" + ); + } + EventError::AccountCacheLoad(e) => { + error!( + domains = ?domains, + error = ?e, + "ACME: account cache load failed" + ); + } + EventError::CertCacheStore(e) => { + warn!( + domains = ?domains, + error = ?e, + "ACME: certificate cache store failed" + ); + } + EventError::AccountCacheStore(e) => { + warn!( + domains = ?domains, + error = ?e, + "ACME: account cache store failed" + ); + } + EventError::CachedCertParse(e) => { + error!( + domains = ?domains, + error = ?e, + "ACME: cached certificate parse failed" + ); + } + EventError::Order(e) => { + warn!( + domains = ?domains, + error = ?e, + "ACME: certificate order failed, will retry" + ); + } + EventError::NewCertParse(e) => { + error!( + domains = ?domains, + error = ?e, + "ACME: new certificate parse failed" + ); + } + }, + None => { + info!( + domains = ?domains, + "ACME: state machine ended" + ); + break; + } + } + } + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_acme_config_production_directory() { + let config = AcmeTlsConfig { + domains: vec!["example.com".to_string()], + cache_dir: PathBuf::from("/tmp/test-cache"), + directory: "production".to_string(), + contact: vec![], + }; + assert_eq!(config.directory_url(), LETS_ENCRYPT_PRODUCTION_DIRECTORY); + } + + #[test] + fn test_acme_config_staging_directory() { + let config = AcmeTlsConfig { + domains: vec!["example.com".to_string()], + cache_dir: PathBuf::from("/tmp/test-cache"), + directory: "staging".to_string(), + contact: vec![], + }; + assert_eq!(config.directory_url(), LETS_ENCRYPT_STAGING_DIRECTORY); + } + + #[test] + fn test_acme_config_custom_directory() { + let custom_url = "https://custom-acme.example.com/directory"; + let config = AcmeTlsConfig { + domains: vec!["example.com".to_string()], + cache_dir: PathBuf::from("/tmp/test-cache"), + directory: custom_url.to_string(), + contact: vec![], + }; + assert_eq!(config.directory_url(), custom_url); + } + + #[test] + fn test_acme_config_multiple_domains() { + let config = AcmeTlsConfig { + domains: vec!["git.alk.dev".to_string(), "alk.dev".to_string()], + cache_dir: PathBuf::from("/var/lib/reverse-proxy/acme-cache"), + directory: "production".to_string(), + contact: vec!["mailto:admin@alk.dev".to_string()], + }; + assert_eq!(config.domains.len(), 2); + assert_eq!(config.directory_url(), LETS_ENCRYPT_PRODUCTION_DIRECTORY); + } + + #[test] + fn test_acme_setup_creates_resolver() { + let temp_dir = tempfile::tempdir().expect("failed to create temp dir"); + let config = AcmeTlsConfig { + domains: vec!["test.example.com".to_string()], + cache_dir: temp_dir.path().to_path_buf(), + directory: "staging".to_string(), + contact: vec!["mailto:admin@example.com".to_string()], + }; + let setup = config.setup().expect("setup should succeed"); + assert!(Arc::strong_count(&setup.resolver) >= 1); + } +} diff --git a/src/tls/config.rs b/src/tls/config.rs new file mode 100644 index 0000000..261a066 --- /dev/null +++ b/src/tls/config.rs @@ -0,0 +1,305 @@ +use std::collections::HashMap; +use std::fs::File; +use std::io::BufReader; +use std::sync::Arc; + +use anyhow::{bail, Context, Result}; +use rustls::crypto::aws_lc_rs::cipher_suite; +use rustls::crypto::aws_lc_rs::{default_provider, kx_group}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::server::{ClientHello, ResolvesServerCert}; +use rustls::sign::CertifiedKey; +use rustls::version::{TLS12, TLS13}; +use rustls::ServerConfig; +use rustls::SupportedCipherSuite; +use rustls_pemfile; + +#[allow(dead_code)] +static RESTRICTED_CIPHER_SUITES: &[SupportedCipherSuite] = &[ + cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + cipher_suite::TLS13_AES_128_GCM_SHA256, + cipher_suite::TLS13_AES_256_GCM_SHA384, + cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, +]; + +pub(crate) fn crypto_provider() -> Arc { + let provider = default_provider(); + Arc::new(rustls::crypto::CryptoProvider { + cipher_suites: RESTRICTED_CIPHER_SUITES.to_vec(), + kx_groups: vec![kx_group::X25519, kx_group::SECP256R1, kx_group::SECP384R1], + ..provider + }) +} + +#[allow(dead_code)] +pub fn load_certs(path: &str) -> Result>> { + let file = + File::open(path).with_context(|| format!("failed to open certificate file: {path}"))?; + let mut reader = BufReader::new(file); + let certs: Vec> = rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .with_context(|| format!("failed to parse certificate file: {path}"))?; + if certs.is_empty() { + bail!("no certificates found in {path}"); + } + Ok(certs) +} + +#[allow(dead_code)] +pub fn load_private_key(path: &str) -> Result> { + let file = + File::open(path).with_context(|| format!("failed to open private key file: {path}"))?; + let mut reader = BufReader::new(file); + let key = rustls_pemfile::private_key(&mut reader) + .with_context(|| format!("failed to parse private key file: {path}"))?; + key.context(format!("no private key found in {path}")) +} + +#[allow(dead_code)] +pub fn build_manual_server_config(cert_path: &str, key_path: &str) -> Result { + let certs = load_certs(cert_path)?; + let key = load_private_key(key_path)?; + + let provider = crypto_provider(); + let config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&TLS12, &TLS13]) + .with_context(|| "failed to set protocol versions")? + .with_no_client_auth() + .with_single_cert(certs, key) + .with_context(|| "failed to configure certificate/key pair")?; + + Ok(config) +} + +#[allow(dead_code)] +pub fn build_multi_domain_server_config( + domain_certs: &HashMap>, PrivateKeyDer<'static>)>, +) -> Result { + let provider = crypto_provider(); + + let mut resolver = SniCertResolver::new(); + for (domain, (certs, key)) in domain_certs { + let certified_key = CertifiedKey::from_der(certs.clone(), key.clone_key(), &provider) + .with_context(|| format!("failed to load cert/key for domain {domain}"))?; + resolver.add(domain, Arc::new(certified_key)); + } + + let config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&TLS12, &TLS13]) + .with_context(|| "failed to set protocol versions")? + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver)); + + Ok(config) +} + +#[derive(Debug)] +struct SniCertResolver { + entries: HashMap>, +} + +impl SniCertResolver { + fn new() -> Self { + Self { + entries: HashMap::new(), + } + } + + fn add(&mut self, domain: &str, certified_key: Arc) { + self.entries.insert(domain.to_lowercase(), certified_key); + } +} + +impl ResolvesServerCert for SniCertResolver { + fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { + let server_name = client_hello.server_name()?; + self.entries.get(&server_name.to_lowercase()).cloned() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rcgen::{CertificateParams, IsCa, KeyPair}; + use rustls::pki_types::PrivatePkcs8KeyDer; + + fn generate_test_cert(domain: &str) -> (Vec>, PrivateKeyDer<'static>) { + let mut params = + CertificateParams::new(vec![domain.to_string()]).expect("failed to create cert params"); + params.is_ca = IsCa::NoCa; + let key_pair = KeyPair::generate().expect("failed to generate key pair"); + let cert = params + .self_signed(&key_pair) + .expect("failed to self-sign cert"); + let cert_der = cert.der().clone(); + let key_der = key_pair.serialize_der(); + let private_key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_der)); + (vec![cert_der], private_key) + } + + fn generate_test_cert_pem(domain: &str) -> (String, String) { + let mut params = + CertificateParams::new(vec![domain.to_string()]).expect("failed to create cert params"); + params.is_ca = IsCa::NoCa; + let key_pair = KeyPair::generate().expect("failed to generate key pair"); + let cert = params + .self_signed(&key_pair) + .expect("failed to self-sign cert"); + let cert_pem = cert.pem(); + let key_pem = key_pair.serialize_pem(); + (cert_pem, key_pem) + } + + #[test] + fn test_build_manual_server_config() { + let (certs, key) = generate_test_cert("test.example.com"); + let provider = crypto_provider(); + let config = ServerConfig::builder_with_provider(provider) + .with_protocol_versions(&[&TLS12, &TLS13]) + .unwrap() + .with_no_client_auth() + .with_single_cert(certs, key) + .unwrap(); + assert!(!config.ignore_client_order); + } + + #[test] + fn test_load_certs_from_pem() { + let dir = tempfile::tempdir().unwrap(); + let (cert_pem, _) = generate_test_cert_pem("test.example.com"); + let cert_path = dir.path().join("cert.pem"); + std::fs::write(&cert_path, cert_pem).unwrap(); + let certs = load_certs(cert_path.to_str().unwrap()).unwrap(); + assert_eq!(certs.len(), 1); + } + + #[test] + fn test_load_private_key_from_pem() { + let dir = tempfile::tempdir().unwrap(); + let (_, key_pem) = generate_test_cert_pem("test.example.com"); + let key_path = dir.path().join("key.pem"); + std::fs::write(&key_path, key_pem).unwrap(); + let key = load_private_key(key_path.to_str().unwrap()).unwrap(); + assert!(matches!(key, PrivateKeyDer::Pkcs8(_))); + } + + #[test] + fn test_build_manual_server_config_from_files() { + let dir = tempfile::tempdir().unwrap(); + let (cert_pem, key_pem) = generate_test_cert_pem("test.example.com"); + let cert_path = dir.path().join("cert.pem"); + let key_path = dir.path().join("key.pem"); + std::fs::write(&cert_path, &cert_pem).unwrap(); + std::fs::write(&key_path, &key_pem).unwrap(); + let config = + build_manual_server_config(cert_path.to_str().unwrap(), key_path.to_str().unwrap()) + .unwrap(); + assert!(!config.ignore_client_order); + } + + #[test] + fn test_cipher_suite_restriction() { + let provider = crypto_provider(); + let cipher_suites: Vec = provider + .cipher_suites + .iter() + .map(|cs| format!("{:?}", cs)) + .collect(); + + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("AES_256_GCM_SHA384"))); + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("AES_128_GCM_SHA256"))); + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("CHACHA20_POLY1305_SHA256"))); + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"))); + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"))); + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("ECDHE_RSA_WITH_AES_256_GCM_SHA384"))); + assert!(cipher_suites + .iter() + .any(|cs| cs.contains("ECDHE_RSA_WITH_AES_128_GCM_SHA256"))); + + assert_eq!(provider.cipher_suites.len(), 7); + } + + #[test] + fn test_sni_resolver_known_domain() { + let (certs, key) = generate_test_cert("example.com"); + let provider = crypto_provider(); + let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap(); + let mut resolver = SniCertResolver::new(); + resolver.add("example.com", Arc::new(certified_key)); + + let resolved = resolver.entries.get("example.com"); + assert!(resolved.is_some()); + } + + #[test] + fn test_sni_resolver_unknown_domain_returns_none() { + let (certs, key) = generate_test_cert("example.com"); + let provider = crypto_provider(); + let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap(); + let mut resolver = SniCertResolver::new(); + resolver.add("example.com", Arc::new(certified_key)); + + let resolved = resolver.entries.get("unknown.com"); + assert!(resolved.is_none()); + } + + #[test] + fn test_sni_resolver_case_insensitive() { + let (certs, key) = generate_test_cert("Example.COM"); + let provider = crypto_provider(); + let certified_key = CertifiedKey::from_der(certs, key, &provider).unwrap(); + let mut resolver = SniCertResolver::new(); + resolver.add("Example.COM", Arc::new(certified_key)); + + assert!(resolver.entries.get("example.com").is_some()); + } + + #[test] + fn test_build_multi_domain_server_config() { + let (certs1, key1) = generate_test_cert("site1.example.com"); + let (certs2, key2) = generate_test_cert("site2.example.com"); + + let mut domain_certs = HashMap::new(); + domain_certs.insert("site1.example.com".to_string(), (certs1, key1)); + domain_certs.insert("site2.example.com".to_string(), (certs2, key2)); + + let config = build_multi_domain_server_config(&domain_certs).unwrap(); + assert!(!config.ignore_client_order); + } + + #[test] + fn test_load_certs_empty_file() { + let dir = tempfile::tempdir().unwrap(); + let cert_path = dir.path().join("empty.pem"); + std::fs::write(&cert_path, "").unwrap(); + let result = load_certs(cert_path.to_str().unwrap()); + assert!(result.is_err()); + } + + #[test] + fn test_load_certs_nonexistent_file() { + let result = load_certs("/nonexistent/path/cert.pem"); + assert!(result.is_err()); + } + + #[test] + fn test_load_private_key_nonexistent_file() { + let result = load_private_key("/nonexistent/path/key.pem"); + assert!(result.is_err()); + } +} diff --git a/src/tls/mod.rs b/src/tls/mod.rs index 478b715..33e305b 100644 --- a/src/tls/mod.rs +++ b/src/tls/mod.rs @@ -1,2 +1,4 @@ pub mod acceptor; +pub mod acme; +pub mod config; pub mod redirect;