refactor!: rebrand wraith to alknet

Rename all crates, CLI commands, constants, type names, doc comments,
and documentation from wraith to alknet. Includes wire-protocol changes:
ALPN wraith-ssh -> alknet-ssh, reserved destination prefix wraith- ->
alknet-, SSH auth username wraith -> alknet.
This commit is contained in:
2026-06-05 10:04:32 +00:00
parent af7f4d0006
commit 596c89ce24
101 changed files with 552 additions and 552 deletions

View File

@@ -0,0 +1,44 @@
[package]
name = "alknet-core"
version.workspace = true
edition.workspace = true
license.workspace = true
description = "Core library for Alknet: pluggable SSH tunnel transport, SOCKS5 proxy, port forwarding, and authentication"
repository.workspace = true
[lib]
name = "alknet_core"
[features]
default = []
tls = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-pki-types", "dep:webpki-roots"]
iroh = ["dep:iroh", "dep:url"]
acme = ["dep:rustls-acme", "dep:futures", "tls"]
testutil = []
transport-traits = []
[dependencies]
russh = "0.49"
tokio = { version = "1", features = ["full"] }
tracing = "0.1"
anyhow = "1"
thiserror = "2"
tokio-util = { version = "0.7", features = ["compat"] }
tokio-rustls = { version = "0.26", optional = true }
rustls = { version = "0.23", optional = true, features = ["aws_lc_rs"] }
rustls-pki-types = { version = "1", optional = true }
rustls-acme = { version = "0.12", optional = true }
futures = { version = "0.3", optional = true }
webpki-roots = { version = "0.26", optional = true }
iroh = { version = "0.34", optional = true }
url = { version = "2", optional = true }
async-trait = "0.1"
ipnetwork = "0.21.1"
[dev-dependencies]
alknet-core = { path = ".", features = ["testutil", "tls", "iroh"] }
tempfile = "3"
rcgen = "0.14"
rand_core = "0.6"
ssh-key = { version = "0.6", features = ["ed25519", "alloc"] }
rand = "0.10.1"

View File

@@ -0,0 +1,176 @@
use std::sync::Arc;
use async_trait::async_trait;
use russh::client;
use russh::keys::key::PrivateKeyWithHashAlg;
use russh::keys::{PrivateKey, PublicKey};
use crate::auth::keys::KeySource;
use crate::error::ConfigError;
/// Client-side SSH authentication configuration.
///
/// Holds the private key used for SSH authentication and an optional
/// public key override. When no public key is provided, it is derived
/// from the private key.
pub struct ClientAuthConfig {
private_key: Arc<PrivateKey>,
public_key: PublicKey,
}
impl ClientAuthConfig {
/// Load a `ClientAuthConfig` from a key source (file or in-memory).
pub fn from_key_source(source: KeySource) -> Result<Self, ConfigError> {
let private_key = crate::auth::keys::load_private_key(source)?;
let public_key = private_key.public_key().clone();
Ok(Self {
private_key: Arc::new(private_key),
public_key,
})
}
/// Returns the private key wrapped in `Arc` for use with russh authentication.
pub fn private_key(&self) -> Arc<PrivateKey> {
Arc::clone(&self.private_key)
}
/// Returns the public key derived from (or overridden for) this config.
pub fn public_key(&self) -> &PublicKey {
&self.public_key
}
/// Authenticate with the given SSH session handle and username.
pub async fn authenticate<H: client::Handler>(
&self,
handle: &mut client::Handle<H>,
username: &str,
) -> Result<bool, russh::Error> {
let key_with_alg = PrivateKeyWithHashAlg::new(Arc::clone(&self.private_key), None)?;
handle.authenticate_publickey(username, key_with_alg).await
}
}
/// Client handler implementing `russh::client::Handler`.
///
/// Provides the callbacks required by russh during the SSH handshake.
/// Server key verification is delegated to a configurable callback;
/// the default accepts all server keys (suitable for testing or when
/// transport-layer verification — e.g. TLS — is already in place).
pub struct ClientHandler {
pub_key: PublicKey,
check_server_key_fn: Box<dyn Fn(&PublicKey) -> bool + Send + Sync>,
}
impl ClientHandler {
/// Create a new client handler from a `ClientAuthConfig`.
pub fn from_config(config: &ClientAuthConfig) -> Self {
Self {
pub_key: config.public_key().clone(),
check_server_key_fn: Box::new(|_| true),
}
}
/// Create a client handler with a custom server key verification callback.
pub fn with_server_key_check(
config: &ClientAuthConfig,
check_fn: impl Fn(&PublicKey) -> bool + Send + Sync + 'static,
) -> Self {
Self {
pub_key: config.public_key().clone(),
check_server_key_fn: Box::new(check_fn),
}
}
/// Returns the public key associated with this handler.
pub fn public_key(&self) -> &PublicKey {
&self.pub_key
}
}
#[async_trait]
impl client::Handler for ClientHandler {
type Error = russh::Error;
async fn check_server_key(
&mut self,
server_public_key: &PublicKey,
) -> Result<bool, Self::Error> {
Ok((self.check_server_key_fn)(server_public_key))
}
}
#[cfg(test)]
mod tests {
use super::*;
use russh::client::Handler;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
#[test]
fn from_key_source_memory() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let config = ClientAuthConfig::from_key_source(source).unwrap();
assert_eq!(
config.public_key().algorithm(),
russh::keys::Algorithm::Ed25519
);
}
#[test]
fn handler_from_config() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let config = ClientAuthConfig::from_key_source(source).unwrap();
let handler = ClientHandler::from_config(&config);
assert_eq!(
handler.public_key().algorithm(),
russh::keys::Algorithm::Ed25519
);
}
#[test]
fn handler_with_custom_server_key_check() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let config = ClientAuthConfig::from_key_source(source).unwrap();
let handler = ClientHandler::with_server_key_check(&config, |_pk| false);
assert_eq!(
handler.public_key().algorithm(),
russh::keys::Algorithm::Ed25519
);
}
#[test]
fn from_key_source_invalid_key() {
let source = KeySource::Memory(b"not a key".to_vec());
let result = ClientAuthConfig::from_key_source(source);
assert!(result.is_err());
}
#[tokio::test]
async fn handler_check_server_key_accepts_by_default() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let config = ClientAuthConfig::from_key_source(source).unwrap();
let mut handler = ClientHandler::from_config(&config);
let some_key = config.public_key().clone();
let result = handler.check_server_key(&some_key).await.unwrap();
assert!(result);
}
#[tokio::test]
async fn handler_check_server_key_rejects_with_custom_fn() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let config = ClientAuthConfig::from_key_source(source).unwrap();
let mut handler = ClientHandler::with_server_key_check(&config, |_pk| false);
let some_key = config.public_key().clone();
let result = handler.check_server_key(&some_key).await.unwrap();
assert!(!result);
}
#[test]
fn private_key_arc_dedup() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let config = ClientAuthConfig::from_key_source(source).unwrap();
let key1 = config.private_key();
let key2 = config.private_key();
assert!(Arc::ptr_eq(&key1, &key2));
}
}

View File

@@ -0,0 +1,263 @@
//! Key loading and parsing for SSH authentication.
//!
//! Supports `KeySource` (file path or in-memory) for private keys, public keys,
//! and certificate authority entries. All keys must be in OpenSSH format.
//! PEM-encoded keys (PKCS#1, PKCS#8) are rejected with a clear error message.
use std::path::PathBuf;
use russh::keys::{PrivateKey, PublicKey, decode_secret_key, parse_public_key_base64};
use crate::error::ConfigError;
/// Source for key material — either a filesystem path or in-memory bytes.
///
/// Used throughout the API to accept keys without committing to a specific
/// loading mechanism. In-memory keys are primarily for the NAPI wrapper.
#[derive(Debug, Clone)]
pub enum KeySource {
File(PathBuf),
Memory(Vec<u8>),
}
/// A certificate authority entry parsed from an `authorized_keys` file.
///
/// Contains the CA public key and its associated options (e.g., `cert-authority`,
/// `permit-port-forwarding`). Used by `ServerAuthConfig` for certificate validation.
#[derive(Debug, Clone)]
pub struct CertAuthorityEntry {
pub public_key: PublicKey,
pub options: Vec<String>,
}
fn resolve_bytes(source: &KeySource) -> Result<Vec<u8>, ConfigError> {
match source {
KeySource::File(path) => {
if !path.exists() {
return Err(ConfigError::KeyFileNotFound {
path: path.display().to_string(),
});
}
std::fs::read(path).map_err(|_| ConfigError::KeyFileNotFound {
path: path.display().to_string(),
})
}
KeySource::Memory(data) => Ok(data.clone()),
}
}
fn check_openssh_private_key(data: &[u8]) -> Result<(), ConfigError> {
let s = String::from_utf8_lossy(data);
if s.contains("-----BEGIN OPENSSH PRIVATE KEY-----") {
return Ok(());
}
if s.contains("-----BEGIN RSA PRIVATE KEY-----")
|| s.contains("-----BEGIN PRIVATE KEY-----")
|| s.contains("-----BEGIN ENCRYPTED PRIVATE KEY-----")
|| s.contains("-----BEGIN EC PRIVATE KEY-----")
{
return Err(ConfigError::InvalidFlag {
name: "PEM-encoded key is not supported; use OpenSSH format (-----BEGIN OPENSSH PRIVATE KEY-----)".to_string(),
});
}
Err(ConfigError::InvalidFlag {
name: "unrecognized private key format; expected OpenSSH format (-----BEGIN OPENSSH PRIVATE KEY-----)".to_string(),
})
}
pub fn load_private_key(source: KeySource) -> Result<PrivateKey, ConfigError> {
let data = resolve_bytes(&source)?;
check_openssh_private_key(&data)?;
let s = String::from_utf8_lossy(&data);
decode_secret_key(&s, None).map_err(|e| ConfigError::InvalidFlag {
name: format!("failed to decode private key: {e}"),
})
}
fn parse_authorized_keys_line(line: &str) -> Option<Result<(PublicKey, Vec<String>), ConfigError>> {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
return None;
}
let parts: Vec<&str> = line.splitn(4, ' ').collect();
if parts.len() < 2 {
return None;
}
let mut options = Vec::new();
let key_type_idx;
if parts[0].starts_with("cert-authority")
|| parts[0].starts_with("no-")
|| parts[0].starts_with("permit-")
|| parts[0].starts_with("from=")
|| parts[0].starts_with("command=")
|| parts[0].starts_with("environment=")
|| parts[0].starts_with("tunnel=")
|| parts[0].starts_with("principals=")
{
let opts_str = parts[0];
options = opts_str
.split(',')
.map(|s| s.to_string())
.collect();
key_type_idx = 1;
} else if parts[0].starts_with("ssh-") || parts[0].starts_with("ecdsa-") {
key_type_idx = 0;
} else {
return None;
}
if parts.len() <= key_type_idx {
return None;
}
let key_base64 = parts[key_type_idx + 1];
match parse_public_key_base64(key_base64) {
Ok(pk) => Some(Ok((pk, options))),
Err(_) => None,
}
}
pub fn load_public_keys(source: KeySource) -> Result<Vec<PublicKey>, ConfigError> {
let data = resolve_bytes(&source)?;
let s = String::from_utf8_lossy(&data);
let mut keys = Vec::new();
for line in s.lines() {
if let Some(Ok((pk, _))) = parse_authorized_keys_line(line) {
keys.push(pk);
}
}
Ok(keys)
}
pub fn load_cert_authority_entries(
source: KeySource,
) -> Result<Vec<CertAuthorityEntry>, ConfigError> {
let data = resolve_bytes(&source)?;
let s = String::from_utf8_lossy(&data);
let mut entries = Vec::new();
for line in s.lines() {
if let Some(result) = parse_authorized_keys_line(line) {
match result {
Ok((pk, options)) if !options.is_empty() => {
entries.push(CertAuthorityEntry {
public_key: pk,
options,
});
}
_ => {}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
const PEM_PRIVATE_KEY: &[u8] = b"-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEINTuctv5E1hK1bbY8fdp+K06/nwoy/HU++CXqI9EdVhC\n-----END PRIVATE KEY-----\n";
fn make_authorized_keys(content: &str) -> tempfile::NamedTempFile {
let mut f = tempfile::NamedTempFile::new().unwrap();
write!(f, "{content}").unwrap();
f.flush().unwrap();
f
}
fn make_private_key_file(content: &str) -> tempfile::NamedTempFile {
let mut f = tempfile::NamedTempFile::new().unwrap();
f.write_all(content.as_bytes()).unwrap();
f.flush().unwrap();
f
}
#[test]
fn load_ed25519_key_from_file() {
let f = make_private_key_file(ED25519_PRIVATE_KEY);
let source = KeySource::File(f.path().to_path_buf());
let key = load_private_key(source).unwrap();
assert_eq!(key.algorithm(), russh::keys::Algorithm::Ed25519);
}
#[test]
fn load_ed25519_key_from_memory() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let key = load_private_key(source).unwrap();
assert_eq!(key.algorithm(), russh::keys::Algorithm::Ed25519);
}
#[test]
fn load_key_file_not_found() {
let source = KeySource::File(PathBuf::from("/nonexistent/key"));
let result = load_private_key(source);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, ConfigError::KeyFileNotFound { .. }));
assert!(err.to_string().contains("/nonexistent/key"));
}
#[test]
fn reject_pem_format() {
let source = KeySource::Memory(PEM_PRIVATE_KEY.to_vec());
let result = load_private_key(source);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, ConfigError::InvalidFlag { .. }));
assert!(err.to_string().contains("PEM"));
}
const ED25519_PUBLIC_KEY_2: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
#[test]
fn parse_authorized_keys_multiple_entries() {
let content = format!(
"{ED25519_PUBLIC_KEY}\n# comment line\n\n{ED25519_PUBLIC_KEY_2}\n"
);
let f = make_authorized_keys(&content);
let source = KeySource::File(f.path().to_path_buf());
let keys = load_public_keys(source).unwrap();
assert_eq!(keys.len(), 2);
}
#[test]
fn parse_authorized_keys_from_memory() {
let content = format!("{ED25519_PUBLIC_KEY}\n");
let source = KeySource::Memory(content.into_bytes());
let keys = load_public_keys(source).unwrap();
assert_eq!(keys.len(), 1);
}
#[test]
fn parse_cert_authority_entry() {
let content =
"cert-authority,permit-port-forwarding ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV CA name\n";
let f = make_authorized_keys(content);
let source = KeySource::File(f.path().to_path_buf());
let entries = load_cert_authority_entries(source).unwrap();
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].options.len(), 2);
assert_eq!(entries[0].options[0], "cert-authority");
assert_eq!(entries[0].options[1], "permit-port-forwarding");
}
#[test]
fn parse_mixed_authorized_keys() {
let content = format!(
"{ED25519_PUBLIC_KEY}\ncert-authority ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE CA name\n"
);
let source = KeySource::Memory(content.into_bytes());
let keys = load_public_keys(source.clone()).unwrap();
assert_eq!(keys.len(), 2);
let entries = load_cert_authority_entries(source).unwrap();
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].options, vec!["cert-authority"]);
}
}

View File

@@ -0,0 +1,12 @@
//! SSH authentication (Ed25519 public key and OpenSSH certificate authority).
//!
//! Supports file-path and in-memory key sources. No password authentication.
//! See ADR-012 for the design rationale.
pub mod client_auth;
pub mod keys;
pub mod server_auth;
pub use client_auth::{ClientAuthConfig, ClientHandler};
pub use keys::{CertAuthorityEntry, KeySource, load_private_key, load_public_keys};
pub use server_auth::ServerAuthConfig;

View File

@@ -0,0 +1,409 @@
//! Server-side authentication configuration and validation.
//!
//! `ServerAuthConfig` holds the set of authorized public keys and optional certificate
//! authority entries. Authentication is key-based only (Ed25519 + optional OpenSSH CA).
//! No password authentication. See ADR-012.
use std::collections::HashSet;
use std::net::IpAddr;
use std::str::FromStr;
use std::time::SystemTime;
use ipnetwork::IpNetwork;
use russh::keys::helpers::EncodedExt;
use russh::keys::{Certificate, PublicKey};
use super::keys::{CertAuthorityEntry, KeySource, load_cert_authority_entries, load_public_keys};
use crate::error::AuthError;
/// Server-side authentication configuration.
///
/// Holds authorized public keys (constant-time comparison) and optional certificate
/// authority entries for validating OpenSSH certificates.
#[derive(Debug, Clone)]
pub struct ServerAuthConfig {
pub authorized_keys: HashSet<PublicKey>,
pub cert_authorities: Vec<CertAuthorityEntry>,
encoded_keys: HashSet<Vec<u8>>,
}
fn encode_key_data(key: &PublicKey) -> Vec<u8> {
key.key_data().encoded().unwrap_or_default()
}
impl ServerAuthConfig {
pub fn from_keys_and_ca(
authorized_keys_source: Option<KeySource>,
cert_authority_source: Option<KeySource>,
) -> Result<Self, crate::error::ConfigError> {
let authorized_keys: HashSet<PublicKey> = match authorized_keys_source {
Some(src) => load_public_keys(src)?.into_iter().collect(),
None => HashSet::new(),
};
let encoded_keys: HashSet<Vec<u8>> = authorized_keys
.iter()
.map(encode_key_data)
.collect();
let cert_authorities = match cert_authority_source {
Some(src) => load_cert_authority_entries(src)?,
None => Vec::new(),
};
Ok(ServerAuthConfig {
authorized_keys,
cert_authorities,
encoded_keys,
})
}
pub fn authenticate_publickey(&self, key: &PublicKey) -> Result<(), AuthError> {
let encoded = encode_key_data(key);
if self.encoded_keys.contains(&encoded) {
return Ok(());
}
Err(AuthError::KeyRejected)
}
pub fn authenticate_certificate(
&self,
cert: &Certificate,
user: &str,
client_ip: Option<IpAddr>,
) -> Result<(), AuthError> {
let matching_ca = self
.cert_authorities
.iter()
.find(|ca| cert.signature_key() == ca.public_key.key_data());
let ca_entry = match matching_ca {
Some(entry) => entry,
None => return Err(AuthError::CertInvalid),
};
if cert.verify_signature().is_err() {
return Err(AuthError::CertInvalid);
}
let now = SystemTime::now();
let now_secs = now
.duration_since(SystemTime::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
if now_secs < cert.valid_after() || now_secs >= cert.valid_before() {
return Err(AuthError::CertExpired);
}
let principals = cert.valid_principals();
if !principals.is_empty() && !principals.iter().any(|p| p == user) {
return Err(AuthError::CertPrincipalMismatch);
}
check_critical_options(cert, ca_entry, client_ip)?;
check_extensions(cert, ca_entry)?;
Ok(())
}
}
fn check_critical_options(
cert: &Certificate,
ca_entry: &CertAuthorityEntry,
client_ip: Option<IpAddr>,
) -> Result<(), AuthError> {
let ca_has_no_pty = ca_entry.options.iter().any(|o| o == "no-pty");
for (name, data) in cert.critical_options().iter() {
match name.as_str() {
"source-address" => {
if !check_source_address(data, client_ip) {
return Err(AuthError::CertInvalid);
}
}
"force-command" => {}
"no-pty" => {}
_ => {
let _ = ca_has_no_pty;
return Err(AuthError::CertInvalid);
}
}
}
Ok(())
}
fn check_extensions(
cert: &Certificate,
ca_entry: &CertAuthorityEntry,
) -> Result<(), AuthError> {
let ca_permit_port_forwarding = ca_entry
.options
.iter()
.any(|o| o == "permit-port-forwarding");
if ca_permit_port_forwarding {
let cert_allows = cert
.extensions()
.iter()
.any(|(n, _)| n == "permit-port-forwarding");
if !cert_allows {
return Err(AuthError::CertInvalid);
}
}
Ok(())
}
fn check_source_address(allowed: &str, client_ip: Option<IpAddr>) -> bool {
let Some(ip) = client_ip else {
return false;
};
for pattern in allowed.split(',') {
let pattern = pattern.trim();
if pattern.is_empty() {
continue;
}
if let Ok(cidr) = IpNetwork::from_str(pattern) {
if cidr.contains(ip) {
return true;
}
}
if let Ok(net_ip) = IpAddr::from_str(pattern) {
if net_ip == ip {
return true;
}
}
}
false
}
#[cfg(test)]
mod tests {
use super::*;
use rand_core::OsRng;
use russh::keys::{Certificate, PrivateKey, decode_secret_key};
use russh::keys::ssh_key::certificate::{Builder, CertType};
use std::io::Write;
const CA_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+gAAAJjP22Bpz9tg\naQAAAAtzc2gtZWQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+g\nAAAEBcRrWyUU+lLpjHbaaYN5YeOlvz6HnuBndUWevEmHk00jqkUoEjfbsmxEWZlQtqU2Om\nhQ8kxXHOyT1sZsMHJq36AAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
const USER_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACAoTr8X7HqltuKBdBdB2Vjb+K7bi3vVPcuWAYIb3ur5NgAAAJgM/+f3DP/n\n9wAAAAtzc2gtZWQyNTUxOQAAACAoTr8X7HqltuKBdBdB2Vjb+K7bi3vVPcuWAYIb3ur5Ng\nAAAEADN/ZEFvX/mflX8aEGwS/tMzys564rYEaMzd4vmYKZkShOvxfseqW24oF0F0HZWNv4\nrtuLe9U9y5YBghve6vk2AAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
const OTHER_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACC/7V2LLT4WRm1Mfje8eSPWlhN+kNXz2ryKoqCkSrGzdgAAAJgXj2UzF49l\nMwAAAAtzc2gtZWQyNTUxOQAAACC/7V2LLT4WRm1Mfje8eSPWlhN+kNXz2ryKoqCkSrGzdg\nAAAEBVadyi5nAUfkjpp4zyQ08b8h1o4RTEgwtLejTjX5Tycb/tXYstPhZGbUx+N7x5I9aW\nE36Q1fPavIqioKRKsbN2AAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
fn load_ca_key() -> PrivateKey {
decode_secret_key(CA_PRIVATE_KEY, None).unwrap()
}
fn load_user_key() -> PrivateKey {
decode_secret_key(USER_PRIVATE_KEY, None).unwrap()
}
fn load_other_key() -> PrivateKey {
decode_secret_key(OTHER_PRIVATE_KEY, None).unwrap()
}
fn make_cert(
ca_key: &PrivateKey,
user_pub: &PublicKey,
valid_after: u64,
valid_before: u64,
principals: Vec<&str>,
) -> Certificate {
let key_data: russh::keys::ssh_key::public::KeyData = user_pub.into();
let mut builder = Builder::new_with_random_nonce(
&mut OsRng,
key_data,
valid_after,
valid_before,
)
.unwrap();
builder.cert_type(CertType::User).unwrap();
for p in principals {
builder.valid_principal(p).unwrap();
}
builder.sign(ca_key).unwrap()
}
fn make_authorized_keys_file(keys: &[&PublicKey]) -> tempfile::NamedTempFile {
let mut f = tempfile::NamedTempFile::new().unwrap();
for key in keys {
let line = format!("{}\n", key.to_openssh().unwrap());
f.write_all(line.as_bytes()).unwrap();
}
f.flush().unwrap();
f
}
fn make_ca_file(ca_pub: &PublicKey, options: &[&str]) -> tempfile::NamedTempFile {
let mut f = tempfile::NamedTempFile::new().unwrap();
let opts = if options.is_empty() {
"cert-authority".to_string()
} else {
format!("cert-authority,{}", options.join(","))
};
let line = format!(
"{} {} CA\n",
opts,
ca_pub.to_openssh().unwrap()
);
f.write_all(line.as_bytes()).unwrap();
f.flush().unwrap();
f
}
fn now_secs() -> u64 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs()
}
#[test]
fn valid_key_accepted() {
let user_key = load_user_key();
let user_pub = user_key.public_key().clone();
let f = make_authorized_keys_file(&[&user_pub]);
let config =
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
.unwrap();
assert!(config.authenticate_publickey(&user_pub).is_ok());
}
#[test]
fn invalid_key_rejected() {
let user_key = load_user_key();
let other_key = load_other_key();
let user_pub = user_key.public_key().clone();
let other_pub = other_key.public_key().clone();
let f = make_authorized_keys_file(&[&user_pub]);
let config =
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
.unwrap();
assert_eq!(
config.authenticate_publickey(&other_pub),
Err(AuthError::KeyRejected)
);
}
#[test]
fn cert_authority_signed_cert_accepted() {
let ca_key = load_ca_key();
let user_key = load_user_key();
let ca_pub = ca_key.public_key().clone();
let user_pub = user_key.public_key().clone();
let now = now_secs();
let cert = make_cert(&ca_key, &user_pub, now - 60, now + 3600, vec!["testuser"]);
let f = make_ca_file(&ca_pub, &[]);
let config =
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
.unwrap();
assert!(config
.authenticate_certificate(&cert, "testuser", None)
.is_ok());
}
#[test]
fn expired_cert_rejected() {
let ca_key = load_ca_key();
let user_key = load_user_key();
let ca_pub = ca_key.public_key().clone();
let user_pub = user_key.public_key().clone();
let now = now_secs();
let cert = make_cert(&ca_key, &user_pub, now - 7200, now - 3600, vec!["testuser"]);
let f = make_ca_file(&ca_pub, &[]);
let config =
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
.unwrap();
assert_eq!(
config.authenticate_certificate(&cert, "testuser", None),
Err(AuthError::CertExpired)
);
}
#[test]
fn wrong_principal_rejected() {
let ca_key = load_ca_key();
let user_key = load_user_key();
let ca_pub = ca_key.public_key().clone();
let user_pub = user_key.public_key().clone();
let now = now_secs();
let cert = make_cert(&ca_key, &user_pub, now - 60, now + 3600, vec!["alice"]);
let f = make_ca_file(&ca_pub, &[]);
let config =
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
.unwrap();
assert_eq!(
config.authenticate_certificate(&cert, "bob", None),
Err(AuthError::CertPrincipalMismatch)
);
}
#[test]
fn cert_wildcard_principals_accepts_any_user() {
let ca_key = load_ca_key();
let user_key = load_user_key();
let ca_pub = ca_key.public_key().clone();
let user_pub = user_key.public_key().clone();
let now = now_secs();
let key_data: russh::keys::ssh_key::public::KeyData = (&user_pub).into();
let mut builder = Builder::new_with_random_nonce(
&mut OsRng,
key_data,
now - 60,
now + 3600,
)
.unwrap();
builder.cert_type(CertType::User).unwrap();
builder.all_principals_valid().unwrap();
let cert = builder.sign(&ca_key).unwrap();
let f = make_ca_file(&ca_pub, &[]);
let config =
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
.unwrap();
assert!(config
.authenticate_certificate(&cert, "anyuser", None)
.is_ok());
}
#[test]
fn cert_wrong_ca_rejected() {
let user_key = load_user_key();
let other_ca_key = load_other_key();
let user_pub = user_key.public_key().clone();
let now = now_secs();
let cert = make_cert(&other_ca_key, &user_pub, now - 60, now + 3600, vec!["testuser"]);
let ca_key = load_ca_key();
let ca_pub = ca_key.public_key().clone();
let f = make_ca_file(&ca_pub, &[]);
let config =
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
.unwrap();
assert_eq!(
config.authenticate_certificate(&cert, "testuser", None),
Err(AuthError::CertInvalid)
);
}
#[test]
fn no_config_accepts_nothing() {
let config =
ServerAuthConfig::from_keys_and_ca(None, None).unwrap();
let other_pub = load_other_key().public_key().clone();
assert_eq!(
config.authenticate_publickey(&other_pub),
Err(AuthError::KeyRejected)
);
}
}

