11 Commits

Author SHA1 Message Date
94feb5fdac feat(cli): implement wraith connect subcommand with clap derive
All CLI flags from client.md: --server, --peer, --transport (default tcp),
--identity, --socks5 (default 127.0.0.1:1080), --forward (repeatable),
--remote-forward (repeatable), --proxy, --iroh-relay, --tls-server-name,
--insecure. Env var defaults: WRAITH_SERVER, WRAITH_IDENTITY. Validates
--server required for tcp/tls, --peer required for iroh, --identity required.
Warns on --proxy with --transport tcp (ADR-019). Translates args to
ConnectOptions and calls ClientSession::new(opts).run().await. Errors to
stderr with non-zero exit.
2026-06-02 11:39:57 +00:00
f13a1c985f Merge remote-tracking branch 'origin/feat/server/channel-proxy'
# Conflicts:
#	crates/wraith-core/src/error.rs
#	crates/wraith-core/src/server/mod.rs
2026-06-02 11:32:28 +00:00
365b11d19e Merge remote-tracking branch 'origin/feat/server/stealth-mode'
# Conflicts:
#	crates/wraith-core/src/error.rs
#	crates/wraith-core/src/server/mod.rs
2026-06-02 11:14:06 +00:00
7dcf7502b7 feat(server): implement stealth mode protocol multiplexing (ADR-017)
Add stealth mode detection that peeks at the first bytes after TLS handshake
to determine SSH vs HTTP protocol. SSH connections proceed to russh handler;
non-SSH connections receive a fake nginx 404 response, making the server
indistinguishable from an ordinary HTTPS site to scanners and DPI systems.

- ProtocolDetection enum (Ssh, Http) for protocol classification
- detect_protocol() uses BufReader::fill_buf() to peek without consuming bytes
- send_fake_nginx_404() writes HTTP/1.1 404 + Server: nginx headers
- validate_stealth_config() enforces TLS transport requirement for stealth
- 17 unit tests covering SSH banner, HTTP, random data, and edge cases
2026-06-02 11:13:15 +00:00
585913d3c8 Merge remote-tracking branch 'origin/feat/napi/connect-function'
# Conflicts:
#	crates/wraith-core/src/error.rs
2026-06-02 11:11:14 +00:00
243243a82f Implement NAPI connect() function — single SSH channel as duplex stream
- Add WraithConnectOptions struct with napi fields: server, peer, transport,
  identity (string path or Buffer), tlsServerName, insecure, irohRelay, proxy
- Add WraithStream napi class wrapping SSH channel read/write halves via
  ChannelStream::into_stream() + tokio::io::split()
- Implement connect() async function: transport creation (tcp, tls), SSH client
  connection, authenticate, open direct_tcpip channel, return WraithStream
- Identity field accepts file path (string) or in-memory key data (Buffer)
- All Rust errors marshalled to JavaScript exceptions with descriptive messages
- Add ForwardError enum to wraith-core (required by forward.rs)
- Enable tls, iroh features on wraith-core dependency
- 7 unit tests for key source resolution and address parsing
2026-06-02 11:10:42 +00:00
2ab5eeda53 Merge remote-tracking branch 'origin/feat/client/connect-options' 2026-06-02 11:07:54 +00:00
128affd264 Implement ConnectOptions struct and ClientSession orchestration with graceful shutdown
Adds client/connect.rs with ConnectOptions (programmatic API per ADR-011),
ClientSession::new() for SSH session establishment, ClientSession::run()
for SOCKS5 + port forwards + shutdown, and graceful shutdown via
SIGTERM/SIGINT with SSH disconnect and 2s drain timeout.
2026-06-02 11:07:33 +00:00
5a2b535605 Merge remote-tracking branch 'origin/feat/server/rate-limiting-and-logging'
# Conflicts:
#	crates/wraith-core/src/error.rs
#	crates/wraith-core/src/server/handler.rs
#	crates/wraith-core/src/server/mod.rs
2026-06-02 11:06:18 +00:00
24b70f5651 Implement server rate limiting and fail2ban-friendly structured logging
Add ConnectionRateLimiter (HashMap<IpAddr, usize>) and AuthAttemptLimiter
with check/on_connect/on_disconnect and check/on_failure methods.
Integrate into ServerHandler with structured tracing::info! logging for
auth attempts, connection opened/closed events. No logging of tunnel
destinations per ADR-006. Also add ForwardError type and fix type
annotation in forward.rs to unblock compilation.
2026-06-02 11:02:55 +00:00
f963898a05 Implement control channel routing for wraith-* reserved destinations (ADR-018)
- Add control_channel.rs with WRAITH_CONTROL_DESTINATION, WRAITH_PREFIX constants
- Add ControlChannelHandler trait and ControlChannelRouter for routing logic
- Add DuplexStream supertrait for Box<dyn> compatibility
- Server handler rejects wraith-* destinations when no handler configured
- Add ForwardError type to fix pre-existing compilation error
- Unit tests: reserved detection, non-reserved pass-through, prefix matching
2026-06-02 11:01:54 +00:00
14 changed files with 2096 additions and 26 deletions

7
Cargo.lock generated
View File

