Implement client-side SSH auth handler with ClientAuthConfig and ClientHandler

This commit is contained in:
2026-06-02 10:03:56 +00:00
parent b4f4f2ed8c
commit eb032c87f1
6 changed files with 679 additions and 15 deletions

127
Cargo.lock generated
View File

@@ -130,7 +130,7 @@ version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048"
dependencies = [ dependencies = [
"asn1-rs-derive", "asn1-rs-derive 0.5.1",
"asn1-rs-impl", "asn1-rs-impl",
"displaydoc", "displaydoc",
"nom", "nom",
@@ -140,6 +140,22 @@ dependencies = [
"time", "time",
] ]
[[package]]
name = "asn1-rs"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7f43a50ac4fdca5df8e885c21b835997f0a1cdee65494a6847694a98652d9d8"
dependencies = [
"asn1-rs-derive 0.6.0",
"asn1-rs-impl",
"displaydoc",
"nom",
"num-traits",
"rusticata-macros",
"thiserror 2.0.18",
"time",
]
[[package]] [[package]]
name = "asn1-rs-derive" name = "asn1-rs-derive"
version = "0.5.1" version = "0.5.1"
@@ -152,6 +168,18 @@ dependencies = [
"synstructure", "synstructure",
] ]
[[package]]
name = "asn1-rs-derive"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3109e49b1e4909e9db6515a30c633684d68cdeaa252f215214cb4fa1a5bfee2c"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]
[[package]] [[package]]
name = "asn1-rs-impl" name = "asn1-rs-impl"
version = "0.2.0" version = "0.2.0"
@@ -337,6 +365,15 @@ dependencies = [
"sha2", "sha2",
] ]
[[package]]
name = "bit-vec"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b71798fca2c1fe1086445a7258a4bc81e6e49dcd24c8d0dd9a1e57395b603f51"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
@@ -811,7 +848,21 @@ version = "9.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553"
dependencies = [ dependencies = [
"asn1-rs", "asn1-rs 0.6.2",
"displaydoc",
"nom",
"num-bigint",
"num-traits",
"rusticata-macros",
]
[[package]]
name = "der-parser"
version = "10.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07da5016415d5a3c4dd39b11ed26f915f52fc4e0dc197d87908bc916e51bc1a6"
dependencies = [
"asn1-rs 0.7.2",
"displaydoc", "displaydoc",
"nom", "nom",
"num-bigint", "num-bigint",
@@ -1933,7 +1984,7 @@ dependencies = [
"pkarr", "pkarr",
"portmapper", "portmapper",
"rand 0.8.6", "rand 0.8.6",
"rcgen", "rcgen 0.13.2",
"reqwest", "reqwest",
"ring", "ring",
"rustls", "rustls",
@@ -1952,7 +2003,7 @@ dependencies = [
"url", "url",
"wasm-bindgen-futures", "wasm-bindgen-futures",
"webpki-roots 0.26.11", "webpki-roots 0.26.11",
"x509-parser", "x509-parser 0.16.0",
"z32", "z32",
] ]
@@ -2640,7 +2691,16 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9"
dependencies = [ dependencies = [
"asn1-rs", "asn1-rs 0.6.2",
]
[[package]]
name = "oid-registry"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12f40cff3dde1b6087cc5d5f5d4d65712f34016a03ed60e9c08dcc392736b5b7"
dependencies = [
"asn1-rs 0.7.2",
] ]
[[package]] [[package]]
@@ -3341,7 +3401,21 @@ dependencies = [
"ring", "ring",
"rustls-pki-types", "rustls-pki-types",
"time", "time",
"yasna", "yasna 0.5.2",
]
[[package]]
name = "rcgen"
version = "0.14.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57f6d249aad744e274e682777a50283a225a32705394ee6d5fcc01efa25e4055"
dependencies = [
"pem",
"ring",
"rustls-pki-types",
"time",
"x509-parser 0.18.1",
"yasna 0.6.0",
] ]
[[package]] [[package]]
@@ -3735,12 +3809,12 @@ dependencies = [
"http 1.4.1", "http 1.4.1",
"log", "log",
"pem", "pem",
"rcgen", "rcgen 0.13.2",
"serde", "serde",
"serde_json", "serde_json",
"thiserror 2.0.18", "thiserror 2.0.18",
"webpki-roots 0.26.11", "webpki-roots 0.26.11",
"x509-parser", "x509-parser 0.16.0",
] ]
[[package]] [[package]]
@@ -5476,15 +5550,18 @@ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
"iroh", "iroh",
"rcgen 0.14.8",
"russh", "russh",
"rustls", "rustls",
"rustls-acme", "rustls-acme",
"rustls-pki-types",
"tempfile", "tempfile",
"thiserror 2.0.18", "thiserror 2.0.18",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tokio-util", "tokio-util",
"tracing", "tracing",
"webpki-roots 0.26.11",
"wraith-core", "wraith-core",
] ]
@@ -5509,17 +5586,35 @@ version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69"
dependencies = [ dependencies = [
"asn1-rs", "asn1-rs 0.6.2",
"data-encoding", "data-encoding",
"der-parser", "der-parser 9.0.0",
"lazy_static", "lazy_static",
"nom", "nom",
"oid-registry", "oid-registry 0.7.1",
"rusticata-macros", "rusticata-macros",
"thiserror 1.0.69", "thiserror 1.0.69",
"time", "time",
] ]
[[package]]
name = "x509-parser"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d43b0f71ce057da06bc0851b23ee24f3f86190b07203dd8f567d0b706a185202"
dependencies = [
"asn1-rs 0.7.2",
"data-encoding",
"der-parser 10.0.0",
"lazy_static",
"nom",
"oid-registry 0.8.1",
"ring",
"rusticata-macros",
"thiserror 2.0.18",
"time",
]
[[package]] [[package]]
name = "xml-rs" name = "xml-rs"
version = "0.8.28" version = "0.8.28"
@@ -5544,6 +5639,16 @@ dependencies = [
"time", "time",
] ]
[[package]]
name = "yasna"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5f6765e852b9b4dc8e2a76843e4d64d1cea8e79bcde0b6901aea8e7c7f08282"
dependencies = [
"bit-vec",
"time",
]
[[package]] [[package]]
name = "yoke" name = "yoke"
version = "0.8.2" version = "0.8.2"