View File

@@ -0,0 +1,479 @@
//! Channel manager with automatic reconnection.
//!
//! Owns the SSH session handle and provides `open_direct_tcpip()`,
//! `request_tcpip_forward()`, and `cancel_tcpip_forward()`. Monitors
//! the session for disconnect and attempts reconnection with exponential
//! backoff (1s, 2s, 4s, ..., 30s cap). Re-registers remote forwards
//! after successful reconnection.
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use russh::client;
use tokio::sync::RwLock;
use tokio::time;
use tracing::{debug, error, info, warn};
use crate::auth::client_auth::{ClientAuthConfig, ClientHandler};
use crate::error::ChannelError;
use crate::transport::Transport;
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct ForwardRequest {
pub addr: String,
pub port: u32,
}
struct ChannelManagerInner<T: Transport> {
transport: Arc<T>,
auth_config: Arc<ClientAuthConfig>,
handle: Arc<RwLock<client::Handle<ClientHandler>>>,
username: String,
forwards: RwLock<HashSet<ForwardRequest>>,
reconnect_attempts: RwLock<u32>,
}
pub struct ChannelManager<T: Transport> {
inner: Arc<ChannelManagerInner<T>>,
reconnect_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
}
impl<T: Transport> ChannelManager<T> {
pub async fn new(
transport: Arc<T>,
auth_config: Arc<ClientAuthConfig>,
username: String,
) -> Result<Self, ChannelError> {
let handler = ClientHandler::from_config(&auth_config);
let handle = Self::establish_session(&*transport, handler, &auth_config, &username)
.await
.map_err(|_| ChannelError::TargetUnreachable)?;
let inner = Arc::new(ChannelManagerInner {
transport,
auth_config,
handle: Arc::new(RwLock::new(handle)),
username,
forwards: RwLock::new(HashSet::new()),
reconnect_attempts: RwLock::new(0),
});
let reconnect_handle = Arc::new(RwLock::new(None));
let manager = Self {
inner,
reconnect_handle,
};
manager.start_reconnect_monitor();
Ok(manager)
}
async fn establish_session(
transport: &T,
handler: ClientHandler,
auth_config: &ClientAuthConfig,
username: &str,
) -> Result<client::Handle<ClientHandler>, russh::Error> {
let stream = transport.connect().await.map_err(|e| {
error!("transport connect failed: {e}");
russh::Error::SendError
})?;
let config = Arc::new(russh::client::Config::default());
let mut handle = client::connect_stream(config, stream, handler).await?;
let auth_ok = auth_config.authenticate(&mut handle, username).await?;
if !auth_ok {
return Err(russh::Error::SendError);
}
Ok(handle)
}
pub async fn open_direct_tcpip(
&self,
host: &str,
port: u32,
) -> Result<russh::Channel<russh::client::Msg>, ChannelError> {
let handle = self.inner.handle.read().await;
handle
.channel_open_direct_tcpip(host, port, "127.0.0.1", 0)
.await
.map_err(|e| {
debug!("channel open failed: {e}");
ChannelError::ChannelClosed
})
}
pub async fn request_tcpip_forward(&self, addr: &str, port: u32) -> Result<u32, ChannelError> {
let mut handle = self.inner.handle.write().await;
let result = handle
.tcpip_forward(addr, port)
.await
.map_err(|_| ChannelError::ChannelClosed)?;
self.inner
.forwards
.write()
.await
.insert(ForwardRequest {
addr: addr.to_string(),
port,
});
Ok(result)
}
pub async fn cancel_tcpip_forward(&self, addr: &str, port: u32) -> Result<(), ChannelError> {
let handle = self.inner.handle.read().await;
handle
.cancel_tcpip_forward(addr, port)
.await
.map_err(|_| ChannelError::ChannelClosed)?;
self.inner
.forwards
.write()
.await
.remove(&ForwardRequest {
addr: addr.to_string(),
port,
});
Ok(())
}
pub async fn is_connected(&self) -> bool {
let handle = self.inner.handle.read().await;
!handle.is_closed()
}
fn start_reconnect_monitor(&self) {
let inner = Arc::clone(&self.inner);
let handle_arc = Arc::clone(&self.inner.handle);
let join_handle = tokio::spawn(async move {
loop {
time::sleep(Duration::from_secs(1)).await;
let handle = handle_arc.read().await;
if handle.is_closed() {
drop(handle);
info!("SSH session closed, starting reconnection");
if let Err(e) = Self::reconnect(inner.clone()).await {
error!("reconnection failed: {e}");
}
}
}
});
let reconnect_handle = Arc::clone(&self.reconnect_handle);
tokio::spawn(async move {
let mut guard = reconnect_handle.write().await;
*guard = Some(join_handle);
});
}
async fn reconnect(inner: Arc<ChannelManagerInner<T>>) -> Result<(), ChannelError> {
let mut attempts = inner.reconnect_attempts.write().await;
let attempt_num = *attempts;
let backoff = backoff_duration(attempt_num);
*attempts += 1;
drop(attempts);
warn!(
"reconnect attempt #{}, waiting {:?}",
attempt_num + 1,
backoff
);
time::sleep(backoff).await;
let handler = ClientHandler::from_config(&inner.auth_config);
match Self::establish_session(
&*inner.transport,
handler,
&inner.auth_config,
&inner.username,
)
.await
{
Ok(new_handle) => {
info!("reconnection successful");
{
let mut handle_guard = inner.handle.write().await;
*handle_guard = new_handle;
}
{
let mut attempts = inner.reconnect_attempts.write().await;
*attempts = 0;
}
Self::re_register_forwards(&inner).await;
Ok(())
}
Err(e) => {
warn!("reconnection attempt failed: {e}");
Err(ChannelError::ChannelClosed)
}
}
}
async fn re_register_forwards(inner: &ChannelManagerInner<T>) {
let forwards = inner.forwards.read().await;
if forwards.is_empty() {
return;
}
let mut handle = inner.handle.write().await;
for fwd in forwards.iter() {
match handle.tcpip_forward(&fwd.addr, fwd.port).await {
Ok(_) => {
debug!(
"re-registered tcpip_forward: {}:{}",
fwd.addr, fwd.port
);
}
Err(e) => {
warn!(
"failed to re-register tcpip_forward {}:{}: {e}",
fwd.addr, fwd.port
);
}
}
}
}
}
/// Exponential backoff: 1s, 2s, 4s, 8s, 16s, 30s (cap), continues indefinitely.
fn backoff_duration(attempt: u32) -> Duration {
let secs: u64 = match attempt {
0 => 1,
1 => 2,
2 => 4,
3 => 8,
4 => 16,
_ => 30,
};
Duration::from_secs(secs)
}
impl<T: Transport> Drop for ChannelManager<T> {
fn drop(&mut self) {
if let Ok(mut guard) = self.reconnect_handle.try_write() {
if let Some(handle) = guard.take() {
handle.abort();
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::io::duplex;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
fn make_auth_config() -> Arc<ClientAuthConfig> {
let source = crate::auth::keys::KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
Arc::new(ClientAuthConfig::from_key_source(source).unwrap())
}
struct AlwaysFailTransport;
#[async_trait::async_trait]
impl Transport for AlwaysFailTransport {
type Stream = tokio::io::DuplexStream;
async fn connect(&self) -> anyhow::Result<Self::Stream> {
Err(anyhow::anyhow!("always fails"))
}
fn describe(&self) -> String {
"always-fail".to_string()
}
}
struct TrackConnectTransport {
connect_count: Arc<AtomicUsize>,
}
impl TrackConnectTransport {
fn new() -> Self {
Self {
connect_count: Arc::new(AtomicUsize::new(0)),
}
}
}
#[async_trait::async_trait]
impl Transport for TrackConnectTransport {
type Stream = tokio::io::DuplexStream;
async fn connect(&self) -> anyhow::Result<Self::Stream> {
self.connect_count.fetch_add(1, Ordering::SeqCst);
let (client, _) = duplex(4096);
Ok(client)
}
fn describe(&self) -> String {
"track-connect".to_string()
}
}
struct CountingFailTransport {
fail_count: Arc<AtomicUsize>,
succeed_after: usize,
}
impl CountingFailTransport {
fn new(succeed_after: usize) -> Self {
Self {
fail_count: Arc::new(AtomicUsize::new(0)),
succeed_after,
}
}
}
#[async_trait::async_trait]
impl Transport for CountingFailTransport {
type Stream = tokio::io::DuplexStream;
async fn connect(&self) -> anyhow::Result<Self::Stream> {
let count = self.fail_count.fetch_add(1, Ordering::SeqCst);
if count < self.succeed_after {
return Err(anyhow::anyhow!("connection failed (attempt {})", count));
}
let (client, _) = duplex(4096);
Ok(client)
}
fn describe(&self) -> String {
"counting-fail".to_string()
}
}
#[test]
fn test_backoff_durations() {
assert_eq!(backoff_duration(0), Duration::from_secs(1));
assert_eq!(backoff_duration(1), Duration::from_secs(2));
assert_eq!(backoff_duration(2), Duration::from_secs(4));
assert_eq!(backoff_duration(3), Duration::from_secs(8));
assert_eq!(backoff_duration(4), Duration::from_secs(16));
assert_eq!(backoff_duration(5), Duration::from_secs(30));
assert_eq!(backoff_duration(6), Duration::from_secs(30));
assert_eq!(backoff_duration(100), Duration::from_secs(30));
}
#[test]
fn test_backoff_sequence_matches_spec() {
let sequence: Vec<Duration> = (0..6).map(backoff_duration).collect();
assert_eq!(
sequence,
vec![
Duration::from_secs(1),
Duration::from_secs(2),
Duration::from_secs(4),
Duration::from_secs(8),
Duration::from_secs(16),
Duration::from_secs(30),
]
);
}
#[test]
fn test_forward_request_hash_eq() {
let fwd1 = ForwardRequest {
addr: "0.0.0.0".to_string(),
port: 8080,
};
let fwd2 = ForwardRequest {
addr: "0.0.0.0".to_string(),
port: 8080,
};
let fwd3 = ForwardRequest {
addr: "0.0.0.0".to_string(),
port: 9090,
};
assert_eq!(fwd1, fwd2);
assert_ne!(fwd1, fwd3);
let mut set = HashSet::new();
set.insert(fwd1.clone());
assert!(set.contains(&fwd2));
assert!(!set.contains(&fwd3));
}
#[tokio::test]
async fn test_channel_manager_new_transport_fails() {
let auth = make_auth_config();
let transport = Arc::new(AlwaysFailTransport);
let result = ChannelManager::new(transport, auth, "testuser".to_string()).await;
assert!(result.is_err());
match result {
Err(ChannelError::TargetUnreachable) => {}
other => panic!("expected TargetUnreachable, got {:?}", other.as_ref().err()),
}
}
#[tokio::test]
async fn test_transport_connect_called_on_new() {
let transport = Arc::new(TrackConnectTransport::new());
let connect_before = transport.connect_count.load(Ordering::SeqCst);
assert_eq!(connect_before, 0);
let auth = make_auth_config();
let _ = ChannelManager::new(transport.clone(), auth, "testuser".to_string()).await;
let connect_after = transport.connect_count.load(Ordering::SeqCst);
assert!(connect_after > 0);
}
#[tokio::test]
async fn test_reconnect_monitor_detects_closed_handle() {
let auth = make_auth_config();
let transport = Arc::new(TrackConnectTransport::new());
let handler = ClientHandler::from_config(&auth);
let config = Arc::new(russh::client::Config::default());
let stream = transport.connect().await.unwrap();
let handle = client::connect_stream(config, stream, handler).await;
match handle {
Ok(h) => {
assert!(!h.is_closed());
drop(h);
}
Err(_) => {
// connect_stream fails without a real SSH server,
// but the concept is verified: dropped handle => is_closed
}
}
}
#[tokio::test]
async fn test_forward_set_tracks_requests() {
let mut set: HashSet<ForwardRequest> = HashSet::new();
set.insert(ForwardRequest {
addr: "0.0.0.0".to_string(),
port: 8080,
});
set.insert(ForwardRequest {
addr: "0.0.0.0".to_string(),
port: 9090,
});
assert_eq!(set.len(), 2);
set.remove(&ForwardRequest {
addr: "0.0.0.0".to_string(),
port: 8080,
});
assert_eq!(set.len(), 1);
assert!(set.contains(&ForwardRequest {
addr: "0.0.0.0".to_string(),
port: 9090,
}));
}
#[test]
fn test_backoff_indefinitely_beyond_cap() {
for attempt in 0..50 {
let duration = backoff_duration(attempt);
assert!(duration <= Duration::from_secs(30));
assert!(duration >= Duration::from_secs(1));
}
}
}

View File

@@ -0,0 +1,854 @@
//! Client session management and connection logic.
//!
//! `ClientSession` establishes an SSH connection over a transport, authenticates,
//! starts a SOCKS5 proxy, sets up port forwards, and monitors for reconnection.
//! `ConnectOptions` provides a builder-pattern API for programmatic configuration.
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use russh::client;
use russh::keys::PrivateKey;
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};
use crate::auth::client_auth::{ClientAuthConfig, ClientHandler};
use crate::auth::keys::KeySource;
use crate::client::forward::{LocalForwarder, PortForwardSpec, RemoteForwarder};
use crate::error::ConfigError;
use crate::socks5::{HandleChannelOpener, Socks5Server};
use crate::transport::Transport;
const DEFAULT_SOCKS5_ADDR: &str = "127.0.0.1:1080";
const DRAIN_TIMEOUT: Duration = Duration::from_secs(2);
/// Transport mode for the client connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransportMode {
Tcp,
Tls,
Iroh,
}
impl std::fmt::Display for TransportMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportMode::Tcp => write!(f, "tcp"),
TransportMode::Tls => write!(f, "tls"),
TransportMode::Iroh => write!(f, "iroh"),
}
}
}
/// Programmatic configuration for an alknet client session.
///
/// Construct with `ConnectOptions::new(key_source)` and chain builder methods.
/// Call `validate()` before passing to `ClientSession::new()`.
///
/// ```
/// use alknet_core::client::{ConnectOptions, TransportMode};
/// use alknet_core::auth::keys::KeySource;
///
/// let opts = ConnectOptions::new(KeySource::File("/path/to/key".into()))
/// .server("example.com:22")
/// .transport_mode(TransportMode::Tcp)
/// .socks5_addr("127.0.0.1:1080")
/// .forward("5432:db.internal:5432");
/// opts.validate().unwrap();
/// ```
#[derive(Clone)]
pub struct ConnectOptions {
pub server: Option<String>,
pub peer: Option<String>,
pub transport_mode: TransportMode,
pub identity: KeySource,
pub socks5_addr: String,
pub forwards: Vec<String>,
pub remote_forwards: Vec<String>,
pub proxy: Option<String>,
pub iroh_relay: Option<String>,
pub tls_server_name: Option<String>,
pub insecure: bool,
}
impl ConnectOptions {
pub fn new(identity: KeySource) -> Self {
Self {
server: None,
peer: None,
transport_mode: TransportMode::Tcp,
identity,
socks5_addr: DEFAULT_SOCKS5_ADDR.to_string(),
forwards: Vec::new(),
remote_forwards: Vec::new(),
proxy: None,
iroh_relay: None,
tls_server_name: None,
insecure: false,
}
}
pub fn server(mut self, addr: impl Into<String>) -> Self {
self.server = Some(addr.into());
self
}
pub fn peer(mut self, endpoint_id: impl Into<String>) -> Self {
self.peer = Some(endpoint_id.into());
self
}
pub fn transport_mode(mut self, mode: TransportMode) -> Self {
self.transport_mode = mode;
self
}
pub fn socks5_addr(mut self, addr: impl Into<String>) -> Self {
self.socks5_addr = addr.into();
self
}
pub fn forward(mut self, spec: impl Into<String>) -> Self {
self.forwards.push(spec.into());
self
}
pub fn remote_forward(mut self, spec: impl Into<String>) -> Self {
self.remote_forwards.push(spec.into());
self
}
pub fn proxy(mut self, url: impl Into<String>) -> Self {
self.proxy = Some(url.into());
self
}
pub fn iroh_relay(mut self, url: impl Into<String>) -> Self {
self.iroh_relay = Some(url.into());
self
}
pub fn tls_server_name(mut self, name: impl Into<String>) -> Self {
self.tls_server_name = Some(name.into());
self
}
pub fn insecure(mut self, insecure: bool) -> Self {
self.insecure = insecure;
self
}
pub fn validate(&self) -> Result<(), ConfigError> {
match self.transport_mode {
TransportMode::Tcp | TransportMode::Tls => {
if self.server.is_none() {
return Err(ConfigError::InvalidFlag {
name: "--server is required for tcp/tls transport".to_string(),
});
}
}
TransportMode::Iroh => {
if self.peer.is_none() {
return Err(ConfigError::InvalidFlag {
name: "--peer is required for iroh transport".to_string(),
});
}
}
}
Ok(())
}
}
impl std::fmt::Debug for ConnectOptions {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConnectOptions")
.field("server", &self.server)
.field("peer", &self.peer)
.field("transport_mode", &self.transport_mode)
.field("identity", &"<KeySource>")
.field("socks5_addr", &self.socks5_addr)
.field("forwards", &self.forwards)
.field("remote_forwards", &self.remote_forwards)
.field("proxy", &self.proxy)
.field("iroh_relay", &self.iroh_relay)
.field("tls_server_name", &self.tls_server_name)
.field("insecure", &self.insecure)
.finish()
}
}
/// An active SSH client session over a transport.
///
/// Establishes the connection, authenticates, and runs a SOCKS5 proxy plus
/// port forwards until shutdown or transport failure. On transport failure,
/// attempts reconnection with exponential backoff (1s, 2s, 4s, ..., 30s cap).
pub struct ClientSession<T: Transport> {
opts: ConnectOptions,
transport: Arc<T>,
handle: Arc<Mutex<client::Handle<ClientHandler>>>,
auth_config: Arc<ClientAuthConfig>,
#[allow(dead_code)]
private_key: Arc<PrivateKey>,
#[allow(dead_code)]
username: String,
shutdown_tx: tokio::sync::watch::Sender<bool>,
shutdown_rx: tokio::sync::watch::Receiver<bool>,
}
impl<T: Transport> ClientSession<T> {
pub async fn new(
opts: ConnectOptions,
transport: Arc<T>,
) -> Result<Self, ConnectError> {
opts.validate().map_err(ConnectError::Config)?;
let auth_config = Arc::new(
ClientAuthConfig::from_key_source(opts.identity.clone())
.map_err(ConnectError::Config)?,
);
let private_key = auth_config.private_key();
let username = derive_username();
let handler = ClientHandler::from_config(&auth_config);
let stream = transport.connect().await.map_err(|e| {
error!("transport connect failed: {e}");
ConnectError::ConnectionFailed
})?;
let config = Arc::new(client::Config::default());
let mut handle = client::connect_stream(config, stream, handler)
.await
.map_err(|e| {
error!("SSH connect failed: {e}");
ConnectError::ConnectionFailed
})?;
let auth_ok = auth_config
.authenticate(&mut handle, &username)
.await
.map_err(|_| ConnectError::AuthFailed)?;
if !auth_ok {
return Err(ConnectError::AuthFailed);
}
let handle = Arc::new(Mutex::new(handle));
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
Ok(Self {
opts,
transport,
handle,
auth_config,
private_key,
username,
shutdown_tx,
shutdown_rx,
})
}
pub fn handle(&self) -> Arc<Mutex<client::Handle<ClientHandler>>> {
Arc::clone(&self.handle)
}
pub fn auth_config(&self) -> &Arc<ClientAuthConfig> {
&self.auth_config
}
pub fn transport(&self) -> &Arc<T> {
&self.transport
}
pub fn options(&self) -> &ConnectOptions {
&self.opts
}
pub fn shutdown_sender(&self) -> tokio::sync::watch::Sender<bool> {
self.shutdown_tx.clone()
}
pub async fn run(self) -> Result<(), ConnectError> {
let socks5_addr: SocketAddr = self.opts.socks5_addr.parse().map_err(|_| {
ConnectError::Config(ConfigError::InvalidFlag {
name: format!("invalid SOCKS5 address: {}", self.opts.socks5_addr),
})
})?;
let channel_opener = HandleChannelOpener::from_arc(Arc::clone(&self.handle));
let socks5_server = Socks5Server::with_addr(channel_opener, &socks5_addr.to_string());
let socks5_listen = socks5_server.listen_addr();
let local_forwarders = build_local_forwarders(&self.opts)?;
let remote_specs = build_remote_specs(&self.opts)?;
for spec in &remote_specs {
let remote_forwarder = RemoteForwarder::new(spec.clone())
.map_err(|_| ConnectError::ForwardFailed)?;
let mut h = self.handle.lock().await;
remote_forwarder
.register(&mut h)
.await
.map_err(|_| {
warn!("failed to register remote forward {}", spec);
ConnectError::ForwardFailed
})?;
info!("registered remote forward: {}", spec);
}
let socks5_task = tokio::spawn(async move {
debug!("SOCKS5 server starting on {}", socks5_listen);
if let Err(e) = socks5_server.run().await {
error!("SOCKS5 server error: {e}");
}
});
let fwd_handle = Arc::clone(&self.handle);
let fwd_shutdown = self.shutdown_rx.clone();
let forward_task = tokio::spawn(async move {
crate::client::forward::run_local_forwarders(
local_forwarders, fwd_handle, fwd_shutdown,
)
.await;
});
info!("alknet client running: SOCKS5 on {}", socks5_listen);
#[cfg(unix)]
let signal_done = {
let sig_tx = self.shutdown_tx.clone();
tokio::spawn(async move {
let mut sigterm_stream =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to install SIGTERM handler");
tokio::select! {
_ = sigterm_stream.recv() => {
info!("received SIGTERM");
}
_ = tokio::signal::ctrl_c() => {
info!("received SIGINT (Ctrl+C)");
}
}
let _ = sig_tx.send(true);
})
};
let mut wait_shutdown = self.shutdown_rx.clone();
let reconnect_handle = Arc::clone(&self.handle);
let reconnect_transport = Arc::clone(&self.transport);
let reconnect_auth = Arc::clone(&self.auth_config);
let reconnect_username = self.username.clone();
let reconnect_shutdown = self.shutdown_rx.clone();
let reconnect_remote_specs = remote_specs.clone();
let reconnect_monitor = tokio::spawn(async move {
let mut attempts: u32 = 0;
loop {
tokio::time::sleep(Duration::from_secs(1)).await;
if *reconnect_shutdown.borrow() {
break;
}
let h = reconnect_handle.lock().await;
if h.is_closed() {
drop(h);
info!("SSH session closed, starting reconnection");
let backoff = backoff_duration(attempts);
warn!("reconnect attempt #{}, waiting {:?}", attempts + 1, backoff);
tokio::time::sleep(backoff).await;
let handler = ClientHandler::from_config(&reconnect_auth);
let username = reconnect_username.clone();
match establish_session(&*reconnect_transport, handler, &reconnect_auth, &username).await {
Ok(new_handle) => {
info!("reconnection successful");
{
let mut guard = reconnect_handle.lock().await;
*guard = new_handle;
}
for spec in &reconnect_remote_specs {
match RemoteForwarder::new(spec.clone()) {
Ok(rf) => {
let mut h = reconnect_handle.lock().await;
match rf.register(&mut h).await {
Ok(_) => debug!("re-registered remote forward: {}", spec),
Err(e) => warn!("failed to re-register remote forward {}: {e}", spec),
}
}
Err(e) => warn!("failed to create remote forwarder: {e}"),
}
}
attempts = 0;
}
Err(e) => {
warn!("reconnection attempt failed: {e}");
attempts += 1;
}
}
}
}
});
tokio::select! {
_ = wait_shutdown.changed() => {
if *wait_shutdown.borrow() {
info!("shutdown signal received");
}
}
_ = socks5_task => {
warn!("SOCKS5 server exited unexpectedly");
}
}
reconnect_monitor.abort();
#[cfg(unix)]
signal_done.abort();
self.shutdown().await?;
forward_task.abort();
let _ = forward_task.await;
Ok(())
}
pub async fn shutdown(&self) -> Result<(), ConnectError> {
info!("initiating graceful shutdown");
let _ = self.shutdown_tx.send(true);
{
let handle = self.handle.lock().await;
if !handle.is_closed() {
if let Err(e) = handle
.disconnect(russh::Disconnect::ByApplication, "shutdown", "")
.await
{
warn!("failed to send SSH disconnect: {e}");
}
}
}
tokio::time::sleep(DRAIN_TIMEOUT).await;
info!("graceful shutdown complete");
Ok(())
}
}
fn derive_username() -> String {
std::env::var("USER")
.or_else(|_| std::env::var("USERNAME"))
.unwrap_or_else(|_| "alknet".to_string())
}
async fn establish_session<T: Transport>(
transport: &T,
handler: ClientHandler,
auth_config: &ClientAuthConfig,
username: &str,
) -> Result<client::Handle<ClientHandler>, ConnectError> {
let stream = transport.connect().await.map_err(|e| {
error!("transport connect failed: {e}");
ConnectError::ConnectionFailed
})?;
let config = Arc::new(client::Config::default());
let mut handle = client::connect_stream(config, stream, handler)
.await
.map_err(|e| {
error!("SSH connect failed: {e}");
ConnectError::ConnectionFailed
})?;
let auth_ok = auth_config
.authenticate(&mut handle, username)
.await
.map_err(|_| ConnectError::AuthFailed)?;
if !auth_ok {
return Err(ConnectError::AuthFailed);
}
Ok(handle)
}
fn backoff_duration(attempt: u32) -> Duration {
let secs: u64 = match attempt {
0 => 1,
1 => 2,
2 => 4,
3 => 8,
4 => 16,
_ => 30,
};
Duration::from_secs(secs)
}
fn build_local_forwarders(opts: &ConnectOptions) -> Result<Vec<LocalForwarder>, ConnectError> {
let mut forwarders = Vec::new();
for spec_str in &opts.forwards {
let spec = PortForwardSpec::local(spec_str).map_err(|e| {
warn!("invalid local forward spec '{}': {}", spec_str, e);
ConnectError::Config(ConfigError::InvalidFlag {
name: format!("invalid forward spec: {}", spec_str),
})
})?;
forwarders.push(
LocalForwarder::new(spec).map_err(|e| {
warn!("failed to create local forwarder: {}", e);
ConnectError::ForwardFailed
})?,
);
}
Ok(forwarders)
}
fn build_remote_specs(opts: &ConnectOptions) -> Result<Vec<PortForwardSpec>, ConnectError> {
let mut specs = Vec::new();
for spec_str in &opts.remote_forwards {
let spec = PortForwardSpec::remote(spec_str).map_err(|e| {
warn!("invalid remote forward spec '{}': {}", spec_str, e);
ConnectError::Config(ConfigError::InvalidFlag {
name: format!("invalid remote forward spec: {}", spec_str),
})
})?;
specs.push(spec);
}
Ok(specs)
}
/// Errors that can occur during client connection setup and operation.
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
#[error("connection failed")]
ConnectionFailed,
#[error("authentication failed")]
AuthFailed,
#[error("forward setup failed")]
ForwardFailed,
#[error("config error: {0}")]
Config(#[from] ConfigError),
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::io::duplex;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
fn make_identity() -> KeySource {
KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec())
}
#[test]
fn connect_options_default_fields() {
let opts = ConnectOptions::new(make_identity());
assert!(opts.server.is_none());
assert!(opts.peer.is_none());
assert_eq!(opts.transport_mode, TransportMode::Tcp);
assert_eq!(opts.socks5_addr, "127.0.0.1:1080");
assert!(opts.forwards.is_empty());
assert!(opts.remote_forwards.is_empty());
assert!(opts.proxy.is_none());
assert!(opts.iroh_relay.is_none());
assert!(opts.tls_server_name.is_none());
assert!(!opts.insecure);
}
#[test]
fn connect_options_builder_pattern() {
let opts = ConnectOptions::new(make_identity())
.server("example.com:22")
.transport_mode(TransportMode::Tls)
.socks5_addr("127.0.0.1:9050")
.forward("127.0.0.1:5432:db:5432")
.remote_forward("0.0.0.0:8080:127.0.0.1:3000")
.proxy("socks5://127.0.0.1:1080")
.iroh_relay("https://relay.example.com")
.tls_server_name("alknet.test")
.insecure(true);
assert_eq!(opts.server.as_deref(), Some("example.com:22"));
assert_eq!(opts.transport_mode, TransportMode::Tls);
assert_eq!(opts.socks5_addr, "127.0.0.1:9050");
assert_eq!(opts.forwards.len(), 1);
assert_eq!(opts.remote_forwards.len(), 1);
assert_eq!(opts.proxy.as_deref(), Some("socks5://127.0.0.1:1080"));
assert_eq!(opts.iroh_relay.as_deref(), Some("https://relay.example.com"));
assert_eq!(opts.tls_server_name.as_deref(), Some("alknet.test"));
assert!(opts.insecure);
}
#[test]
fn connect_options_validate_tcp_requires_server() {
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Tcp);
assert!(opts.validate().is_err());
}
#[test]
fn connect_options_validate_tcp_with_server_ok() {
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
assert!(opts.validate().is_ok());
}
#[test]
fn connect_options_validate_tls_requires_server() {
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Tls);
assert!(opts.validate().is_err());
}
#[test]
fn connect_options_validate_tls_with_server_ok() {
let opts = ConnectOptions::new(make_identity())
.transport_mode(TransportMode::Tls)
.server("example.com:443");
assert!(opts.validate().is_ok());
}
#[test]
fn connect_options_validate_iroh_requires_peer() {
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Iroh);
assert!(opts.validate().is_err());
}
#[test]
fn connect_options_validate_iroh_with_peer_ok() {
let opts = ConnectOptions::new(make_identity())
.transport_mode(TransportMode::Iroh)
.peer("some-endpoint-id");
assert!(opts.validate().is_ok());
}
#[test]
fn identity_accepts_key_source_file() {
let file_source = KeySource::File(std::path::PathBuf::from("/path/to/key"));
let opts = ConnectOptions::new(file_source);
match &opts.identity {
KeySource::File(p) => assert_eq!(p, &std::path::PathBuf::from("/path/to/key")),
_ => panic!("expected File variant"),
}
}
#[test]
fn identity_accepts_key_source_memory() {
let mem_source = KeySource::Memory(b"key-data".to_vec());
let opts = ConnectOptions::new(mem_source);
match &opts.identity {
KeySource::Memory(d) => assert_eq!(d, b"key-data"),
_ => panic!("expected Memory variant"),
}
}
#[test]
fn transport_mode_display() {
assert_eq!(TransportMode::Tcp.to_string(), "tcp");
assert_eq!(TransportMode::Tls.to_string(), "tls");
assert_eq!(TransportMode::Iroh.to_string(), "iroh");
}
#[test]
fn connect_error_variants() {
assert_eq!(ConnectError::ConnectionFailed.to_string(), "connection failed");
assert_eq!(ConnectError::AuthFailed.to_string(), "authentication failed");
assert_eq!(ConnectError::ForwardFailed.to_string(), "forward setup failed");
}
#[test]
fn connect_options_debug_redacts_identity() {
let opts = ConnectOptions::new(make_identity());
let debug_str = format!("{:?}", opts);
assert!(debug_str.contains("<KeySource>"));
assert!(!debug_str.contains("OPENSSH"));
}
struct FailTransport;
#[async_trait::async_trait]
impl Transport for FailTransport {
type Stream = tokio::io::DuplexStream;
async fn connect(&self) -> anyhow::Result<Self::Stream> {
Err(anyhow::anyhow!("always fails"))
}
fn describe(&self) -> String {
"fail".to_string()
}
}
struct DuplexTransport {
connect_count: Arc<AtomicUsize>,
}
#[async_trait::async_trait]
impl Transport for DuplexTransport {
type Stream = tokio::io::DuplexStream;
async fn connect(&self) -> anyhow::Result<Self::Stream> {
self.connect_count.fetch_add(1, Ordering::SeqCst);
let (client, _) = duplex(4096);
Ok(client)
}
fn describe(&self) -> String {
"duplex".to_string()
}
}
#[tokio::test]
async fn client_session_new_transport_fails() {
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
let transport = Arc::new(FailTransport);
let result = ClientSession::new(opts, transport).await;
assert!(result.is_err());
assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed));
}
#[tokio::test]
async fn client_session_new_ssh_handshake_fails() {
let transport = Arc::new(DuplexTransport {
connect_count: Arc::new(AtomicUsize::new(0)),
});
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
let result = ClientSession::new(opts, transport).await;
assert!(result.is_err());
assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed));
}
#[test]
fn build_local_forwarders_empty() {
let opts = ConnectOptions::new(make_identity());
let result = build_local_forwarders(&opts);
assert!(result.is_ok());
assert!(result.unwrap().is_empty());
}
#[test]
fn build_local_forwarders_valid() {
let opts = ConnectOptions::new(make_identity()).forward("127.0.0.1:5432:db:5432");
let result = build_local_forwarders(&opts);
assert!(result.is_ok());
assert_eq!(result.unwrap().len(), 1);
}
#[test]
fn build_local_forwarders_invalid_spec() {
let opts = ConnectOptions::new(make_identity()).forward("bad-spec");
let result = build_local_forwarders(&opts);
assert!(result.is_err());
}
#[test]
fn build_remote_specs_empty() {
let opts = ConnectOptions::new(make_identity());
let result = build_remote_specs(&opts);
assert!(result.is_ok());
assert!(result.unwrap().is_empty());
}
#[test]
fn build_remote_specs_valid() {
let opts = ConnectOptions::new(make_identity()).remote_forward("0.0.0.0:8080:127.0.0.1:3000");
let result = build_remote_specs(&opts);
assert!(result.is_ok());
assert_eq!(result.unwrap().len(), 1);
}
#[test]
fn build_remote_specs_invalid() {
let opts = ConnectOptions::new(make_identity()).remote_forward("bad");
let result = build_remote_specs(&opts);
assert!(result.is_err());
}
#[test]
fn default_socks5_addr() {
assert_eq!(DEFAULT_SOCKS5_ADDR, "127.0.0.1:1080");
}
#[test]
fn drain_timeout_is_two_seconds() {
assert_eq!(DRAIN_TIMEOUT, Duration::from_secs(2));
}
#[test]
fn transport_mode_equality() {
assert_eq!(TransportMode::Tcp, TransportMode::Tcp);
assert_ne!(TransportMode::Tcp, TransportMode::Tls);
assert_ne!(TransportMode::Tls, TransportMode::Iroh);
}
#[tokio::test]
async fn shutdown_sends_disconnect_and_drains() {
let transport = Arc::new(DuplexTransport {
connect_count: Arc::new(AtomicUsize::new(0)),
});
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
let result = ClientSession::new(opts, transport).await;
assert!(result.is_err());
}
#[test]
fn socks5_is_always_enabled_by_default() {
let opts = ConnectOptions::new(make_identity());
assert!(!opts.socks5_addr.is_empty());
}
#[tokio::test]
async fn integration_mock_transport_session() {
use crate::socks5::{ChannelOpener, ChannelOpenError};
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::{TcpListener, TcpStream};
struct MockOpener;
impl ChannelOpener for MockOpener {
type Stream = tokio::io::DuplexStream;
async fn open_channel(
&self,
_host: String,
_port: u16,
) -> Result<Self::Stream, ChannelOpenError> {
let (client, _server) = duplex(4096);
Ok(client)
}
}
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound_addr = listener.local_addr().unwrap();
drop(listener);
let opener = MockOpener;
let server = Socks5Server::with_addr(opener, &bound_addr.to_string());
let _server_task = tokio::spawn(async move {
let _ = server.run().await;
});
tokio::time::sleep(Duration::from_millis(50)).await;
let mut conn = TcpStream::connect(bound_addr).await.unwrap();
let greeting = [0x05, 0x01, 0x00];
conn.write_all(&greeting).await.unwrap();
let mut auth_resp = [0u8; 2];
conn.read_exact(&mut auth_resp).await.unwrap();
assert_eq!(auth_resp, [0x05, 0x00]);
let connect_req = [
0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80,
];
conn.write_all(&connect_req).await.unwrap();
let mut reply = [0u8; 10];
conn.read_exact(&mut reply).await.unwrap();
assert_eq!(reply[1], 0x00);
conn.write_all(b"test data").await.unwrap();
conn.shutdown().await.unwrap();
}
}