@@ -2395,6 +2395,7 @@ version = "3.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1d395473824516f38dd1071a1a37bc57daa7be65b293ebba4ead5f7abb017a2"
dependencies = [
"anyhow",
"bitflags 2.11.1",
"ctor",
"futures",
@@ -2402,6 +2403,7 @@ dependencies = [
"napi-sys",
"nohash-hasher",
"rustc-hash",
"tokio",
]
[[package]]
@@ -5583,7 +5585,9 @@ version = "0.1.0"
dependencies = [
"anyhow",
"clap",
"iroh",
"tokio",
"url",
"wraith-core",
]
@@ -5593,6 +5597,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"futures",
"ipnetwork",
"iroh",
"rand 0.10.1",
@@ -5620,6 +5625,8 @@ version = "0.1.0"
dependencies = [
"napi",
"napi-derive",
"russh",
"tokio",
"wraith-core",
]

View File

@@ -0,0 +1,727 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use russh::client;
use russh::keys::PrivateKey;
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};
use crate::auth::client_auth::{ClientAuthConfig, ClientHandler};
use crate::auth::keys::KeySource;
use crate::client::forward::{LocalForwarder, PortForwardSpec, RemoteForwarder};
use crate::error::ConfigError;
use crate::socks5::{HandleChannelOpener, Socks5Server};
use crate::transport::Transport;
const DEFAULT_SOCKS5_ADDR: &str = "127.0.0.1:1080";
const DRAIN_TIMEOUT: Duration = Duration::from_secs(2);
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransportMode {
Tcp,
Tls,
Iroh,
}
impl std::fmt::Display for TransportMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportMode::Tcp => write!(f, "tcp"),
TransportMode::Tls => write!(f, "tls"),
TransportMode::Iroh => write!(f, "iroh"),
}
}
}
#[derive(Clone)]
pub struct ConnectOptions {
pub server: Option<String>,
pub peer: Option<String>,
pub transport_mode: TransportMode,
pub identity: KeySource,
pub socks5_addr: String,
pub forwards: Vec<String>,
pub remote_forwards: Vec<String>,
pub proxy: Option<String>,
pub iroh_relay: Option<String>,
pub tls_server_name: Option<String>,
pub insecure: bool,
}
impl ConnectOptions {
pub fn new(identity: KeySource) -> Self {
Self {
server: None,
peer: None,
transport_mode: TransportMode::Tcp,
identity,
socks5_addr: DEFAULT_SOCKS5_ADDR.to_string(),
forwards: Vec::new(),
remote_forwards: Vec::new(),
proxy: None,
iroh_relay: None,
tls_server_name: None,
insecure: false,
}
}
pub fn server(mut self, addr: impl Into<String>) -> Self {
self.server = Some(addr.into());
self
}
pub fn peer(mut self, endpoint_id: impl Into<String>) -> Self {
self.peer = Some(endpoint_id.into());
self
}
pub fn transport_mode(mut self, mode: TransportMode) -> Self {
self.transport_mode = mode;
self
}
pub fn socks5_addr(mut self, addr: impl Into<String>) -> Self {
self.socks5_addr = addr.into();
self
}
pub fn forward(mut self, spec: impl Into<String>) -> Self {
self.forwards.push(spec.into());
self
}
pub fn remote_forward(mut self, spec: impl Into<String>) -> Self {
self.remote_forwards.push(spec.into());
self
}
pub fn proxy(mut self, url: impl Into<String>) -> Self {
self.proxy = Some(url.into());
self
}
pub fn iroh_relay(mut self, url: impl Into<String>) -> Self {
self.iroh_relay = Some(url.into());
self
}
pub fn tls_server_name(mut self, name: impl Into<String>) -> Self {
self.tls_server_name = Some(name.into());
self
}
pub fn insecure(mut self, insecure: bool) -> Self {
self.insecure = insecure;
self
}
pub fn validate(&self) -> Result<(), ConfigError> {
match self.transport_mode {
TransportMode::Tcp | TransportMode::Tls => {
if self.server.is_none() {
return Err(ConfigError::InvalidFlag {
name: "--server is required for tcp/tls transport".to_string(),
});
}
}
TransportMode::Iroh => {
if self.peer.is_none() {
return Err(ConfigError::InvalidFlag {
name: "--peer is required for iroh transport".to_string(),
});
}
}
}
Ok(())
}
}
impl std::fmt::Debug for ConnectOptions {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConnectOptions")
.field("server", &self.server)
.field("peer", &self.peer)
.field("transport_mode", &self.transport_mode)
.field("identity", &"<KeySource>")
.field("socks5_addr", &self.socks5_addr)
.field("forwards", &self.forwards)
.field("remote_forwards", &self.remote_forwards)
.field("proxy", &self.proxy)
.field("iroh_relay", &self.iroh_relay)
.field("tls_server_name", &self.tls_server_name)
.field("insecure", &self.insecure)
.finish()
}
}
pub struct ClientSession<T: Transport> {
opts: ConnectOptions,
transport: Arc<T>,
handle: Arc<Mutex<client::Handle<ClientHandler>>>,
auth_config: Arc<ClientAuthConfig>,
#[allow(dead_code)]
private_key: Arc<PrivateKey>,
#[allow(dead_code)]
username: String,
shutdown_tx: tokio::sync::watch::Sender<bool>,
shutdown_rx: tokio::sync::watch::Receiver<bool>,
}
impl<T: Transport> ClientSession<T> {
pub async fn new(
opts: ConnectOptions,
transport: Arc<T>,
) -> Result<Self, ConnectError> {
opts.validate().map_err(ConnectError::Config)?;
let auth_config = Arc::new(
ClientAuthConfig::from_key_source(opts.identity.clone())
.map_err(ConnectError::Config)?,
);
let private_key = auth_config.private_key();
let username = derive_username();
let handler = ClientHandler::from_config(&auth_config);
let stream = transport.connect().await.map_err(|e| {
error!("transport connect failed: {e}");
ConnectError::ConnectionFailed
})?;
let config = Arc::new(client::Config::default());
let mut handle = client::connect_stream(config, stream, handler)
.await
.map_err(|e| {
error!("SSH connect failed: {e}");
ConnectError::ConnectionFailed
})?;
let auth_ok = auth_config
.authenticate(&mut handle, &username)
.await
.map_err(|_| ConnectError::AuthFailed)?;
if !auth_ok {
return Err(ConnectError::AuthFailed);
}
let handle = Arc::new(Mutex::new(handle));
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
Ok(Self {
opts,
transport,
handle,
auth_config,
private_key,
username,
shutdown_tx,
shutdown_rx,
})
}
pub fn handle(&self) -> Arc<Mutex<client::Handle<ClientHandler>>> {
Arc::clone(&self.handle)
}
pub fn auth_config(&self) -> &Arc<ClientAuthConfig> {
&self.auth_config
}
pub fn transport(&self) -> &Arc<T> {
&self.transport
}
pub fn options(&self) -> &ConnectOptions {
&self.opts
}
pub fn shutdown_sender(&self) -> tokio::sync::watch::Sender<bool> {
self.shutdown_tx.clone()
}
pub async fn run(self) -> Result<(), ConnectError> {
let socks5_addr: SocketAddr = self.opts.socks5_addr.parse().map_err(|_| {
ConnectError::Config(ConfigError::InvalidFlag {
name: format!("invalid SOCKS5 address: {}", self.opts.socks5_addr),
})
})?;
let channel_opener = HandleChannelOpener::from_arc(Arc::clone(&self.handle));
let socks5_server = Socks5Server::with_addr(channel_opener, &socks5_addr.to_string());
let socks5_listen = socks5_server.listen_addr();
let local_forwarders = build_local_forwarders(&self.opts)?;
let remote_specs = build_remote_specs(&self.opts)?;
for spec in &remote_specs {
let remote_forwarder = RemoteForwarder::new(spec.clone())
.map_err(|_| ConnectError::ForwardFailed)?;
let mut h = self.handle.lock().await;
remote_forwarder
.register(&mut h)
.await
.map_err(|_| {
warn!("failed to register remote forward {}", spec);
ConnectError::ForwardFailed
})?;
info!("registered remote forward: {}", spec);
}
let socks5_task = tokio::spawn(async move {
debug!("SOCKS5 server starting on {}", socks5_listen);
if let Err(e) = socks5_server.run().await {
error!("SOCKS5 server error: {e}");
}
});
let fwd_handle = Arc::clone(&self.handle);
let fwd_shutdown = self.shutdown_rx.clone();
let forward_task = tokio::spawn(async move {
crate::client::forward::run_local_forwarders(
local_forwarders, fwd_handle, fwd_shutdown,
)
.await;
});
info!("wraith client running: SOCKS5 on {}", socks5_listen);
#[cfg(unix)]
let signal_done = {
let sig_tx = self.shutdown_tx.clone();
tokio::spawn(async move {
let mut sigterm_stream =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to install SIGTERM handler");
tokio::select! {
_ = sigterm_stream.recv() => {
info!("received SIGTERM");
}
_ = tokio::signal::ctrl_c() => {
info!("received SIGINT (Ctrl+C)");
}
}
let _ = sig_tx.send(true);
})
};
let mut wait_shutdown = self.shutdown_rx.clone();
tokio::select! {
_ = wait_shutdown.changed() => {
if *wait_shutdown.borrow() {
info!("shutdown signal received");
}
}
_ = socks5_task => {
warn!("SOCKS5 server exited unexpectedly");
}
}
#[cfg(unix)]
signal_done.abort();
self.shutdown().await?;
forward_task.abort();
let _ = forward_task.await;
Ok(())
}
pub async fn shutdown(&self) -> Result<(), ConnectError> {
info!("initiating graceful shutdown");
let _ = self.shutdown_tx.send(true);
{
let handle = self.handle.lock().await;
if !handle.is_closed() {
if let Err(e) = handle
.disconnect(russh::Disconnect::ByApplication, "shutdown", "")
.await
{
warn!("failed to send SSH disconnect: {e}");
}
}
}
tokio::time::sleep(DRAIN_TIMEOUT).await;
info!("graceful shutdown complete");
Ok(())
}
}
fn derive_username() -> String {
std::env::var("USER")
.or_else(|_| std::env::var("USERNAME"))
.unwrap_or_else(|_| "wraith".to_string())
}
fn build_local_forwarders(opts: &ConnectOptions) -> Result<Vec<LocalForwarder>, ConnectError> {
let mut forwarders = Vec::new();
for spec_str in &opts.forwards {
let spec = PortForwardSpec::local(spec_str).map_err(|e| {
warn!("invalid local forward spec '{}': {}", spec_str, e);
ConnectError::Config(ConfigError::InvalidFlag {
name: format!("invalid forward spec: {}", spec_str),
})
})?;
forwarders.push(
LocalForwarder::new(spec).map_err(|e| {
warn!("failed to create local forwarder: {}", e);
ConnectError::ForwardFailed
})?,
);
}
Ok(forwarders)
}
fn build_remote_specs(opts: &ConnectOptions) -> Result<Vec<PortForwardSpec>, ConnectError> {
let mut specs = Vec::new();
for spec_str in &opts.remote_forwards {
let spec = PortForwardSpec::remote(spec_str).map_err(|e| {
warn!("invalid remote forward spec '{}': {}", spec_str, e);
ConnectError::Config(ConfigError::InvalidFlag {
name: format!("invalid remote forward spec: {}", spec_str),
})
})?;
specs.push(spec);
}
Ok(specs)
}
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
#[error("connection failed")]
ConnectionFailed,
#[error("authentication failed")]
AuthFailed,
#[error("forward setup failed")]
ForwardFailed,
#[error("config error: {0}")]
Config(#[from] ConfigError),
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::io::duplex;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
fn make_identity() -> KeySource {
KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec())
}
#[test]
fn connect_options_default_fields() {
let opts = ConnectOptions::new(make_identity());
assert!(opts.server.is_none());
assert!(opts.peer.is_none());
assert_eq!(opts.transport_mode, TransportMode::Tcp);
assert_eq!(opts.socks5_addr, "127.0.0.1:1080");
assert!(opts.forwards.is_empty());
assert!(opts.remote_forwards.is_empty());
assert!(opts.proxy.is_none());
assert!(opts.iroh_relay.is_none());
assert!(opts.tls_server_name.is_none());
assert!(!opts.insecure);
}
#[test]
fn connect_options_builder_pattern() {
let opts = ConnectOptions::new(make_identity())
.server("example.com:22")
.transport_mode(TransportMode::Tls)
.socks5_addr("127.0.0.1:9050")
.forward("127.0.0.1:5432:db:5432")
.remote_forward("0.0.0.0:8080:127.0.0.1:3000")
.proxy("socks5://127.0.0.1:1080")
.iroh_relay("https://relay.example.com")
.tls_server_name("wraith.test")
.insecure(true);
assert_eq!(opts.server.as_deref(), Some("example.com:22"));
assert_eq!(opts.transport_mode, TransportMode::Tls);
assert_eq!(opts.socks5_addr, "127.0.0.1:9050");
assert_eq!(opts.forwards.len(), 1);
assert_eq!(opts.remote_forwards.len(), 1);
assert_eq!(opts.proxy.as_deref(), Some("socks5://127.0.0.1:1080"));
assert_eq!(opts.iroh_relay.as_deref(), Some("https://relay.example.com"));
assert_eq!(opts.tls_server_name.as_deref(), Some("wraith.test"));
assert!(opts.insecure);
}
#[test]
fn connect_options_validate_tcp_requires_server() {
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Tcp);
assert!(opts.validate().is_err());
}
#[test]
fn connect_options_validate_tcp_with_server_ok() {
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
assert!(opts.validate().is_ok());
}
#[test]
fn connect_options_validate_tls_requires_server() {
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Tls);
assert!(opts.validate().is_err());
}
#[test]
fn connect_options_validate_tls_with_server_ok() {
let opts = ConnectOptions::new(make_identity())
.transport_mode(TransportMode::Tls)
.server("example.com:443");
assert!(opts.validate().is_ok());
}
#[test]
fn connect_options_validate_iroh_requires_peer() {
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Iroh);
assert!(opts.validate().is_err());
}
#[test]
fn connect_options_validate_iroh_with_peer_ok() {
let opts = ConnectOptions::new(make_identity())
.transport_mode(TransportMode::Iroh)
.peer("some-endpoint-id");
assert!(opts.validate().is_ok());
}
#[test]
fn identity_accepts_key_source_file() {
let file_source = KeySource::File(std::path::PathBuf::from("/path/to/key"));
let opts = ConnectOptions::new(file_source);
match &opts.identity {
KeySource::File(p) => assert_eq!(p, &std::path::PathBuf::from("/path/to/key")),
_ => panic!("expected File variant"),
}
}
#[test]
fn identity_accepts_key_source_memory() {
let mem_source = KeySource::Memory(b"key-data".to_vec());
let opts = ConnectOptions::new(mem_source);
match &opts.identity {
KeySource::Memory(d) => assert_eq!(d, b"key-data"),
_ => panic!("expected Memory variant"),
}
}
#[test]
fn transport_mode_display() {
assert_eq!(TransportMode::Tcp.to_string(), "tcp");
assert_eq!(TransportMode::Tls.to_string(), "tls");
assert_eq!(TransportMode::Iroh.to_string(), "iroh");
}
#[test]
fn connect_error_variants() {
assert_eq!(ConnectError::ConnectionFailed.to_string(), "connection failed");
assert_eq!(ConnectError::AuthFailed.to_string(), "authentication failed");
assert_eq!(ConnectError::ForwardFailed.to_string(), "forward setup failed");
}
#[test]
fn connect_options_debug_redacts_identity() {
let opts = ConnectOptions::new(make_identity());
let debug_str = format!("{:?}", opts);
assert!(debug_str.contains("<KeySource>"));
assert!(!debug_str.contains("OPENSSH"));
}
struct FailTransport;
#[async_trait::async_trait]
impl Transport for FailTransport {
type Stream = tokio::io::DuplexStream;
async fn connect(&self) -> anyhow::Result<Self::Stream> {
Err(anyhow::anyhow!("always fails"))
}
fn describe(&self) -> String {
"fail".to_string()
}
}
struct DuplexTransport {
connect_count: Arc<AtomicUsize>,
}
#[async_trait::async_trait]
impl Transport for DuplexTransport {
type Stream = tokio::io::DuplexStream;
async fn connect(&self) -> anyhow::Result<Self::Stream> {
self.connect_count.fetch_add(1, Ordering::SeqCst);
let (client, _) = duplex(4096);
Ok(client)
}
fn describe(&self) -> String {
"duplex".to_string()
}
}
#[tokio::test]
async fn client_session_new_transport_fails() {
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
let transport = Arc::new(FailTransport);
let result = ClientSession::new(opts, transport).await;
assert!(result.is_err());
assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed));
}
#[tokio::test]
async fn client_session_new_ssh_handshake_fails() {
let transport = Arc::new(DuplexTransport {
connect_count: Arc::new(AtomicUsize::new(0)),
});
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
let result = ClientSession::new(opts, transport).await;
assert!(result.is_err());
assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed));
}
#[test]
fn build_local_forwarders_empty() {
let opts = ConnectOptions::new(make_identity());
let result = build_local_forwarders(&opts);
assert!(result.is_ok());
assert!(result.unwrap().is_empty());
}
#[test]
fn build_local_forwarders_valid() {
let opts = ConnectOptions::new(make_identity()).forward("127.0.0.1:5432:db:5432");
let result = build_local_forwarders(&opts);
assert!(result.is_ok());
assert_eq!(result.unwrap().len(), 1);
}
#[test]
fn build_local_forwarders_invalid_spec() {
let opts = ConnectOptions::new(make_identity()).forward("bad-spec");
let result = build_local_forwarders(&opts);
assert!(result.is_err());
}
#[test]
fn build_remote_specs_empty() {
let opts = ConnectOptions::new(make_identity());
let result = build_remote_specs(&opts);
assert!(result.is_ok());
assert!(result.unwrap().is_empty());
}
#[test]
fn build_remote_specs_valid() {
let opts = ConnectOptions::new(make_identity()).remote_forward("0.0.0.0:8080:127.0.0.1:3000");
let result = build_remote_specs(&opts);
assert!(result.is_ok());
assert_eq!(result.unwrap().len(), 1);
}
#[test]
fn build_remote_specs_invalid() {
let opts = ConnectOptions::new(make_identity()).remote_forward("bad");
let result = build_remote_specs(&opts);
assert!(result.is_err());
}
#[test]
fn default_socks5_addr() {
assert_eq!(DEFAULT_SOCKS5_ADDR, "127.0.0.1:1080");
}
#[test]
fn drain_timeout_is_two_seconds() {
assert_eq!(DRAIN_TIMEOUT, Duration::from_secs(2));
}
#[test]
fn transport_mode_equality() {
assert_eq!(TransportMode::Tcp, TransportMode::Tcp);
assert_ne!(TransportMode::Tcp, TransportMode::Tls);
assert_ne!(TransportMode::Tls, TransportMode::Iroh);
}
#[tokio::test]
async fn shutdown_sends_disconnect_and_drains() {
let transport = Arc::new(DuplexTransport {
connect_count: Arc::new(AtomicUsize::new(0)),
});
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
let result = ClientSession::new(opts, transport).await;
assert!(result.is_err());
}
#[test]
fn socks5_is_always_enabled_by_default() {
let opts = ConnectOptions::new(make_identity());
assert!(!opts.socks5_addr.is_empty());
}
#[tokio::test]
async fn integration_mock_transport_session() {
use crate::socks5::{ChannelOpener, ChannelOpenError};
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::{TcpListener, TcpStream};
struct MockOpener;
impl ChannelOpener for MockOpener {
type Stream = tokio::io::DuplexStream;
async fn open_channel(
&self,
_host: String,
_port: u16,
) -> Result<Self::Stream, ChannelOpenError> {
let (client, _server) = duplex(4096);
Ok(client)
}
}
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound_addr = listener.local_addr().unwrap();
drop(listener);
let opener = MockOpener;
let server = Socks5Server::with_addr(opener, &bound_addr.to_string());
let _server_task = tokio::spawn(async move {
let _ = server.run().await;
});
tokio::time::sleep(Duration::from_millis(50)).await;
let mut conn = TcpStream::connect(bound_addr).await.unwrap();
let greeting = [0x05, 0x01, 0x00];
conn.write_all(&greeting).await.unwrap();
let mut auth_resp = [0u8; 2];
conn.read_exact(&mut auth_resp).await.unwrap();
assert_eq!(auth_resp, [0x05, 0x00]);
let connect_req = [
0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80,
];
conn.write_all(&connect_req).await.unwrap();
let mut reply = [0u8; 10];
conn.read_exact(&mut reply).await.unwrap();
assert_eq!(reply[1], 0x00);
conn.write_all(b"test data").await.unwrap();
conn.shutdown().await.unwrap();
}
}