View File

@@ -8,7 +8,7 @@ name = "wraith_core"
[features] [features]
default = [] default = []
tls = ["dep:tokio-rustls", "dep:rustls"] tls = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-pki-types", "dep:webpki-roots"]
iroh = ["dep:iroh"] iroh = ["dep:iroh"]
acme = ["dep:rustls-acme", "tls"] acme = ["dep:rustls-acme", "tls"]
testutil = [] testutil = []
@@ -22,11 +22,14 @@ anyhow = "1"
thiserror = "2" thiserror = "2"
tokio-util = { version = "0.7", features = ["compat"] } tokio-util = { version = "0.7", features = ["compat"] }
tokio-rustls = { version = "0.26", optional = true } tokio-rustls = { version = "0.26", optional = true }
rustls = { version = "0.23", optional = true } rustls = { version = "0.23", optional = true, features = ["aws_lc_rs"] }
rustls-pki-types = { version = "1", optional = true }
rustls-acme = { version = "0.12", optional = true } rustls-acme = { version = "0.12", optional = true }
webpki-roots = { version = "0.26", optional = true }
iroh = { version = "0.34", optional = true } iroh = { version = "0.34", optional = true }
async-trait = "0.1" async-trait = "0.1"
[dev-dependencies] [dev-dependencies]
wraith-core = { path = ".", features = ["testutil"] } wraith-core = { path = ".", features = ["testutil", "tls"] }
tempfile = "3" tempfile = "3"
rcgen = "0.14"

View File