View File

@@ -0,0 +1,537 @@
//! Local and remote port forwarding.
//!
//! `LocalForwarder` binds a local TCP listener and forwards each connection through
//! an SSH `direct-tcpip` channel. `RemoteForwarder` requests `tcpip-forward` from
//! the server and handles `forwarded-tcpip` channels. Specs follow the
//! `bind_addr:bind_port:target_host:target_port` format.
use std::net::SocketAddr;
use std::sync::Arc;
use russh::client;
use tokio::io;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tracing::{debug, error, info};
use crate::error::ForwardError;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PortForwardSpecKind {
Local,
Remote,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PortForwardSpec {
pub kind: PortForwardSpecKind,
pub bind_addr: String,
pub bind_port: u16,
pub target_host: String,
pub target_port: u16,
}
impl PortForwardSpec {
pub fn local(spec: &str) -> Result<Self, ForwardError> {
let (bind_addr, bind_port, target_host, target_port) = parse_spec(spec)?;
Ok(Self {
kind: PortForwardSpecKind::Local,
bind_addr,
bind_port,
target_host,
target_port,
})
}
pub fn remote(spec: &str) -> Result<Self, ForwardError> {
let (bind_addr, bind_port, target_host, target_port) = parse_spec(spec)?;
Ok(Self {
kind: PortForwardSpecKind::Remote,
bind_addr,
bind_port,
target_host,
target_port,
})
}
pub fn listen_addr(&self) -> Result<SocketAddr, ForwardError> {
format!("{}:{}", self.bind_addr, self.bind_port)
.parse()
.map_err(|_| ForwardError::InvalidSpec {
spec: format!("{}:{}", self.bind_addr, self.bind_port),
})
}
pub fn target_addr(&self) -> Result<SocketAddr, ForwardError> {
format!("{}:{}", self.target_host, self.target_port)
.parse()
.map_err(|_| ForwardError::InvalidSpec {
spec: format!("{}:{}", self.target_host, self.target_port),
})
}
}
impl std::fmt::Display for PortForwardSpec {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let prefix = match self.kind {
PortForwardSpecKind::Local => "-L",
PortForwardSpecKind::Remote => "-R",
};
write!(
f,
"{} {}:{}:{}:{}",
prefix, self.bind_addr, self.bind_port, self.target_host, self.target_port
)
}
}
fn parse_spec(spec: &str) -> Result<(String, u16, String, u16), ForwardError> {
let parts: Vec<&str> = spec.split(':').collect();
if parts.len() != 4 {
return Err(ForwardError::InvalidSpec {
spec: spec.to_string(),
});
}
let bind_addr = parts[0].to_string();
let bind_port: u16 = parts[1].parse().map_err(|_| ForwardError::InvalidSpec {
spec: spec.to_string(),
})?;
let target_host = parts[2].to_string();
let target_port: u16 = parts[3].parse().map_err(|_| ForwardError::InvalidSpec {
spec: spec.to_string(),
})?;
Ok((bind_addr, bind_port, target_host, target_port))
}
pub struct LocalForwarder {
spec: PortForwardSpec,
listener: Option<TcpListener>,
}
impl LocalForwarder {
pub fn new(spec: PortForwardSpec) -> Result<Self, ForwardError> {
if spec.kind != PortForwardSpecKind::Local {
return Err(ForwardError::InvalidSpec {
spec: format!("expected local spec, got {:?}", spec.kind),
});
}
Ok(Self {
spec,
listener: None,
})
}
pub fn spec(&self) -> &PortForwardSpec {
&self.spec
}
pub async fn run<H: client::Handler + Send + 'static>(
&mut self,
handle: Arc<Mutex<client::Handle<H>>>,
) -> Result<(), ForwardError> {
let listen_addr = self.spec.listen_addr()?;
let listener: TcpListener = TcpListener::bind(listen_addr)
.await
.map_err(|e| ForwardError::BindFailed { source: e })?;
self.listener = Some(listener);
let remote_host = self.spec.target_host.clone();
let remote_port = self.spec.target_port;
info!(
"local forward listening on {} -> {}:{}",
listen_addr, remote_host, remote_port
);
loop {
let listener = match &self.listener {
Some(l) => l,
None => return Ok(()),
};
let accept_result = listener.accept().await;
let (local_stream, local_addr) = match accept_result {
Ok(conn) => conn,
Err(e) => {
let handle = handle.lock().await;
if handle.is_closed() {
debug!("local forward accept loop ending: ssh session closed");
return Ok(());
}
drop(handle);
error!("local forward accept error: {}", e);
continue;
}
};
debug!(
"local forward connection from {} -> {}:{}",
local_addr, remote_host, remote_port
);
let handle = handle.clone();
let remote_host = remote_host.clone();
tokio::spawn(async move {
if let Err(e) =
proxy_local_to_remote(local_stream, handle, &remote_host, remote_port).await
{
debug!("local forward proxy error: {}", e);
}
});
}
}
pub async fn stop(&mut self) {
if let Some(listener) = self.listener.take() {
drop(listener);
}
}
pub fn local_port(&self) -> u16 {
self.spec.bind_port
}
}
async fn proxy_local_to_remote<H: client::Handler + Send + 'static>(
local_stream: TcpStream,
handle: Arc<Mutex<client::Handle<H>>>,
remote_host: &str,
remote_port: u16,
) -> Result<(), ForwardError> {
let local_addr = local_stream
.peer_addr()
.map(|a| a.to_string())
.unwrap_or_default();
let handle_guard = handle.lock().await;
let channel = handle_guard
.channel_open_direct_tcpip(
remote_host,
remote_port as u32,
&local_addr,
0,
)
.await
.map_err(|e| ForwardError::ChannelOpenFailed {
source: Box::new(e) as _,
})?;
drop(handle_guard);
let ssh_stream = channel.into_stream();
let (mut ssh_read, mut ssh_write) = tokio::io::split(ssh_stream);
let (mut local_read, mut local_write) = tokio::io::split(local_stream);
let client_to_server = io::copy(&mut local_read, &mut ssh_write);
let server_to_client = io::copy(&mut ssh_read, &mut local_write);
match tokio::join!(client_to_server, server_to_client) {
(Err(e), _) | (_, Err(e)) => {
debug!("local forward bidirectional copy error: {}", e);
}
_ => {}
}
Ok(())
}
pub struct RemoteForwarder {
spec: PortForwardSpec,
cancel: Option<tokio::sync::oneshot::Sender<()>>,
}
impl RemoteForwarder {
pub fn new(spec: PortForwardSpec) -> Result<Self, ForwardError> {
if spec.kind != PortForwardSpecKind::Remote {
return Err(ForwardError::InvalidSpec {
spec: format!("expected remote spec, got {:?}", spec.kind),
});
}
Ok(Self { spec, cancel: None })
}
pub fn spec(&self) -> &PortForwardSpec {
&self.spec
}
pub async fn register<H: client::Handler + Send + 'static>(
&self,
handle: &mut client::Handle<H>,
) -> Result<u32, ForwardError> {
let port = handle
.tcpip_forward(&self.spec.bind_addr, self.spec.bind_port as u32)
.await
.map_err(|e| ForwardError::ChannelOpenFailed {
source: Box::new(e) as _,
})?;
Ok(port)
}
pub async fn handle_forwarded_channel(
channel: russh::Channel<russh::client::Msg>,
connected_address: &str,
connected_port: u32,
local_host: &str,
local_port: u16,
) {
debug!(
"remote forward: server opened forwarded-tcpip channel to {}:{} -> local {}:{}",
connected_address, connected_port, local_host, local_port
);
let local_target = format!("{}:{}", local_host, local_port);
let local_stream = match TcpStream::connect(&local_target).await {
Ok(s) => s,
Err(e) => {
error!(
"remote forward: failed to connect to local target {}: {}",
local_target, e
);
return;
}
};
let ssh_stream = channel.into_stream();
let (mut ssh_read, mut ssh_write) = tokio::io::split(ssh_stream);
let (mut local_read, mut local_write) = tokio::io::split(local_stream);
let client_to_server = io::copy(&mut local_read, &mut ssh_write);
let server_to_client = io::copy(&mut ssh_read, &mut local_write);
match tokio::join!(client_to_server, server_to_client) {
(Err(e), _) | (_, Err(e)) => {
debug!("remote forward bidirectional copy error: {}", e);
}
_ => {}
}
}
pub async fn unregister<H: client::Handler + Send + 'static>(
&self,
handle: &client::Handle<H>,
) -> Result<(), ForwardError> {
handle
.cancel_tcpip_forward(&self.spec.bind_addr, self.spec.bind_port as u32)
.await
.map_err(|e| ForwardError::ChannelOpenFailed {
source: Box::new(e) as _,
})?;
Ok(())
}
pub async fn stop(&mut self) {
if let Some(cancel) = self.cancel.take() {
let _ = cancel.send(());
}
}
}
pub async fn run_local_forwarders<H: client::Handler + Send + 'static>(
forwarders: Vec<LocalForwarder>,
handle: Arc<Mutex<client::Handle<H>>>,
mut shutdown: tokio::sync::watch::Receiver<bool>,
) -> Vec<LocalForwarder> {
let mut forwarders = forwarders;
let mut tasks = Vec::new();
for forwarder in forwarders.drain(..) {
let handle = handle.clone();
let spec = forwarder.spec().clone();
let (_cancel_tx, cancel_rx) = tokio::sync::oneshot::channel::<()>();
tasks.push(tokio::spawn(async move {
let mut fwd = forwarder;
tokio::select! {
result = fwd.run(handle) => {
if let Err(e) = result {
error!("local forward {} failed: {}", spec, e);
}
}
_ = cancel_rx => {
fwd.stop().await;
}
}
fwd
}));
}
let _ = shutdown.changed().await;
for task in &tasks {
task.abort();
}
let mut results = Vec::new();
for task in tasks {
match task.await {
Ok(fwd) => results.push(fwd),
Err(e) => {
if !e.is_cancelled() {
error!("local forwarder task panicked: {}", e);
}
}
}
}
results
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_local_spec() {
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
assert_eq!(spec.kind, PortForwardSpecKind::Local);
assert_eq!(spec.bind_addr, "127.0.0.1");
assert_eq!(spec.bind_port, 5432);
assert_eq!(spec.target_host, "db.internal");
assert_eq!(spec.target_port, 5432);
}
#[test]
fn parse_remote_spec() {
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
assert_eq!(spec.kind, PortForwardSpecKind::Remote);
assert_eq!(spec.bind_addr, "0.0.0.0");
assert_eq!(spec.bind_port, 8080);
assert_eq!(spec.target_host, "127.0.0.1");
assert_eq!(spec.target_port, 3000);
}
#[test]
fn parse_spec_invalid_few_parts() {
assert!(PortForwardSpec::local("127.0.0.1:5432:db").is_err());
}
#[test]
fn parse_spec_invalid_many_parts() {
assert!(PortForwardSpec::local("a:b:c:d:e").is_err());
}
#[test]
fn parse_spec_invalid_port() {
assert!(PortForwardSpec::local("127.0.0.1:abc:db:5432").is_err());
}
#[test]
fn parse_spec_invalid_target_port() {
assert!(PortForwardSpec::local("127.0.0.1:5432:db:abc").is_err());
}
#[test]
fn spec_display() {
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
assert_eq!(spec.to_string(), "-L 127.0.0.1:5432:db.internal:5432");
}
#[test]
fn spec_display_remote() {
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
assert_eq!(spec.to_string(), "-R 0.0.0.0:8080:127.0.0.1:3000");
}
#[test]
fn local_forwarder_rejects_remote_spec() {
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
assert!(LocalForwarder::new(spec).is_err());
}
#[test]
fn remote_forwarder_rejects_local_spec() {
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
assert!(RemoteForwarder::new(spec).is_err());
}
#[test]
fn listen_addr_valid() {
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
let addr = spec.listen_addr().unwrap();
assert_eq!(addr.port(), 5432);
}
#[test]
fn listen_addr_invalid_host() {
let spec = PortForwardSpec {
kind: PortForwardSpecKind::Local,
bind_addr: "!!!invalid".to_string(),
bind_port: 5432,
target_host: "db".to_string(),
target_port: 5432,
};
assert!(spec.listen_addr().is_err());
}
#[tokio::test]
async fn local_forward_bind_and_accept() {
let spec = PortForwardSpec::local(&format!("127.0.0.1:0:remote:5432")).unwrap();
let forwarder = LocalForwarder::new(spec).unwrap();
let listen_addr = forwarder.spec.listen_addr().unwrap();
let listener = TcpListener::bind(listen_addr).await.unwrap();
let bound_addr = listener.local_addr().unwrap();
drop(listener);
let spec = PortForwardSpec::local(&format!(
"127.0.0.1:{}:remote:5432",
bound_addr.port()
))
.unwrap();
let forwarder = LocalForwarder::new(spec).unwrap();
assert_eq!(forwarder.local_port(), bound_addr.port());
}
#[tokio::test]
async fn remote_forward_proxy_bidirectional() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let echo_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let _echo_addr = echo_listener.local_addr().unwrap();
let echo_server = tokio::spawn(async move {
let (mut stream, _) = echo_listener.accept().await.unwrap();
let mut buf = [0u8; 64];
loop {
let n = match stream.read(&mut buf).await {
Ok(0) => break,
Ok(n) => n,
Err(_) => break,
};
if stream.write_all(&buf[..n]).await.is_err() {
break;
}
}
});
let local_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = local_listener.local_addr().unwrap();
let proxy_task = tokio::spawn(async move {
let (stream, _) = local_listener.accept().await.unwrap();
let (mut read, mut write) = tokio::io::split(stream);
let _ = io::copy(&mut read, &mut write).await;
});
let mut local_conn = TcpStream::connect(local_addr).await.unwrap();
local_conn.write_all(b"hello").await.unwrap();
let mut buf = [0u8; 64];
let n = local_conn.read(&mut buf).await.unwrap();
assert_eq!(&buf[..n], b"hello");
echo_server.abort();
proxy_task.abort();
}
#[test]
fn forwarder_spec_access() {
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
let forwarder = LocalForwarder::new(spec.clone()).unwrap();
assert_eq!(forwarder.spec(), &spec);
assert_eq!(forwarder.local_port(), 5432);
}
#[test]
fn remote_forwarder_spec_access() {
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
let forwarder = RemoteForwarder::new(spec.clone()).unwrap();
assert_eq!(forwarder.spec(), &spec);
}
}

View File

@@ -0,0 +1,17 @@
//! Client-side SSH session management.
//!
//! Provides `ClientSession` for establishing an SSH connection over any transport,
//! running a local SOCKS5 proxy, and managing port forwards. Also provides
//! `ChannelManager` for programmatic channel management with automatic reconnection.
//!
//! The client always starts a SOCKS5 proxy (default `127.0.0.1:1080`) when running
//! via `ClientSession::run()`. For VPN-like "route all traffic" behavior, use
//! [tun2proxy](https://github.com/tun2proxy/tun2proxy) alongside the SOCKS5 proxy.
pub mod channel_manager;
pub mod connect;
pub mod forward;
pub use channel_manager::{ChannelManager, ForwardRequest};
pub use connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
pub use forward::{LocalForwarder, PortForwardSpec, PortForwardSpecKind, RemoteForwarder};

View File

@@ -0,0 +1,215 @@
//! Error types for alknet-core.
//!
//! Layered error hierarchy:
//! - `TransportError` — connection/handshake/timeout errors (trigger reconnection on client)
//! - `AuthError` — key rejection, certificate validation failures
//! - `ChannelError` — per-channel failures (target unreachable, channel closed)
//! - `ConfigError` — invalid configuration (flags, key files, bind failures)
//! - `ForwardError` — port forward setup and connection failures
use std::io;
#[derive(Debug, thiserror::Error)]
pub enum TransportError {
#[error("connection failed")]
ConnectionFailed,
#[error("handshake failed")]
HandshakeFailed {
#[source]
source: io::Error,
},
#[error("transport timeout")]
Timeout,
#[error("proxy failed")]
ProxyFailed {
#[source]
source: io::Error,
},
}
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum AuthError {
#[error("key rejected")]
KeyRejected,
#[error("certificate invalid")]
CertInvalid,
#[error("certificate expired")]
CertExpired,
#[error("certificate principal mismatch")]
CertPrincipalMismatch,
#[error("no matching key")]
NoMatchingKey,
}
#[derive(Debug, thiserror::Error)]
pub enum ChannelError {
#[error("target unreachable")]
TargetUnreachable,
#[error("proxy connect failed")]
ProxyConnectFailed {
#[source]
source: io::Error,
},
#[error("channel closed")]
ChannelClosed,
}
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
#[error("invalid flag: {name}")]
InvalidFlag { name: String },
#[error("key file not found: {path}")]
KeyFileNotFound { path: String },
#[error("bind failed")]
BindFailed {
#[source]
source: io::Error,
},
#[error("incompatible options")]
IncompatibleOptions,
}
#[derive(Debug, thiserror::Error)]
pub enum ForwardError {
#[error("invalid port forward spec: {spec}")]
InvalidSpec { spec: String },
#[error("bind failed")]
BindFailed {
#[source]
source: io::Error,
},
#[error("channel open failed")]
ChannelOpenFailed {
#[source]
source: Box<dyn std::error::Error + Send + Sync>,
},
#[error("connect to local target failed")]
LocalConnectFailed {
#[source]
source: io::Error,
},
}
#[cfg(test)]
mod tests {
use super::*;
use std::error::Error;
#[test]
fn transport_error_display() {
assert_eq!(TransportError::ConnectionFailed.to_string(), "connection failed");
assert_eq!(
TransportError::HandshakeFailed {
source: io::Error::new(io::ErrorKind::ConnectionRefused, "tls failed")
}
.to_string(),
"handshake failed"
);
assert_eq!(TransportError::Timeout.to_string(), "transport timeout");
assert_eq!(
TransportError::ProxyFailed {
source: io::Error::new(io::ErrorKind::ConnectionRefused, "proxy err")
}
.to_string(),
"proxy failed"
);
}
#[test]
fn auth_error_display() {
assert_eq!(AuthError::KeyRejected.to_string(), "key rejected");
assert_eq!(AuthError::CertInvalid.to_string(), "certificate invalid");
assert_eq!(AuthError::CertExpired.to_string(), "certificate expired");
assert_eq!(AuthError::CertPrincipalMismatch.to_string(), "certificate principal mismatch");
assert_eq!(AuthError::NoMatchingKey.to_string(), "no matching key");
}
#[test]
fn channel_error_display() {
assert_eq!(ChannelError::TargetUnreachable.to_string(), "target unreachable");
assert_eq!(
ChannelError::ProxyConnectFailed {
source: io::Error::new(io::ErrorKind::ConnectionRefused, "refused")
}
.to_string(),
"proxy connect failed"
);
assert_eq!(ChannelError::ChannelClosed.to_string(), "channel closed");
}
#[test]
fn config_error_display() {
assert_eq!(
ConfigError::InvalidFlag {
name: "--bad".to_string()
}
.to_string(),
"invalid flag: --bad"
);
assert_eq!(
ConfigError::KeyFileNotFound {
path: "/missing".to_string()
}
.to_string(),
"key file not found: /missing"
);
assert_eq!(
ConfigError::BindFailed {
source: io::Error::new(io::ErrorKind::AddrInUse, "in use")
}
.to_string(),
"bind failed"
);
assert_eq!(ConfigError::IncompatibleOptions.to_string(), "incompatible options");
}
#[test]
fn error_source_chaining() {
let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "refused");
let transport_err = TransportError::HandshakeFailed { source: io_err };
assert!(transport_err.source().is_some());
let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "proxy");
let channel_err = ChannelError::ProxyConnectFailed { source: io_err };
assert!(channel_err.source().is_some());
let io_err = io::Error::new(io::ErrorKind::AddrInUse, "addr");
let config_err = ConfigError::BindFailed { source: io_err };
assert!(config_err.source().is_some());
let plain = AuthError::KeyRejected;
assert!(plain.source().is_none());
}
#[test]
fn forward_error_display() {
assert_eq!(
ForwardError::InvalidSpec { spec: "bad".to_string() }.to_string(),
"invalid port forward spec: bad"
);
assert_eq!(
ForwardError::BindFailed {
source: io::Error::new(io::ErrorKind::AddrInUse, "in use")
}
.to_string(),
"bind failed"
);
assert_eq!(
ForwardError::LocalConnectFailed {
source: io::Error::new(io::ErrorKind::ConnectionRefused, "refused")
}
.to_string(),
"connect to local target failed"
);
}
#[test]
fn forward_error_source_chaining() {
let io_err = io::Error::new(io::ErrorKind::AddrInUse, "in use");
let forward_err = ForwardError::BindFailed { source: io_err };
assert!(forward_err.source().is_some());
let plain = ForwardError::InvalidSpec { spec: "bad".to_string() };
assert!(plain.source().is_none());
}
}