View File

@@ -1,5 +1,7 @@
pub mod channel_manager;
pub mod connect;
pub mod forward;
pub use channel_manager::{ChannelManager, ForwardRequest};
pub use connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
pub use forward::{LocalForwarder, PortForwardSpec, PortForwardSpecKind, RemoteForwarder};

View File

@@ -10,4 +10,5 @@ pub mod testutil;
pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError};
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
pub use client::channel_manager::{ChannelManager, ForwardRequest};
pub use client::channel_manager::{ChannelManager, ForwardRequest};
pub use client::connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};

View File

@@ -0,0 +1,186 @@
use std::io;
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
pub const WRAITH_CONTROL_DESTINATION: &str = "wraith-control";
pub const WRAITH_PREFIX: &str = "wraith-";
pub fn is_reserved_destination(host: &str) -> bool {
host.starts_with(WRAITH_PREFIX)
}
pub trait DuplexStream: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T: AsyncRead + AsyncWrite + Unpin + Send> DuplexStream for T {}
#[async_trait]
pub trait ControlChannelHandler: Send + Sync {
async fn handle_channel(&self, stream: Box<dyn DuplexStream>);
}
pub struct ControlChannelRouter {
handler: Option<Box<dyn ControlChannelHandler>>,
}
impl ControlChannelRouter {
pub fn new(handler: Option<Box<dyn ControlChannelHandler>>) -> Self {
Self { handler }
}
pub fn without_handler() -> Self {
Self { handler: None }
}
pub fn with_handler(handler: Box<dyn ControlChannelHandler>) -> Self {
Self {
handler: Some(handler),
}
}
pub fn has_handler(&self) -> bool {
self.handler.is_some()
}
pub async fn route(&self, stream: Box<dyn DuplexStream>) -> io::Result<()> {
match &self.handler {
Some(handler) => {
handler.handle_channel(stream).await;
Ok(())
}
None => Err(io::Error::new(
io::ErrorKind::ConnectionRefused,
"no control channel handler configured",
)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::duplex;
#[test]
fn wraith_control_destination_constant() {
assert_eq!(WRAITH_CONTROL_DESTINATION, "wraith-control");
}
#[test]
fn wraith_prefix_constant() {
assert_eq!(WRAITH_PREFIX, "wraith-");
}
#[test]
fn reserved_destination_detected() {
assert!(is_reserved_destination("wraith-control"));
assert!(is_reserved_destination("wraith-status"));
assert!(is_reserved_destination("wraith-events"));
assert!(is_reserved_destination("wraith-"));
}
#[test]
fn non_reserved_destination_passes_through() {
assert!(!is_reserved_destination("example.com"));
assert!(!is_reserved_destination("localhost"));
assert!(!is_reserved_destination("192.168.1.1"));
assert!(!is_reserved_destination("wraith.example.com"));
assert!(!is_reserved_destination(""));
assert!(!is_reserved_destination("wrait-control"));
assert!(!is_reserved_destination("WRAITH-control"));
}
#[test]
fn prefix_matching_case_sensitive() {
assert!(!is_reserved_destination("Wraith-control"));
assert!(!is_reserved_destination("WRAITH-control"));
assert!(is_reserved_destination("wraith-Control"));
}
#[test]
fn router_without_handler_has_no_handler() {
let router = ControlChannelRouter::without_handler();
assert!(!router.has_handler());
}
#[test]
fn router_with_handler_has_handler() {
struct DummyHandler;
#[async_trait]
impl ControlChannelHandler for DummyHandler {
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {}
}
let router = ControlChannelRouter::with_handler(Box::new(DummyHandler));
assert!(router.has_handler());
}
#[tokio::test]
async fn route_without_handler_returns_error() {
let router = ControlChannelRouter::without_handler();
let (_client, server) = duplex(64);
let stream: Box<dyn DuplexStream> = Box::new(server);
let result = router.route(stream).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused);
}
#[tokio::test]
async fn route_with_handler_succeeds() {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
struct TrackedHandler {
called: Arc<AtomicBool>,
}
#[async_trait]
impl ControlChannelHandler for TrackedHandler {
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {
self.called.store(true, Ordering::SeqCst);
}
}
let called = Arc::new(AtomicBool::new(false));
let handler = TrackedHandler {
called: called.clone(),
};
let router = ControlChannelRouter::with_handler(Box::new(handler));
let (_client, server) = duplex(64);
let stream: Box<dyn DuplexStream> = Box::new(server);
let result = router.route(stream).await;
assert!(result.is_ok());
assert!(called.load(Ordering::SeqCst));
}
#[tokio::test]
async fn route_with_handler_can_read_write() {
struct EchoHandler;
#[async_trait]
impl ControlChannelHandler for EchoHandler {
async fn handle_channel(&self, mut stream: Box<dyn DuplexStream>) {
let mut buf = [0u8; 64];
let n = stream.read(&mut buf).await.unwrap();
stream.write_all(&buf[..n]).await.unwrap();
}
}
let router = ControlChannelRouter::with_handler(Box::new(EchoHandler));
let (client, server) = duplex(64);
let stream: Box<dyn DuplexStream> = Box::new(server);
tokio::spawn(async move {
router.route(stream).await.unwrap();
});
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut client = client;
client.write_all(b"hello").await.unwrap();
let mut buf = [0u8; 5];
client.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello");
}
#[test]
fn control_channel_destination_matches_prefix() {
assert!(is_reserved_destination(WRAITH_CONTROL_DESTINATION));
}
}

View File

@@ -1,5 +1,6 @@
use std::net::SocketAddr;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use russh::keys::ssh_key::HashAlg;
@@ -7,8 +8,10 @@ use russh::server::{Auth, Handler, Msg, Session};
use russh::Channel;
use crate::auth::ServerAuthConfig;
const WRAITH_PREFIX: &str = "wraith-";
use crate::server::control_channel::{
ControlChannelHandler, ControlChannelRouter, WRAITH_PREFIX,
};
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
#[derive(Debug, Clone)]
pub enum ProxyMode {
@@ -22,11 +25,34 @@ pub struct ProxyConfig {
pub mode: ProxyMode,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TransportKind {
Tcp,
Tls,
Iroh,
}
impl std::fmt::Display for TransportKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportKind::Tcp => write!(f, "tcp"),
TransportKind::Tls => write!(f, "tls"),
TransportKind::Iroh => write!(f, "iroh"),
}
}
}
pub struct ServerHandler {
auth_config: Arc<ServerAuthConfig>,
#[allow(dead_code)]
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
control_channel_router: ControlChannelRouter,
transport: TransportKind,
connection_limiter: Arc<ConnectionRateLimiter>,
connection_allowed: bool,
auth_limiter: AuthAttemptLimiter,
connected_at: Instant,
}
impl ServerHandler {
@@ -34,13 +60,82 @@ impl ServerHandler {
auth_config: Arc<ServerAuthConfig>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
transport: TransportKind,
connection_limiter: Arc<ConnectionRateLimiter>,
max_auth_attempts: usize,
) -> Self {
let allowed = if let Some(addr) = remote_addr {
let ip = addr.ip();
if connection_limiter.check(ip) {
connection_limiter.on_connect(ip);
tracing::info!(
remote_addr = %addr,
transport = %transport,
"connection opened"
);
true
} else {
tracing::info!(
remote_addr = %addr,
transport = %transport,
"connection rejected"
);
false
}
} else {
true
};
Self {
auth_config,
outbound_proxy,
remote_addr,
control_channel_router: ControlChannelRouter::without_handler(),
transport,
connection_limiter,
connection_allowed: allowed,
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
connected_at: Instant::now(),
}
}
pub fn is_connection_allowed(&self) -> bool {
self.connection_allowed
}
pub fn remote_ip(&self) -> Option<IpAddr> {
self.remote_addr.map(|a| a.ip())
}
}
impl Drop for ServerHandler {
fn drop(&mut self) {
if let Some(addr) = self.remote_addr {
if self.connection_allowed {
self.connection_limiter.on_disconnect(addr.ip());
}
let duration = self.connected_at.elapsed();
tracing::info!(
remote_addr = %addr,
duration_secs = duration.as_secs_f64(),
"connection closed"
);
}
}
}
impl ServerHandler {
pub fn with_control_channel_handler(
mut self,
handler: Box<dyn ControlChannelHandler>,
) -> Self {
self.control_channel_router = ControlChannelRouter::with_handler(handler);
self
}
pub fn control_channel_router(&self) -> &ControlChannelRouter {
&self.control_channel_router
}
}
#[async_trait]
@@ -52,6 +147,23 @@ impl Handler for ServerHandler {
user: &str,
public_key: &russh::keys::ssh_key::PublicKey,
) -> Result<Auth, Self::Error> {
if !self.auth_limiter.check() {
let remote_addr_display = self
.remote_addr
.map_or("unknown".to_string(), |a| a.to_string());
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "reject",
"auth attempt"
);
return Ok(Auth::Reject {
proceed_with_methods: None,
});
}
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
let remote_addr_display = self
.remote_addr
@@ -64,6 +176,7 @@ impl Handler for ServerHandler {
Ok(()) => {
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "accept",
"auth attempt"
@@ -71,8 +184,10 @@ impl Handler for ServerHandler {
Ok(Auth::Accept)
}
Err(_) => {
self.auth_limiter.on_failure();
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "reject",
"auth attempt"
@@ -99,6 +214,16 @@ impl Handler for ServerHandler {
port = port_to_connect,
"routing to internal control channel handler"
);
if !self.control_channel_router.has_handler() {
tracing::warn!(
host = host_to_connect,
"no control channel handler configured, rejecting channel open"
);
return Ok(false);
}
let _ = channel;
return Ok(true);
}
@@ -174,10 +299,22 @@ mod tests {
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
}
fn default_limiter() -> Arc<ConnectionRateLimiter> {
Arc::new(ConnectionRateLimiter::new(0))
}
fn make_handler(
auth_config: Arc<ServerAuthConfig>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
) -> ServerHandler {
ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10)
}
#[tokio::test]
async fn auth_delegation_accepts_known_key() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let mut handler = ServerHandler::new(auth_config, None, None);
let mut handler = make_handler(auth_config, None, None);
let ssh_key = load_key().public_key().clone();
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
@@ -187,7 +324,7 @@ mod tests {
#[tokio::test]
async fn auth_delegation_rejects_unknown_key() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let mut handler = ServerHandler::new(auth_config, None, None);
let mut handler = make_handler(auth_config, None, None);
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
let other_ssh_key = russh::keys::parse_public_key_base64(
@@ -210,7 +347,7 @@ mod tests {
#[tokio::test]
async fn auth_delegation_empty_config_rejects_all() {
let auth_config = make_empty_auth_config();
let mut handler = ServerHandler::new(auth_config, None, None);
let mut handler = make_handler(auth_config, None, None);
let ssh_key = load_key().public_key().clone();
let result = handler
@@ -229,7 +366,7 @@ mod tests {
async fn auth_logging_includes_remote_addr() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
let mut handler = ServerHandler::new(auth_config, None, Some(remote_addr));
let mut handler = make_handler(auth_config, None, Some(remote_addr));
let ssh_key = load_key().public_key().clone();
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
@@ -237,12 +374,20 @@ mod tests {
#[test]
fn reserved_wraith_destination_routing() {
assert!("wraith-control".starts_with(WRAITH_PREFIX));
assert!("wraith-status".starts_with(WRAITH_PREFIX));
assert!("wraith-events".starts_with(WRAITH_PREFIX));
assert!(!"example.com".starts_with(WRAITH_PREFIX));
assert!(!"localhost".starts_with(WRAITH_PREFIX));
assert!(!"wraith.example.com".starts_with(WRAITH_PREFIX));
use crate::server::control_channel::is_reserved_destination;
assert!(is_reserved_destination("wraith-control"));
assert!(is_reserved_destination("wraith-status"));
assert!(is_reserved_destination("wraith-events"));
assert!(!is_reserved_destination("example.com"));
assert!(!is_reserved_destination("localhost"));
assert!(!is_reserved_destination("wraith.example.com"));
}
#[test]
fn server_handler_without_control_handler_rejects_wraith_destinations() {
let auth_config = make_empty_auth_config();
let handler = make_handler(auth_config, None, None);
assert!(!handler.control_channel_router().has_handler());
}
#[test]
@@ -273,7 +418,7 @@ mod tests {
});
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
let handler = ServerHandler::new(auth_config, proxy.clone(), remote);
let handler = make_handler(auth_config, proxy.clone(), remote);
assert!(handler.outbound_proxy.is_some());
assert!(handler.remote_addr.is_some());
}
@@ -281,9 +426,108 @@ mod tests {
#[test]
fn one_handler_per_connection() {
let auth_config = make_empty_auth_config();
let handler1 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
let handler2 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
assert!(handler1.remote_addr != handler2.remote_addr);
}
#[tokio::test]
async fn auth_rate_limit_rejects_after_max_failures() {
let auth_config = make_empty_auth_config();
let limiter = Arc::new(ConnectionRateLimiter::new(0));
let mut handler = ServerHandler::new(
auth_config,
None,
Some("10.0.0.1:22".parse().unwrap()),
TransportKind::Tcp,
limiter,
2,
);
let ssh_key = load_key().public_key().clone();
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
assert_eq!(r1, Auth::Reject { proceed_with_methods: None });
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
assert_eq!(r2, Auth::Reject { proceed_with_methods: None });
assert!(!handler.auth_limiter.check());
}
#[test]
fn connection_rate_limit_blocks_over_limit() {
let limiter = Arc::new(ConnectionRateLimiter::new(1));
let auth_config = make_empty_auth_config();
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
let h1 = ServerHandler::new(
auth_config.clone(),
None,
Some(addr),
TransportKind::Tcp,
limiter.clone(),
10,
);
assert!(h1.is_connection_allowed());
let h2 = ServerHandler::new(
auth_config.clone(),
None,
Some(addr),
TransportKind::Tcp,
limiter.clone(),
10,
);
assert!(!h2.is_connection_allowed());
drop(h1);
let h3 = ServerHandler::new(
auth_config,
None,
Some(addr),
TransportKind::Tcp,
limiter,
10,
);
assert!(h3.is_connection_allowed());
}
#[test]
fn transport_kind_display() {
assert_eq!(TransportKind::Tcp.to_string(), "tcp");
assert_eq!(TransportKind::Tls.to_string(), "tls");
assert_eq!(TransportKind::Iroh.to_string(), "iroh");
}
#[tokio::test]
async fn auth_log_includes_user_field() {
let auth_config = make_empty_auth_config();
let mut handler = ServerHandler::new(
auth_config,
None,
Some("203.0.113.50:12345".parse().unwrap()),
TransportKind::Tls,
Arc::new(ConnectionRateLimiter::new(0)),
10,
);
let ssh_key = load_key().public_key().clone();
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
}
#[test]
fn connection_closed_logs_duration_on_drop() {
let auth_config = make_empty_auth_config();
let _handler = ServerHandler::new(
auth_config,
None,
Some("203.0.113.50:12345".parse().unwrap()),
TransportKind::Tcp,
Arc::new(ConnectionRateLimiter::new(0)),
10,
);
}
}

View File

@@ -1,5 +1,14 @@
pub mod channel_proxy;
pub mod control_channel;
pub mod handler;
pub mod rate_limit;
pub mod stealth;
pub use channel_proxy::{ChannelProxyError, connect_outbound, proxy_channel};
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
pub use channel_proxy::{connect_outbound, proxy_channel};
pub use control_channel::{
ControlChannelHandler, ControlChannelRouter, DuplexStream, WRAITH_CONTROL_DESTINATION,
WRAITH_PREFIX, is_reserved_destination,
};
pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
pub use stealth::{ProtocolDetection, detect_protocol, send_fake_nginx_404, validate_stealth_config};

View File

@@ -0,0 +1,193 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Mutex;
pub struct ConnectionRateLimiter {
max_per_ip: usize,
active: Mutex<HashMap<IpAddr, usize>>,
}
impl ConnectionRateLimiter {
pub fn new(max_per_ip: usize) -> Self {
Self {
max_per_ip,
active: Mutex::new(HashMap::new()),
}
}
pub fn check(&self, ip: IpAddr) -> bool {
if self.max_per_ip == 0 {
return true;
}
let active = self.active.lock().unwrap();
let count = active.get(&ip).copied().unwrap_or(0);
count < self.max_per_ip
}
pub fn on_connect(&self, ip: IpAddr) {
let mut active = self.active.lock().unwrap();
*active.entry(ip).or_insert(0) += 1;
}
pub fn on_disconnect(&self, ip: IpAddr) {
let mut active = self.active.lock().unwrap();
if let Some(count) = active.get_mut(&ip) {
if *count > 1 {
*count -= 1;
} else {
active.remove(&ip);
}
}
}
}
pub struct AuthAttemptLimiter {
max_attempts: usize,
failures: usize,
}
impl AuthAttemptLimiter {
pub fn new(max_attempts: usize) -> Self {
Self {
max_attempts,
failures: 0,
}
}
pub fn check(&self) -> bool {
if self.max_attempts == 0 {
return true;
}
self.failures < self.max_attempts
}
pub fn on_failure(&mut self) {
self.failures += 1;
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
fn ip(n: u8) -> IpAddr {
IpAddr::V4(Ipv4Addr::new(192, 168, 1, n))
}
#[test]
fn connection_limiter_allows_when_under_limit() {
let limiter = ConnectionRateLimiter::new(3);
assert!(limiter.check(ip(1)));
}
#[test]
fn connection_limiter_blocks_when_at_limit() {
let limiter = ConnectionRateLimiter::new(2);
limiter.on_connect(ip(1));
limiter.on_connect(ip(1));
assert!(!limiter.check(ip(1)));
}
#[test]
fn connection_limiter_allows_after_disconnect() {
let limiter = ConnectionRateLimiter::new(2);
limiter.on_connect(ip(1));
limiter.on_connect(ip(1));
assert!(!limiter.check(ip(1)));
limiter.on_disconnect(ip(1));
assert!(limiter.check(ip(1)));
}
#[test]
fn connection_limiter_unlimited_when_zero() {
let limiter = ConnectionRateLimiter::new(0);
for _ in 0..100 {
limiter.on_connect(ip(1));
}
assert!(limiter.check(ip(1)));
}
#[test]
fn connection_limiter_tracks_per_ip_independently() {
let limiter = ConnectionRateLimiter::new(1);
limiter.on_connect(ip(1));
assert!(!limiter.check(ip(1)));
assert!(limiter.check(ip(2)));
}
#[test]
fn connection_limiter_ipv6() {
let limiter = ConnectionRateLimiter::new(1);
let ip6 = IpAddr::V6(Ipv6Addr::LOCALHOST);
limiter.on_connect(ip6);
assert!(!limiter.check(ip6));
}
#[test]
fn connection_limiter_disconnect_removes_zero_entry() {
let limiter = ConnectionRateLimiter::new(3);
limiter.on_connect(ip(1));
limiter.on_disconnect(ip(1));
{
let active = limiter.active.lock().unwrap();
assert!(!active.contains_key(&ip(1)));
}
}
#[test]
fn auth_limiter_allows_when_under_limit() {
let limiter = AuthAttemptLimiter::new(3);
assert!(limiter.check());
}
#[test]
fn auth_limiter_blocks_after_max_failures() {
let mut limiter = AuthAttemptLimiter::new(2);
limiter.on_failure();
limiter.on_failure();
assert!(!limiter.check());
}
#[test]
fn auth_limiter_unlimited_when_zero() {
let mut limiter = AuthAttemptLimiter::new(0);
for _ in 0..100 {
limiter.on_failure();
}
assert!(limiter.check());
}
#[test]
fn auth_limiter_still_allows_at_one_below_limit() {
let mut limiter = AuthAttemptLimiter::new(3);
limiter.on_failure();
limiter.on_failure();
assert!(limiter.check());
limiter.on_failure();
assert!(!limiter.check());
}
#[test]
fn connection_limiter_thread_safety() {
use std::sync::Arc;
use std::thread;
let limiter = Arc::new(ConnectionRateLimiter::new(100));
let mut handles = vec![];
for i in 0..10 {
let lim = Arc::clone(&limiter);
handles.push(thread::spawn(move || {
let ip_addr = ip((i % 3) as u8 + 1);
lim.on_connect(ip_addr);
assert!(lim.check(ip_addr));
lim.on_disconnect(ip_addr);
}));
}
for h in handles {
h.join().unwrap();
}
}
}

View File

@@ -0,0 +1,218 @@
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
const SSH_BANNER_PREFIX: &[u8] = b"SSH-2.0-";
const FAKE_NGINX_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nServer: nginx\r\n\r\n";
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolDetection {
Ssh,
Http,
}
pub async fn detect_protocol<S>(stream: S) -> (ProtocolDetection, BufReader<S>)
where
S: AsyncRead + Unpin,
{
let mut reader = BufReader::new(stream);
let detection = match reader.fill_buf().await {
Ok(buf) if buf.len() >= SSH_BANNER_PREFIX.len() => {
if &buf[..SSH_BANNER_PREFIX.len()] == SSH_BANNER_PREFIX {
ProtocolDetection::Ssh
} else {
ProtocolDetection::Http
}
}
Ok(buf) if !buf.is_empty() => {
if buf.starts_with(SSH_BANNER_PREFIX) {
ProtocolDetection::Ssh
} else {
ProtocolDetection::Http
}
}
_ => ProtocolDetection::Http,
};
(detection, reader)
}
pub async fn send_fake_nginx_404<S>(reader: &mut BufReader<S>)
where
S: AsyncRead + AsyncWrite + Unpin,
{
let _ = reader.get_mut().write_all(FAKE_NGINX_404).await;
let _ = reader.get_mut().shutdown().await;
}
pub fn validate_stealth_config(stealth: bool, transport_is_tls: bool) -> Result<(), &'static str> {
if stealth && !transport_is_tls {
return Err("stealth mode requires TLS transport (--transport tls)");
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
async fn write_and_detect(data: &[u8]) -> ProtocolDetection {
let (client, server) = duplex(1024);
let mut client = client;
client.write_all(data).await.unwrap();
drop(client);
let (detection, _) = detect_protocol(server).await;
detection
}
#[tokio::test]
async fn ssh_banner_detected() {
let detection = write_and_detect(b"SSH-2.0-OpenSSH_9.0\r\n").await;
assert_eq!(detection, ProtocolDetection::Ssh);
}
#[tokio::test]
async fn ssh_banner_other_implementation() {
let detection = write_and_detect(b"SSH-2.0-russh_0.49\r\n").await;
assert_eq!(detection, ProtocolDetection::Ssh);
}
#[tokio::test]
async fn ssh_banner_minimal() {
let detection = write_and_detect(b"SSH-2.0-X\n").await;
assert_eq!(detection, ProtocolDetection::Ssh);
}
#[tokio::test]
async fn http_get_detected_as_http() {
let detection = write_and_detect(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn http_post_detected_as_http() {
let detection = write_and_detect(b"POST /api HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn random_data_detected_as_http() {
let detection = write_and_detect(b"\x01\x02\x03\x04\x05\x06\x07\x08").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn empty_stream_detected_as_http() {
let (client, server) = duplex(1024);
drop(client);
let (detection, _) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn ssh_banner_bytes_preserved_by_bufreader() {
let (client, server) = duplex(1024);
let mut client = client;
let banner = b"SSH-2.0-OpenSSH_9.0\r\n";
client.write_all(banner).await.unwrap();
client.write_all(b"subsequent data").await.unwrap();
drop(client);
let (detection, mut reader) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Ssh);
let mut all_data = Vec::new();
reader.read_to_end(&mut all_data).await.unwrap();
assert!(all_data.starts_with(banner), "banner bytes must be preserved after detection");
}
#[tokio::test]
async fn fake_nginx_404_response() {
let (client, server) = duplex(1024);
let (mut client_read, mut client_write) = tokio::io::split(client);
client_write.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
drop(client_write);
let (detection, mut reader) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Http);
send_fake_nginx_404(&mut reader).await;
let mut buf = [0u8; 256];
let n = client_read.read(&mut buf).await.unwrap();
let response = String::from_utf8_lossy(&buf[..n]);
assert!(response.contains("HTTP/1.1 404 Not Found"));
assert!(response.contains("Server: nginx"));
}
#[tokio::test]
async fn protocol_detection_enum_equality() {
assert_eq!(ProtocolDetection::Ssh, ProtocolDetection::Ssh);
assert_eq!(ProtocolDetection::Http, ProtocolDetection::Http);
assert_ne!(ProtocolDetection::Ssh, ProtocolDetection::Http);
}
#[test]
fn validate_stealth_without_tls_rejected() {
let result = validate_stealth_config(true, false);
assert!(result.is_err());
assert!(result.unwrap_err().contains("TLS transport"));
}
#[test]
fn validate_stealth_with_tls_accepted() {
let result = validate_stealth_config(true, true);
assert!(result.is_ok());
}
#[test]
fn validate_no_stealth_with_tcp_accepted() {
let result = validate_stealth_config(false, false);
assert!(result.is_ok());
}
#[test]
fn validate_no_stealth_with_tls_accepted() {
let result = validate_stealth_config(false, true);
assert!(result.is_ok());
}
#[tokio::test]
async fn short_data_detected_as_http() {
let detection = write_and_detect(b"GE").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn partial_ssh_prefix_detected_as_http() {
let detection = write_and_detect(b"SSH-1.").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn http_request_gets_404_then_closed() {
let (client, server) = duplex(1024);
let mut client = client;
client.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
let (detection, mut reader) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Http);
send_fake_nginx_404(&mut reader).await;
let mut buf = [0u8; 256];
let n = client.read(&mut buf).await.unwrap();
let response = String::from_utf8_lossy(&buf[..n]);
assert!(response.starts_with("HTTP/1.1 404 Not Found"));
assert!(response.contains("Server: nginx"));
let mut extra = [0u8; 16];
let result = client.read(&mut extra).await;
assert!(result.is_err() || result.unwrap() == 0);
}
}

View File

@@ -7,6 +7,8 @@ edition = "2021"
crate-type = ["cdylib"]
[dependencies]
wraith-core = { path = "../wraith-core" }
napi = "3"
napi-derive = "3"
wraith-core = { path = "../wraith-core", features = ["tls", "iroh"] }
napi = { version = "3", features = ["async", "error_anyhow"] }
napi-derive = "3"
tokio = { version = "1", features = ["io-util", "sync"] }
russh = "0.49"

View File

@@ -0,0 +1,249 @@
use std::net::SocketAddr;
use std::sync::Arc;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use russh::client;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::Mutex;
use wraith_core::auth::client_auth::{ClientAuthConfig, ClientHandler};
use wraith_core::auth::keys::KeySource;
use wraith_core::transport::{TcpTransport, TlsTransport, Transport};
const DEFAULT_HOST: &str = "wraith-control";
const DEFAULT_PORT: u32 = 0;
#[napi(object)]
pub struct WraithConnectOptions {
pub server: Option<String>,
pub peer: Option<String>,
pub transport: String,
pub identity: Option<Either<String, Buffer>>,
pub tls_server_name: Option<String>,
pub insecure: Option<bool>,
pub iroh_relay: Option<String>,
pub proxy: Option<String>,
}
fn resolve_key_source(identity: &Option<Either<String, Buffer>>) -> Result<KeySource> {
match identity {
None => Err(Error::new(
Status::InvalidArg,
"identity is required: provide a file path (string) or key data (Buffer)",
)),
Some(Either::A(path)) => Ok(KeySource::File(path.into())),
Some(Either::B(buf)) => Ok(KeySource::Memory(buf.to_vec())),
}
}
fn parse_addr(addr_str: &str) -> Result<SocketAddr> {
addr_str.parse().map_err(|e| {
Error::new(
Status::InvalidArg,
format!("invalid server address '{}': {}", addr_str, e),
)
})
}
#[napi]
pub struct WraithStream {
read: Arc<Mutex<tokio::io::ReadHalf<russh::ChannelStream<client::Msg>>>>,
write: Arc<Mutex<tokio::io::WriteHalf<russh::ChannelStream<client::Msg>>>>,
}
#[napi]
impl WraithStream {
#[napi]
pub async fn read(&self, size: u32) -> Result<Buffer> {
let mut buf = vec![0u8; size as usize];
let mut guard = self.read.lock().await;
let n = guard.read(&mut buf).await.map_err(|e| {
Error::new(Status::GenericFailure, format!("read failed: {}", e))
})?;
if n == 0 {
return Ok(Vec::<u8>::new().into());
}
buf.truncate(n);
Ok(buf.into())
}
#[napi]
pub async fn write(&self, data: Buffer) -> Result<()> {
let mut guard = self.write.lock().await;
guard.write_all(&data).await.map_err(|e| {
Error::new(Status::GenericFailure, format!("write failed: {}", e))
})?;
Ok(())
}
#[napi]
pub async fn close(&self) -> Result<()> {
let mut guard = self.write.lock().await;
guard.shutdown().await.map_err(|e| {
Error::new(Status::GenericFailure, format!("close failed: {}", e))
})
}
}
#[napi]
pub async fn connect(options: WraithConnectOptions) -> Result<WraithStream> {
let key_source = resolve_key_source(&options.identity)?;
let auth_config = Arc::new(ClientAuthConfig::from_key_source(key_source).map_err(|e| {
Error::new(Status::InvalidArg, format!("invalid identity key: {}", e))
})?);
let transport_mode = options.transport.to_lowercase();
let handler = ClientHandler::from_config(&auth_config);
let username = "wraith".to_string();
let config = Arc::new(client::Config::default());
let mut handle: client::Handle<ClientHandler> = match transport_mode.as_str() {
"tcp" => {
let server = options.server.as_ref().ok_or_else(|| {
Error::new(Status::InvalidArg, "server is required for tcp transport")
})?;
let addr = parse_addr(server)?;
let transport = TcpTransport::new(addr);
let stream = transport.connect().await.map_err(|e| {
Error::new(Status::GenericFailure, format!("tcp connect failed: {}", e))
})?;
client::connect_stream(config, stream, handler)
.await
.map_err(|e| {
Error::new(
Status::GenericFailure,
format!("ssh handshake failed: {}", e),
)
})?
}
"tls" => {
let server = options.server.as_ref().ok_or_else(|| {
Error::new(Status::InvalidArg, "server is required for tls transport")
})?;
let addr = parse_addr(server)?;
let mut transport = TlsTransport::new(addr);
if let Some(ref name) = options.tls_server_name {
transport = transport.with_server_name(name);
}
if let Some(true) = options.insecure {
transport = transport.with_insecure(true);
}
let stream = transport.connect().await.map_err(|e| {
Error::new(Status::GenericFailure, format!("tls connect failed: {}", e))
})?;
client::connect_stream(config, stream, handler)
.await
.map_err(|e| {
Error::new(
Status::GenericFailure,
format!("ssh handshake failed: {}", e),
)
})?
}
"iroh" => {
return Err(Error::new(
Status::GenericFailure,
"iroh transport is not yet supported in napi connect()".to_string(),
));
}
_ => {
return Err(Error::new(
Status::InvalidArg,
format!("unknown transport '{}'; expected tcp, tls, or iroh", transport_mode),
));
}
};
let auth_ok = auth_config
.authenticate(&mut handle, &username)
.await
.map_err(|e| {
Error::new(Status::GenericFailure, format!("ssh auth failed: {}", e))
})?;
if !auth_ok {
return Err(Error::new(Status::GenericFailure, "ssh authentication rejected"));
}
let channel = handle
.channel_open_direct_tcpip(DEFAULT_HOST, DEFAULT_PORT, "127.0.0.1", 0)
.await
.map_err(|e| {
Error::new(
Status::GenericFailure,
format!("failed to open ssh channel: {}", e),
)
})?;
let stream = channel.into_stream();
let (read_half, write_half) = tokio::io::split(stream);
Ok(WraithStream {
read: Arc::new(Mutex::new(read_half)),
write: Arc::new(Mutex::new(write_half)),
})
}
#[cfg(test)]
mod tests {
use super::*;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
#[test]
fn resolve_key_source_file_path() {
let identity = Some(Either::<String, Buffer>::A("/path/to/key".to_string()));
let result = resolve_key_source(&identity);
assert!(result.is_ok());
match result.unwrap() {
KeySource::File(p) => assert_eq!(p.to_str(), Some("/path/to/key")),
_ => panic!("expected File variant"),
}
}
#[test]
fn resolve_key_source_buffer() {
let identity = Some(Either::<String, Buffer>::B(Buffer::from(ED25519_PRIVATE_KEY.as_bytes().to_vec())));
let result = resolve_key_source(&identity);
assert!(result.is_ok());
match result.unwrap() {
KeySource::Memory(data) => assert!(!data.is_empty()),
_ => panic!("expected Memory variant"),
}
}
#[test]
fn resolve_key_source_missing() {
let identity: Option<Either<String, Buffer>> = None;
let result = resolve_key_source(&identity);
assert!(result.is_err());
}
#[test]
fn parse_addr_valid() {
let addr = parse_addr("127.0.0.1:22");
assert!(addr.is_ok());
assert_eq!(addr.unwrap().port(), 22);
}
#[test]
fn parse_addr_invalid() {
let addr = parse_addr("not-an-address");
assert!(addr.is_err());
}
#[test]
fn auth_config_from_memory_key() {
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
let config = ClientAuthConfig::from_key_source(source);
assert!(config.is_ok());
}
#[test]
fn auth_config_from_invalid_key() {
let source = KeySource::Memory(b"not-a-key".to_vec());
let config = ClientAuthConfig::from_key_source(source);
assert!(config.is_err());
}
}

View File

@@ -1,3 +1,5 @@
#[allow(unused_imports)]
#[macro_use]
extern crate napi_derive;
extern crate napi_derive;
mod connect;

View File

@@ -7,8 +7,15 @@ edition = "2021"
name = "wraith"
path = "src/main.rs"
[features]
default = ["tls", "iroh"]
tls = ["wraith-core/tls"]
iroh = ["wraith-core/iroh", "dep:iroh", "dep:url"]
[dependencies]
wraith-core = { path = "../wraith-core" }
clap = { version = "4", features = ["derive"] }
clap = { version = "4", features = ["derive", "env"] }
tokio = { version = "1", features = ["full"] }
anyhow = "1"
anyhow = "1"
iroh = { version = "0.34", optional = true }
url = { version = "2", optional = true }

View File

@@ -1 +1,224 @@
fn main() {}
use std::net::SocketAddr;
use std::process;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use clap::{Parser, Subcommand, ValueEnum};
use wraith_core::auth::keys::KeySource;
use wraith_core::client::{ConnectOptions, TransportMode};
use wraith_core::transport::TcpTransport;
#[cfg(feature = "tls")]
use wraith_core::transport::TlsTransport;
#[cfg(feature = "iroh")]
use wraith_core::transport::IrohTransport;
use wraith_core::transport::Transport;
#[derive(Parser)]
#[command(name = "wraith", version, about = "Wraith SSH tunnel client")]
struct Cli {
#[command(subcommand)]
command: Commands,
}
#[derive(Subcommand)]
enum Commands {
#[command(about = "Connect to a wraith server and start a SOCKS5 proxy / port forwarding session")]
Connect {
#[arg(long, help = "TCP/TLS server address (required for tcp/tls transport)", env = "WRAITH_SERVER")]
server: Option<String>,
#[arg(long, help = "iroh endpoint ID, base58-encoded (required for iroh transport)")]
peer: Option<String>,
#[arg(long, value_enum, default_value = "tcp", help = "Transport mode")]
transport: TransportModeArg,
#[arg(long, help = "SSH private key path", env = "WRAITH_IDENTITY")]
identity: Option<String>,
#[arg(long, default_value = "127.0.0.1:1080", help = "SOCKS5 listen address")]
socks5: String,
#[arg(long, action = clap::ArgAction::Append, help = "Port forward spec (repeatable, e.g. 5432:db:5432)")]
forward: Vec<String>,
#[arg(long, action = clap::ArgAction::Append, help = "Remote port forward spec (repeatable)")]
remote_forward: Vec<String>,
#[arg(long, help = "Upstream proxy URL (socks5:// or http://)")]
proxy: Option<String>,
#[arg(long, help = "iroh relay URL")]
iroh_relay: Option<String>,
#[arg(long, help = "SNI hostname for TLS")]
tls_server_name: Option<String>,
#[arg(long, help = "Accept self-signed TLS certs")]
insecure: bool,
},
}
#[derive(Clone, Debug, ValueEnum)]
enum TransportModeArg {
Tcp,
Tls,
Iroh,
}
impl From<TransportModeArg> for TransportMode {
fn from(val: TransportModeArg) -> Self {
match val {
TransportModeArg::Tcp => TransportMode::Tcp,
TransportModeArg::Tls => TransportMode::Tls,
TransportModeArg::Iroh => TransportMode::Iroh,
}
}
}
#[tokio::main]
async fn main() {
if let Err(e) = run().await {
eprintln!("error: {e}");
process::exit(1);
}
}
async fn run() -> Result<()> {
let cli = Cli::parse();
match cli.command {
Commands::Connect {
server,
peer,
transport,
identity,
socks5,
forward,
remote_forward,
proxy,
iroh_relay,
tls_server_name,
insecure,
} => {
let identity_val = identity
.ok_or_else(|| anyhow!("--identity is required (or set WRAITH_IDENTITY env var)"))?;
let key_source = KeySource::File(identity_val.into());
let transport_mode: TransportMode = transport.into();
if proxy.is_some() && matches!(transport_mode, TransportMode::Tcp) {
eprintln!("warning: --proxy with --transport tcp is effectively a no-op (TCP transport is already a direct connection); use the SOCKS5 server instead");
}
let mut opts = ConnectOptions::new(key_source)
.transport_mode(transport_mode.clone())
.socks5_addr(&socks5);
if let Some(ref s) = server {
opts = opts.server(s);
}
if let Some(ref p) = peer {
opts = opts.peer(p);
}
for fwd in &forward {
opts = opts.forward(fwd);
}
for rfwd in &remote_forward {
opts = opts.remote_forward(rfwd);
}
if let Some(ref p) = proxy {
opts = opts.proxy(p);
}
if let Some(ref r) = iroh_relay {
opts = opts.iroh_relay(r);
}
if let Some(ref n) = tls_server_name {
opts = opts.tls_server_name(n);
}
if insecure {
opts = opts.insecure(true);
}
opts.validate().map_err(|e| anyhow!("{e}"))?;
match transport_mode {
TransportMode::Tcp => {
let addr: SocketAddr = server
.as_deref()
.ok_or_else(|| anyhow!("--server is required for tcp transport"))?
.parse()
.map_err(|e| anyhow!("invalid server address: {e}"))?;
let t = Arc::new(TcpTransport::new(addr));
connect_and_run(opts, t).await
}
TransportMode::Tls => {
#[cfg(not(feature = "tls"))]
{
return Err(anyhow!("TLS transport is not available (wraith-core built without 'tls' feature)"));
}
#[cfg(feature = "tls")]
{
let addr: SocketAddr = server
.as_deref()
.ok_or_else(|| anyhow!("--server is required for tls transport"))?
.parse()
.map_err(|e| anyhow!("invalid server address: {e}"))?;
let mut t = TlsTransport::new(addr);
if let Some(ref n) = tls_server_name {
t = t.with_server_name(n);
}
t = t.with_insecure(insecure);
let t = Arc::new(t);
connect_and_run(opts, t).await
}
}
TransportMode::Iroh => {
#[cfg(not(feature = "iroh"))]
{
return Err(anyhow!("iroh transport is not available (wraith-core built without 'iroh' feature)"));
}
#[cfg(feature = "iroh")]
{
use iroh::{NodeId, RelayUrl};
let node_id_str = peer
.as_deref()
.ok_or_else(|| anyhow!("--peer is required for iroh transport"))?;
let node_id: NodeId = node_id_str
.parse()
.map_err(|e| anyhow!("invalid iroh peer endpoint ID: {e}"))?;
let relay_url: Option<RelayUrl> = match iroh_relay.as_deref() {
Some(u) => Some(
u.parse()
.map_err(|e| anyhow!("invalid iroh relay URL: {e}"))?,
),
None => None,
};
let proxy_url: Option<url::Url> = match proxy.as_deref() {
Some(u) => Some(
u.parse()
.map_err(|e| anyhow!("invalid proxy URL: {e}"))?,
),
None => None,
};
let t = Arc::new(
IrohTransport::new(node_id, relay_url, proxy_url)
.await
.map_err(|e| anyhow!("failed to create iroh transport: {e}"))?,
);
connect_and_run(opts, t).await
}
}
}
}
}
}
async fn connect_and_run<T: Transport>(opts: ConnectOptions, transport: Arc<T>) -> Result<()> {
wraith_core::client::ClientSession::new(opts, transport)
.await
.map_err(|e| anyhow!("{e}"))?
.run()
.await
.map_err(|e| anyhow!("{e}"))
}