@@ -0,0 +1,176 @@
use std::sync::Arc;
use async_trait::async_trait;
use russh::client;
use russh::keys::key::PrivateKeyWithHashAlg;
use russh::keys::{PrivateKey, PublicKey};
use crate::auth::keys::KeySource;
use crate::error::ConfigError;
/// Client-side SSH authentication configuration.
///
/// Holds the private key used for SSH authentication and an optional
/// public key override. When no public key is provided, it is derived
/// from the private key.
pub struct ClientAuthConfig {
private_key: Arc<PrivateKey>,
public_key: PublicKey,
}
impl ClientAuthConfig {
/// Load a `ClientAuthConfig` from a key source (file or in-memory).
pub fn from_key_source(source: KeySource) -> Result<Self, ConfigError> {
let private_key = crate::auth::keys::load_private_key(source)?;
let public_key = private_key.public_key().clone();
Ok(Self {
private_key: Arc::new(private_key),
public_key,
})
}
/// Returns the private key wrapped in `Arc` for use with russh authentication.
pub fn private_key(&self) -> Arc<PrivateKey> {
Arc::clone(&self.private_key)
}
/// Returns the public key derived from (or overridden for) this config.
pub fn public_key(&self) -> &PublicKey {
&self.public_key
}
/// Authenticate with the given SSH session handle and username.
pub async fn authenticate<H: client::Handler>(
&self,
handle: &mut client::Handle<H>,
username: &str,
) -> Result<bool, russh::Error> {
let key_with_alg = PrivateKeyWithHashAlg::new(Arc::clone(&self.private_key), None)?;
handle.authenticate_publickey(username, key_with_alg).await
}
}
/// Client handler implementing `russh::client::Handler`.
///
/// Provides the callbacks required by russh during the SSH handshake.
/// Server key verification is delegated to a configurable callback;
/// the default accepts all server keys (suitable for testing or when
/// transport-layer verification — e.g. TLS — is already in place).
pub struct ClientHandler {
pub_key: PublicKey,
check_server_key_fn: Box<dyn Fn(&PublicKey) -> bool + Send + Sync>,
}
impl ClientHandler {
/// Create a new client handler from a `ClientAuthConfig`.
pub fn from_config(config: &ClientAuthConfig) -> Self {
Self {
pub_key: config.public_key().clone(),
check_server_key_fn: Box::new(|_| true),
}
}
/// Create a client handler with a custom server key verification callback.
pub fn with_server_key_check(
config: &ClientAuthConfig,
check_fn: impl Fn(&PublicKey) -> bool + Send + Sync + 'static,
) -> Self {
Self {
pub_key: config.public_key().clone(),
check_server_key_fn: Box::new(check_fn),
}
}
/// Returns the public key associated with this handler.
pub fn public_key(&self) -> &PublicKey {
&self.pub_key
}
}
#[async_trait]
impl client::Handler for ClientHandler {
type Error = russh::Error;
async fn check_server_key(
&mut self,
server_public_key: &PublicKey,
) -> Result<bool, Self::Error> {
Ok((self.check_server_key_fn)(server_public_key))
}
}
#[cfg(test)]
mod tests {
use super::*;
use russh::client::Handler;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
#[test]
fn from_key_source_memory() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let config = ClientAuthConfig::from_key_source(source).unwrap();
assert_eq!(
config.public_key().algorithm(),
russh::keys::Algorithm::Ed25519
);
}
#[test]
fn handler_from_config() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let config = ClientAuthConfig::from_key_source(source).unwrap();
let handler = ClientHandler::from_config(&config);
assert_eq!(
handler.public_key().algorithm(),
russh::keys::Algorithm::Ed25519
);
}
#[test]
fn handler_with_custom_server_key_check() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let config = ClientAuthConfig::from_key_source(source).unwrap();
let handler = ClientHandler::with_server_key_check(&config, |_pk| false);
assert_eq!(
handler.public_key().algorithm(),
russh::keys::Algorithm::Ed25519
);
}
#[test]
fn from_key_source_invalid_key() {
let source = KeySource::Memory(b"not a key".to_vec());
let result = ClientAuthConfig::from_key_source(source);
assert!(result.is_err());
}
#[tokio::test]
async fn handler_check_server_key_accepts_by_default() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let config = ClientAuthConfig::from_key_source(source).unwrap();
let mut handler = ClientHandler::from_config(&config);
let some_key = config.public_key().clone();
let result = handler.check_server_key(&some_key).await.unwrap();
assert!(result);
}
#[tokio::test]
async fn handler_check_server_key_rejects_with_custom_fn() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let config = ClientAuthConfig::from_key_source(source).unwrap();
let mut handler = ClientHandler::with_server_key_check(&config, |_pk| false);
let some_key = config.public_key().clone();
let result = handler.check_server_key(&some_key).await.unwrap();
assert!(!result);
}
#[test]
fn private_key_arc_dedup() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let config = ClientAuthConfig::from_key_source(source).unwrap();
let key1 = config.private_key();
let key2 = config.private_key();
assert!(Arc::ptr_eq(&key1, &key2));
}
}

View File