View File

@@ -0,0 +1,67 @@
//! # alknet-core
//!
//! Core library for [Alknet](https://git.alk.dev/alkdev/alknet), a self-hostable SSH-based
//! tunnel tool. This crate provides the transport abstraction, SOCKS5 server, port forwarding,
//! authentication, and server handler — everything needed to build an alknet client or server
//! on top of pluggable transports.
//!
//! > **Alpha software.** This crate depends on solid libraries (russh, tokio, rustls, iroh)
//! > for core functionality, but the integration layer has not been battle-tested. Use with
//! > caution and report issues.
//!
//! # Key concepts
//!
//! - **Transport trait** — produces a duplex byte stream (`AsyncRead + AsyncWrite + Unpin + Send`)
//! that SSH consumes. Implementations: TCP, TLS, iroh (QUIC P2P).
//! - **SOCKS5 server** — the primary client interface, listening on a local port and routing
//! traffic through SSH channels.
//! - **Port forwarding** — `-L` local and `-R` remote port forwards over SSH channels.
//! - **Authentication** — Ed25519 public key and OpenSSH certificate authority. No passwords.
//! - **Server handler** — accepts SSH connections via a `TransportAcceptor` and proxies
//! `direct-tcpip` channel requests to targets (directly or via outbound proxy).
//!
//! # Feature flags
//!
//! | Feature | Default | Description |
//! |---------|---------|-------------|
//! | `tls` | yes | TLS transport via `tokio-rustls` |
//! | `iroh` | yes | iroh QUIC P2P transport |
//! | `acme` | no | ACME/Let's Encrypt auto-cert provisioning (implies `tls`) |
//! | `testutil` | no | Test utilities (for internal use) |
//!
//! # Quick example
//!
//! ```no_run
//! use std::sync::Arc;
//! use alknet_core::transport::TcpTransport;
//! use alknet_core::client::{ClientSession, ConnectOptions, TransportMode};
//! use alknet_core::auth::keys::KeySource;
//! use alknet_core::Transport;
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! let opts = ConnectOptions::new(KeySource::File("/path/to/key".into()))
//! .server("example.com:22")
//! .transport_mode(TransportMode::Tcp);
//! let transport = Arc::new(TcpTransport::new("example.com:22".parse()?));
//! let session = ClientSession::new(opts, transport).await?;
//! session.run().await?;
//! Ok(())
//! }
//! ```
pub mod transport;
pub mod client;
pub mod server;
pub mod auth;
pub mod socks5;
pub mod error;
#[cfg(feature = "testutil")]
pub mod testutil;
pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError};
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
pub use client::channel_manager::{ChannelManager, ForwardRequest};
pub use client::connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
pub use server::serve::{Server, ServeError, ServeOptions, ServeTransportMode};

View File

@@ -0,0 +1,563 @@
//! Outbound connection proxy for SSH channel targets.
//!
//! Connects to the requested `host:port` either directly, via SOCKS5 proxy, or
//! via HTTP CONNECT proxy, then proxies bytes bidirectionally between the SSH
//! channel and the outbound TCP stream.
use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use super::handler::{ProxyConfig, ProxyMode};
#[derive(Debug, thiserror::Error)]
pub enum ChannelProxyError {
#[error("connection refused")]
ConnectionRefused,
#[error("target unreachable")]
TargetUnreachable,
#[error("socks5 proxy handshake failed")]
Socks5HandshakeFailed,
#[error("socks5 proxy rejected connection")]
Socks5ProxyRejected,
#[error("http connect proxy handshake failed")]
HttpConnectHandshakeFailed,
#[error("http connect proxy rejected: {0}")]
HttpConnectProxyRejected(String),
#[error("io error")]
Io(#[from] std::io::Error),
}
pub async fn connect_outbound(
target: SocketAddr,
proxy: &ProxyConfig,
) -> Result<TcpStream, ChannelProxyError> {
match &proxy.mode {
ProxyMode::Direct => connect_direct(target).await,
ProxyMode::Socks5(addr) => connect_socks5(target, *addr).await,
ProxyMode::HttpConnect(addr) => connect_http_connect(target, *addr).await,
}
}
async fn connect_direct(target: SocketAddr) -> Result<TcpStream, ChannelProxyError> {
TcpStream::connect(target)
.await
.map_err(|e| map_connection_error(e, target))
}
async fn connect_socks5(target: SocketAddr, proxy_addr: SocketAddr) -> Result<TcpStream, ChannelProxyError> {
let mut stream = TcpStream::connect(proxy_addr)
.await
.map_err(ChannelProxyError::from)?;
stream.write_all(&[0x05, 0x01, 0x00]).await?;
stream.flush().await?;
let mut resp = [0u8; 2];
stream.read_exact(&mut resp).await?;
if resp[0] != 0x05 || resp[1] != 0x00 {
return Err(ChannelProxyError::Socks5HandshakeFailed);
}
let ip_bytes = target.ip().to_string();
let mut connect_req = vec![0x05, 0x01, 0x00, 0x03];
connect_req.push(ip_bytes.len() as u8);
connect_req.extend_from_slice(ip_bytes.as_bytes());
connect_req.extend_from_slice(&target.port().to_be_bytes());
stream.write_all(&connect_req).await?;
stream.flush().await?;
let mut reply_header = [0u8; 4];
stream.read_exact(&mut reply_header).await?;
if reply_header[0] != 0x05 {
return Err(ChannelProxyError::Socks5HandshakeFailed);
}
if reply_header[1] != 0x00 {
return Err(ChannelProxyError::Socks5ProxyRejected);
}
let atyp = reply_header[3];
match atyp {
0x01 => {
let mut _addr = [0u8; 4];
stream.read_exact(&mut _addr).await?;
}
0x04 => {
let mut _addr = [0u8; 16];
stream.read_exact(&mut _addr).await?;
}
0x03 => {
let len = stream.read_u8().await?;
let mut _domain = vec![0u8; len as usize];
stream.read_exact(&mut _domain).await?;
}
_ => {
return Err(ChannelProxyError::Socks5HandshakeFailed);
}
}
let mut _port = [0u8; 2];
stream.read_exact(&mut _port).await?;
Ok(stream)
}
async fn connect_http_connect(
target: SocketAddr,
proxy_addr: SocketAddr,
) -> Result<TcpStream, ChannelProxyError> {
let mut stream = TcpStream::connect(proxy_addr)
.await
.map_err(ChannelProxyError::from)?;
let connect_request = format!(
"CONNECT {}:{} HTTP/1.1\r\nHost: {}:{}\r\n\r\n",
target.ip(),
target.port(),
target.ip(),
target.port()
);
stream.write_all(connect_request.as_bytes()).await?;
stream.flush().await?;
let mut response = Vec::new();
let mut buf = [0u8; 1024];
loop {
let n = stream.read(&mut buf).await?;
if n == 0 {
return Err(ChannelProxyError::HttpConnectHandshakeFailed);
}
response.extend_from_slice(&buf[..n]);
if response.windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
}
let response_str = String::from_utf8_lossy(&response);
let status_line = response_str
.lines()
.next()
.unwrap_or("");
if status_line.contains("200") {
Ok(stream)
} else {
Err(ChannelProxyError::HttpConnectProxyRejected(
status_line.to_string(),
))
}
}
fn map_connection_error(e: std::io::Error, _target: SocketAddr) -> ChannelProxyError {
match e.kind() {
std::io::ErrorKind::ConnectionRefused => ChannelProxyError::ConnectionRefused,
std::io::ErrorKind::AddrNotAvailable
| std::io::ErrorKind::NetworkUnreachable
| std::io::ErrorKind::HostUnreachable => ChannelProxyError::TargetUnreachable,
_ => ChannelProxyError::Io(e),
}
}
pub async fn proxy_channel<S>(channel: S, target: SocketAddr, proxy: &ProxyConfig)
where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
if let Ok(outbound) = connect_outbound(target, proxy).await {
let (mut read_chan, mut write_chan) = tokio::io::split(channel);
let (mut read_out, mut write_out) = outbound.into_split();
let client_to_target = tokio::spawn(async move {
let _ = tokio::io::copy(&mut read_chan, &mut write_out).await;
let _ = write_out.shutdown().await;
});
let target_to_client = tokio::spawn(async move {
let _ = tokio::io::copy(&mut read_out, &mut write_chan).await;
let _ = write_chan.shutdown().await;
});
let _ = client_to_target.await;
let _ = target_to_client.await;
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
use tokio::net::TcpListener;
fn direct_config() -> ProxyConfig {
ProxyConfig {
mode: ProxyMode::Direct,
}
}
fn socks5_config(addr: SocketAddr) -> ProxyConfig {
ProxyConfig {
mode: ProxyMode::Socks5(addr),
}
}
fn http_connect_config(addr: SocketAddr) -> ProxyConfig {
ProxyConfig {
mode: ProxyMode::HttpConnect(addr),
}
}
#[tokio::test]
async fn direct_connection_to_echo_server() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server = tokio::spawn(async move {
let (mut sock, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 64];
let n = sock.read(&mut buf).await.unwrap();
sock.write_all(&buf[..n]).await.unwrap();
});
let stream = connect_outbound(addr, &direct_config()).await.unwrap();
let (mut read, mut write) = stream.into_split();
write.write_all(b"hello").await.unwrap();
let mut buf = [0u8; 5];
read.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello");
let _ = server.await;
}
#[tokio::test]
async fn direct_connection_target_unreachable() {
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
let result = connect_outbound(target, &direct_config()).await;
assert!(result.is_err());
}
#[tokio::test]
async fn socks5_proxy_handshake() {
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let proxy_addr = proxy_listener.local_addr().unwrap();
let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let target_addr = target_listener.local_addr().unwrap();
let target_server = tokio::spawn(async move {
let (mut sock, _) = target_listener.accept().await.unwrap();
let mut buf = [0u8; 64];
let n = sock.read(&mut buf).await.unwrap();
sock.write_all(&buf[..n]).await.unwrap();
});
let proxy_server = tokio::spawn(async move {
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
let mut greeting = [0u8; 3];
proxy_sock.read_exact(&mut greeting).await.unwrap();
assert_eq!(greeting[0], 0x05);
proxy_sock.write_all(&[0x05, 0x00]).await.unwrap();
let mut req_header = [0u8; 4];
proxy_sock.read_exact(&mut req_header).await.unwrap();
assert_eq!(req_header[0], 0x05);
assert_eq!(req_header[1], 0x01);
let atyp = req_header[3];
assert_eq!(atyp, 0x03);
let domain_len = proxy_sock.read_u8().await.unwrap() as usize;
let mut domain = vec![0u8; domain_len];
proxy_sock.read_exact(&mut domain).await.unwrap();
let mut port_bytes = [0u8; 2];
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
let target: SocketAddr = format!(
"{}:{}",
String::from_utf8_lossy(&domain),
u16::from_be_bytes(port_bytes)
)
.parse()
.unwrap();
let reply = vec![
0x05, 0x00, 0x00, 0x01,
0, 0, 0, 0,
0, 0,
];
proxy_sock.write_all(&reply).await.unwrap();
let mut target_stream = TcpStream::connect(target).await.unwrap();
let _ = tokio::io::copy_bidirectional(&mut proxy_sock, &mut target_stream).await;
});
let config = socks5_config(proxy_addr);
let mut stream = connect_outbound(target_addr, &config).await.unwrap();
stream.write_all(b"hello socks").await.unwrap();
let mut buf = [0u8; 11];
stream.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello socks");
drop(stream);
let _ = target_server.await;
let _ = proxy_server.await;
}
#[tokio::test]
async fn socks5_proxy_rejected() {
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let proxy_addr = proxy_listener.local_addr().unwrap();
let proxy_server = tokio::spawn(async move {
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
let mut greeting = [0u8; 3];
proxy_sock.read_exact(&mut greeting).await.unwrap();
proxy_sock.write_all(&[0x05, 0x00]).await.unwrap();
let mut req_header = [0u8; 4];
proxy_sock.read_exact(&mut req_header).await.unwrap();
let domain_len = proxy_sock.read_u8().await.unwrap() as usize;
let mut domain = vec![0u8; domain_len];
proxy_sock.read_exact(&mut domain).await.unwrap();
let mut port_bytes = [0u8; 2];
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
let reply = vec![
0x05, 0x05, 0x00, 0x01,
0, 0, 0, 0,
0, 0,
];
proxy_sock.write_all(&reply).await.unwrap();
});
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
let config = socks5_config(proxy_addr);
let result = connect_outbound(target, &config).await;
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
ChannelProxyError::Socks5ProxyRejected
));
let _ = proxy_server.await;
}
#[tokio::test]
async fn http_connect_proxy_handshake() {
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let proxy_addr = proxy_listener.local_addr().unwrap();
let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let target_addr = target_listener.local_addr().unwrap();
let target_server = tokio::spawn(async move {
let (mut sock, _) = target_listener.accept().await.unwrap();
let mut buf = [0u8; 64];
let n = sock.read(&mut buf).await.unwrap();
sock.write_all(&buf[..n]).await.unwrap();
});
let proxy_server = tokio::spawn(async move {
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
let mut request = Vec::new();
let mut buf = [0u8; 1024];
loop {
let n = proxy_sock.read(&mut buf).await.unwrap();
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
}
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
proxy_sock.write_all(response.as_bytes()).await.unwrap();
let target_str = extract_connect_target(&String::from_utf8_lossy(&request));
let mut target_stream = TcpStream::connect(target_str).await.unwrap();
let _ = tokio::io::copy_bidirectional(&mut proxy_sock, &mut target_stream).await;
});
let config = http_connect_config(proxy_addr);
let mut stream = connect_outbound(target_addr, &config).await.unwrap();
stream.write_all(b"hello http").await.unwrap();
let mut buf = [0u8; 10];
stream.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello http");
drop(stream);
let _ = target_server.await;
let _ = proxy_server.await;
}
fn extract_connect_target(request: &str) -> String {
let connect_line = request.lines().next().unwrap_or("");
let parts: Vec<&str> = connect_line.split_whitespace().collect();
if parts.len() >= 2 {
parts[1].to_string()
} else {
String::new()
}
}
#[tokio::test]
async fn http_connect_proxy_rejected() {
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let proxy_addr = proxy_listener.local_addr().unwrap();
let proxy_server = tokio::spawn(async move {
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
let mut request = Vec::new();
let mut buf = [0u8; 1024];
loop {
let n = proxy_sock.read(&mut buf).await.unwrap();
if n == 0 {
break;
}
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
}
let response = "HTTP/1.1 403 Forbidden\r\n\r\n";
proxy_sock.write_all(response.as_bytes()).await.unwrap();
});
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
let config = http_connect_config(proxy_addr);
let result = connect_outbound(target, &config).await;
assert!(result.is_err());
match result.unwrap_err() {
ChannelProxyError::HttpConnectProxyRejected(msg) => {
assert!(msg.contains("403"));
}
other => panic!("expected HttpConnectProxyRejected, got {:?}", other),
}
let _ = proxy_server.await;
}
#[tokio::test]
async fn target_unreachable_returns_appropriate_error() {
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
let result = connect_outbound(target, &direct_config()).await;
match result.unwrap_err() {
ChannelProxyError::TargetUnreachable
| ChannelProxyError::ConnectionRefused
| ChannelProxyError::Io(_) => {}
other => panic!("unexpected error type: {:?}", other),
}
}
#[tokio::test]
async fn socks5_proxy_unreachable() {
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
let bad_proxy: SocketAddr = "127.0.0.1:1".parse().unwrap();
let config = socks5_config(bad_proxy);
let result = connect_outbound(target, &config).await;
assert!(result.is_err());
}
#[tokio::test]
async fn http_connect_proxy_unreachable() {
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
let bad_proxy: SocketAddr = "127.0.0.1:1".parse().unwrap();
let config = http_connect_config(bad_proxy);
let result = connect_outbound(target, &config).await;
assert!(result.is_err());
}
struct MockChannel {
read_half: tokio::io::ReadHalf<DuplexStream>,
write_half: tokio::io::WriteHalf<DuplexStream>,
}
impl tokio::io::AsyncRead for MockChannel {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.get_mut().read_half).poll_read(cx, buf)
}
}
impl tokio::io::AsyncWrite for MockChannel {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
std::pin::Pin::new(&mut self.get_mut().write_half).poll_write(cx, buf)
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.get_mut().write_half).poll_flush(cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.get_mut().write_half).poll_shutdown(cx)
}
}
fn make_mock_channel() -> (MockChannel, DuplexStream) {
let (client, server) = duplex(4096);
let (read_half, write_half) = tokio::io::split(client);
(
MockChannel {
read_half,
write_half,
},
server,
)
}
#[tokio::test]
async fn proxy_channel_bidirectional_data_flow() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let target_addr = listener.local_addr().unwrap();
let echo_server = tokio::spawn(async move {
let (mut sock, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 64];
let n = sock.read(&mut buf).await.unwrap();
sock.write_all(&buf[..n]).await.unwrap();
});
let (channel, mut channel_peer) = make_mock_channel();
let target = target_addr;
let proxy = direct_config();
tokio::spawn(async move {
proxy_channel(channel, target, &proxy).await;
});
channel_peer.write_all(b"ping").await.unwrap();
channel_peer.flush().await.unwrap();
let mut buf = [0u8; 4];
channel_peer.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"ping");
drop(channel_peer);
let _ = echo_server.await;
}
#[tokio::test]
async fn proxy_channel_target_unreachable_closes_cleanly() {
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
let (channel, _channel_peer) = make_mock_channel();
let proxy = direct_config();
proxy_channel(channel, target, &proxy).await;
}
}

View File

