Implement client-side SSH auth handler with ClientAuthConfig and ClientHandler
This commit is contained in:
127
Cargo.lock
generated
127
Cargo.lock
generated
@@ -130,7 +130,7 @@ version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048"
|
||||
dependencies = [
|
||||
"asn1-rs-derive",
|
||||
"asn1-rs-derive 0.5.1",
|
||||
"asn1-rs-impl",
|
||||
"displaydoc",
|
||||
"nom",
|
||||
@@ -140,6 +140,22 @@ dependencies = [
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b7f43a50ac4fdca5df8e885c21b835997f0a1cdee65494a6847694a98652d9d8"
|
||||
dependencies = [
|
||||
"asn1-rs-derive 0.6.0",
|
||||
"asn1-rs-impl",
|
||||
"displaydoc",
|
||||
"nom",
|
||||
"num-traits",
|
||||
"rusticata-macros",
|
||||
"thiserror 2.0.18",
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs-derive"
|
||||
version = "0.5.1"
|
||||
@@ -152,6 +168,18 @@ dependencies = [
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs-derive"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3109e49b1e4909e9db6515a30c633684d68cdeaa252f215214cb4fa1a5bfee2c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs-impl"
|
||||
version = "0.2.0"
|
||||
@@ -337,6 +365,15 @@ dependencies = [
|
||||
"sha2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bit-vec"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b71798fca2c1fe1086445a7258a4bc81e6e49dcd24c8d0dd9a1e57395b603f51"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
@@ -811,7 +848,21 @@ version = "9.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553"
|
||||
dependencies = [
|
||||
"asn1-rs",
|
||||
"asn1-rs 0.6.2",
|
||||
"displaydoc",
|
||||
"nom",
|
||||
"num-bigint",
|
||||
"num-traits",
|
||||
"rusticata-macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "der-parser"
|
||||
version = "10.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07da5016415d5a3c4dd39b11ed26f915f52fc4e0dc197d87908bc916e51bc1a6"
|
||||
dependencies = [
|
||||
"asn1-rs 0.7.2",
|
||||
"displaydoc",
|
||||
"nom",
|
||||
"num-bigint",
|
||||
@@ -1933,7 +1984,7 @@ dependencies = [
|
||||
"pkarr",
|
||||
"portmapper",
|
||||
"rand 0.8.6",
|
||||
"rcgen",
|
||||
"rcgen 0.13.2",
|
||||
"reqwest",
|
||||
"ring",
|
||||
"rustls",
|
||||
@@ -1952,7 +2003,7 @@ dependencies = [
|
||||
"url",
|
||||
"wasm-bindgen-futures",
|
||||
"webpki-roots 0.26.11",
|
||||
"x509-parser",
|
||||
"x509-parser 0.16.0",
|
||||
"z32",
|
||||
]
|
||||
|
||||
@@ -2640,7 +2691,16 @@ version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9"
|
||||
dependencies = [
|
||||
"asn1-rs",
|
||||
"asn1-rs 0.6.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "oid-registry"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "12f40cff3dde1b6087cc5d5f5d4d65712f34016a03ed60e9c08dcc392736b5b7"
|
||||
dependencies = [
|
||||
"asn1-rs 0.7.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3341,7 +3401,21 @@ dependencies = [
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"time",
|
||||
"yasna",
|
||||
"yasna 0.5.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rcgen"
|
||||
version = "0.14.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57f6d249aad744e274e682777a50283a225a32705394ee6d5fcc01efa25e4055"
|
||||
dependencies = [
|
||||
"pem",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"time",
|
||||
"x509-parser 0.18.1",
|
||||
"yasna 0.6.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3735,12 +3809,12 @@ dependencies = [
|
||||
"http 1.4.1",
|
||||
"log",
|
||||
"pem",
|
||||
"rcgen",
|
||||
"rcgen 0.13.2",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.18",
|
||||
"webpki-roots 0.26.11",
|
||||
"x509-parser",
|
||||
"x509-parser 0.16.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5476,15 +5550,18 @@ dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"iroh",
|
||||
"rcgen 0.14.8",
|
||||
"russh",
|
||||
"rustls",
|
||||
"rustls-acme",
|
||||
"rustls-pki-types",
|
||||
"tempfile",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"webpki-roots 0.26.11",
|
||||
"wraith-core",
|
||||
]
|
||||
|
||||
@@ -5509,17 +5586,35 @@ version = "0.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69"
|
||||
dependencies = [
|
||||
"asn1-rs",
|
||||
"asn1-rs 0.6.2",
|
||||
"data-encoding",
|
||||
"der-parser",
|
||||
"der-parser 9.0.0",
|
||||
"lazy_static",
|
||||
"nom",
|
||||
"oid-registry",
|
||||
"oid-registry 0.7.1",
|
||||
"rusticata-macros",
|
||||
"thiserror 1.0.69",
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x509-parser"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d43b0f71ce057da06bc0851b23ee24f3f86190b07203dd8f567d0b706a185202"
|
||||
dependencies = [
|
||||
"asn1-rs 0.7.2",
|
||||
"data-encoding",
|
||||
"der-parser 10.0.0",
|
||||
"lazy_static",
|
||||
"nom",
|
||||
"oid-registry 0.8.1",
|
||||
"ring",
|
||||
"rusticata-macros",
|
||||
"thiserror 2.0.18",
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xml-rs"
|
||||
version = "0.8.28"
|
||||
@@ -5544,6 +5639,16 @@ dependencies = [
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yasna"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b5f6765e852b9b4dc8e2a76843e4d64d1cea8e79bcde0b6901aea8e7c7f08282"
|
||||
dependencies = [
|
||||
"bit-vec",
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yoke"
|
||||
version = "0.8.2"
|
||||
|
||||
@@ -8,7 +8,7 @@ name = "wraith_core"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
tls = ["dep:tokio-rustls", "dep:rustls"]
|
||||
tls = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-pki-types", "dep:webpki-roots"]
|
||||
iroh = ["dep:iroh"]
|
||||
acme = ["dep:rustls-acme", "tls"]
|
||||
testutil = []
|
||||
@@ -22,11 +22,14 @@ anyhow = "1"
|
||||
thiserror = "2"
|
||||
tokio-util = { version = "0.7", features = ["compat"] }
|
||||
tokio-rustls = { version = "0.26", optional = true }
|
||||
rustls = { version = "0.23", optional = true }
|
||||
rustls = { version = "0.23", optional = true, features = ["aws_lc_rs"] }
|
||||
rustls-pki-types = { version = "1", optional = true }
|
||||
rustls-acme = { version = "0.12", optional = true }
|
||||
webpki-roots = { version = "0.26", optional = true }
|
||||
iroh = { version = "0.34", optional = true }
|
||||
async-trait = "0.1"
|
||||
|
||||
[dev-dependencies]
|
||||
wraith-core = { path = ".", features = ["testutil"] }
|
||||
tempfile = "3"
|
||||
wraith-core = { path = ".", features = ["testutil", "tls"] }
|
||||
tempfile = "3"
|
||||
rcgen = "0.14"
|
||||
176
crates/wraith-core/src/auth/client_auth.rs
Normal file
176
crates/wraith-core/src/auth/client_auth.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use russh::client;
|
||||
use russh::keys::key::PrivateKeyWithHashAlg;
|
||||
use russh::keys::{PrivateKey, PublicKey};
|
||||
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::error::ConfigError;
|
||||
|
||||
/// Client-side SSH authentication configuration.
|
||||
///
|
||||
/// Holds the private key used for SSH authentication and an optional
|
||||
/// public key override. When no public key is provided, it is derived
|
||||
/// from the private key.
|
||||
pub struct ClientAuthConfig {
|
||||
private_key: Arc<PrivateKey>,
|
||||
public_key: PublicKey,
|
||||
}
|
||||
|
||||
impl ClientAuthConfig {
|
||||
/// Load a `ClientAuthConfig` from a key source (file or in-memory).
|
||||
pub fn from_key_source(source: KeySource) -> Result<Self, ConfigError> {
|
||||
let private_key = crate::auth::keys::load_private_key(source)?;
|
||||
let public_key = private_key.public_key().clone();
|
||||
Ok(Self {
|
||||
private_key: Arc::new(private_key),
|
||||
public_key,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the private key wrapped in `Arc` for use with russh authentication.
|
||||
pub fn private_key(&self) -> Arc<PrivateKey> {
|
||||
Arc::clone(&self.private_key)
|
||||
}
|
||||
|
||||
/// Returns the public key derived from (or overridden for) this config.
|
||||
pub fn public_key(&self) -> &PublicKey {
|
||||
&self.public_key
|
||||
}
|
||||
|
||||
/// Authenticate with the given SSH session handle and username.
|
||||
pub async fn authenticate<H: client::Handler>(
|
||||
&self,
|
||||
handle: &mut client::Handle<H>,
|
||||
username: &str,
|
||||
) -> Result<bool, russh::Error> {
|
||||
let key_with_alg = PrivateKeyWithHashAlg::new(Arc::clone(&self.private_key), None)?;
|
||||
handle.authenticate_publickey(username, key_with_alg).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Client handler implementing `russh::client::Handler`.
|
||||
///
|
||||
/// Provides the callbacks required by russh during the SSH handshake.
|
||||
/// Server key verification is delegated to a configurable callback;
|
||||
/// the default accepts all server keys (suitable for testing or when
|
||||
/// transport-layer verification — e.g. TLS — is already in place).
|
||||
pub struct ClientHandler {
|
||||
pub_key: PublicKey,
|
||||
check_server_key_fn: Box<dyn Fn(&PublicKey) -> bool + Send + Sync>,
|
||||
}
|
||||
|
||||
impl ClientHandler {
|
||||
/// Create a new client handler from a `ClientAuthConfig`.
|
||||
pub fn from_config(config: &ClientAuthConfig) -> Self {
|
||||
Self {
|
||||
pub_key: config.public_key().clone(),
|
||||
check_server_key_fn: Box::new(|_| true),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a client handler with a custom server key verification callback.
|
||||
pub fn with_server_key_check(
|
||||
config: &ClientAuthConfig,
|
||||
check_fn: impl Fn(&PublicKey) -> bool + Send + Sync + 'static,
|
||||
) -> Self {
|
||||
Self {
|
||||
pub_key: config.public_key().clone(),
|
||||
check_server_key_fn: Box::new(check_fn),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the public key associated with this handler.
|
||||
pub fn public_key(&self) -> &PublicKey {
|
||||
&self.pub_key
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl client::Handler for ClientHandler {
|
||||
type Error = russh::Error;
|
||||
|
||||
async fn check_server_key(
|
||||
&mut self,
|
||||
server_public_key: &PublicKey,
|
||||
) -> Result<bool, Self::Error> {
|
||||
Ok((self.check_server_key_fn)(server_public_key))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use russh::client::Handler;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
#[test]
|
||||
fn from_key_source_memory() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
assert_eq!(
|
||||
config.public_key().algorithm(),
|
||||
russh::keys::Algorithm::Ed25519
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_from_config() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let handler = ClientHandler::from_config(&config);
|
||||
assert_eq!(
|
||||
handler.public_key().algorithm(),
|
||||
russh::keys::Algorithm::Ed25519
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_with_custom_server_key_check() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let handler = ClientHandler::with_server_key_check(&config, |_pk| false);
|
||||
assert_eq!(
|
||||
handler.public_key().algorithm(),
|
||||
russh::keys::Algorithm::Ed25519
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_key_source_invalid_key() {
|
||||
let source = KeySource::Memory(b"not a key".to_vec());
|
||||
let result = ClientAuthConfig::from_key_source(source);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handler_check_server_key_accepts_by_default() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let mut handler = ClientHandler::from_config(&config);
|
||||
let some_key = config.public_key().clone();
|
||||
let result = handler.check_server_key(&some_key).await.unwrap();
|
||||
assert!(result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handler_check_server_key_rejects_with_custom_fn() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let mut handler = ClientHandler::with_server_key_check(&config, |_pk| false);
|
||||
let some_key = config.public_key().clone();
|
||||
let result = handler.check_server_key(&some_key).await.unwrap();
|
||||
assert!(!result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_key_arc_dedup() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let key1 = config.private_key();
|
||||
let key2 = config.private_key();
|
||||
assert!(Arc::ptr_eq(&key1, &key2));
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
pub mod client_auth;
|
||||
pub mod keys;
|
||||
|
||||
pub use client_auth::{ClientAuthConfig, ClientHandler};
|
||||
pub use keys::{CertAuthorityEntry, KeySource, load_private_key, load_public_keys};
|
||||
@@ -2,6 +2,12 @@ mod tcp;
|
||||
|
||||
pub use tcp::{TcpAcceptor, TcpTransport};
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
mod tls;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
pub use tls::{AcmeConfig, TlsAcceptor, TlsTransport};
|
||||
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
372
crates/wraith-core/src/transport/tls.rs
Normal file
372
crates/wraith-core/src/transport/tls.rs
Normal file
@@ -0,0 +1,372 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
|
||||
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_rustls::{client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector};
|
||||
|
||||
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
pub struct TlsTransport {
|
||||
addr: SocketAddr,
|
||||
tls_server_name: Option<String>,
|
||||
insecure: bool,
|
||||
root_cert: Option<CertificateDer<'static>>,
|
||||
}
|
||||
|
||||
impl TlsTransport {
|
||||
pub fn new(addr: SocketAddr) -> Self {
|
||||
Self {
|
||||
addr,
|
||||
tls_server_name: None,
|
||||
insecure: false,
|
||||
root_cert: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
|
||||
self.tls_server_name = Some(name.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_insecure(mut self, insecure: bool) -> Self {
|
||||
self.insecure = insecure;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_root_cert(mut self, cert: CertificateDer<'static>) -> Self {
|
||||
self.root_cert = Some(cert);
|
||||
self
|
||||
}
|
||||
|
||||
fn build_client_config(&self) -> Result<ClientConfig> {
|
||||
if self.insecure {
|
||||
let config = ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(NoVerifier))
|
||||
.with_no_client_auth();
|
||||
return Ok(config);
|
||||
}
|
||||
|
||||
let mut root_store = RootCertStore::empty();
|
||||
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
|
||||
if let Some(ref cert) = self.root_cert {
|
||||
root_store.add(cert.clone())?;
|
||||
}
|
||||
|
||||
let config = ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn resolve_server_name(&self) -> Result<ServerName<'static>> {
|
||||
let name = match &self.tls_server_name {
|
||||
Some(n) => n.clone(),
|
||||
None => self.addr.ip().to_string(),
|
||||
};
|
||||
ServerName::try_from(name.clone())
|
||||
.map_err(move |e| anyhow!("invalid server name '{}': {}", name, e))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for TlsTransport {
|
||||
type Stream = ClientTlsStream<TcpStream>;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream> {
|
||||
let tcp_stream = TcpStream::connect(self.addr).await?;
|
||||
let config = self.build_client_config()?;
|
||||
let connector = TlsConnector::from(Arc::new(config));
|
||||
let server_name = self.resolve_server_name()?;
|
||||
let tls_stream = connector.connect(server_name, tcp_stream).await?;
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
format!("tls://{}", self.addr)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AcmeConfig {
|
||||
pub domain: String,
|
||||
}
|
||||
|
||||
pub struct TlsAcceptor {
|
||||
listener: TcpListener,
|
||||
listen_addr: SocketAddr,
|
||||
#[allow(dead_code)]
|
||||
server_config: Arc<ServerConfig>,
|
||||
tokio_acceptor: TokioTlsAcceptor,
|
||||
}
|
||||
|
||||
impl TlsAcceptor {
|
||||
pub async fn bind(
|
||||
addr: SocketAddr,
|
||||
tls_certs: Vec<CertificateDer<'static>>,
|
||||
tls_key: PrivateKeyDer<'static>,
|
||||
_acme_config: Option<AcmeConfig>,
|
||||
) -> Result<Self> {
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let listen_addr = listener.local_addr()?;
|
||||
|
||||
let server_config = ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(tls_certs, tls_key)?;
|
||||
|
||||
let server_config = Arc::new(server_config);
|
||||
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
|
||||
|
||||
Ok(Self {
|
||||
listener,
|
||||
listen_addr,
|
||||
server_config,
|
||||
tokio_acceptor,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn listen_addr(&self) -> SocketAddr {
|
||||
self.listen_addr
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportAcceptor for TlsAcceptor {
|
||||
type Stream = tokio_rustls::server::TlsStream<TcpStream>;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let (tcp_stream, remote_addr) = self.listener.accept().await?;
|
||||
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
|
||||
|
||||
let server_name = tls_stream
|
||||
.get_ref()
|
||||
.1
|
||||
.server_name()
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let info = TransportInfo {
|
||||
remote_addr: Some(remote_addr),
|
||||
transport_kind: TransportKind::Tls { server_name },
|
||||
};
|
||||
|
||||
Ok((tls_stream, info))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NoVerifier;
|
||||
|
||||
impl ServerCertVerifier for NoVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> std::result::Result<ServerCertVerified, rustls::Error> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_doc: &DigitallySignedStruct,
|
||||
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_doc: &DigitallySignedStruct,
|
||||
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA512,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rcgen::{CertificateParams, KeyPair};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
fn generate_self_signed_cert() -> (CertificateDer<'static>, PrivateKeyDer<'static>) {
|
||||
let params = CertificateParams::new(vec!["localhost".to_string()]).unwrap();
|
||||
let key_pair = KeyPair::generate().unwrap();
|
||||
let cert = params.self_signed(&key_pair).unwrap();
|
||||
let cert_der: CertificateDer<'static> = cert.into();
|
||||
let key_der = PrivateKeyDer::Pkcs8(key_pair.serialize_der().into());
|
||||
(cert_der, key_der)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_transport_describe_format() {
|
||||
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
|
||||
let transport = TlsTransport::new(addr).with_server_name("example.com");
|
||||
assert_eq!(transport.describe(), "tls://1.2.3.4:443");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_transport_describe_with_ip() {
|
||||
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
|
||||
let transport = TlsTransport::new(addr);
|
||||
assert_eq!(transport.describe(), "tls://1.2.3.4:443");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_transport_builder_methods() {
|
||||
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
|
||||
let transport = TlsTransport::new(addr)
|
||||
.with_server_name("wraith.test")
|
||||
.with_insecure(true);
|
||||
assert_eq!(transport.tls_server_name, Some("wraith.test".to_string()));
|
||||
assert!(transport.insecure);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_connect_insecure_self_signed() {
|
||||
let (cert_der, key_der) = generate_self_signed_cert();
|
||||
|
||||
let acceptor = TlsAcceptor::bind(
|
||||
"127.0.0.1:0".parse().unwrap(),
|
||||
vec![cert_der],
|
||||
key_der,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
|
||||
let transport = TlsTransport::new(addr)
|
||||
.with_server_name("localhost")
|
||||
.with_insecure(true);
|
||||
|
||||
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
|
||||
|
||||
let mut client = transport.connect().await.unwrap();
|
||||
|
||||
let (mut server, info) = accept_handle.await.unwrap();
|
||||
assert!(info.remote_addr.is_some());
|
||||
assert!(matches!(
|
||||
info.transport_kind,
|
||||
TransportKind::Tls { .. }
|
||||
));
|
||||
|
||||
client.write_all(b"hello tls").await.unwrap();
|
||||
let mut buf = [0u8; 9];
|
||||
server.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello tls");
|
||||
|
||||
server.write_all(b"reply").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
client.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"reply");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_acceptor_returns_server_name() {
|
||||
let (cert_der, key_der) = generate_self_signed_cert();
|
||||
|
||||
let acceptor = TlsAcceptor::bind(
|
||||
"127.0.0.1:0".parse().unwrap(),
|
||||
vec![cert_der],
|
||||
key_der,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
|
||||
let transport = TlsTransport::new(addr)
|
||||
.with_server_name("localhost")
|
||||
.with_insecure(true);
|
||||
|
||||
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
|
||||
|
||||
let _client = transport.connect().await.unwrap();
|
||||
|
||||
let (_, info) = accept_handle.await.unwrap();
|
||||
if let TransportKind::Tls { server_name } = info.transport_kind {
|
||||
assert_eq!(server_name, Some("localhost".to_string()));
|
||||
} else {
|
||||
panic!("expected TransportKind::Tls");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_full_client_to_server_connection() {
|
||||
let (cert_der, key_der) = generate_self_signed_cert();
|
||||
|
||||
let acceptor = TlsAcceptor::bind(
|
||||
"127.0.0.1:0".parse().unwrap(),
|
||||
vec![cert_der],
|
||||
key_der,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
|
||||
let transport = TlsTransport::new(addr)
|
||||
.with_server_name("localhost")
|
||||
.with_insecure(true);
|
||||
|
||||
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
|
||||
|
||||
let mut client = transport.connect().await.unwrap();
|
||||
let (mut server, _info) = accept_handle.await.unwrap();
|
||||
|
||||
let msg = b"wraith integration test";
|
||||
client.write_all(msg).await.unwrap();
|
||||
let mut buf = vec![0u8; msg.len()];
|
||||
server.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf[..], msg);
|
||||
|
||||
let reply = b"ok";
|
||||
server.write_all(reply).await.unwrap();
|
||||
let mut buf = [0u8; 2];
|
||||
client.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, reply);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_acceptor_bind_port_zero_assigns_ephemeral() {
|
||||
let (cert_der, key_der) = generate_self_signed_cert();
|
||||
|
||||
let acceptor = TlsAcceptor::bind(
|
||||
"127.0.0.1:0".parse().unwrap(),
|
||||
vec![cert_der],
|
||||
key_der,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_ne!(acceptor.listen_addr().port(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_verifier_accepts_any_cert() {
|
||||
let verifier = NoVerifier;
|
||||
assert!(verifier.supported_verify_schemes().len() > 0);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user