@@ -1,3 +1,5 @@
pub mod client_auth;
pub mod keys; pub mod keys;
pub use client_auth::{ClientAuthConfig, ClientHandler};
pub use keys::{CertAuthorityEntry, KeySource, load_private_key, load_public_keys}; pub use keys::{CertAuthorityEntry, KeySource, load_private_key, load_public_keys};

View File

@@ -2,6 +2,12 @@ mod tcp;
pub use tcp::{TcpAcceptor, TcpTransport}; pub use tcp::{TcpAcceptor, TcpTransport};
#[cfg(feature = "tls")]
mod tls;
#[cfg(feature = "tls")]
pub use tls::{AcmeConfig, TlsAcceptor, TlsTransport};
use std::net::SocketAddr; use std::net::SocketAddr;
use anyhow::Result; use anyhow::Result;

View File

@@ -0,0 +1,372 @@
use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::{client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector};
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
pub struct TlsTransport {
addr: SocketAddr,
tls_server_name: Option<String>,
insecure: bool,
root_cert: Option<CertificateDer<'static>>,
}
impl TlsTransport {
pub fn new(addr: SocketAddr) -> Self {
Self {
addr,
tls_server_name: None,
insecure: false,
root_cert: None,
}
}
pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
self.tls_server_name = Some(name.into());
self
}
pub fn with_insecure(mut self, insecure: bool) -> Self {
self.insecure = insecure;
self
}
pub fn with_root_cert(mut self, cert: CertificateDer<'static>) -> Self {
self.root_cert = Some(cert);
self
}
fn build_client_config(&self) -> Result<ClientConfig> {
if self.insecure {
let config = ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoVerifier))
.with_no_client_auth();
return Ok(config);
}
let mut root_store = RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
if let Some(ref cert) = self.root_cert {
root_store.add(cert.clone())?;
}
let config = ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
Ok(config)
}
fn resolve_server_name(&self) -> Result<ServerName<'static>> {
let name = match &self.tls_server_name {
Some(n) => n.clone(),
None => self.addr.ip().to_string(),
};
ServerName::try_from(name.clone())
.map_err(move |e| anyhow!("invalid server name '{}': {}", name, e))
}
}
#[async_trait]
impl Transport for TlsTransport {
type Stream = ClientTlsStream<TcpStream>;
async fn connect(&self) -> Result<Self::Stream> {
let tcp_stream = TcpStream::connect(self.addr).await?;
let config = self.build_client_config()?;
let connector = TlsConnector::from(Arc::new(config));
let server_name = self.resolve_server_name()?;
let tls_stream = connector.connect(server_name, tcp_stream).await?;
Ok(tls_stream)
}
fn describe(&self) -> String {
format!("tls://{}", self.addr)
}
}
#[derive(Debug)]
pub struct AcmeConfig {
pub domain: String,
}
pub struct TlsAcceptor {
listener: TcpListener,
listen_addr: SocketAddr,
#[allow(dead_code)]
server_config: Arc<ServerConfig>,
tokio_acceptor: TokioTlsAcceptor,
}
impl TlsAcceptor {
pub async fn bind(
addr: SocketAddr,
tls_certs: Vec<CertificateDer<'static>>,
tls_key: PrivateKeyDer<'static>,
_acme_config: Option<AcmeConfig>,
) -> Result<Self> {
let listener = TcpListener::bind(addr).await?;
let listen_addr = listener.local_addr()?;
let server_config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(tls_certs, tls_key)?;
let server_config = Arc::new(server_config);
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
Ok(Self {
listener,
listen_addr,
server_config,
tokio_acceptor,
})
}
pub fn listen_addr(&self) -> SocketAddr {
self.listen_addr
}
}
#[async_trait]
impl TransportAcceptor for TlsAcceptor {
type Stream = tokio_rustls::server::TlsStream<TcpStream>;
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
let (tcp_stream, remote_addr) = self.listener.accept().await?;
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
let server_name = tls_stream
.get_ref()
.1
.server_name()
.map(|s| s.to_string());
let info = TransportInfo {
remote_addr: Some(remote_addr),
transport_kind: TransportKind::Tls { server_name },
};
Ok((tls_stream, info))
}
}
#[derive(Debug)]
struct NoVerifier;
impl ServerCertVerifier for NoVerifier {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> std::result::Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_doc: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_doc: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
]
}
}
#[cfg(test)]
mod tests {
use super::*;
use rcgen::{CertificateParams, KeyPair};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
fn generate_self_signed_cert() -> (CertificateDer<'static>, PrivateKeyDer<'static>) {
let params = CertificateParams::new(vec!["localhost".to_string()]).unwrap();
let key_pair = KeyPair::generate().unwrap();
let cert = params.self_signed(&key_pair).unwrap();
let cert_der: CertificateDer<'static> = cert.into();
let key_der = PrivateKeyDer::Pkcs8(key_pair.serialize_der().into());
(cert_der, key_der)
}
#[test]
fn tls_transport_describe_format() {
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
let transport = TlsTransport::new(addr).with_server_name("example.com");
assert_eq!(transport.describe(), "tls://1.2.3.4:443");
}
#[test]
fn tls_transport_describe_with_ip() {
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
let transport = TlsTransport::new(addr);
assert_eq!(transport.describe(), "tls://1.2.3.4:443");
}
#[test]
fn tls_transport_builder_methods() {
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
let transport = TlsTransport::new(addr)
.with_server_name("wraith.test")
.with_insecure(true);
assert_eq!(transport.tls_server_name, Some("wraith.test".to_string()));
assert!(transport.insecure);
}
#[tokio::test]
async fn tls_connect_insecure_self_signed() {
let (cert_der, key_der) = generate_self_signed_cert();
let acceptor = TlsAcceptor::bind(
"127.0.0.1:0".parse().unwrap(),
vec![cert_der],
key_der,
None,
)
.await
.unwrap();
let addr = acceptor.listen_addr();
let transport = TlsTransport::new(addr)
.with_server_name("localhost")
.with_insecure(true);
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
let mut client = transport.connect().await.unwrap();
let (mut server, info) = accept_handle.await.unwrap();
assert!(info.remote_addr.is_some());
assert!(matches!(
info.transport_kind,
TransportKind::Tls { .. }
));
client.write_all(b"hello tls").await.unwrap();
let mut buf = [0u8; 9];
server.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello tls");
server.write_all(b"reply").await.unwrap();
let mut buf = [0u8; 5];
client.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"reply");
}
#[tokio::test]
async fn tls_acceptor_returns_server_name() {
let (cert_der, key_der) = generate_self_signed_cert();
let acceptor = TlsAcceptor::bind(
"127.0.0.1:0".parse().unwrap(),
vec![cert_der],
key_der,
None,
)
.await
.unwrap();
let addr = acceptor.listen_addr();
let transport = TlsTransport::new(addr)
.with_server_name("localhost")
.with_insecure(true);
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
let _client = transport.connect().await.unwrap();
let (_, info) = accept_handle.await.unwrap();
if let TransportKind::Tls { server_name } = info.transport_kind {
assert_eq!(server_name, Some("localhost".to_string()));
} else {
panic!("expected TransportKind::Tls");
}
}
#[tokio::test]
async fn tls_full_client_to_server_connection() {
let (cert_der, key_der) = generate_self_signed_cert();
let acceptor = TlsAcceptor::bind(
"127.0.0.1:0".parse().unwrap(),
vec![cert_der],
key_der,
None,
)
.await
.unwrap();
let addr = acceptor.listen_addr();
let transport = TlsTransport::new(addr)
.with_server_name("localhost")
.with_insecure(true);
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
let mut client = transport.connect().await.unwrap();
let (mut server, _info) = accept_handle.await.unwrap();
let msg = b"wraith integration test";
client.write_all(msg).await.unwrap();
let mut buf = vec![0u8; msg.len()];
server.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf[..], msg);
let reply = b"ok";
server.write_all(reply).await.unwrap();
let mut buf = [0u8; 2];
client.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, reply);
}
#[tokio::test]
async fn tls_acceptor_bind_port_zero_assigns_ephemeral() {
let (cert_der, key_der) = generate_self_signed_cert();
let acceptor = TlsAcceptor::bind(
"127.0.0.1:0".parse().unwrap(),
vec![cert_der],
key_der,
None,
)
.await
.unwrap();
assert_ne!(acceptor.listen_addr().port(), 0);
}
#[test]
fn no_verifier_accepts_any_cert() {
let verifier = NoVerifier;
assert!(verifier.supported_verify_schemes().len() > 0);
}
}