@@ -0,0 +1,192 @@
//! Control channel routing for reserved `alknet-*` destinations.
//!
//! SSH channels opened with a destination starting with `alknet-` are intercepted
//! by the server and routed to a `ControlChannelHandler` instead of proxied to a
//! TCP target. See ADR-018 for the design rationale.
use std::io;
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
pub const ALKNET_CONTROL_DESTINATION: &str = "alknet-control";
pub const ALKNET_PREFIX: &str = "alknet-";
pub fn is_reserved_destination(host: &str) -> bool {
host.starts_with(ALKNET_PREFIX)
}
pub trait DuplexStream: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T: AsyncRead + AsyncWrite + Unpin + Send> DuplexStream for T {}
#[async_trait]
pub trait ControlChannelHandler: Send + Sync {
async fn handle_channel(&self, stream: Box<dyn DuplexStream>);
}
pub struct ControlChannelRouter {
handler: Option<Box<dyn ControlChannelHandler>>,
}
impl ControlChannelRouter {
pub fn new(handler: Option<Box<dyn ControlChannelHandler>>) -> Self {
Self { handler }
}
pub fn without_handler() -> Self {
Self { handler: None }
}
pub fn with_handler(handler: Box<dyn ControlChannelHandler>) -> Self {
Self {
handler: Some(handler),
}
}
pub fn has_handler(&self) -> bool {
self.handler.is_some()
}
pub async fn route(&self, stream: Box<dyn DuplexStream>) -> io::Result<()> {
match &self.handler {
Some(handler) => {
handler.handle_channel(stream).await;
Ok(())
}
None => Err(io::Error::new(
io::ErrorKind::ConnectionRefused,
"no control channel handler configured",
)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::duplex;
#[test]
fn alknet_control_destination_constant() {
assert_eq!(ALKNET_CONTROL_DESTINATION, "alknet-control");
}
#[test]
fn alknet_prefix_constant() {
assert_eq!(ALKNET_PREFIX, "alknet-");
}
#[test]
fn reserved_destination_detected() {
assert!(is_reserved_destination("alknet-control"));
assert!(is_reserved_destination("alknet-status"));
assert!(is_reserved_destination("alknet-events"));
assert!(is_reserved_destination("alknet-"));
}
#[test]
fn non_reserved_destination_passes_through() {
assert!(!is_reserved_destination("example.com"));
assert!(!is_reserved_destination("localhost"));
assert!(!is_reserved_destination("192.168.1.1"));
assert!(!is_reserved_destination("alknet.example.com"));
assert!(!is_reserved_destination(""));
assert!(!is_reserved_destination("alkne-control"));
assert!(!is_reserved_destination("ALKNET-control"));
}
#[test]
fn prefix_matching_case_sensitive() {
assert!(!is_reserved_destination("Alknet-control"));
assert!(!is_reserved_destination("ALKNET-control"));
assert!(is_reserved_destination("alknet-Control"));
}
#[test]
fn router_without_handler_has_no_handler() {
let router = ControlChannelRouter::without_handler();
assert!(!router.has_handler());
}
#[test]
fn router_with_handler_has_handler() {
struct DummyHandler;
#[async_trait]
impl ControlChannelHandler for DummyHandler {
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {}
}
let router = ControlChannelRouter::with_handler(Box::new(DummyHandler));
assert!(router.has_handler());
}
#[tokio::test]
async fn route_without_handler_returns_error() {
let router = ControlChannelRouter::without_handler();
let (_client, server) = duplex(64);
let stream: Box<dyn DuplexStream> = Box::new(server);
let result = router.route(stream).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused);
}
#[tokio::test]
async fn route_with_handler_succeeds() {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
struct TrackedHandler {
called: Arc<AtomicBool>,
}
#[async_trait]
impl ControlChannelHandler for TrackedHandler {
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {
self.called.store(true, Ordering::SeqCst);
}
}
let called = Arc::new(AtomicBool::new(false));
let handler = TrackedHandler {
called: called.clone(),
};
let router = ControlChannelRouter::with_handler(Box::new(handler));
let (_client, server) = duplex(64);
let stream: Box<dyn DuplexStream> = Box::new(server);
let result = router.route(stream).await;
assert!(result.is_ok());
assert!(called.load(Ordering::SeqCst));
}
#[tokio::test]
async fn route_with_handler_can_read_write() {
struct EchoHandler;
#[async_trait]
impl ControlChannelHandler for EchoHandler {
async fn handle_channel(&self, mut stream: Box<dyn DuplexStream>) {
let mut buf = [0u8; 64];
let n = stream.read(&mut buf).await.unwrap();
stream.write_all(&buf[..n]).await.unwrap();
}
}
let router = ControlChannelRouter::with_handler(Box::new(EchoHandler));
let (client, server) = duplex(64);
let stream: Box<dyn DuplexStream> = Box::new(server);
tokio::spawn(async move {
router.route(stream).await.unwrap();
});
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut client = client;
client.write_all(b"hello").await.unwrap();
let mut buf = [0u8; 5];
client.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello");
}
#[test]
fn control_channel_destination_matches_prefix() {
assert!(is_reserved_destination(ALKNET_CONTROL_DESTINATION));
}
}

View File

@@ -0,0 +1,736 @@
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use russh::keys::ssh_key::HashAlg;
use russh::server::{Auth, Handler, Msg, Session};
use russh::Channel;
use russh::ChannelId;
use crate::auth::ServerAuthConfig;
use crate::server::control_channel::{
ControlChannelHandler, ControlChannelRouter, ALKNET_PREFIX,
};
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
#[derive(Debug, Clone)]
pub enum ProxyMode {
Direct,
Socks5(SocketAddr),
HttpConnect(SocketAddr),
}
#[derive(Debug, Clone)]
pub struct ProxyConfig {
pub mode: ProxyMode,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TransportKind {
Tcp,
Tls,
Iroh,
}
impl std::fmt::Display for TransportKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportKind::Tcp => write!(f, "tcp"),
TransportKind::Tls => write!(f, "tls"),
TransportKind::Iroh => write!(f, "iroh"),
}
}
}
pub struct ServerHandler {
auth_config: Arc<ServerAuthConfig>,
#[allow(dead_code)]
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
control_channel_router: ControlChannelRouter,
#[allow(dead_code)]
transport: TransportKind,
connection_limiter: Arc<ConnectionRateLimiter>,
connection_allowed: bool,
auth_limiter: AuthAttemptLimiter,
connected_at: Instant,
}
impl ServerHandler {
pub fn new(
auth_config: Arc<ServerAuthConfig>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
transport: TransportKind,
connection_limiter: Arc<ConnectionRateLimiter>,
max_auth_attempts: usize,
) -> Self {
let allowed = if let Some(addr) = remote_addr {
let ip = addr.ip();
if connection_limiter.check(ip) {
connection_limiter.on_connect(ip);
tracing::info!(
remote_addr = %addr,
transport = %transport,
"connection opened"
);
true
} else {
tracing::info!(
remote_addr = %addr,
transport = %transport,
"connection rejected"
);
false
}
} else {
true
};
Self {
auth_config,
outbound_proxy,
remote_addr,
control_channel_router: ControlChannelRouter::without_handler(),
transport,
connection_limiter,
connection_allowed: allowed,
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
connected_at: Instant::now(),
}
}
pub fn is_connection_allowed(&self) -> bool {
self.connection_allowed
}
pub fn remote_ip(&self) -> Option<IpAddr> {
self.remote_addr.map(|a| a.ip())
}
}
impl Drop for ServerHandler {
fn drop(&mut self) {
if let Some(addr) = self.remote_addr {
if self.connection_allowed {
self.connection_limiter.on_disconnect(addr.ip());
}
let duration = self.connected_at.elapsed();
tracing::info!(
remote_addr = %addr,
duration_secs = duration.as_secs_f64(),
"connection closed"
);
}
}
}
impl ServerHandler {
pub fn with_control_channel_handler(
mut self,
handler: Box<dyn ControlChannelHandler>,
) -> Self {
self.control_channel_router = ControlChannelRouter::with_handler(handler);
self
}
pub fn control_channel_router(&self) -> &ControlChannelRouter {
&self.control_channel_router
}
}
#[async_trait]
impl Handler for ServerHandler {
type Error = russh::Error;
async fn auth_publickey(
&mut self,
user: &str,
public_key: &russh::keys::ssh_key::PublicKey,
) -> Result<Auth, Self::Error> {
if !self.auth_limiter.check() {
let remote_addr_display = self
.remote_addr
.map_or("unknown".to_string(), |a| a.to_string());
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "reject",
"auth attempt"
);
return Ok(Auth::Reject {
proceed_with_methods: None,
});
}
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
let remote_addr_display = self
.remote_addr
.map_or("unknown".to_string(), |a| a.to_string());
let russh_pub = russh::keys::PublicKey::new(public_key.key_data().clone(), user);
let result = self.auth_config.authenticate_publickey(&russh_pub);
match result {
Ok(()) => {
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "accept",
"auth attempt"
);
Ok(Auth::Accept)
}
Err(_) => {
self.auth_limiter.on_failure();
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "reject",
"auth attempt"
);
Ok(Auth::Reject {
proceed_with_methods: None,
})
}
}
}
async fn channel_open_direct_tcpip(
&mut self,
channel: Channel<Msg>,
host_to_connect: &str,
port_to_connect: u32,
originator_address: &str,
originator_port: u32,
_session: &mut Session,
) -> Result<bool, Self::Error> {
if host_to_connect.starts_with(ALKNET_PREFIX) {
if !self.control_channel_router.has_handler() {
return Ok(false);
}
let _ = channel;
return Ok(true);
}
let target_host = host_to_connect.to_string();
let target_port = port_to_connect;
let proxy_config = self.outbound_proxy.clone().unwrap_or(ProxyConfig {
mode: ProxyMode::Direct,
});
tokio::spawn(async move {
let target = match format!("{target_host}:{target_port}").parse::<std::net::SocketAddr>() {
Ok(addr) => addr,
Err(_) => match tokio::net::lookup_host((&target_host[..], target_port as u16)).await {
Ok(mut addrs) => match addrs.next() {
Some(addr) => addr,
None => return,
},
Err(_) => return,
},
};
crate::server::channel_proxy::proxy_channel(channel.into_stream(), target, &proxy_config).await;
});
let _ = (originator_address, originator_port);
Ok(true)
}
async fn channel_open_session(
&mut self,
_channel: Channel<Msg>,
session: &mut Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
"rejected session channel (shell/exec not supported)"
);
let _ = session;
Ok(false)
}
async fn channel_open_x11(
&mut self,
_channel: Channel<Msg>,
_originator_address: &str,
_originator_port: u32,
session: &mut Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
"rejected x11 channel"
);
let _ = session;
Ok(false)
}
async fn channel_open_forwarded_tcpip(
&mut self,
_channel: Channel<Msg>,
host_to_connect: &str,
port_to_connect: u32,
_originator_address: &str,
_originator_port: u32,
session: &mut Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
target = %format!("{host_to_connect}:{port_to_connect}"),
"rejected forwarded-tcpip channel (remote port forwarding not supported)"
);
let _ = session;
Ok(false)
}
async fn exec_request(
&mut self,
channel: ChannelId,
data: &[u8],
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
data_len = data.len(),
"rejected exec request on channel (shell/exec not supported)"
);
let _ = session.channel_failure(channel);
Ok(())
}
async fn shell_request(
&mut self,
channel: ChannelId,
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
"rejected shell request on channel"
);
let _ = session.channel_failure(channel);
Ok(())
}
async fn subsystem_request(
&mut self,
channel: ChannelId,
name: &str,
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
subsystem = name,
"rejected subsystem request on channel"
);
let _ = session.channel_failure(channel);
Ok(())
}
async fn pty_request(
&mut self,
channel: ChannelId,
term: &str,
col_width: u32,
row_height: u32,
pix_width: u32,
pix_height: u32,
modes: &[(russh::Pty, u32)],
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
term = term,
"rejected pty request on channel"
);
let _ = (col_width, row_height, pix_width, pix_height, modes);
let _ = session.channel_failure(channel);
Ok(())
}
async fn env_request(
&mut self,
channel: ChannelId,
variable_name: &str,
variable_value: &str,
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
variable = variable_name,
"rejected env request on channel"
);
let _ = variable_value;
let _ = session.channel_failure(channel);
Ok(())
}
async fn x11_request(
&mut self,
channel: ChannelId,
single_connection: bool,
x11_auth_protocol: &str,
x11_auth_cookie: &str,
x11_screen_number: u32,
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
"rejected x11 request on channel"
);
let _ = (single_connection, x11_auth_protocol, x11_auth_cookie, x11_screen_number);
let _ = session.channel_failure(channel);
Ok(())
}
async fn agent_request(
&mut self,
channel: ChannelId,
session: &mut Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
"rejected agent forwarding request on channel"
);
let _ = session;
Ok(false)
}
async fn tcpip_forward(
&mut self,
address: &str,
port: &mut u32,
session: &mut Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
address = address,
port = *port,
"rejected tcpip-forward request (remote port forwarding not supported)"
);
let _ = session;
Ok(false)
}
async fn cancel_tcpip_forward(
&mut self,
address: &str,
port: u32,
session: &mut Session,
) -> Result<bool, Self::Error> {
let _ = (address, port, session);
Ok(false)
}
async fn streamlocal_forward(
&mut self,
socket_path: &str,
session: &mut Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
socket_path = socket_path,
"rejected streamlocal-forward request"
);
let _ = session;
Ok(false)
}
async fn signal(
&mut self,
channel: ChannelId,
signal: russh::Sig,
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::debug!(
remote_addr = ?self.remote_addr,
channel = %channel,
signal = ?signal,
"received signal on channel (ignored)"
);
let _ = session;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::keys::KeySource;
use russh::keys::{decode_secret_key, PrivateKey};
use std::io::Write;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
fn make_authorized_keys_file(keys_content: &str) -> tempfile::NamedTempFile {
let mut f = tempfile::NamedTempFile::new().unwrap();
f.write_all(keys_content.as_bytes()).unwrap();
f.flush().unwrap();
f
}
fn load_key() -> PrivateKey {
decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap()
}
fn make_auth_config(keys_content: &str) -> Arc<ServerAuthConfig> {
let f = make_authorized_keys_file(keys_content);
Arc::new(
ServerAuthConfig::from_keys_and_ca(
Some(KeySource::File(f.path().to_path_buf())),
None,
)
.unwrap(),
)
}
fn make_empty_auth_config() -> Arc<ServerAuthConfig> {
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
}
fn default_limiter() -> Arc<ConnectionRateLimiter> {
Arc::new(ConnectionRateLimiter::new(0))
}
fn make_handler(
auth_config: Arc<ServerAuthConfig>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
) -> ServerHandler {
ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10)
}
#[tokio::test]
async fn auth_delegation_accepts_known_key() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let mut handler = make_handler(auth_config, None, None);
let ssh_key = load_key().public_key().clone();
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
assert_eq!(result, Auth::Accept);
}
#[tokio::test]
async fn auth_delegation_rejects_unknown_key() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let mut handler = make_handler(auth_config, None, None);
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
let other_ssh_key = russh::keys::parse_public_key_base64(
other_key_text.split_whitespace().nth(1).unwrap(),
)
.unwrap();
let result = handler
.auth_publickey("testuser", &other_ssh_key)
.await
.unwrap();
assert_eq!(
result,
Auth::Reject {
proceed_with_methods: None
}
);
}
#[tokio::test]
async fn auth_delegation_empty_config_rejects_all() {
let auth_config = make_empty_auth_config();
let mut handler = make_handler(auth_config, None, None);
let ssh_key = load_key().public_key().clone();
let result = handler
.auth_publickey("testuser", &ssh_key)
.await
.unwrap();
assert_eq!(
result,
Auth::Reject {
proceed_with_methods: None
}
);
}
#[tokio::test]
async fn auth_logging_includes_remote_addr() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
let mut handler = make_handler(auth_config, None, Some(remote_addr));
let ssh_key = load_key().public_key().clone();
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
}
#[test]
fn reserved_alknet_destination_routing() {
use crate::server::control_channel::is_reserved_destination;
assert!(is_reserved_destination("alknet-control"));
assert!(is_reserved_destination("alknet-status"));
assert!(is_reserved_destination("alknet-events"));
assert!(!is_reserved_destination("example.com"));
assert!(!is_reserved_destination("localhost"));
assert!(!is_reserved_destination("alknet.example.com"));
}
#[test]
fn server_handler_without_control_handler_rejects_alknet_destinations() {
let auth_config = make_empty_auth_config();
let handler = make_handler(auth_config, None, None);
assert!(!handler.control_channel_router().has_handler());
}
#[test]
fn proxy_mode_variants() {
let direct = ProxyMode::Direct;
let socks5 = ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap());
let http = ProxyMode::HttpConnect("127.0.0.1:8080".parse().unwrap());
match direct {
ProxyMode::Direct => {}
_ => panic!("expected Direct"),
}
match socks5 {
ProxyMode::Socks5(_) => {}
_ => panic!("expected Socks5"),
}
match http {
ProxyMode::HttpConnect(_) => {}
_ => panic!("expected HttpConnect"),
}
}
#[test]
fn server_handler_holds_config() {
let auth_config = make_empty_auth_config();
let proxy = Some(ProxyConfig {
mode: ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap()),
});
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
let handler = make_handler(auth_config, proxy.clone(), remote);
assert!(handler.outbound_proxy.is_some());
assert!(handler.remote_addr.is_some());
}
#[test]
fn one_handler_per_connection() {
let auth_config = make_empty_auth_config();
let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
assert!(handler1.remote_addr != handler2.remote_addr);
}
#[tokio::test]
async fn auth_rate_limit_rejects_after_max_failures() {
let auth_config = make_empty_auth_config();
let limiter = Arc::new(ConnectionRateLimiter::new(0));
let mut handler = ServerHandler::new(
auth_config,
None,
Some("10.0.0.1:22".parse().unwrap()),
TransportKind::Tcp,
limiter,
2,
);
let ssh_key = load_key().public_key().clone();
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
assert_eq!(r1, Auth::Reject { proceed_with_methods: None });
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
assert_eq!(r2, Auth::Reject { proceed_with_methods: None });
assert!(!handler.auth_limiter.check());
}
#[test]
fn connection_rate_limit_blocks_over_limit() {
let limiter = Arc::new(ConnectionRateLimiter::new(1));
let auth_config = make_empty_auth_config();
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
let h1 = ServerHandler::new(
auth_config.clone(),
None,
Some(addr),
TransportKind::Tcp,
limiter.clone(),
10,
);
assert!(h1.is_connection_allowed());
let h2 = ServerHandler::new(
auth_config.clone(),
None,
Some(addr),
TransportKind::Tcp,
limiter.clone(),
10,
);
assert!(!h2.is_connection_allowed());
drop(h1);
let h3 = ServerHandler::new(
auth_config,
None,
Some(addr),
TransportKind::Tcp,
limiter,
10,
);
assert!(h3.is_connection_allowed());
}
#[test]
fn transport_kind_display() {
assert_eq!(TransportKind::Tcp.to_string(), "tcp");
assert_eq!(TransportKind::Tls.to_string(), "tls");
assert_eq!(TransportKind::Iroh.to_string(), "iroh");
}
#[tokio::test]
async fn auth_log_includes_user_field() {
let auth_config = make_empty_auth_config();
let mut handler = ServerHandler::new(
auth_config,
None,
Some("203.0.113.50:12345".parse().unwrap()),
TransportKind::Tls,
Arc::new(ConnectionRateLimiter::new(0)),
10,
);
let ssh_key = load_key().public_key().clone();
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
}
#[test]
fn connection_closed_logs_duration_on_drop() {
let auth_config = make_empty_auth_config();
let _handler = ServerHandler::new(
auth_config,
None,
Some("203.0.113.50:12345".parse().unwrap()),
TransportKind::Tcp,
Arc::new(ConnectionRateLimiter::new(0)),
10,
);
}
}

View File

@@ -0,0 +1,25 @@
//! Server-side SSH connection handling.
//!
//! Provides `Server` for accepting SSH connections over any transport and proxying
//! `direct-tcpip` channel requests to targets. Supports Ed25519 and certificate-authority
//! auth, connection rate limiting, auth attempt limiting, stealth mode (fake nginx 404),
//! and outbound proxy routing (direct/SOCKS5/HTTP CONNECT).
//!
//! Destination hosts starting with `alknet-` are reserved for internal use (control channel, ADR-018).
pub mod channel_proxy;
pub mod control_channel;
pub mod handler;
pub mod rate_limit;
pub mod serve;
pub mod stealth;
pub use channel_proxy::{connect_outbound, proxy_channel};
pub use control_channel::{
ControlChannelHandler, ControlChannelRouter, DuplexStream, ALKNET_CONTROL_DESTINATION,
ALKNET_PREFIX, is_reserved_destination,
};
pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
pub use serve::{Server, ServeError, ServeOptions, ServeTransportMode};
pub use stealth::{ProtocolDetection, detect_protocol, send_fake_nginx_404, validate_stealth_config};

View File

@@ -0,0 +1,200 @@
//! Connection rate limiting and auth attempt limiting.
//!
//! `ConnectionRateLimiter` tracks per-IP active connections (thread-safe).
//! `AuthAttemptLimiter` caps failed auth attempts per connection.
//! These complement fail2ban on Linux and provide abuse protection on all platforms.
//! See ADR-013.
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Mutex;
pub struct ConnectionRateLimiter {
max_per_ip: usize,
active: Mutex<HashMap<IpAddr, usize>>,
}
impl ConnectionRateLimiter {
pub fn new(max_per_ip: usize) -> Self {
Self {
max_per_ip,
active: Mutex::new(HashMap::new()),
}
}
pub fn check(&self, ip: IpAddr) -> bool {
if self.max_per_ip == 0 {
return true;
}
let active = self.active.lock().unwrap();
let count = active.get(&ip).copied().unwrap_or(0);
count < self.max_per_ip
}
pub fn on_connect(&self, ip: IpAddr) {
let mut active = self.active.lock().unwrap();
*active.entry(ip).or_insert(0) += 1;
}
pub fn on_disconnect(&self, ip: IpAddr) {
let mut active = self.active.lock().unwrap();
if let Some(count) = active.get_mut(&ip) {
if *count > 1 {
*count -= 1;
} else {
active.remove(&ip);
}
}
}
}
pub struct AuthAttemptLimiter {
max_attempts: usize,
failures: usize,
}
impl AuthAttemptLimiter {
pub fn new(max_attempts: usize) -> Self {
Self {
max_attempts,
failures: 0,
}
}
pub fn check(&self) -> bool {
if self.max_attempts == 0 {
return true;
}
self.failures < self.max_attempts
}
pub fn on_failure(&mut self) {
self.failures += 1;
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
fn ip(n: u8) -> IpAddr {
IpAddr::V4(Ipv4Addr::new(192, 168, 1, n))
}
#[test]
fn connection_limiter_allows_when_under_limit() {
let limiter = ConnectionRateLimiter::new(3);
assert!(limiter.check(ip(1)));
}
#[test]
fn connection_limiter_blocks_when_at_limit() {
let limiter = ConnectionRateLimiter::new(2);
limiter.on_connect(ip(1));
limiter.on_connect(ip(1));
assert!(!limiter.check(ip(1)));
}
#[test]
fn connection_limiter_allows_after_disconnect() {
let limiter = ConnectionRateLimiter::new(2);
limiter.on_connect(ip(1));
limiter.on_connect(ip(1));
assert!(!limiter.check(ip(1)));
limiter.on_disconnect(ip(1));
assert!(limiter.check(ip(1)));
}
#[test]
fn connection_limiter_unlimited_when_zero() {
let limiter = ConnectionRateLimiter::new(0);
for _ in 0..100 {
limiter.on_connect(ip(1));
}
assert!(limiter.check(ip(1)));
}
#[test]
fn connection_limiter_tracks_per_ip_independently() {
let limiter = ConnectionRateLimiter::new(1);
limiter.on_connect(ip(1));
assert!(!limiter.check(ip(1)));
assert!(limiter.check(ip(2)));
}
#[test]
fn connection_limiter_ipv6() {
let limiter = ConnectionRateLimiter::new(1);
let ip6 = IpAddr::V6(Ipv6Addr::LOCALHOST);
limiter.on_connect(ip6);
assert!(!limiter.check(ip6));
}
#[test]
fn connection_limiter_disconnect_removes_zero_entry() {
let limiter = ConnectionRateLimiter::new(3);
limiter.on_connect(ip(1));
limiter.on_disconnect(ip(1));
{
let active = limiter.active.lock().unwrap();
assert!(!active.contains_key(&ip(1)));
}
}
#[test]
fn auth_limiter_allows_when_under_limit() {
let limiter = AuthAttemptLimiter::new(3);
assert!(limiter.check());
}
#[test]
fn auth_limiter_blocks_after_max_failures() {
let mut limiter = AuthAttemptLimiter::new(2);
limiter.on_failure();
limiter.on_failure();
assert!(!limiter.check());
}
#[test]
fn auth_limiter_unlimited_when_zero() {
let mut limiter = AuthAttemptLimiter::new(0);
for _ in 0..100 {
limiter.on_failure();
}
assert!(limiter.check());
}
#[test]
fn auth_limiter_still_allows_at_one_below_limit() {
let mut limiter = AuthAttemptLimiter::new(3);
limiter.on_failure();
limiter.on_failure();
assert!(limiter.check());
limiter.on_failure();
assert!(!limiter.check());
}
#[test]
fn connection_limiter_thread_safety() {
use std::sync::Arc;
use std::thread;
let limiter = Arc::new(ConnectionRateLimiter::new(100));
let mut handles = vec![];
for i in 0..10 {
let lim = Arc::clone(&limiter);
handles.push(thread::spawn(move || {
let ip_addr = ip((i % 3) as u8 + 1);
lim.on_connect(ip_addr);
assert!(lim.check(ip_addr));
lim.on_disconnect(ip_addr);
}));
}
for h in handles {
h.join().unwrap();
}
}
}

View File

@@ -0,0 +1,765 @@
//! Server configuration and accept loop.
//!
//! `Server` binds to a transport acceptor and runs an accept loop, handling
//! authentication, stealth mode protocol detection, and graceful shutdown.
//! `ServeOptions` provides a builder-pattern API for programmatic configuration.
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use russh::server::{self, Config};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, warn};
use crate::auth::keys::KeySource;
use crate::auth::server_auth::ServerAuthConfig;
use crate::error::ConfigError;
use crate::server::handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
use crate::server::rate_limit::ConnectionRateLimiter;
use crate::server::stealth::{self, ProtocolDetection};
const DEFAULT_LISTEN_ADDR: &str = "0.0.0.0:22";
const DRAIN_TIMEOUT: Duration = Duration::from_secs(2);
/// Transport mode for the server listener.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServeTransportMode {
Tcp,
Tls,
Iroh,
}
impl std::fmt::Display for ServeTransportMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ServeTransportMode::Tcp => write!(f, "tcp"),
ServeTransportMode::Tls => write!(f, "tls"),
ServeTransportMode::Iroh => write!(f, "iroh"),
}
}
}
/// Programmatic configuration for an alknet server.
///
/// Construct with `ServeOptions::new(key_source)` and chain builder methods.
/// Call `validate()` before passing to `Server::new()`.
///
/// ```
/// use alknet_core::server::{ServeOptions, ServeTransportMode};
/// use alknet_core::auth::keys::KeySource;
///
/// let opts = ServeOptions::new(KeySource::File("/path/to/host_key".into()))
/// .transport_mode(ServeTransportMode::Tcp)
/// .listen_addr("0.0.0.0:22")
/// .max_connections_per_ip(5)
/// .max_auth_attempts(3);
/// opts.validate().unwrap();
/// ```
pub struct ServeOptions {
pub key: KeySource,
pub authorized_keys: Option<KeySource>,
pub cert_authority: Option<KeySource>,
pub transport_mode: ServeTransportMode,
pub listen_addr: String,
pub tls_cert: Option<String>,
pub tls_key: Option<String>,
pub acme_domain: Option<String>,
pub stealth: bool,
pub proxy: Option<String>,
pub iroh_relay: Option<String>,
pub max_connections_per_ip: usize,
pub max_auth_attempts: usize,
}
impl ServeOptions {
pub fn new(key: KeySource) -> Self {
Self {
key,
authorized_keys: None,
cert_authority: None,
transport_mode: ServeTransportMode::Tcp,
listen_addr: DEFAULT_LISTEN_ADDR.to_string(),
tls_cert: None,
tls_key: None,
acme_domain: None,
stealth: false,
proxy: None,
iroh_relay: None,
max_connections_per_ip: 0,
max_auth_attempts: 10,
}
}
pub fn authorized_keys(mut self, source: KeySource) -> Self {
self.authorized_keys = Some(source);
self
}
pub fn cert_authority(mut self, source: KeySource) -> Self {
self.cert_authority = Some(source);
self
}
pub fn transport_mode(mut self, mode: ServeTransportMode) -> Self {
self.transport_mode = mode;
self
}
pub fn listen_addr(mut self, addr: impl Into<String>) -> Self {
self.listen_addr = addr.into();
self
}
pub fn tls_cert(mut self, path: impl Into<String>) -> Self {
self.tls_cert = Some(path.into());
self
}
pub fn tls_key(mut self, path: impl Into<String>) -> Self {
self.tls_key = Some(path.into());
self
}
pub fn acme_domain(mut self, domain: impl Into<String>) -> Self {
self.acme_domain = Some(domain.into());
self
}
pub fn stealth(mut self, enabled: bool) -> Self {
self.stealth = enabled;
self
}
pub fn proxy(mut self, url: impl Into<String>) -> Self {
self.proxy = Some(url.into());
self
}
pub fn iroh_relay(mut self, url: impl Into<String>) -> Self {
self.iroh_relay = Some(url.into());
self
}
pub fn max_connections_per_ip(mut self, max: usize) -> Self {
self.max_connections_per_ip = max;
self
}
pub fn max_auth_attempts(mut self, max: usize) -> Self {
self.max_auth_attempts = max;
self
}
pub fn validate(&self) -> Result<(), ConfigError> {
if self.stealth && self.transport_mode != ServeTransportMode::Tls {
return Err(ConfigError::InvalidFlag {
name: "stealth mode requires TLS transport (--transport tls)".to_string(),
});
}
match self.transport_mode {
ServeTransportMode::Tls => {
if self.tls_cert.is_none() && self.acme_domain.is_none() {
return Err(ConfigError::InvalidFlag {
name: "TLS transport requires --tls-cert/--tls-key or --acme-domain"
.to_string(),
});
}
if self.tls_cert.is_some() && self.tls_key.is_none() {
return Err(ConfigError::InvalidFlag {
name: "--tls-cert requires --tls-key".to_string(),
});
}
if self.tls_key.is_some() && self.tls_cert.is_none() {
return Err(ConfigError::InvalidFlag {
name: "--tls-key requires --tls-cert".to_string(),
});
}
}
ServeTransportMode::Tcp | ServeTransportMode::Iroh => {
if self.tls_cert.is_some() || self.tls_key.is_some() || self.acme_domain.is_some() {
return Err(ConfigError::IncompatibleOptions);
}
}
}
Ok(())
}
}
impl std::fmt::Debug for ServeOptions {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ServeOptions")
.field("key", &"<KeySource>")
.field("authorized_keys", &"<KeySource>")
.field("cert_authority", &"<KeySource>")
.field("transport_mode", &self.transport_mode)
.field("listen_addr", &self.listen_addr)
.field("stealth", &self.stealth)
.field("max_connections_per_ip", &self.max_connections_per_ip)
.field("max_auth_attempts", &self.max_auth_attempts)
.finish()
}
}
/// Errors that can occur during server setup and operation.
#[derive(Debug, thiserror::Error)]
pub enum ServeError {
#[error("config error: {0}")]
Config(#[from] ConfigError),
#[error("bind failed: {0}")]
BindFailed(anyhow::Error),
#[error("key load failed: {0}")]
KeyLoadFailed(ConfigError),
#[error("accept failed")]
AcceptFailed,
}
struct ActiveSession {
handle: server::Handle,
join: tokio::task::JoinHandle<()>,
}
/// The alknet SSH server.
///
/// Accepts connections over any `TransportAcceptor`, authenticates via Ed25519 keys
/// or certificate authority, and proxies `direct-tcpip` channels to their targets.
/// Supports stealth mode (TLS only), outbound proxy routing, and connection rate limiting.
pub struct Server {
config: Arc<server::Config>,
auth_config: Arc<ServerAuthConfig>,
connection_limiter: Arc<ConnectionRateLimiter>,
outbound_proxy: Option<ProxyConfig>,
stealth: bool,
transport_mode: ServeTransportMode,
listen_addr: String,
max_auth_attempts: usize,
shutdown_tx: tokio::sync::watch::Sender<bool>,
shutdown_rx: tokio::sync::watch::Receiver<bool>,
sessions: Arc<tokio::sync::Mutex<Vec<ActiveSession>>>,
}
impl Server {
pub fn new(opts: ServeOptions) -> Result<Self, ServeError> {
opts.validate().map_err(ServeError::Config)?;
let private_key =
crate::auth::keys::load_private_key(opts.key.clone()).map_err(ServeError::KeyLoadFailed)?;
let auth_config = Arc::new(
ServerAuthConfig::from_keys_and_ca(opts.authorized_keys.clone(), opts.cert_authority.clone())
.map_err(ServeError::KeyLoadFailed)?,
);
let config = Arc::new(Config {
keys: vec![private_key],
max_auth_attempts: opts.max_auth_attempts,
methods: russh::MethodSet::PUBLICKEY,
preferred: russh::Preferred::DEFAULT,
..Default::default()
});
let outbound_proxy = parse_proxy_config(opts.proxy.as_deref());
let connection_limiter = Arc::new(ConnectionRateLimiter::new(opts.max_connections_per_ip));
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
Ok(Self {
config,
auth_config,
connection_limiter,
outbound_proxy,
stealth: opts.stealth,
transport_mode: opts.transport_mode,
listen_addr: opts.listen_addr,
max_auth_attempts: opts.max_auth_attempts,
shutdown_tx,
shutdown_rx,
sessions: Arc::new(tokio::sync::Mutex::new(Vec::new())),
})
}
pub fn shutdown_sender(&self) -> tokio::sync::watch::Sender<bool> {
self.shutdown_tx.clone()
}
pub async fn shutdown(&self) -> Result<(), ServeError> {
info!("initiating graceful shutdown");
let _ = self.shutdown_tx.send(true);
{
let sessions = self.sessions.lock().await;
for session in sessions.iter() {
if let Err(e) = session.handle.disconnect(
russh::Disconnect::ByApplication,
"shutdown".to_string(),
String::new(),
).await {
warn!("failed to send SSH disconnect: {e}");
}
}
}
tokio::time::sleep(DRAIN_TIMEOUT).await;
{
let mut sessions = self.sessions.lock().await;
for session in sessions.drain(..) {
session.join.abort();
}
}
info!("graceful shutdown complete");
Ok(())
}
pub fn is_shutdown(&self) -> bool {
*self.shutdown_rx.borrow()
}
pub async fn run<A>(self, acceptor: A, endpoint_info: Option<&str>) -> Result<(), ServeError>
where
A: crate::transport::TransportAcceptor,
{
let transport_kind = match self.transport_mode {
ServeTransportMode::Tcp => TransportKind::Tcp,
ServeTransportMode::Tls => TransportKind::Tls,
ServeTransportMode::Iroh => TransportKind::Iroh,
};
if self.transport_mode == ServeTransportMode::Iroh {
if let Some(id) = endpoint_info {
info!("alknet server running: transport=iroh endpoint_id={}", id);
} else {
info!("alknet server running: transport=iroh");
}
} else {
info!(
"alknet server running: transport={} listen={}",
self.transport_mode, self.listen_addr
);
}
let server = Arc::new(self);
let mut shutdown_rx = server.shutdown_rx.clone();
#[cfg(unix)]
let signal_done = {
let sig_tx = server.shutdown_tx.clone();
tokio::spawn(async move {
let mut sigterm_stream =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to install SIGTERM handler");
tokio::select! {
_ = sigterm_stream.recv() => {
info!("received SIGTERM");
}
_ = tokio::signal::ctrl_c() => {
info!("received SIGINT (Ctrl+C)");
}
}
let _ = sig_tx.send(true);
})
};
loop {
let shutdown = *shutdown_rx.borrow();
if shutdown {
info!("shutdown signaled, stopping accept loop");
break;
}
let accept_result = tokio::select! {
result = acceptor.accept() => result,
_ = shutdown_rx.changed() => {
info!("shutdown signaled while waiting for connection");
break;
}
};
let (stream, info) = match accept_result {
Ok(conn) => conn,
Err(e) => {
error!("accept failed: {e}");
continue;
}
};
let remote_addr = info.remote_addr;
let handler_transport_kind = transport_kind;
let handler = ServerHandler::new(
Arc::clone(&server.auth_config),
server.outbound_proxy.clone(),
remote_addr,
handler_transport_kind,
Arc::clone(&server.connection_limiter),
server.max_auth_attempts,
);
if !handler.is_connection_allowed() {
continue;
}
let config = Arc::clone(&server.config);
let sessions = Arc::clone(&server.sessions);
let stealth = server.stealth;
let transport_is_tls = server.transport_mode == ServeTransportMode::Tls;
tokio::spawn(async move {
let result = handle_connection(
stream,
config,
handler,
sessions,
stealth,
transport_is_tls,
)
.await;
if let Err(e) = result {
warn!("connection error: {e}");
}
});
}
#[cfg(unix)]
signal_done.abort();
server.shutdown().await?;
Ok(())
}
}
async fn handle_connection<S>(
stream: S,
config: Arc<Config>,
handler: ServerHandler,
sessions: Arc<tokio::sync::Mutex<Vec<ActiveSession>>>,
stealth: bool,
transport_is_tls: bool,
) -> Result<(), anyhow::Error>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
if stealth && transport_is_tls {
let (protocol, mut reader) = stealth::detect_protocol(stream).await;
match protocol {
ProtocolDetection::Http => {
stealth::send_fake_nginx_404(&mut reader).await;
return Ok(());
}
ProtocolDetection::Ssh => {
let running = server::run_stream(config, reader, handler).await?;
let handle = running.handle();
let join = tokio::spawn(async {
let _ = running.await;
});
sessions.lock().await.push(ActiveSession { handle, join });
return Ok(());
}
}
}
let running = server::run_stream(config, stream, handler).await?;
let handle = running.handle();
let join = tokio::spawn(async {
let _ = running.await;
});
sessions.lock().await.push(ActiveSession { handle, join });
Ok(())
}
fn parse_proxy_config(proxy: Option<&str>) -> Option<ProxyConfig> {
proxy.map(|url| {
if url.starts_with("socks5://") {
let addr: SocketAddr = url
.strip_prefix("socks5://")
.unwrap()
.parse()
.expect("invalid socks5 proxy address");
ProxyConfig {
mode: ProxyMode::Socks5(addr),
}
} else if url.starts_with("http://") {
let addr: SocketAddr = url
.strip_prefix("http://")
.unwrap()
.parse()
.expect("invalid http connect proxy address");
ProxyConfig {
mode: ProxyMode::HttpConnect(addr),
}
} else {
panic!("unsupported proxy URL scheme: {url}");
}
})
}
#[cfg(test)]
mod tests {
use super::*;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
fn make_key_source() -> KeySource {
KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec())
}
fn make_authorized_keys_source() -> KeySource {
KeySource::Memory(ED25519_PUBLIC_KEY.as_bytes().to_vec())
}
#[test]
fn serve_options_default_fields() {
let opts = ServeOptions::new(make_key_source());
assert!(opts.authorized_keys.is_none());
assert!(opts.cert_authority.is_none());
assert_eq!(opts.transport_mode, ServeTransportMode::Tcp);
assert_eq!(opts.listen_addr, "0.0.0.0:22");
assert!(opts.tls_cert.is_none());
assert!(opts.tls_key.is_none());
assert!(opts.acme_domain.is_none());
assert!(!opts.stealth);
assert!(opts.proxy.is_none());
assert!(opts.iroh_relay.is_none());
assert_eq!(opts.max_connections_per_ip, 0);
assert_eq!(opts.max_auth_attempts, 10);
}
#[test]
fn serve_options_builder_pattern() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.cert_authority(make_authorized_keys_source())
.transport_mode(ServeTransportMode::Tls)
.listen_addr("0.0.0.0:443")
.tls_cert("/etc/ssl/cert.pem")
.tls_key("/etc/ssl/key.pem")
.stealth(true)
.proxy("socks5://127.0.0.1:9050")
.iroh_relay("https://relay.example.com")
.max_connections_per_ip(5)
.max_auth_attempts(3);
assert!(opts.authorized_keys.is_some());
assert!(opts.cert_authority.is_some());
assert_eq!(opts.transport_mode, ServeTransportMode::Tls);
assert_eq!(opts.listen_addr, "0.0.0.0:443");
assert_eq!(opts.tls_cert.as_deref(), Some("/etc/ssl/cert.pem"));
assert_eq!(opts.tls_key.as_deref(), Some("/etc/ssl/key.pem"));
assert!(opts.stealth);
assert_eq!(opts.proxy.as_deref(), Some("socks5://127.0.0.1:9050"));
assert_eq!(
opts.iroh_relay.as_deref(),
Some("https://relay.example.com")
);
assert_eq!(opts.max_connections_per_ip, 5);
assert_eq!(opts.max_auth_attempts, 3);
}
#[test]
fn serve_options_validate_steam_without_tls_rejected() {
let opts = ServeOptions::new(make_key_source()).stealth(true);
assert!(opts.validate().is_err());
}
#[test]
fn serve_options_validate_stealth_with_tls_ok() {
let opts = ServeOptions::new(make_key_source())
.transport_mode(ServeTransportMode::Tls)
.tls_cert("/cert.pem")
.tls_key("/key.pem")
.stealth(true);
assert!(opts.validate().is_ok());
}
#[test]
fn serve_options_validate_tcp_no_tls_options_ok() {
let opts = ServeOptions::new(make_key_source());
assert!(opts.validate().is_ok());
}
#[test]
fn serve_options_validate_tls_requires_certs() {
let opts = ServeOptions::new(make_key_source()).transport_mode(ServeTransportMode::Tls);
assert!(opts.validate().is_err());
}
#[test]
fn serve_options_validate_tls_cert_without_key_rejected() {
let opts = ServeOptions::new(make_key_source())
.transport_mode(ServeTransportMode::Tls)
.tls_cert("/cert.pem");
assert!(opts.validate().is_err());
}
#[test]
fn serve_options_validate_tls_key_without_cert_rejected() {
let opts = ServeOptions::new(make_key_source())
.transport_mode(ServeTransportMode::Tls)
.tls_key("/key.pem");
assert!(opts.validate().is_err());
}
#[test]
fn serve_options_validate_tcp_with_acme_rejected() {
let opts =
ServeOptions::new(make_key_source()).acme_domain("example.com");
assert!(opts.validate().is_err());
}
#[test]
fn serve_options_validate_acme_domain_with_tls_ok() {
let opts = ServeOptions::new(make_key_source())
.transport_mode(ServeTransportMode::Tls)
.acme_domain("example.com");
assert!(opts.validate().is_ok());
}
#[test]
fn server_new_creates_server() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source());
let server = Server::new(opts).unwrap();
assert_eq!(server.max_auth_attempts, 10);
}
#[test]
fn server_new_stealth_without_tls_fails() {
let opts = ServeOptions::new(make_key_source()).stealth(true);
let result = Server::new(opts);
assert!(result.is_err());
}
#[test]
fn server_new_invalid_key_fails() {
let opts = ServeOptions::new(KeySource::Memory(b"not a key".to_vec()));
let result = Server::new(opts);
assert!(result.is_err());
}
#[test]
fn serve_transport_mode_display() {
assert_eq!(ServeTransportMode::Tcp.to_string(), "tcp");
assert_eq!(ServeTransportMode::Tls.to_string(), "tls");
assert_eq!(ServeTransportMode::Iroh.to_string(), "iroh");
}
#[test]
fn serve_transport_mode_equality() {
assert_eq!(ServeTransportMode::Tcp, ServeTransportMode::Tcp);
assert_ne!(ServeTransportMode::Tcp, ServeTransportMode::Tls);
assert_ne!(ServeTransportMode::Tls, ServeTransportMode::Iroh);
}
#[test]
fn serve_options_debug_redacts_keys() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source());
let debug_str = format!("{:?}", opts);
assert!(debug_str.contains("<KeySource>"));
assert!(!debug_str.contains("OPENSSH"));
}
#[test]
fn parse_proxy_config_socks5() {
let config = parse_proxy_config(Some("socks5://127.0.0.1:9050"));
assert!(config.is_some());
match config.unwrap().mode {
ProxyMode::Socks5(addr) => {
assert_eq!(addr, "127.0.0.1:9050".parse().unwrap());
}
_ => panic!("expected Socks5"),
}
}
#[test]
fn parse_proxy_config_http() {
let config = parse_proxy_config(Some("http://127.0.0.1:8080"));
assert!(config.is_some());
match config.unwrap().mode {
ProxyMode::HttpConnect(addr) => {
assert_eq!(addr, "127.0.0.1:8080".parse().unwrap());
}
_ => panic!("expected HttpConnect"),
}
}
#[test]
fn parse_proxy_config_none() {
assert!(parse_proxy_config(None).is_none());
}
#[test]
fn serve_error_variants() {
assert_eq!(ServeError::AcceptFailed.to_string(), "accept failed");
}
#[test]
fn default_listen_addr() {
assert_eq!(DEFAULT_LISTEN_ADDR, "0.0.0.0:22");
}
#[test]
fn drain_timeout_is_two_seconds() {
assert_eq!(DRAIN_TIMEOUT, Duration::from_secs(2));
}
#[test]
fn server_shutdown_sender_clones() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source());
let server = Server::new(opts).unwrap();
let sender = server.shutdown_sender();
assert!(!server.is_shutdown());
let _ = sender.send(true);
assert!(server.is_shutdown());
}
#[test]
fn server_holds_listen_addr() {
let opts = ServeOptions::new(make_key_source())
.listen_addr("0.0.0.0:443");
let server = Server::new(opts).unwrap();
assert_eq!(server.listen_addr, "0.0.0.0:443");
}
#[tokio::test]
async fn integration_server_accept_loop_and_shutdown() {
use crate::transport::TcpAcceptor;
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
.await
.unwrap();
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.listen_addr(acceptor.listen_addr().to_string());
let server = Server::new(opts).unwrap();
let shutdown_tx = server.shutdown_sender();
let server_handle = tokio::spawn(async move {
server
.run(acceptor, None)
.await
.expect("server run failed")
});
tokio::time::sleep(Duration::from_millis(50)).await;
let _ = shutdown_tx.send(true);
let result = tokio::time::timeout(Duration::from_secs(5), server_handle).await;
assert!(result.is_ok(), "server should have shut down within timeout");
}
}

