diff --git a/Cargo.lock b/Cargo.lock index a501ba5..263c2e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -130,7 +130,7 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" dependencies = [ - "asn1-rs-derive", + "asn1-rs-derive 0.5.1", "asn1-rs-impl", "displaydoc", "nom", @@ -140,6 +140,22 @@ dependencies = [ "time", ] +[[package]] +name = "asn1-rs" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f43a50ac4fdca5df8e885c21b835997f0a1cdee65494a6847694a98652d9d8" +dependencies = [ + "asn1-rs-derive 0.6.0", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror 2.0.18", + "time", +] + [[package]] name = "asn1-rs-derive" version = "0.5.1" @@ -152,6 +168,18 @@ dependencies = [ "synstructure", ] +[[package]] +name = "asn1-rs-derive" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3109e49b1e4909e9db6515a30c633684d68cdeaa252f215214cb4fa1a5bfee2c" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "asn1-rs-impl" version = "0.2.0" @@ -337,6 +365,15 @@ dependencies = [ "sha2", ] +[[package]] +name = "bit-vec" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b71798fca2c1fe1086445a7258a4bc81e6e49dcd24c8d0dd9a1e57395b603f51" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -811,7 +848,21 @@ version = "9.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" dependencies = [ - "asn1-rs", + "asn1-rs 0.6.2", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + +[[package]] +name = "der-parser" +version = "10.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07da5016415d5a3c4dd39b11ed26f915f52fc4e0dc197d87908bc916e51bc1a6" +dependencies = [ + "asn1-rs 0.7.2", "displaydoc", "nom", "num-bigint", @@ -1933,7 +1984,7 @@ dependencies = [ "pkarr", "portmapper", "rand 0.8.6", - "rcgen", + "rcgen 0.13.2", "reqwest", "ring", "rustls", @@ -1952,7 +2003,7 @@ dependencies = [ "url", "wasm-bindgen-futures", "webpki-roots 0.26.11", - "x509-parser", + "x509-parser 0.16.0", "z32", ] @@ -2640,7 +2691,16 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" dependencies = [ - "asn1-rs", + "asn1-rs 0.6.2", +] + +[[package]] +name = "oid-registry" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12f40cff3dde1b6087cc5d5f5d4d65712f34016a03ed60e9c08dcc392736b5b7" +dependencies = [ + "asn1-rs 0.7.2", ] [[package]] @@ -3341,7 +3401,21 @@ dependencies = [ "ring", "rustls-pki-types", "time", - "yasna", + "yasna 0.5.2", +] + +[[package]] +name = "rcgen" +version = "0.14.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57f6d249aad744e274e682777a50283a225a32705394ee6d5fcc01efa25e4055" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "x509-parser 0.18.1", + "yasna 0.6.0", ] [[package]] @@ -3735,12 +3809,12 @@ dependencies = [ "http 1.4.1", "log", "pem", - "rcgen", + "rcgen 0.13.2", "serde", "serde_json", "thiserror 2.0.18", "webpki-roots 0.26.11", - "x509-parser", + "x509-parser 0.16.0", ] [[package]] @@ -5476,15 +5550,18 @@ dependencies = [ "anyhow", "async-trait", "iroh", + "rcgen 0.14.8", "russh", "rustls", "rustls-acme", + "rustls-pki-types", "tempfile", "thiserror 2.0.18", "tokio", "tokio-rustls", "tokio-util", "tracing", + "webpki-roots 0.26.11", "wraith-core", ] @@ -5509,17 +5586,35 @@ version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" dependencies = [ - "asn1-rs", + "asn1-rs 0.6.2", "data-encoding", - "der-parser", + "der-parser 9.0.0", "lazy_static", "nom", - "oid-registry", + "oid-registry 0.7.1", "rusticata-macros", "thiserror 1.0.69", "time", ] +[[package]] +name = "x509-parser" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d43b0f71ce057da06bc0851b23ee24f3f86190b07203dd8f567d0b706a185202" +dependencies = [ + "asn1-rs 0.7.2", + "data-encoding", + "der-parser 10.0.0", + "lazy_static", + "nom", + "oid-registry 0.8.1", + "ring", + "rusticata-macros", + "thiserror 2.0.18", + "time", +] + [[package]] name = "xml-rs" version = "0.8.28" @@ -5544,6 +5639,16 @@ dependencies = [ "time", ] +[[package]] +name = "yasna" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5f6765e852b9b4dc8e2a76843e4d64d1cea8e79bcde0b6901aea8e7c7f08282" +dependencies = [ + "bit-vec", + "time", +] + [[package]] name = "yoke" version = "0.8.2" diff --git a/crates/wraith-core/Cargo.toml b/crates/wraith-core/Cargo.toml index f90dbc0..a507f8f 100644 --- a/crates/wraith-core/Cargo.toml +++ b/crates/wraith-core/Cargo.toml @@ -8,7 +8,7 @@ name = "wraith_core" [features] default = [] -tls = ["dep:tokio-rustls", "dep:rustls"] +tls = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-pki-types", "dep:webpki-roots"] iroh = ["dep:iroh"] acme = ["dep:rustls-acme", "tls"] testutil = [] @@ -22,11 +22,14 @@ anyhow = "1" thiserror = "2" tokio-util = { version = "0.7", features = ["compat"] } tokio-rustls = { version = "0.26", optional = true } -rustls = { version = "0.23", optional = true } +rustls = { version = "0.23", optional = true, features = ["aws_lc_rs"] } +rustls-pki-types = { version = "1", optional = true } rustls-acme = { version = "0.12", optional = true } +webpki-roots = { version = "0.26", optional = true } iroh = { version = "0.34", optional = true } async-trait = "0.1" [dev-dependencies] -wraith-core = { path = ".", features = ["testutil"] } -tempfile = "3" \ No newline at end of file +wraith-core = { path = ".", features = ["testutil", "tls"] } +tempfile = "3" +rcgen = "0.14" \ No newline at end of file diff --git a/crates/wraith-core/src/transport/mod.rs b/crates/wraith-core/src/transport/mod.rs index 4ac17cb..8d4ffdf 100644 --- a/crates/wraith-core/src/transport/mod.rs +++ b/crates/wraith-core/src/transport/mod.rs @@ -2,6 +2,12 @@ mod tcp; pub use tcp::{TcpAcceptor, TcpTransport}; +#[cfg(feature = "tls")] +mod tls; + +#[cfg(feature = "tls")] +pub use tls::{AcmeConfig, TlsAcceptor, TlsTransport}; + use std::net::SocketAddr; use anyhow::Result; diff --git a/crates/wraith-core/src/transport/tls.rs b/crates/wraith-core/src/transport/tls.rs new file mode 100644 index 0000000..d247bb4 --- /dev/null +++ b/crates/wraith-core/src/transport/tls.rs @@ -0,0 +1,386 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; +use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_rustls::{client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector}; + +use super::{Transport, TransportAcceptor, TransportInfo, TransportKind}; + +/// A TLS-based client transport that connects to a remote address over TLS. +/// +/// Wraps a TCP connection with a TLS client session via `tokio_rustls::TlsConnector`. +/// Supports insecure mode (accepts any certificate, for development) and +/// custom root CA certificates for verification. The `tls_server_name` field +/// overrides the SNI hostname sent during the TLS handshake (ADR-010). +pub struct TlsTransport { + addr: SocketAddr, + tls_server_name: Option, + insecure: bool, + root_cert: Option>, +} + +impl TlsTransport { + pub fn new(addr: SocketAddr) -> Self { + Self { + addr, + tls_server_name: None, + insecure: false, + root_cert: None, + } + } + + pub fn with_server_name(mut self, name: impl Into) -> Self { + self.tls_server_name = Some(name.into()); + self + } + + pub fn with_insecure(mut self, insecure: bool) -> Self { + self.insecure = insecure; + self + } + + pub fn with_root_cert(mut self, cert: CertificateDer<'static>) -> Self { + self.root_cert = Some(cert); + self + } + + fn build_client_config(&self) -> Result { + if self.insecure { + let config = ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(NoVerifier)) + .with_no_client_auth(); + return Ok(config); + } + + let mut root_store = RootCertStore::empty(); + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + + if let Some(ref cert) = self.root_cert { + root_store.add(cert.clone())?; + } + + let config = ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + Ok(config) + } + + fn resolve_server_name(&self) -> Result> { + let name = match &self.tls_server_name { + Some(n) => n.clone(), + None => self.addr.ip().to_string(), + }; + ServerName::try_from(name.clone()) + .map_err(move |e| anyhow!("invalid server name '{}': {}", name, e)) + } +} + +#[async_trait] +impl Transport for TlsTransport { + type Stream = ClientTlsStream; + + async fn connect(&self) -> Result { + let tcp_stream = TcpStream::connect(self.addr).await?; + let config = self.build_client_config()?; + let connector = TlsConnector::from(Arc::new(config)); + let server_name = self.resolve_server_name()?; + let tls_stream = connector.connect(server_name, tcp_stream).await?; + Ok(tls_stream) + } + + fn describe(&self) -> String { + format!("tls://{}", self.addr) + } +} + +/// Stub configuration for ACME certificate provisioning (ADR-008). +/// Feature-gated behind the `acme` feature. When implemented, this will +/// hold the ACME domain and challenge responder configuration. +#[derive(Debug)] +pub struct AcmeConfig { + pub domain: String, +} + +/// A TLS-based server transport acceptor that accepts TCP connections +/// and wraps them with TLS server sessions via `tokio_rustls::TlsAcceptor`. +/// +/// Requires certificate and private key configuration. Supports manual +/// cert/key paths and an ACME config stub (ADR-008). +pub struct TlsAcceptor { + listener: TcpListener, + listen_addr: SocketAddr, + #[allow(dead_code)] + server_config: Arc, + tokio_acceptor: TokioTlsAcceptor, +} + +impl TlsAcceptor { + pub async fn bind( + addr: SocketAddr, + tls_certs: Vec>, + tls_key: PrivateKeyDer<'static>, + _acme_config: Option, + ) -> Result { + let listener = TcpListener::bind(addr).await?; + let listen_addr = listener.local_addr()?; + + let server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(tls_certs, tls_key)?; + + let server_config = Arc::new(server_config); + let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone()); + + Ok(Self { + listener, + listen_addr, + server_config, + tokio_acceptor, + }) + } + + pub fn listen_addr(&self) -> SocketAddr { + self.listen_addr + } +} + +#[async_trait] +impl TransportAcceptor for TlsAcceptor { + type Stream = tokio_rustls::server::TlsStream; + + async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> { + let (tcp_stream, remote_addr) = self.listener.accept().await?; + let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?; + + let server_name = tls_stream + .get_ref() + .1 + .server_name() + .map(|s| s.to_string()); + + let info = TransportInfo { + remote_addr: Some(remote_addr), + transport_kind: TransportKind::Tls { server_name }, + }; + + Ok((tls_stream, info)) + } +} + +#[derive(Debug)] +struct NoVerifier; + +impl ServerCertVerifier for NoVerifier { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls::pki_types::UnixTime, + ) -> std::result::Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _doc: &DigitallySignedStruct, + ) -> std::result::Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _doc: &DigitallySignedStruct, + ) -> std::result::Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + vec![ + rustls::SignatureScheme::ECDSA_NISTP256_SHA256, + rustls::SignatureScheme::ECDSA_NISTP384_SHA384, + rustls::SignatureScheme::ED25519, + rustls::SignatureScheme::RSA_PSS_SHA256, + rustls::SignatureScheme::RSA_PSS_SHA384, + rustls::SignatureScheme::RSA_PSS_SHA512, + ] + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rcgen::{CertificateParams, KeyPair}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + fn generate_self_signed_cert() -> (CertificateDer<'static>, PrivateKeyDer<'static>) { + let params = CertificateParams::new(vec!["localhost".to_string()]).unwrap(); + let key_pair = KeyPair::generate().unwrap(); + let cert = params.self_signed(&key_pair).unwrap(); + let cert_der: CertificateDer<'static> = cert.into(); + let key_der = PrivateKeyDer::Pkcs8(key_pair.serialize_der().into()); + (cert_der, key_der) + } + + #[test] + fn tls_transport_describe_format() { + let addr: SocketAddr = "1.2.3.4:443".parse().unwrap(); + let transport = TlsTransport::new(addr).with_server_name("example.com"); + assert_eq!(transport.describe(), "tls://1.2.3.4:443"); + } + + #[test] + fn tls_transport_describe_with_ip() { + let addr: SocketAddr = "1.2.3.4:443".parse().unwrap(); + let transport = TlsTransport::new(addr); + assert_eq!(transport.describe(), "tls://1.2.3.4:443"); + } + + #[test] + fn tls_transport_builder_methods() { + let addr: SocketAddr = "1.2.3.4:443".parse().unwrap(); + let transport = TlsTransport::new(addr) + .with_server_name("wraith.test") + .with_insecure(true); + assert_eq!(transport.tls_server_name, Some("wraith.test".to_string())); + assert!(transport.insecure); + } + + #[tokio::test] + async fn tls_connect_insecure_self_signed() { + let (cert_der, key_der) = generate_self_signed_cert(); + + let acceptor = TlsAcceptor::bind( + "127.0.0.1:0".parse().unwrap(), + vec![cert_der], + key_der, + None, + ) + .await + .unwrap(); + let addr = acceptor.listen_addr(); + + let transport = TlsTransport::new(addr) + .with_server_name("localhost") + .with_insecure(true); + + let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() }); + + let mut client = transport.connect().await.unwrap(); + + let (mut server, info) = accept_handle.await.unwrap(); + assert!(info.remote_addr.is_some()); + assert!(matches!( + info.transport_kind, + TransportKind::Tls { .. } + )); + + client.write_all(b"hello tls").await.unwrap(); + let mut buf = [0u8; 9]; + server.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"hello tls"); + + server.write_all(b"reply").await.unwrap(); + let mut buf = [0u8; 5]; + client.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"reply"); + } + + #[tokio::test] + async fn tls_acceptor_returns_server_name() { + let (cert_der, key_der) = generate_self_signed_cert(); + + let acceptor = TlsAcceptor::bind( + "127.0.0.1:0".parse().unwrap(), + vec![cert_der], + key_der, + None, + ) + .await + .unwrap(); + let addr = acceptor.listen_addr(); + + let transport = TlsTransport::new(addr) + .with_server_name("localhost") + .with_insecure(true); + + let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() }); + + let _client = transport.connect().await.unwrap(); + + let (_, info) = accept_handle.await.unwrap(); + if let TransportKind::Tls { server_name } = info.transport_kind { + assert_eq!(server_name, Some("localhost".to_string())); + } else { + panic!("expected TransportKind::Tls"); + } + } + + #[tokio::test] + async fn tls_full_client_to_server_connection() { + let (cert_der, key_der) = generate_self_signed_cert(); + + let acceptor = TlsAcceptor::bind( + "127.0.0.1:0".parse().unwrap(), + vec![cert_der], + key_der, + None, + ) + .await + .unwrap(); + let addr = acceptor.listen_addr(); + + let transport = TlsTransport::new(addr) + .with_server_name("localhost") + .with_insecure(true); + + let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() }); + + let mut client = transport.connect().await.unwrap(); + let (mut server, _info) = accept_handle.await.unwrap(); + + let msg = b"wraith integration test"; + client.write_all(msg).await.unwrap(); + let mut buf = vec![0u8; msg.len()]; + server.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf[..], msg); + + let reply = b"ok"; + server.write_all(reply).await.unwrap(); + let mut buf = [0u8; 2]; + client.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, reply); + } + + #[tokio::test] + async fn tls_acceptor_bind_port_zero_assigns_ephemeral() { + let (cert_der, key_der) = generate_self_signed_cert(); + + let acceptor = TlsAcceptor::bind( + "127.0.0.1:0".parse().unwrap(), + vec![cert_der], + key_der, + None, + ) + .await + .unwrap(); + assert_ne!(acceptor.listen_addr().port(), 0); + } + + #[test] + fn no_verifier_accepts_any_cert() { + let verifier = NoVerifier; + assert!(verifier.supported_verify_schemes().len() > 0); + } +} \ No newline at end of file