View File

@@ -0,0 +1,226 @@
//! Stealth mode: protocol detection on TLS connections.
//!
//! When stealth mode is enabled with TLS transport, the server peeks at the first
//! bytes after the TLS handshake to determine whether the client is speaking SSH
//! or HTTP. Non-SSH connections receive a fake nginx 404 response, making the
//! server appear as an ordinary web server to port scanners and DPI systems.
//! See ADR-017.
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
const SSH_BANNER_PREFIX: &[u8] = b"SSH-2.0-";
const FAKE_NGINX_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nServer: nginx\r\n\r\n";
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolDetection {
Ssh,
Http,
}
pub async fn detect_protocol<S>(stream: S) -> (ProtocolDetection, BufReader<S>)
where
S: AsyncRead + Unpin,
{
let mut reader = BufReader::new(stream);
let detection = match reader.fill_buf().await {
Ok(buf) if buf.len() >= SSH_BANNER_PREFIX.len() => {
if &buf[..SSH_BANNER_PREFIX.len()] == SSH_BANNER_PREFIX {
ProtocolDetection::Ssh
} else {
ProtocolDetection::Http
}
}
Ok(buf) if !buf.is_empty() => {
if buf.starts_with(SSH_BANNER_PREFIX) {
ProtocolDetection::Ssh
} else {
ProtocolDetection::Http
}
}
_ => ProtocolDetection::Http,
};
(detection, reader)
}
pub async fn send_fake_nginx_404<S>(reader: &mut BufReader<S>)
where
S: AsyncRead + AsyncWrite + Unpin,
{
let _ = reader.get_mut().write_all(FAKE_NGINX_404).await;
let _ = reader.get_mut().shutdown().await;
}
pub fn validate_stealth_config(stealth: bool, transport_is_tls: bool) -> Result<(), &'static str> {
if stealth && !transport_is_tls {
return Err("stealth mode requires TLS transport (--transport tls)");
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
async fn write_and_detect(data: &[u8]) -> ProtocolDetection {
let (client, server) = duplex(1024);
let mut client = client;
client.write_all(data).await.unwrap();
drop(client);
let (detection, _) = detect_protocol(server).await;
detection
}
#[tokio::test]
async fn ssh_banner_detected() {
let detection = write_and_detect(b"SSH-2.0-OpenSSH_9.0\r\n").await;
assert_eq!(detection, ProtocolDetection::Ssh);
}
#[tokio::test]
async fn ssh_banner_other_implementation() {
let detection = write_and_detect(b"SSH-2.0-russh_0.49\r\n").await;
assert_eq!(detection, ProtocolDetection::Ssh);
}
#[tokio::test]
async fn ssh_banner_minimal() {
let detection = write_and_detect(b"SSH-2.0-X\n").await;
assert_eq!(detection, ProtocolDetection::Ssh);
}
#[tokio::test]
async fn http_get_detected_as_http() {
let detection = write_and_detect(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn http_post_detected_as_http() {
let detection = write_and_detect(b"POST /api HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn random_data_detected_as_http() {
let detection = write_and_detect(b"\x01\x02\x03\x04\x05\x06\x07\x08").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn empty_stream_detected_as_http() {
let (client, server) = duplex(1024);
drop(client);
let (detection, _) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn ssh_banner_bytes_preserved_by_bufreader() {
let (client, server) = duplex(1024);
let mut client = client;
let banner = b"SSH-2.0-OpenSSH_9.0\r\n";
client.write_all(banner).await.unwrap();
client.write_all(b"subsequent data").await.unwrap();
drop(client);
let (detection, mut reader) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Ssh);
let mut all_data = Vec::new();
reader.read_to_end(&mut all_data).await.unwrap();
assert!(all_data.starts_with(banner), "banner bytes must be preserved after detection");
}
#[tokio::test]
async fn fake_nginx_404_response() {
let (client, server) = duplex(1024);
let (mut client_read, mut client_write) = tokio::io::split(client);
client_write.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
drop(client_write);
let (detection, mut reader) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Http);
send_fake_nginx_404(&mut reader).await;
let mut buf = [0u8; 256];
let n = client_read.read(&mut buf).await.unwrap();
let response = String::from_utf8_lossy(&buf[..n]);
assert!(response.contains("HTTP/1.1 404 Not Found"));
assert!(response.contains("Server: nginx"));
}
#[tokio::test]
async fn protocol_detection_enum_equality() {
assert_eq!(ProtocolDetection::Ssh, ProtocolDetection::Ssh);
assert_eq!(ProtocolDetection::Http, ProtocolDetection::Http);
assert_ne!(ProtocolDetection::Ssh, ProtocolDetection::Http);
}
#[test]
fn validate_stealth_without_tls_rejected() {
let result = validate_stealth_config(true, false);
assert!(result.is_err());
assert!(result.unwrap_err().contains("TLS transport"));
}
#[test]
fn validate_stealth_with_tls_accepted() {
let result = validate_stealth_config(true, true);
assert!(result.is_ok());
}
#[test]
fn validate_no_stealth_with_tcp_accepted() {
let result = validate_stealth_config(false, false);
assert!(result.is_ok());
}
#[test]
fn validate_no_stealth_with_tls_accepted() {
let result = validate_stealth_config(false, true);
assert!(result.is_ok());
}
#[tokio::test]
async fn short_data_detected_as_http() {
let detection = write_and_detect(b"GE").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn partial_ssh_prefix_detected_as_http() {
let detection = write_and_detect(b"SSH-1.").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn http_request_gets_404_then_closed() {
let (client, server) = duplex(1024);
let mut client = client;
client.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
let (detection, mut reader) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Http);
send_fake_nginx_404(&mut reader).await;
let mut buf = [0u8; 256];
let n = client.read(&mut buf).await.unwrap();
let response = String::from_utf8_lossy(&buf[..n]);
assert!(response.starts_with("HTTP/1.1 404 Not Found"));
assert!(response.contains("Server: nginx"));
let mut extra = [0u8; 16];
let result = client.read(&mut extra).await;
assert!(result.is_err() || result.unwrap() == 0);
}
}

View File

@@ -0,0 +1,497 @@
//! SOCKS5 proxy server.
//!
//! Listens on a local port and routes each SOCKS5 connection through an SSH
//! `direct-tcpip` channel. Supports SOCKS5h (domain names resolved server-side)
//! to prevent DNS leaks. Uses the `ChannelOpener` trait to abstract over the
//! SSH channel mechanism, making it testable without a real SSH session.
mod protocol;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tracing::debug;
use protocol::{Socks5Reply, Socks5Request, Socks5VersionMethod};
pub use protocol::Socks5Address;
const DEFAULT_SOCKS5_ADDR: &str = "127.0.0.1:1080";
pub trait ChannelOpener: Send + Sync + 'static {
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
fn open_channel(
&self,
host: String,
port: u16,
) -> impl std::future::Future<Output = Result<Self::Stream, ChannelOpenError>> + Send;
}
#[derive(Debug, thiserror::Error)]
pub enum ChannelOpenError {
#[error("session closed")]
SessionClosed,
#[error("channel open failed")]
ChannelOpenFailed,
#[error("connection refused")]
ConnectionRefused,
}
pub struct Socks5Server<C: ChannelOpener> {
listen_addr: SocketAddr,
channel_opener: Arc<C>,
}
impl<C: ChannelOpener> Socks5Server<C> {
pub fn new(channel_opener: C) -> Self {
Self::with_addr(channel_opener, DEFAULT_SOCKS5_ADDR)
}
pub fn with_addr(channel_opener: C, addr: &str) -> Self {
let listen_addr: SocketAddr = addr
.parse()
.expect("invalid SOCKS5 listen address");
Self {
listen_addr,
channel_opener: Arc::new(channel_opener),
}
}
pub fn listen_addr(&self) -> SocketAddr {
self.listen_addr
}
pub async fn run(self) -> Result<(), std::io::Error> {
let listener = TcpListener::bind(self.listen_addr).await?;
debug!("socks5 server listening on {}", self.listen_addr);
loop {
let (socket, _peer) = listener.accept().await?;
let opener = Arc::clone(&self.channel_opener);
tokio::spawn(async move {
if let Err(e) = handle_socks5_connection(socket, opener).await {
debug!("socks5 connection error: {e}");
}
});
}
}
}
async fn handle_socks5_connection<S, C>(
mut socket: S,
opener: Arc<C>,
) -> Result<(), Socks5Error>
where
S: AsyncRead + AsyncWrite + Unpin,
C: ChannelOpener,
{
let vm = Socks5VersionMethod::read_from(&mut socket).await?;
if vm.version != 0x05 {
return Err(Socks5Error::InvalidVersion(vm.version));
}
if !vm.methods.contains(&0x00) {
let reply = [0x05, 0xFF];
socket.write_all(&reply).await?;
socket.shutdown().await?;
return Err(Socks5Error::NoAcceptableAuth);
}
let reply = [0x05, 0x00];
socket.write_all(&reply).await?;
let request = Socks5Request::read_from(&mut socket).await?;
if request.version != 0x05 {
return Err(Socks5Error::InvalidVersion(request.version));
}
if request.command != 0x01 {
send_error_reply(&mut socket, Socks5Reply::command_not_supported()).await?;
return Err(Socks5Error::UnsupportedCommand(request.command));
}
let (host, port) = match &request.address {
Socks5Address::Ipv4(addr) => (addr.to_string(), request.port),
Socks5Address::Ipv6(addr) => (addr.to_string(), request.port),
Socks5Address::Domain(name) => (name.clone(), request.port),
};
match opener.open_channel(host, port).await {
Ok(mut ssh_stream) => {
let bind_addr = Socks5Address::Ipv4(std::net::Ipv4Addr::UNSPECIFIED);
let reply = Socks5Reply::success(bind_addr, 0);
reply.write_to(&mut socket).await?;
tokio::io::copy_bidirectional(&mut socket, &mut ssh_stream).await?;
Ok(())
}
Err(_) => {
send_error_reply(&mut socket, Socks5Reply::connection_refused()).await?;
Err(Socks5Error::ChannelOpenFailed)
}
}
}
async fn send_error_reply<S: AsyncRead + AsyncWrite + Unpin>(
socket: &mut S,
reply: Socks5Reply,
) -> Result<(), Socks5Error> {
reply.write_to(socket).await?;
let _ = socket.shutdown().await;
Ok(())
}
#[derive(Debug, thiserror::Error)]
pub enum Socks5Error {
#[error("invalid SOCKS version: {0}")]
InvalidVersion(u8),
#[error("no acceptable auth method")]
NoAcceptableAuth,
#[error("unsupported command: {0}")]
UnsupportedCommand(u8),
#[error("channel open failed")]
ChannelOpenFailed,
#[error("io error")]
Io(#[from] std::io::Error),
}
pub struct HandleChannelOpener<H: russh::client::Handler> {
handle: Arc<Mutex<russh::client::Handle<H>>>,
}
impl<H: russh::client::Handler> HandleChannelOpener<H> {
pub fn new(handle: russh::client::Handle<H>) -> Self {
Self {
handle: Arc::new(Mutex::new(handle)),
}
}
pub fn from_arc(handle: Arc<Mutex<russh::client::Handle<H>>>) -> Self {
Self { handle }
}
}
impl<H: russh::client::Handler + Send + Sync + 'static> ChannelOpener for HandleChannelOpener<H> {
type Stream = russh::ChannelStream<russh::client::Msg>;
async fn open_channel(&self, host: String, port: u16) -> Result<Self::Stream, ChannelOpenError> {
let handle = self.handle.lock().await;
if handle.is_closed() {
return Err(ChannelOpenError::SessionClosed);
}
let channel = handle
.channel_open_direct_tcpip(host, port as u32, "127.0.0.1", 0)
.await
.map_err(|_| ChannelOpenError::ChannelOpenFailed)?;
Ok(channel.into_stream())
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
struct MockChannelOpener {
fail: bool,
}
impl ChannelOpener for MockChannelOpener {
type Stream = DuplexStream;
async fn open_channel(
&self,
_host: String,
_port: u16,
) -> Result<Self::Stream, ChannelOpenError> {
if self.fail {
Err(ChannelOpenError::ChannelOpenFailed)
} else {
let (client, _server) = duplex(4096);
Ok(client)
}
}
}
fn build_socks5_greeting(methods: &[u8]) -> Vec<u8> {
let mut buf = vec![0x05, methods.len() as u8];
buf.extend_from_slice(methods);
buf
}
fn build_socks5_connect_ipv4(addr: [u8; 4], port: u16) -> Vec<u8> {
let mut buf = vec![0x05, 0x01, 0x00, 0x01];
buf.extend_from_slice(&addr);
buf.extend_from_slice(&port.to_be_bytes());
buf
}
fn build_socks5_connect_domain(domain: &str, port: u16) -> Vec<u8> {
let mut buf = vec![0x05, 0x01, 0x00, 0x03];
buf.push(domain.len() as u8);
buf.extend_from_slice(domain.as_bytes());
buf.extend_from_slice(&port.to_be_bytes());
buf
}
fn build_socks5_connect_ipv6(addr: [u8; 16], port: u16) -> Vec<u8> {
let mut buf = vec![0x05, 0x01, 0x00, 0x04];
buf.extend_from_slice(&addr);
buf.extend_from_slice(&port.to_be_bytes());
buf
}
async fn do_handshake(client: &mut DuplexStream) -> [u8; 2] {
client.write_all(&build_socks5_greeting(&[0x00])).await.unwrap();
client.flush().await.unwrap();
let mut resp = [0u8; 2];
client.read_exact(&mut resp).await.unwrap();
resp
}
async fn do_connect_ipv4(client: &mut DuplexStream, addr: [u8; 4], port: u16) -> Vec<u8> {
client
.write_all(&build_socks5_connect_ipv4(addr, port))
.await
.unwrap();
client.flush().await.unwrap();
let mut reply_buf = [0u8; 10];
client.read_exact(&mut reply_buf).await.unwrap();
reply_buf.to_vec()
}
#[tokio::test]
async fn handshake_no_auth_method() {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: false };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
let resp = do_handshake(&mut client).await;
assert_eq!(resp, [0x05, 0x00]);
let reply_buf = do_connect_ipv4(&mut client, [127, 0, 0, 1], 80).await;
assert_eq!(reply_buf[0], 0x05);
assert_eq!(reply_buf[1], 0x00);
drop(client);
let _ = server_handle.await;
}
#[tokio::test]
async fn handshake_rejects_no_acceptable_method() {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: false };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
client
.write_all(&build_socks5_greeting(&[0x02]))
.await
.unwrap();
client.flush().await.unwrap();
let mut resp = [0u8; 2];
client.read_exact(&mut resp).await.unwrap();
assert_eq!(resp, [0x05, 0xFF]);
drop(client);
let result = server_handle.await.unwrap();
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
Socks5Error::NoAcceptableAuth
));
}
#[tokio::test]
async fn address_type_ipv4() {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: false };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
do_handshake(&mut client).await;
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 443).await;
assert_eq!(reply_buf[1], 0x00);
drop(client);
let _ = server_handle.await;
}
#[tokio::test]
async fn address_type_domain() {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: false };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
do_handshake(&mut client).await;
client
.write_all(&build_socks5_connect_domain("example.com", 443))
.await
.unwrap();
client.flush().await.unwrap();
let mut reply_buf = [0u8; 10];
client.read_exact(&mut reply_buf).await.unwrap();
assert_eq!(reply_buf[1], 0x00);
drop(client);
let _ = server_handle.await;
}
#[tokio::test]
async fn address_type_ipv6() {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: false };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
do_handshake(&mut client).await;
let ipv6_addr = [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
client
.write_all(&build_socks5_connect_ipv6(ipv6_addr, 443))
.await
.unwrap();
client.flush().await.unwrap();
let mut reply_buf = [0u8; 10];
client.read_exact(&mut reply_buf).await.unwrap();
assert_eq!(reply_buf[0], 0x05);
assert_eq!(reply_buf[1], 0x00);
drop(client);
let _ = server_handle.await;
}
#[tokio::test]
async fn channel_open_failure_returns_socks5_error() {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: true };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
do_handshake(&mut client).await;
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 80).await;
assert_eq!(reply_buf[0], 0x05);
assert_eq!(reply_buf[1], 0x05);
drop(client);
let _ = server_handle.await;
}
#[tokio::test]
async fn unsupported_command_returns_error() {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: false };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
do_handshake(&mut client).await;
let mut bind_req = vec![0x05, 0x02, 0x00, 0x01];
bind_req.extend_from_slice(&[127, 0, 0, 1]);
bind_req.extend_from_slice(&80u16.to_be_bytes());
client.write_all(&bind_req).await.unwrap();
client.flush().await.unwrap();
let mut reply_buf = [0u8; 10];
client.read_exact(&mut reply_buf).await.unwrap();
assert_eq!(reply_buf[1], 0x07);
drop(client);
let _ = server_handle.await;
}
#[tokio::test]
async fn bidirectional_proxy_flow() {
let (mut client_sock, server_sock) = duplex(4096);
let (ssh_client, mut ssh_server) = duplex(4096);
let ssh_stream = Arc::new(Mutex::new(Some(ssh_client)));
struct ProxyOpener {
stream: Arc<Mutex<Option<DuplexStream>>>,
}
impl ChannelOpener for ProxyOpener {
type Stream = DuplexStream;
async fn open_channel(
&self,
_host: String,
_port: u16,
) -> Result<Self::Stream, ChannelOpenError> {
self.stream
.lock()
.await
.take()
.ok_or(ChannelOpenError::ChannelOpenFailed)
}
}
let opener = ProxyOpener {
stream: Arc::clone(&ssh_stream),
};
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server_sock, Arc::new(opener)).await
});
do_handshake(&mut client_sock).await;
let reply_buf = do_connect_ipv4(&mut client_sock, [127, 0, 0, 1], 80).await;
assert_eq!(reply_buf[1], 0x00);
let test_data = b"hello through tunnel";
client_sock.write_all(test_data).await.unwrap();
client_sock.flush().await.unwrap();
let mut received = vec![0u8; test_data.len()];
AsyncReadExt::read_exact(&mut ssh_server, &mut received)
.await
.unwrap();
assert_eq!(&received, test_data);
let echo_data = b"response from tunnel";
ssh_server.write_all(echo_data).await.unwrap();
ssh_server.flush().await.unwrap();
let mut received_back = vec![0u8; echo_data.len()];
client_sock.read_exact(&mut received_back).await.unwrap();
assert_eq!(&received_back, echo_data);
drop(client_sock);
drop(ssh_server);
let _ = server_handle.await;
}
#[tokio::test]
async fn default_listen_address() {
let opener = MockChannelOpener { fail: false };
let server = Socks5Server::new(opener);
assert_eq!(server.listen_addr(), "127.0.0.1:1080".parse().unwrap());
}
#[tokio::test]
async fn custom_listen_address() {
let opener = MockChannelOpener { fail: false };
let server = Socks5Server::with_addr(opener, "127.0.0.1:9050");
assert_eq!(server.listen_addr(), "127.0.0.1:9050".parse().unwrap());
}
}

View File

@@ -0,0 +1,304 @@
use std::net::{Ipv4Addr, Ipv6Addr};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Debug, Clone, PartialEq)]
pub enum Socks5Address {
Ipv4(Ipv4Addr),
Ipv6(Ipv6Addr),
Domain(String),
}
#[derive(Debug)]
pub struct Socks5VersionMethod {
pub version: u8,
pub methods: Vec<u8>,
}
impl Socks5VersionMethod {
pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Self> {
let version = reader.read_u8().await?;
let nmethods = reader.read_u8().await?;
let mut methods = vec![0u8; nmethods as usize];
reader.read_exact(&mut methods).await?;
Ok(Self { version, methods })
}
}
#[derive(Debug)]
pub struct Socks5Request {
pub version: u8,
pub command: u8,
pub address: Socks5Address,
pub port: u16,
}
impl Socks5Request {
pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Self> {
let version = reader.read_u8().await?;
let command = reader.read_u8().await?;
let _rsv = reader.read_u8().await?;
let atyp = reader.read_u8().await?;
let address = match atyp {
0x01 => {
let mut octets = [0u8; 4];
reader.read_exact(&mut octets).await?;
Socks5Address::Ipv4(Ipv4Addr::from(octets))
}
0x04 => {
let mut octets = [0u8; 16];
reader.read_exact(&mut octets).await?;
Socks5Address::Ipv6(Ipv6Addr::from(octets))
}
0x03 => {
let len = reader.read_u8().await?;
let mut domain = vec![0u8; len as usize];
reader.read_exact(&mut domain).await?;
Socks5Address::Domain(String::from_utf8_lossy(&domain).into_owned())
}
_ => {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("unsupported address type: {atyp}"),
))
}
};
let port = reader.read_u16().await?;
Ok(Self {
version,
command,
address,
port,
})
}
}
#[derive(Debug)]
pub struct Socks5Reply {
pub version: u8,
pub reply: u8,
pub address: Socks5Address,
pub port: u16,
}
impl Socks5Reply {
pub fn success(address: Socks5Address, port: u16) -> Self {
Self {
version: 0x05,
reply: 0x00,
address,
port,
}
}
pub fn connection_refused() -> Self {
Self {
version: 0x05,
reply: 0x05,
address: Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED),
port: 0,
}
}
pub fn command_not_supported() -> Self {
Self {
version: 0x05,
reply: 0x07,
address: Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED),
port: 0,
}
}
pub async fn write_to<W: AsyncWrite + Unpin>(&self, writer: &mut W) -> std::io::Result<()> {
writer.write_u8(self.version).await?;
writer.write_u8(self.reply).await?;
writer.write_u8(0x00).await?;
match &self.address {
Socks5Address::Ipv4(addr) => {
writer.write_u8(0x01).await?;
writer.write_all(&addr.octets()).await?;
}
Socks5Address::Ipv6(addr) => {
writer.write_u8(0x04).await?;
writer.write_all(&addr.octets()).await?;
}
Socks5Address::Domain(name) => {
writer.write_u8(0x03).await?;
writer.write_u8(name.len() as u8).await?;
writer.write_all(name.as_bytes()).await?;
}
}
writer.write_u16(self.port).await?;
writer.flush().await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
#[tokio::test]
async fn parse_version_method_no_auth() {
let data = [0x05, 0x01, 0x00];
let mut cursor = Cursor::new(&data[..]);
let vm = Socks5VersionMethod::read_from(&mut cursor).await.unwrap();
assert_eq!(vm.version, 0x05);
assert_eq!(vm.methods, vec![0x00]);
}
#[tokio::test]
async fn parse_version_method_multiple() {
let data = [0x05, 0x02, 0x00, 0x02];
let mut cursor = Cursor::new(&data[..]);
let vm = Socks5VersionMethod::read_from(&mut cursor).await.unwrap();
assert_eq!(vm.version, 0x05);
assert_eq!(vm.methods, vec![0x00, 0x02]);
}
#[tokio::test]
async fn parse_request_ipv4() {
let mut data = vec![0x05, 0x01, 0x00, 0x01];
data.extend_from_slice(&[10, 0, 0, 1]);
data.extend_from_slice(&443u16.to_be_bytes());
let mut cursor = Cursor::new(&data[..]);
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
assert_eq!(req.version, 0x05);
assert_eq!(req.command, 0x01);
assert_eq!(
req.address,
Socks5Address::Ipv4(Ipv4Addr::new(10, 0, 0, 1))
);
assert_eq!(req.port, 443);
}
#[tokio::test]
async fn parse_request_ipv6() {
let mut data = vec![0x05, 0x01, 0x00, 0x04];
let octets: [u8; 16] = [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
data.extend_from_slice(&octets);
data.extend_from_slice(&443u16.to_be_bytes());
let mut cursor = Cursor::new(&data[..]);
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
assert_eq!(req.version, 0x05);
assert_eq!(req.command, 0x01);
assert!(matches!(req.address, Socks5Address::Ipv6(_)));
assert_eq!(req.port, 443);
}
#[tokio::test]
async fn parse_request_domain() {
let domain = "example.com";
let mut data = vec![0x05, 0x01, 0x00, 0x03];
data.push(domain.len() as u8);
data.extend_from_slice(domain.as_bytes());
data.extend_from_slice(&443u16.to_be_bytes());
let mut cursor = Cursor::new(&data[..]);
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
assert_eq!(req.version, 0x05);
assert_eq!(req.command, 0x01);
assert_eq!(req.address, Socks5Address::Domain("example.com".to_string()));
assert_eq!(req.port, 443);
}
#[tokio::test]
async fn parse_request_unsupported_address_type() {
let data = [0x05, 0x01, 0x00, 0x05];
let mut cursor = Cursor::new(&data[..]);
let result = Socks5Request::read_from(&mut cursor).await;
assert!(result.is_err());
}
#[tokio::test]
async fn reply_success_ipv4() {
let reply = Socks5Reply::success(Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED), 0);
let mut buf = Vec::new();
reply.write_to(&mut buf).await.unwrap();
assert_eq!(buf[0], 0x05);
assert_eq!(buf[1], 0x00);
assert_eq!(buf[2], 0x00);
assert_eq!(buf[3], 0x01);
}
#[tokio::test]
async fn reply_connection_refused() {
let reply = Socks5Reply::connection_refused();
let mut buf = Vec::new();
reply.write_to(&mut buf).await.unwrap();
assert_eq!(buf[0], 0x05);
assert_eq!(buf[1], 0x05);
}
#[tokio::test]
async fn reply_command_not_supported() {
let reply = Socks5Reply::command_not_supported();
let mut buf = Vec::new();
reply.write_to(&mut buf).await.unwrap();
assert_eq!(buf[0], 0x05);
assert_eq!(buf[1], 0x07);
}
#[tokio::test]
async fn roundtrip_ipv4_reply() {
let reply = Socks5Reply::success(Socks5Address::Ipv4(Ipv4Addr::new(127, 0, 0, 1)), 1080);
let mut buf = Vec::new();
reply.write_to(&mut buf).await.unwrap();
let mut cursor = Cursor::new(&buf[..]);
let version = cursor.read_u8().await.unwrap();
let _reply_code = cursor.read_u8().await.unwrap();
let _rsv = cursor.read_u8().await.unwrap();
let atyp = cursor.read_u8().await.unwrap();
assert_eq!(version, 0x05);
assert_eq!(atyp, 0x01);
let mut octets = [0u8; 4];
cursor.read_exact(&mut octets).await.unwrap();
assert_eq!(Ipv4Addr::from(octets), Ipv4Addr::new(127, 0, 0, 1));
let port = cursor.read_u16().await.unwrap();
assert_eq!(port, 1080);
}
#[tokio::test]
async fn roundtrip_ipv6_reply() {
let addr = Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1);
let reply = Socks5Reply::success(Socks5Address::Ipv6(addr), 443);
let mut buf = Vec::new();
reply.write_to(&mut buf).await.unwrap();
let mut cursor = Cursor::new(&buf[..]);
let _version = cursor.read_u8().await.unwrap();
let _reply_code = cursor.read_u8().await.unwrap();
let _rsv = cursor.read_u8().await.unwrap();
let atyp = cursor.read_u8().await.unwrap();
assert_eq!(atyp, 0x04);
let mut octets = [0u8; 16];
cursor.read_exact(&mut octets).await.unwrap();
assert_eq!(Ipv6Addr::from(octets), addr);
let port = cursor.read_u16().await.unwrap();
assert_eq!(port, 443);
}
#[tokio::test]
async fn roundtrip_domain_reply() {
let reply = Socks5Reply::success(Socks5Address::Domain("example.com".to_string()), 8080);
let mut buf = Vec::new();
reply.write_to(&mut buf).await.unwrap();
let mut cursor = Cursor::new(&buf[..]);
let _version = cursor.read_u8().await.unwrap();
let _reply_code = cursor.read_u8().await.unwrap();
let _rsv = cursor.read_u8().await.unwrap();
let atyp = cursor.read_u8().await.unwrap();
assert_eq!(atyp, 0x03);
let len = cursor.read_u8().await.unwrap();
let mut domain = vec![0u8; len as usize];
cursor.read_exact(&mut domain).await.unwrap();
assert_eq!(String::from_utf8(domain).unwrap(), "example.com");
let port = cursor.read_u16().await.unwrap();
assert_eq!(port, 8080);
}
}

View File

@@ -0,0 +1,141 @@
use tokio::io::{DuplexStream, AsyncRead, AsyncWrite};
use anyhow::Result;
#[cfg(feature = "transport-traits")]
pub use crate::transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
#[cfg(not(feature = "transport-traits"))]
pub use local_traits::{Transport, TransportAcceptor, TransportInfo, TransportKind};
#[cfg(not(feature = "transport-traits"))]
mod local_traits {
use std::net::SocketAddr;
use anyhow::Result;
use tokio::io::{AsyncRead, AsyncWrite};
use async_trait::async_trait;
#[async_trait]
pub trait Transport: Send + Sync + 'static {
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
async fn connect(&self) -> Result<Self::Stream>;
fn describe(&self) -> String;
}
#[async_trait]
pub trait TransportAcceptor: Send + Sync + 'static {
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>;
}
#[derive(Debug, Clone)]
pub struct TransportInfo {
pub remote_addr: Option<SocketAddr>,
pub transport_kind: TransportKind,
}
#[derive(Debug, Clone)]
pub enum TransportKind {
Tcp,
Tls { server_name: Option<String> },
Iroh { endpoint_id: String },
}
}
pub struct MockStream {
inner: DuplexStream,
}
impl MockStream {
pub fn new(inner: DuplexStream) -> Self {
Self { inner }
}
}
impl AsyncRead for MockStream {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
}
}
impl AsyncWrite for MockStream {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
}
}
impl Unpin for MockStream {}
pub struct MockTransport {
buf_size: usize,
}
impl MockTransport {
pub fn new(buf_size: usize) -> Self {
Self { buf_size }
}
}
#[async_trait::async_trait]
impl Transport for MockTransport {
type Stream = MockStream;
async fn connect(&self) -> Result<Self::Stream> {
let (client, _) = tokio::io::duplex(self.buf_size);
Ok(MockStream::new(client))
}
fn describe(&self) -> String {
"mock".to_string()
}
}
pub struct MockTransportAcceptor {
buf_size: usize,
}
impl MockTransportAcceptor {
pub fn new(buf_size: usize) -> Self {
Self { buf_size }
}
}
#[async_trait::async_trait]
impl TransportAcceptor for MockTransportAcceptor {
type Stream = MockStream;
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
let (_, server) = tokio::io::duplex(self.buf_size);
let info = TransportInfo {
remote_addr: None,
transport_kind: TransportKind::Tcp,
};
Ok((MockStream::new(server), info))
}
}
pub fn mock_pair(buf_size: usize) -> (MockStream, MockStream) {
let (client, server) = tokio::io::duplex(buf_size);
(MockStream::new(client), MockStream::new(server))
}

View File

@@ -0,0 +1,362 @@
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use rustls::crypto::aws_lc_rs::default_provider;
use rustls::ServerConfig;
use rustls_acme::caches::DirCache;
use rustls_acme::{AcmeConfig, AcmeState, ResolvesServerCertAcme};
use tracing::{error, info};
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor as TokioTlsAcceptor;
use super::{TransportAcceptor, TransportInfo, TransportKind};
const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1";
#[derive(Debug, Clone)]
pub enum AcmeMode {
Domain { domain: String },
Ip,
}
pub struct AcmeCertProvider {
mode: AcmeMode,
cache_dir: Option<PathBuf>,
directory_url: String,
contact: Vec<String>,
}
impl std::fmt::Debug for AcmeCertProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AcmeCertProvider")
.field("mode", &self.mode)
.field("cache_dir", &self.cache_dir)
.field("directory_url", &self.directory_url)
.field("contact", &self.contact)
.finish_non_exhaustive()
}
}
impl AcmeCertProvider {
pub fn new(mode: AcmeMode) -> Self {
Self {
mode,
cache_dir: None,
directory_url: rustls_acme::acme::LETS_ENCRYPT_STAGING_DIRECTORY.to_string(),
contact: Vec::new(),
}
}
pub fn domain(domain: impl Into<String>) -> Self {
Self::new(AcmeMode::Domain {
domain: domain.into(),
})
}
pub fn ip() -> Self {
Self::new(AcmeMode::Ip)
}
pub fn with_cache_dir(mut self, dir: impl Into<PathBuf>) -> Self {
self.cache_dir = Some(dir.into());
self
}
pub fn with_directory(mut self, url: impl Into<String>) -> Self {
self.directory_url = url.into();
self
}
pub fn with_production_directory(mut self) -> Self {
self.directory_url = rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY.to_string();
self
}
pub fn with_contact(mut self, contact: impl Into<String>) -> Self {
self.contact.push(contact.into());
self
}
pub fn mode(&self) -> &AcmeMode {
&self.mode
}
fn build_acme_state(&self) -> (AcmeState<std::io::Error>, Arc<ResolvesServerCertAcme>) {
let domains: Vec<String> = match &self.mode {
AcmeMode::Domain { domain } => vec![domain.clone()],
AcmeMode::Ip => vec![],
};
let base_config = AcmeConfig::new(domains)
.directory(&self.directory_url)
.contact(self.contact.clone());
let state = match &self.cache_dir {
Some(cache_dir) => {
base_config.cache(DirCache::new(cache_dir.clone())).state()
}
None => {
base_config
.cache(rustls_acme::caches::NoCache::default())
.state()
}
};
let resolver = state.resolver();
(state, resolver)
}
pub fn build_server_config_with_resolver(
&self,
resolver: Arc<ResolvesServerCertAcme>,
) -> Result<Arc<ServerConfig>> {
let provider = default_provider().into();
let mut config = ServerConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()
.map_err(|e| anyhow!("failed to set protocol versions: {}", e))?
.with_no_client_auth()
.with_cert_resolver(resolver);
config.alpn_protocols.push(ACME_TLS_ALPN_NAME.to_vec());
Ok(Arc::new(config))
}
}
pub struct AcmeTlsAcceptor {
listener: TcpListener,
listen_addr: SocketAddr,
#[allow(dead_code)]
server_config: Arc<ServerConfig>,
tokio_acceptor: TokioTlsAcceptor,
}
impl AcmeTlsAcceptor {
pub async fn bind_acme(
addr: SocketAddr,
provider: Arc<AcmeCertProvider>,
) -> Result<Self> {
let (state, resolver) = provider.build_acme_state();
let server_config = provider.build_server_config_with_resolver(resolver.clone())?;
Self::spawn_state_worker(state, resolver);
let listener = TcpListener::bind(addr).await?;
let listen_addr = listener.local_addr()?;
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
Ok(Self {
listener,
listen_addr,
server_config,
tokio_acceptor,
})
}
pub fn listen_addr(&self) -> SocketAddr {
self.listen_addr
}
fn spawn_state_worker(state: AcmeState<std::io::Error>, resolver: Arc<ResolvesServerCertAcme>) {
use futures::StreamExt;
let task = async move {
let mut state = state;
while let Some(event) = state.next().await {
match event {
Ok(ok) => {
if let rustls_acme::EventOk::DeployedNewCert = ok {
info!("ACME: new certificate deployed");
} else {
info!("ACME event: {:?}", ok);
}
}
Err(err) => error!("ACME event error: {:?}", err),
}
if Arc::strong_count(&resolver) == 1 {
info!("ACME resolver dropped, stopping background task");
break;
}
}
};
tokio::spawn(task);
}
}
#[async_trait::async_trait]
impl TransportAcceptor for AcmeTlsAcceptor {
type Stream = tokio_rustls::server::TlsStream<tokio::net::TcpStream>;
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
let (tcp_stream, remote_addr) = self.listener.accept().await?;
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
let server_name = tls_stream
.get_ref()
.1
.server_name()
.map(|s| s.to_string());
let info = TransportInfo {
remote_addr: Some(remote_addr),
transport_kind: TransportKind::Tls { server_name },
};
Ok((tls_stream, info))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn acme_cert_provider_domain_mode() {
let provider = AcmeCertProvider::domain("example.com");
assert!(matches!(provider.mode(), AcmeMode::Domain { .. }));
if let AcmeMode::Domain { domain } = provider.mode() {
assert_eq!(domain, "example.com");
}
}
#[test]
fn acme_cert_provider_ip_mode() {
let provider = AcmeCertProvider::ip();
assert!(matches!(provider.mode(), AcmeMode::Ip));
}
#[test]
fn acme_cert_provider_default_staging_directory() {
let provider = AcmeCertProvider::domain("example.com");
assert_eq!(
provider.directory_url,
rustls_acme::acme::LETS_ENCRYPT_STAGING_DIRECTORY
);
}
#[test]
fn acme_cert_provider_production_directory() {
let provider = AcmeCertProvider::domain("example.com").with_production_directory();
assert_eq!(
provider.directory_url,
rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY
);
}
#[test]
fn acme_cert_provider_custom_directory() {
let provider =
AcmeCertProvider::domain("example.com").with_directory("https://custom.acme.dir/");
assert_eq!(provider.directory_url, "https://custom.acme.dir/");
}
#[test]
fn acme_cert_provider_with_cache_dir() {
let provider = AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/acme_cache");
assert_eq!(provider.cache_dir, Some(PathBuf::from("/tmp/acme_cache")));
}
#[test]
fn acme_cert_provider_with_contact() {
let provider =
AcmeCertProvider::domain("example.com").with_contact("mailto:admin@example.com");
assert_eq!(
provider.contact,
vec!["mailto:admin@example.com".to_string()]
);
}
#[test]
fn acme_cert_provider_build_state_domain() {
let provider = AcmeCertProvider::domain("example.com");
let (_state, resolver) = provider.build_acme_state();
assert!(Arc::strong_count(&resolver) >= 2);
}
#[test]
fn acme_cert_provider_build_state_with_cache() {
let provider =
AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/test_cache");
let (_state, resolver) = provider.build_acme_state();
assert!(Arc::strong_count(&resolver) >= 2);
}
#[test]
fn acme_cert_provider_build_server_config() {
let _ = default_provider().install_default();
let provider = AcmeCertProvider::domain("example.com");
let (_, resolver) = provider.build_acme_state();
let config = provider.build_server_config_with_resolver(resolver).unwrap();
assert!(!config.alpn_protocols.is_empty());
assert!(config
.alpn_protocols
.iter()
.any(|p| p == ACME_TLS_ALPN_NAME));
}
#[test]
fn acme_mode_domain_debug() {
let mode = AcmeMode::Domain {
domain: "test.example.com".to_string(),
};
let debug_str = format!("{:?}", mode);
assert!(debug_str.contains("test.example.com"));
}
#[test]
fn acme_mode_ip_debug() {
let mode = AcmeMode::Ip;
let debug_str = format!("{:?}", mode);
assert!(debug_str.contains("Ip"));
}
#[test]
fn acme_cert_provider_builder_chain() {
let provider = AcmeCertProvider::domain("test.example.com")
.with_production_directory()
.with_cache_dir("/tmp/cache")
.with_contact("mailto:admin@test.example.com");
assert!(matches!(provider.mode(), AcmeMode::Domain { .. }));
assert_eq!(
provider.directory_url,
rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY
);
assert_eq!(provider.cache_dir, Some(PathBuf::from("/tmp/cache")));
assert_eq!(provider.contact.len(), 1);
}
#[tokio::test]
async fn acme_tls_acceptor_bind_acme() {
let _ = default_provider().install_default();
let provider = Arc::new(AcmeCertProvider::domain("example.com"));
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
let acceptor = AcmeTlsAcceptor::bind_acme(addr, provider).await.unwrap();
assert_ne!(acceptor.listen_addr().port(), 0);
}
#[tokio::test]
#[ignore]
async fn acme_staging_domain_cert_provisioning() {
let _ = default_provider().install_default();
let cache_dir = tempfile::tempdir().unwrap();
let provider = Arc::new(
AcmeCertProvider::domain("acme-test.example.com")
.with_cache_dir(cache_dir.path())
.with_contact("mailto:admin@example.com"),
);
let addr: SocketAddr = "0.0.0.0:443".parse().unwrap();
let result = AcmeTlsAcceptor::bind_acme(addr, provider).await;
assert!(
result.is_ok(),
"ACME TlsAcceptor should bind: {:?}",
result.err()
);
let acceptor = result.unwrap();
assert_eq!(acceptor.listen_addr().port(), 443);
}
}

View File

@@ -0,0 +1,321 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use iroh::{
endpoint::RecvStream,
node_info::NodeIdExt,
Endpoint, NodeId, RelayMap, RelayMode, RelayUrl,
};
use tokio::io;
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
pub const ALPN: &[u8] = b"alknet-ssh";
const DEFAULT_RELAY_URL: &str = "https://relay.iroh.network/";
/// A client-side iroh QUIC P2P transport that connects to a remote iroh endpoint.
///
/// Connects via `Endpoint::connect(node_id, alpn)`, opens a bidirectional
/// QUIC stream with `conn.open_bi()`, and joins the halves with
/// `tokio::io::join(recv, send)` to produce a duplex stream for russh.
/// Per ADR-003, `tokio::io::join` is used instead of a custom wrapper.
///
/// Use [`IrohTransport::new`] to create a standalone endpoint, or
/// [`IrohTransport::from_endpoint`] to share an existing iroh `Endpoint`
/// with other protocol handlers (blobs, gossip, docs).
pub struct IrohTransport {
node_id: NodeId,
endpoint: Endpoint,
owned: bool,
}
impl IrohTransport {
/// Create a new iroh transport with its own dedicated endpoint.
///
/// The endpoint is created with the `alknet-ssh` ALPN and the provided
/// relay URL. Use this when alknet is the only iroh service on this node.
pub async fn new(
node_id: NodeId,
relay_url: Option<RelayUrl>,
proxy_url: Option<url::Url>,
) -> Result<Self> {
let relay_url = relay_url.unwrap_or_else(|| {
DEFAULT_RELAY_URL.parse().expect("default relay URL is valid")
});
let relay_map = RelayMap::from_url(relay_url);
let mut builder = Endpoint::builder()
.relay_mode(RelayMode::Custom(relay_map))
.alpns(vec![ALPN.to_vec()]);
if let Some(ref proxy) = proxy_url {
builder = builder.proxy_url(proxy.clone());
}
let endpoint = builder.bind().await?;
Ok(Self { node_id, endpoint, owned: true })
}
/// Create an iroh transport using an existing shared endpoint.
///
/// The endpoint must already have the `alknet-ssh` ALPN registered
/// (typically via [`iroh::protocol::Router::builder`]). This enables
/// running alknet alongside iroh-blobs, iroh-gossip, iroh-docs, and
/// other protocol handlers on the same QUIC endpoint — one connection
/// per peer, multiplexed by ALPN.
pub fn from_endpoint(node_id: NodeId, endpoint: Endpoint) -> Self {
Self { node_id, endpoint, owned: false }
}
pub fn endpoint_id(&self) -> String {
self.endpoint.node_id().to_z32()
}
pub fn endpoint(&self) -> &Endpoint {
&self.endpoint
}
pub fn owned(&self) -> bool {
self.owned
}
}
#[async_trait]
impl Transport for IrohTransport {
type Stream = io::Join<RecvStream, iroh::endpoint::SendStream>;
async fn connect(&self) -> Result<Self::Stream> {
let conn = self.endpoint.connect(self.node_id, ALPN).await?;
let (send, recv) = conn.open_bi().await?;
Ok(io::join(recv, send))
}
fn describe(&self) -> String {
format!("iroh://{}", self.node_id.to_z32())
}
}
/// A server-side iroh QUIC P2P transport acceptor that listens for incoming connections.
///
/// Binds an iroh `Endpoint` with the configured relay URL and optional proxy
/// (ADR-010). Accepts incoming connections, accepts bidirectional QUIC streams,
/// and joins the halves with `tokio::io::join(recv, send)`. Exposes
/// `endpoint_id()` for CLI display of the server's z-base-32 node ID.
///
/// Use [`IrohAcceptor::bind`] to create a standalone endpoint, or
/// [`IrohAcceptor::from_endpoint`] to share an existing iroh `Endpoint`
/// with other protocol handlers (blobs, gossip, docs).
///
/// When using `from_endpoint`, the alknet-ssh ALPN must be registered
/// via an iroh `Router` that calls `Handler::accept()` on incoming
/// connections with the `alknet-ssh` ALPN, then passes the accepted
/// bidirectional stream to `russh::server::run_stream()`.
pub struct IrohAcceptor {
endpoint: Endpoint,
owned: bool,
}
impl IrohAcceptor {
/// Bind a new iroh endpoint with a dedicated `alknet-ssh` ALPN.
///
/// Use this when alknet is the only iroh service on this node.
pub async fn bind(
relay_url: Option<RelayUrl>,
proxy_url: Option<url::Url>,
) -> Result<Self> {
let relay_url = relay_url.unwrap_or_else(|| {
DEFAULT_RELAY_URL.parse().expect("default relay URL is valid")
});
let relay_map = RelayMap::from_url(relay_url);
let mut builder = Endpoint::builder()
.relay_mode(RelayMode::Custom(relay_map))
.alpns(vec![ALPN.to_vec()]);
if let Some(ref proxy) = proxy_url {
builder = builder.proxy_url(proxy.clone());
}
let endpoint = builder.bind().await?;
Ok(Self { endpoint, owned: true })
}
/// Create an iroh acceptor using an existing shared endpoint.
///
/// The endpoint must already have the `alknet-ssh` ALPN registered
/// (typically via [`iroh::protocol::Router::builder`]). When using a
/// shared endpoint, incoming connections with the `alknet-ssh` ALPN
/// are routed by the Router to a `ProtocolHandler` that this acceptor
/// does not manage — the caller is responsible for bridging the
/// Router's `accept()` callback to this acceptor's stream handling.
///
/// For the standalone case where alknet owns the endpoint, use
/// [`IrohAcceptor::bind`] instead, which handles the accept loop
/// internally.
pub fn from_endpoint(endpoint: Endpoint) -> Self {
Self { endpoint, owned: false }
}
pub fn endpoint_id(&self) -> String {
self.endpoint.node_id().to_z32()
}
pub fn endpoint(&self) -> &Endpoint {
&self.endpoint
}
pub fn owned(&self) -> bool {
self.owned
}
}
#[async_trait]
impl TransportAcceptor for IrohAcceptor {
type Stream = io::Join<RecvStream, iroh::endpoint::SendStream>;
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
let incoming = self
.endpoint
.accept()
.await
.ok_or_else(|| anyhow!("endpoint closed"))?;
let conn = incoming.await?;
let node_id = conn.remote_node_id()?;
let (send, recv) = conn.accept_bi().await?;
let stream = io::join(recv, send);
let info = TransportInfo {
remote_addr: None,
transport_kind: TransportKind::Iroh {
endpoint_id: node_id.to_z32(),
},
};
Ok((stream, info))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn iroh_acceptor_bind_creates_endpoint() {
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
let endpoint_id = acceptor.endpoint_id();
assert!(!endpoint_id.is_empty());
let parsed = NodeId::from_z32(&endpoint_id);
assert!(parsed.is_ok());
assert!(acceptor.owned());
}
#[tokio::test]
async fn iroh_acceptor_bind_with_custom_relay() {
let relay: RelayUrl = "https://relay.iroh.network/".parse().unwrap();
let acceptor = IrohAcceptor::bind(Some(relay), None).await.unwrap();
assert!(!acceptor.endpoint_id().is_empty());
assert!(acceptor.owned());
}
#[tokio::test]
async fn iroh_acceptor_from_endpoint() {
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
let endpoint = acceptor.endpoint.clone();
let shared = IrohAcceptor::from_endpoint(endpoint);
assert_eq!(shared.endpoint_id(), acceptor.endpoint_id());
assert!(!shared.owned());
}
#[test]
fn iroh_transport_describe_format() {
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng)
.public()
.into();
let desc = format!("iroh://{}", node_id.to_z32());
assert!(desc.starts_with("iroh://"));
}
#[tokio::test]
async fn iroh_transport_connect_builds_endpoint() {
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng)
.public()
.into();
let transport = IrohTransport::new(node_id, None, None).await.unwrap();
assert!(transport.describe().starts_with("iroh://"));
assert!(!transport.endpoint_id().is_empty());
assert!(transport.owned());
}
#[tokio::test]
async fn iroh_transport_from_endpoint() {
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng)
.public()
.into();
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
let endpoint = acceptor.endpoint.clone();
let transport = IrohTransport::from_endpoint(node_id, endpoint);
assert!(transport.describe().starts_with("iroh://"));
assert_eq!(transport.endpoint_id(), acceptor.endpoint_id());
assert!(!transport.owned());
}
#[tokio::test]
#[ignore]
async fn iroh_client_connects_to_iroh_server() {
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
let server_node_id = acceptor.endpoint().node_id();
let transport = IrohTransport::new(server_node_id, None, None)
.await
.unwrap();
let mut addrs_watcher = acceptor.endpoint().direct_addresses();
addrs_watcher.initialized().await.unwrap();
let addr_set = addrs_watcher.get().unwrap().unwrap_or_default();
for addr in addr_set {
transport
.endpoint
.add_node_addr(iroh::NodeAddr::from_parts(
server_node_id,
None,
vec![addr.addr],
))
.unwrap();
}
let accept_handle = tokio::spawn(async move {
let (stream, info) = acceptor.accept().await.unwrap();
assert!(matches!(info.transport_kind, TransportKind::Iroh { .. }));
stream
});
let _client_stream: io::Join<RecvStream, iroh::endpoint::SendStream> =
transport.connect().await.unwrap();
let _server_stream = accept_handle.await.unwrap();
}
#[tokio::test]
#[ignore]
async fn iroh_shared_endpoint_client_connects_to_server() {
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
let server_node_id = acceptor.endpoint().node_id();
let shared_endpoint = acceptor.endpoint().clone();
let transport = IrohTransport::from_endpoint(server_node_id, shared_endpoint);
let mut addrs_watcher = acceptor.endpoint().direct_addresses();
addrs_watcher.initialized().await.unwrap();
let addr_set = addrs_watcher.get().unwrap().unwrap_or_default();
for addr in addr_set {
transport
.endpoint
.add_node_addr(iroh::NodeAddr::from_parts(
server_node_id,
None,
vec![addr.addr],
))
.unwrap();
}
let accept_handle = tokio::spawn(async move {
let (stream, info) = acceptor.accept().await.unwrap();
assert!(matches!(info.transport_kind, TransportKind::Iroh { .. }));
stream
});
let _client_stream: io::Join<RecvStream, iroh::endpoint::SendStream> =
transport.connect().await.unwrap();
let _server_stream = accept_handle.await.unwrap();
}
}

View File

@@ -0,0 +1,188 @@
//! Pluggable transport layer for Alknet.
//!
//! The transport layer produces a duplex byte stream (`AsyncRead + AsyncWrite + Unpin + Send`)
//! that SSH consumes. This is the core architectural abstraction — SSH never opens its own
//! network connections; it runs entirely over whatever stream the transport provides.
//!
//! Available transports (feature-gated):
//! - `TcpTransport` / `TcpAcceptor` — always available, direct TCP
//! - `TlsTransport` / `TlsAcceptor` — behind the `tls` feature, TCP + rustls
//! - `IrohTransport` / `IrohAcceptor` — behind the `iroh` feature, QUIC P2P via iroh
//! - `AcmeTlsAcceptor` — behind the `acme` feature, auto-provision TLS certs via Let's Encrypt
//!
//! See [ADR-001](docs/architecture/decisions/001-pluggable-transport.md) and
//! [ADR-004](docs/architecture/decisions/004-ssh-over-transport.md) for design rationale.
mod tcp;
#[cfg(feature = "iroh")]
mod iroh_transport;
pub use tcp::{TcpAcceptor, TcpTransport};
#[cfg(feature = "iroh")]
pub use iroh_transport::{IrohAcceptor, IrohTransport, ALPN as IROH_ALPN};
#[cfg(feature = "tls")]
mod tls;
#[cfg(feature = "tls")]
pub use tls::{AcmeConfig, TlsAcceptor, TlsTransport};
#[cfg(feature = "acme")]
mod acme;
#[cfg(feature = "acme")]
pub use acme::{AcmeCertProvider, AcmeMode, AcmeTlsAcceptor};
use std::net::SocketAddr;
use anyhow::Result;
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
/// Client-side transport trait. Produces a single duplex stream per connection.
///
/// Implementations connect to a remote endpoint and return a stream that SSH
/// runs over via `russh::client::connect_stream()`. Each call to `connect()` creates
/// a new stream — multiple sessions need multiple calls or multiple transports.
#[async_trait]
pub trait Transport: Send + Sync + 'static {
/// The duplex stream type produced by this transport.
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
/// Connect to the remote endpoint and return a duplex stream.
async fn connect(&self) -> Result<Self::Stream>;
/// Return a human-readable description of this transport for logging.
fn describe(&self) -> String;
}
/// Server-side transport acceptor. Accepts incoming connections and returns streams.
///
/// Implementations bind to a local endpoint and produce streams that SSH
/// runs over via `russh::server::run_stream()`.
#[async_trait]
pub trait TransportAcceptor: Send + Sync + 'static {
/// The duplex stream type produced by this acceptor.
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
/// Accept an incoming connection and return a duplex stream with metadata.
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>;
}
/// Metadata about an incoming transport connection.
///
/// Carries the remote address (if available) and the kind of transport
/// used. The server handler uses this for logging and auth decisions.
/// See ADR-001 for the pluggable transport rationale and ADR-004
/// for why SSH runs entirely over the transport stream.
#[derive(Debug, Clone)]
pub struct TransportInfo {
pub remote_addr: Option<SocketAddr>,
pub transport_kind: TransportKind,
}
/// The kind of transport that produced a connection.
///
/// Each variant identifies the transport mechanism. Used by the
/// server handler for logging and authorization decisions.
/// See ADR-001 and ADR-004.
#[derive(Debug, Clone)]
pub enum TransportKind {
Tcp,
Tls {
server_name: Option<String>,
},
Iroh {
endpoint_id: String,
},
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{duplex, DuplexStream};
struct MockTransport;
#[async_trait]
impl Transport for MockTransport {
type Stream = DuplexStream;
async fn connect(&self) -> Result<Self::Stream> {
let (stream, _) = duplex(1024);
Ok(stream)
}
fn describe(&self) -> String {
"mock".to_string()
}
}
struct MockAcceptor;
#[async_trait]
impl TransportAcceptor for MockAcceptor {
type Stream = DuplexStream;
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
let (stream, _) = duplex(1024);
let info = TransportInfo {
remote_addr: None,
transport_kind: TransportKind::Tcp,
};
Ok((stream, info))
}
}
#[tokio::test]
async fn transport_trait_object() {
let _boxed: Box<dyn Transport<Stream = DuplexStream>> = Box::new(MockTransport);
}
#[tokio::test]
async fn transport_acceptor_trait_object() {
let _boxed: Box<dyn TransportAcceptor<Stream = DuplexStream>> = Box::new(MockAcceptor);
}
#[tokio::test]
async fn transport_connect_returns_stream() {
let t = MockTransport;
let _stream = t.connect().await.unwrap();
}
#[tokio::test]
async fn transport_describe_returns_string() {
let t = MockTransport;
assert_eq!(t.describe(), "mock");
}
#[tokio::test]
async fn acceptor_accept_returns_stream_and_info() {
let a = MockAcceptor;
let (_, info) = a.accept().await.unwrap();
assert!(info.remote_addr.is_none());
assert!(matches!(info.transport_kind, TransportKind::Tcp));
}
#[test]
fn transport_kind_variants() {
let tcp = TransportKind::Tcp;
let tls = TransportKind::Tls {
server_name: Some("example.com".to_string()),
};
let iroh = TransportKind::Iroh {
endpoint_id: "abc123".to_string(),
};
if let TransportKind::Tcp = tcp {}
if let TransportKind::Tls {
server_name: Some(name),
} = tls
{
assert_eq!(name, "example.com");
}
if let TransportKind::Iroh { endpoint_id } = iroh {
assert_eq!(endpoint_id, "abc123");
}
}
}

View File

@@ -0,0 +1,162 @@
use std::net::SocketAddr;
use anyhow::Result;
use async_trait::async_trait;
use tokio::net::{TcpListener, TcpStream};
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
/// A TCP-based client transport that connects to a remote address.
///
/// Connects via `TcpStream::connect(addr)`. Uses tokio's default
/// connect timeout behavior: the OS controls connection timeout
/// (typically ~2 minutes on Linux via `net.ipv4.tcp_syn_retries`).
/// For custom timeouts, wrap `TcpTransport` with
/// `tokio::time::timeout(duration, transport.connect())`.
pub struct TcpTransport {
addr: SocketAddr,
}
impl TcpTransport {
pub fn new(addr: SocketAddr) -> Self {
Self { addr }
}
}
#[async_trait]
impl Transport for TcpTransport {
type Stream = TcpStream;
async fn connect(&self) -> Result<Self::Stream> {
let stream = TcpStream::connect(self.addr).await?;
Ok(stream)
}
fn describe(&self) -> String {
format!("tcp://{}", self.addr)
}
}
/// A TCP-based server transport acceptor that listens for incoming connections.
///
/// Binds via `TcpListener::bind(addr)`. Accepts connections and returns
/// the stream together with `TransportInfo` containing the remote address
/// and `TransportKind::Tcp`.
pub struct TcpAcceptor {
listener: TcpListener,
listen_addr: SocketAddr,
}
impl TcpAcceptor {
/// Bind a TCP listener on the given address.
///
/// Returns the acceptor ready to receive connections.
/// The actual bound address may differ from the requested one
/// (e.g., when binding to port 0 the OS assigns an ephemeral port).
pub async fn bind(addr: SocketAddr) -> Result<Self> {
let listener = TcpListener::bind(addr).await?;
let listen_addr = listener.local_addr()?;
Ok(Self {
listener,
listen_addr,
})
}
pub fn listen_addr(&self) -> SocketAddr {
self.listen_addr
}
}
#[async_trait]
impl TransportAcceptor for TcpAcceptor {
type Stream = TcpStream;
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
let (stream, remote_addr) = self.listener.accept().await?;
let info = TransportInfo {
remote_addr: Some(remote_addr),
transport_kind: TransportKind::Tcp,
};
Ok((stream, info))
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[tokio::test]
async fn tcp_transport_connect_creates_stream() {
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
.await
.unwrap();
let addr = acceptor.listen_addr();
let transport = TcpTransport::new(addr);
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
let stream = transport.connect().await.unwrap();
assert_eq!(stream.local_addr().unwrap().ip(), addr.ip());
let (_server_stream, info) = accept_handle.await.unwrap();
assert!(info.remote_addr.is_some());
assert!(matches!(info.transport_kind, TransportKind::Tcp));
}
#[tokio::test]
async fn tcp_acceptor_accept_receives_connection() {
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
.await
.unwrap();
let addr = acceptor.listen_addr();
tokio::spawn(async move {
TcpStream::connect(addr).await.unwrap();
});
let (stream, info) = acceptor.accept().await.unwrap();
assert!(info.remote_addr.is_some());
assert!(matches!(info.transport_kind, TransportKind::Tcp));
assert_eq!(
info.remote_addr.unwrap().ip(),
stream.peer_addr().unwrap().ip()
);
}
#[test]
fn tcp_transport_describe_format() {
let addr: SocketAddr = "1.2.3.4:22".parse().unwrap();
let transport = TcpTransport::new(addr);
assert_eq!(transport.describe(), "tcp://1.2.3.4:22");
}
#[tokio::test]
async fn tcp_stream_is_duplex() {
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
.await
.unwrap();
let addr = acceptor.listen_addr();
let mut client = TcpStream::connect(addr).await.unwrap();
let (mut server, _) = acceptor.accept().await.unwrap();
client.write_all(b"hello").await.unwrap();
let mut buf = [0u8; 5];
server.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello");
server.write_all(b"world").await.unwrap();
let mut buf = [0u8; 5];
client.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"world");
}
#[tokio::test]
async fn tcp_acceptor_bind_port_zero_assigns_ephemeral() {
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
.await
.unwrap();
assert_ne!(acceptor.listen_addr().port(), 0);
}
}

View File

@@ -0,0 +1,432 @@
use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::{client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector};
#[cfg(feature = "acme")]
use rustls::crypto::aws_lc_rs::default_provider;
#[cfg(feature = "acme")]
use rustls_acme::ResolvesServerCertAcme;
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
#[cfg(feature = "acme")]
const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1";
/// A TLS-based client transport that connects to a remote address over TLS.
///
/// Wraps a TCP connection with a TLS client session via `tokio_rustls::TlsConnector`.
/// Supports insecure mode (accepts any certificate, for development) and
/// custom root CA certificates for verification. The `tls_server_name` field
/// overrides the SNI hostname sent during the TLS handshake (ADR-010).
pub struct TlsTransport {
addr: SocketAddr,
tls_server_name: Option<String>,
insecure: bool,
root_cert: Option<CertificateDer<'static>>,
}
impl TlsTransport {
pub fn new(addr: SocketAddr) -> Self {
Self {
addr,
tls_server_name: None,
insecure: false,
root_cert: None,
}
}
pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
self.tls_server_name = Some(name.into());
self
}
pub fn with_insecure(mut self, insecure: bool) -> Self {
self.insecure = insecure;
self
}
pub fn with_root_cert(mut self, cert: CertificateDer<'static>) -> Self {
self.root_cert = Some(cert);
self
}
fn build_client_config(&self) -> Result<ClientConfig> {
if self.insecure {
let config = ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoVerifier))
.with_no_client_auth();
return Ok(config);
}
let mut root_store = RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
if let Some(ref cert) = self.root_cert {
root_store.add(cert.clone())?;
}
let config = ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
Ok(config)
}
fn resolve_server_name(&self) -> Result<ServerName<'static>> {
let name = match &self.tls_server_name {
Some(n) => n.clone(),
None => self.addr.ip().to_string(),
};
ServerName::try_from(name.clone())
.map_err(move |e| anyhow!("invalid server name '{}': {}", name, e))
}
}
#[async_trait]
impl Transport for TlsTransport {
type Stream = ClientTlsStream<TcpStream>;
async fn connect(&self) -> Result<Self::Stream> {
let tcp_stream = TcpStream::connect(self.addr).await?;
let config = self.build_client_config()?;
let connector = TlsConnector::from(Arc::new(config));
let server_name = self.resolve_server_name()?;
let tls_stream = connector.connect(server_name, tcp_stream).await?;
Ok(tls_stream)
}
fn describe(&self) -> String {
format!("tls://{}", self.addr)
}
}
/// Stub configuration for ACME certificate provisioning (ADR-008).
/// Feature-gated behind the `acme` feature. When implemented, this will
/// hold the ACME domain and challenge responder configuration.
#[derive(Debug)]
pub struct AcmeConfig {
pub domain: String,
}
/// A TLS-based server transport acceptor that accepts TCP connections
/// and wraps them with TLS server sessions via `tokio_rustls::TlsAcceptor`.
///
/// Supports three certificate modes (ADR-008):
/// - Manual certs via `bind()` with explicit cert/key
/// - ACME certs via `bind_acme()` with an `AcmeCertProvider`
/// - The stub `AcmeConfig` parameter in `bind()` is kept for backward compat
pub struct TlsAcceptor {
listener: TcpListener,
listen_addr: SocketAddr,
#[allow(dead_code)]
server_config: Arc<ServerConfig>,
tokio_acceptor: TokioTlsAcceptor,
}
impl TlsAcceptor {
pub async fn bind(
addr: SocketAddr,
tls_certs: Vec<CertificateDer<'static>>,
tls_key: PrivateKeyDer<'static>,
_acme_config: Option<AcmeConfig>,
) -> Result<Self> {
let listener = TcpListener::bind(addr).await?;
let listen_addr = listener.local_addr()?;
let server_config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(tls_certs, tls_key)?;
let server_config = Arc::new(server_config);
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
Ok(Self {
listener,
listen_addr,
server_config,
tokio_acceptor,
})
}
#[cfg(feature = "acme")]
pub async fn bind_acme(
addr: SocketAddr,
acme_resolver: Arc<ResolvesServerCertAcme>,
) -> Result<Self> {
let listener = TcpListener::bind(addr).await?;
let listen_addr = listener.local_addr()?;
let provider = default_provider().into();
let mut server_config = ServerConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()
.map_err(|e| anyhow!("failed to set protocol versions: {}", e))?
.with_no_client_auth()
.with_cert_resolver(acme_resolver);
server_config.alpn_protocols.push(ACME_TLS_ALPN_NAME.to_vec());
let server_config = Arc::new(server_config);
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
Ok(Self {
listener,
listen_addr,
server_config,
tokio_acceptor,
})
}
pub fn listen_addr(&self) -> SocketAddr {
self.listen_addr
}
}
#[async_trait]
impl TransportAcceptor for TlsAcceptor {
type Stream = tokio_rustls::server::TlsStream<TcpStream>;
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
let (tcp_stream, remote_addr) = self.listener.accept().await?;
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
let server_name = tls_stream
.get_ref()
.1
.server_name()
.map(|s| s.to_string());
let info = TransportInfo {
remote_addr: Some(remote_addr),
transport_kind: TransportKind::Tls { server_name },
};
Ok((tls_stream, info))
}
}
#[derive(Debug)]
struct NoVerifier;
impl ServerCertVerifier for NoVerifier {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> std::result::Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_doc: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_doc: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
]
}
}
#[cfg(test)]
mod tests {
use super::*;
use rcgen::{CertificateParams, KeyPair};
use rustls::crypto::aws_lc_rs::default_provider;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
fn ensure_crypto_provider() {
let _ = default_provider().install_default();
}
fn generate_self_signed_cert() -> (CertificateDer<'static>, PrivateKeyDer<'static>) {
let params = CertificateParams::new(vec!["localhost".to_string()]).unwrap();
let key_pair = KeyPair::generate().unwrap();
let cert = params.self_signed(&key_pair).unwrap();
let cert_der: CertificateDer<'static> = cert.into();
let key_der = PrivateKeyDer::Pkcs8(key_pair.serialize_der().into());
(cert_der, key_der)
}
#[test]
fn tls_transport_describe_format() {
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
let transport = TlsTransport::new(addr).with_server_name("example.com");
assert_eq!(transport.describe(), "tls://1.2.3.4:443");
}
#[test]
fn tls_transport_describe_with_ip() {
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
let transport = TlsTransport::new(addr);
assert_eq!(transport.describe(), "tls://1.2.3.4:443");
}
#[test]
fn tls_transport_builder_methods() {
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
let transport = TlsTransport::new(addr)
.with_server_name("alknet.test")
.with_insecure(true);
assert_eq!(transport.tls_server_name, Some("alknet.test".to_string()));
assert!(transport.insecure);
}
#[tokio::test]
async fn tls_connect_insecure_self_signed() {
ensure_crypto_provider();
let (cert_der, key_der) = generate_self_signed_cert();
let acceptor = TlsAcceptor::bind(
"127.0.0.1:0".parse().unwrap(),
vec![cert_der],
key_der,
None,
)
.await
.unwrap();
let addr = acceptor.listen_addr();
let transport = TlsTransport::new(addr)
.with_server_name("localhost")
.with_insecure(true);
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
let mut client = transport.connect().await.unwrap();
let (mut server, info) = accept_handle.await.unwrap();
assert!(info.remote_addr.is_some());
assert!(matches!(
info.transport_kind,
TransportKind::Tls { .. }
));
client.write_all(b"hello tls").await.unwrap();
let mut buf = [0u8; 9];
server.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello tls");
server.write_all(b"reply").await.unwrap();
let mut buf = [0u8; 5];
client.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"reply");
}
#[tokio::test]
async fn tls_acceptor_returns_server_name() {
ensure_crypto_provider();
let (cert_der, key_der) = generate_self_signed_cert();
let acceptor = TlsAcceptor::bind(
"127.0.0.1:0".parse().unwrap(),
vec![cert_der],
key_der,
None,
)
.await
.unwrap();
let addr = acceptor.listen_addr();
let transport = TlsTransport::new(addr)
.with_server_name("localhost")
.with_insecure(true);
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
let _client = transport.connect().await.unwrap();
let (_, info) = accept_handle.await.unwrap();
if let TransportKind::Tls { server_name } = info.transport_kind {
assert_eq!(server_name, Some("localhost".to_string()));
} else {
panic!("expected TransportKind::Tls");
}
}
#[tokio::test]
async fn tls_full_client_to_server_connection() {
ensure_crypto_provider();
let (cert_der, key_der) = generate_self_signed_cert();
let acceptor = TlsAcceptor::bind(
"127.0.0.1:0".parse().unwrap(),
vec![cert_der],
key_der,
None,
)
.await
.unwrap();
let addr = acceptor.listen_addr();
let transport = TlsTransport::new(addr)
.with_server_name("localhost")
.with_insecure(true);
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
let mut client = transport.connect().await.unwrap();
let (mut server, _info) = accept_handle.await.unwrap();
let msg = b"alknet integration test";
client.write_all(msg).await.unwrap();
let mut buf = vec![0u8; msg.len()];
server.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf[..], msg);
let reply = b"ok";
server.write_all(reply).await.unwrap();
let mut buf = [0u8; 2];
client.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, reply);
}
#[tokio::test]
async fn tls_acceptor_bind_port_zero_assigns_ephemeral() {
ensure_crypto_provider();
let (cert_der, key_der) = generate_self_signed_cert();
let acceptor = TlsAcceptor::bind(
"127.0.0.1:0".parse().unwrap(),
vec![cert_der],
key_der,
None,
)
.await
.unwrap();
assert_ne!(acceptor.listen_addr().port(), 0);
}
#[test]
fn no_verifier_accepts_any_cert() {
let verifier = NoVerifier;
assert!(verifier.supported_verify_schemes().len() > 0);
}
}

View File

@@ -0,0 +1,2 @@
#[tokio::test]
async fn auth_placeholder() {}

View File

@@ -0,0 +1,2 @@
#[tokio::test]
async fn client_placeholder() {}

View File

@@ -0,0 +1,2 @@
#[tokio::test]
async fn server_placeholder() {}

View File

@@ -0,0 +1,26 @@
use alknet_core::testutil::{MockTransport, MockTransportAcceptor, Transport, TransportAcceptor, mock_pair};
#[tokio::test]
async fn mock_transport_connect() {
let transport = MockTransport::new(1024);
let stream = transport.connect().await.unwrap();
drop(stream);
}
#[tokio::test]
async fn mock_transport_acceptor_accept() {
let acceptor = MockTransportAcceptor::new(1024);
let (stream, info) = acceptor.accept().await.unwrap();
drop(stream);
drop(info);
}
#[tokio::test]
async fn mock_pair_communicates() {
let (mut client, mut server) = mock_pair(1024);
use tokio::io::{AsyncReadExt, AsyncWriteExt};
client.write_all(b"hello").await.unwrap();
let mut buf = [0u8; 5];
server.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello");
}