diff --git a/Cargo.lock b/Cargo.lock index 11c430d..5b099c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3863,6 +3863,15 @@ dependencies = [ "x509-parser 0.16.0", ] +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" version = "1.14.1" @@ -5586,6 +5595,10 @@ dependencies = [ "anyhow", "clap", "iroh", + "rustls", + "rustls-acme", + "rustls-pemfile", + "rustls-pki-types", "tokio", "url", "wraith-core", diff --git a/crates/wraith/Cargo.toml b/crates/wraith/Cargo.toml index bb52c54..312dfb5 100644 --- a/crates/wraith/Cargo.toml +++ b/crates/wraith/Cargo.toml @@ -9,8 +9,9 @@ path = "src/main.rs" [features] default = ["tls", "iroh"] -tls = ["wraith-core/tls"] +tls = ["wraith-core/tls", "dep:rustls-pemfile", "dep:rustls-pki-types"] iroh = ["wraith-core/iroh", "dep:iroh", "dep:url"] +acme = ["wraith-core/acme", "dep:rustls-acme", "dep:rustls", "tls"] [dependencies] wraith-core = { path = "../wraith-core" } @@ -18,4 +19,8 @@ clap = { version = "4", features = ["derive", "env"] } tokio = { version = "1", features = ["full"] } anyhow = "1" iroh = { version = "0.34", optional = true } -url = { version = "2", optional = true } \ No newline at end of file +url = { version = "2", optional = true } +rustls-acme = { version = "0.12", optional = true } +rustls = { version = "0.23", optional = true, features = ["aws_lc_rs"] } +rustls-pemfile = { version = "2", optional = true } +rustls-pki-types = { version = "1", optional = true } \ No newline at end of file diff --git a/crates/wraith/src/main.rs b/crates/wraith/src/main.rs index eea1d15..24587be 100644 --- a/crates/wraith/src/main.rs +++ b/crates/wraith/src/main.rs @@ -6,15 +6,16 @@ use anyhow::{anyhow, Result}; use clap::{Parser, Subcommand, ValueEnum}; use wraith_core::auth::keys::KeySource; use wraith_core::client::{ConnectOptions, TransportMode}; +use wraith_core::server::{ServeOptions, ServeTransportMode, Server}; +#[cfg(feature = "iroh")] +use wraith_core::transport::IrohTransport; use wraith_core::transport::TcpTransport; #[cfg(feature = "tls")] use wraith_core::transport::TlsTransport; -#[cfg(feature = "iroh")] -use wraith_core::transport::IrohTransport; use wraith_core::transport::Transport; #[derive(Parser)] -#[command(name = "wraith", version, about = "Wraith SSH tunnel client")] +#[command(name = "wraith", version, about = "Wraith SSH tunnel tool")] struct Cli { #[command(subcommand)] command: Commands, @@ -22,12 +23,21 @@ struct Cli { #[derive(Subcommand)] enum Commands { - #[command(about = "Connect to a wraith server and start a SOCKS5 proxy / port forwarding session")] + #[command( + about = "Connect to a wraith server and start a SOCKS5 proxy / port forwarding session" + )] Connect { - #[arg(long, help = "TCP/TLS server address (required for tcp/tls transport)", env = "WRAITH_SERVER")] + #[arg( + long, + help = "TCP/TLS server address (required for tcp/tls transport)", + env = "WRAITH_SERVER" + )] server: Option, - #[arg(long, help = "iroh endpoint ID, base58-encoded (required for iroh transport)")] + #[arg( + long, + help = "iroh endpoint ID, base58-encoded (required for iroh transport)" + )] peer: Option, #[arg(long, value_enum, default_value = "tcp", help = "Transport mode")] @@ -57,6 +67,68 @@ enum Commands { #[arg(long, help = "Accept self-signed TLS certs")] insecure: bool, }, + + #[command(about = "Start the wraith server (accept SSH connections)")] + Serve { + #[arg(long, help = "SSH host key path (required)")] + key: String, + + #[arg(long, help = "Authorized keys file path")] + authorized_keys: Option, + + #[arg(long, help = "CA public key for certificate authority auth")] + cert_authority: Option, + + #[arg( + long, + value_enum, + default_value = "tcp", + help = "Transport mode (tcp, tls, iroh)" + )] + transport: ServeTransportModeArg, + + #[arg( + long, + default_value = "0.0.0.0:22", + help = "Listen address for TCP/TLS" + )] + listen: String, + + #[arg(long, help = "TLS certificate path (manual)")] + tls_cert: Option, + + #[arg(long, help = "TLS private key path (manual)")] + tls_key: Option, + + #[arg(long, help = "ACME auto-cert domain")] + acme_domain: Option, + + #[arg( + long, + help = "Serve fake nginx 404 to non-SSH connections (requires --transport tls)" + )] + stealth: bool, + + #[arg(long, help = "Outbound proxy URL (socks5:// or http://)")] + proxy: Option, + + #[arg(long, help = "iroh relay server URL")] + iroh_relay: Option, + + #[arg( + long, + default_value_t = 0, + help = "Max concurrent connections per IP (0 = unlimited)" + )] + max_connections_per_ip: usize, + + #[arg( + long, + default_value_t = 10, + help = "Max auth failures before disconnect" + )] + max_auth_attempts: usize, + }, } #[derive(Clone, Debug, ValueEnum)] @@ -76,6 +148,23 @@ impl From for TransportMode { } } +#[derive(Clone, Debug, ValueEnum)] +enum ServeTransportModeArg { + Tcp, + Tls, + Iroh, +} + +impl From for ServeTransportMode { + fn from(val: ServeTransportModeArg) -> Self { + match val { + ServeTransportModeArg::Tcp => ServeTransportMode::Tcp, + ServeTransportModeArg::Tls => ServeTransportMode::Tls, + ServeTransportModeArg::Iroh => ServeTransportMode::Iroh, + } + } +} + #[tokio::main] async fn main() { if let Err(e) = run().await { @@ -101,114 +190,177 @@ async fn run() -> Result<()> { tls_server_name, insecure, } => { - let identity_val = identity - .ok_or_else(|| anyhow!("--identity is required (or set WRAITH_IDENTITY env var)"))?; - let key_source = KeySource::File(identity_val.into()); + run_connect( + server, + peer, + transport, + identity, + socks5, + forward, + remote_forward, + proxy, + iroh_relay, + tls_server_name, + insecure, + ) + .await + } + Commands::Serve { + key, + authorized_keys, + cert_authority, + transport, + listen, + tls_cert, + tls_key, + acme_domain, + stealth, + proxy, + iroh_relay, + max_connections_per_ip, + max_auth_attempts, + } => { + run_serve( + key, + authorized_keys, + cert_authority, + transport, + listen, + tls_cert, + tls_key, + acme_domain, + stealth, + proxy, + iroh_relay, + max_connections_per_ip, + max_auth_attempts, + ) + .await + } + } +} - let transport_mode: TransportMode = transport.into(); +#[allow(clippy::too_many_arguments)] +async fn run_connect( + server: Option, + peer: Option, + transport: TransportModeArg, + identity: Option, + socks5: String, + forward: Vec, + remote_forward: Vec, + proxy: Option, + iroh_relay: Option, + tls_server_name: Option, + insecure: bool, +) -> Result<()> { + let identity_val = identity + .ok_or_else(|| anyhow!("--identity is required (or set WRAITH_IDENTITY env var)"))?; + let key_source = KeySource::File(identity_val.into()); - if proxy.is_some() && matches!(transport_mode, TransportMode::Tcp) { - eprintln!("warning: --proxy with --transport tcp is effectively a no-op (TCP transport is already a direct connection); use the SOCKS5 server instead"); - } + let transport_mode: TransportMode = transport.into(); - let mut opts = ConnectOptions::new(key_source) - .transport_mode(transport_mode.clone()) - .socks5_addr(&socks5); + if proxy.is_some() && matches!(transport_mode, TransportMode::Tcp) { + eprintln!("warning: --proxy with --transport tcp is effectively a no-op (TCP transport is already a direct connection); use the SOCKS5 server instead"); + } - if let Some(ref s) = server { - opts = opts.server(s); - } - if let Some(ref p) = peer { - opts = opts.peer(p); - } - for fwd in &forward { - opts = opts.forward(fwd); - } - for rfwd in &remote_forward { - opts = opts.remote_forward(rfwd); - } - if let Some(ref p) = proxy { - opts = opts.proxy(p); - } - if let Some(ref r) = iroh_relay { - opts = opts.iroh_relay(r); - } - if let Some(ref n) = tls_server_name { - opts = opts.tls_server_name(n); - } - if insecure { - opts = opts.insecure(true); - } + let mut opts = ConnectOptions::new(key_source) + .transport_mode(transport_mode.clone()) + .socks5_addr(&socks5); - opts.validate().map_err(|e| anyhow!("{e}"))?; + if let Some(ref s) = server { + opts = opts.server(s); + } + if let Some(ref p) = peer { + opts = opts.peer(p); + } + for fwd in &forward { + opts = opts.forward(fwd); + } + for rfwd in &remote_forward { + opts = opts.remote_forward(rfwd); + } + if let Some(ref p) = proxy { + opts = opts.proxy(p); + } + if let Some(ref r) = iroh_relay { + opts = opts.iroh_relay(r); + } + if let Some(ref n) = tls_server_name { + opts = opts.tls_server_name(n); + } + if insecure { + opts = opts.insecure(true); + } - match transport_mode { - TransportMode::Tcp => { - let addr: SocketAddr = server - .as_deref() - .ok_or_else(|| anyhow!("--server is required for tcp transport"))? - .parse() - .map_err(|e| anyhow!("invalid server address: {e}"))?; - let t = Arc::new(TcpTransport::new(addr)); - connect_and_run(opts, t).await - } - TransportMode::Tls => { - #[cfg(not(feature = "tls"))] - { - return Err(anyhow!("TLS transport is not available (wraith-core built without 'tls' feature)")); - } - #[cfg(feature = "tls")] - { - let addr: SocketAddr = server - .as_deref() - .ok_or_else(|| anyhow!("--server is required for tls transport"))? - .parse() - .map_err(|e| anyhow!("invalid server address: {e}"))?; - let mut t = TlsTransport::new(addr); - if let Some(ref n) = tls_server_name { - t = t.with_server_name(n); - } - t = t.with_insecure(insecure); - let t = Arc::new(t); - connect_and_run(opts, t).await - } - } - TransportMode::Iroh => { - #[cfg(not(feature = "iroh"))] - { - return Err(anyhow!("iroh transport is not available (wraith-core built without 'iroh' feature)")); - } - #[cfg(feature = "iroh")] - { - use iroh::{NodeId, RelayUrl}; - let node_id_str = peer - .as_deref() - .ok_or_else(|| anyhow!("--peer is required for iroh transport"))?; - let node_id: NodeId = node_id_str - .parse() - .map_err(|e| anyhow!("invalid iroh peer endpoint ID: {e}"))?; - let relay_url: Option = match iroh_relay.as_deref() { - Some(u) => Some( - u.parse() - .map_err(|e| anyhow!("invalid iroh relay URL: {e}"))?, - ), - None => None, - }; - let proxy_url: Option = match proxy.as_deref() { - Some(u) => Some( - u.parse() - .map_err(|e| anyhow!("invalid proxy URL: {e}"))?, - ), - None => None, - }; - let t = Arc::new( - IrohTransport::new(node_id, relay_url, proxy_url) - .await - .map_err(|e| anyhow!("failed to create iroh transport: {e}"))?, - ); - connect_and_run(opts, t).await - } + opts.validate().map_err(|e| anyhow!("{e}"))?; + + match transport_mode { + TransportMode::Tcp => { + let addr: SocketAddr = server + .as_deref() + .ok_or_else(|| anyhow!("--server is required for tcp transport"))? + .parse() + .map_err(|e| anyhow!("invalid server address: {e}"))?; + let t = Arc::new(TcpTransport::new(addr)); + connect_and_run(opts, t).await + } + TransportMode::Tls => { + #[cfg(not(feature = "tls"))] + { + Err(anyhow!( + "TLS transport is not available (wraith-core built without 'tls' feature)" + )) + } + #[cfg(feature = "tls")] + { + let addr: SocketAddr = server + .as_deref() + .ok_or_else(|| anyhow!("--server is required for tls transport"))? + .parse() + .map_err(|e| anyhow!("invalid server address: {e}"))?; + let mut t = TlsTransport::new(addr); + if let Some(ref n) = tls_server_name { + t = t.with_server_name(n); } + t = t.with_insecure(insecure); + let t = Arc::new(t); + connect_and_run(opts, t).await + } + } + TransportMode::Iroh => { + #[cfg(not(feature = "iroh"))] + { + Err(anyhow!( + "iroh transport is not available (wraith-core built without 'iroh' feature)" + )) + } + #[cfg(feature = "iroh")] + { + use iroh::{NodeId, RelayUrl}; + let node_id_str = peer + .as_deref() + .ok_or_else(|| anyhow!("--peer is required for iroh transport"))?; + let node_id: NodeId = node_id_str + .parse() + .map_err(|e| anyhow!("invalid iroh peer endpoint ID: {e}"))?; + let relay_url: Option = match iroh_relay.as_deref() { + Some(u) => Some( + u.parse() + .map_err(|e| anyhow!("invalid iroh relay URL: {e}"))?, + ), + None => None, + }; + let proxy_url: Option = match proxy.as_deref() { + Some(u) => Some(u.parse().map_err(|e| anyhow!("invalid proxy URL: {e}"))?), + None => None, + }; + let t = Arc::new( + IrohTransport::new(node_id, relay_url, proxy_url) + .await + .map_err(|e| anyhow!("failed to create iroh transport: {e}"))?, + ); + connect_and_run(opts, t).await } } } @@ -221,4 +373,168 @@ async fn connect_and_run(opts: ConnectOptions, transport: Arc) .run() .await .map_err(|e| anyhow!("{e}")) -} \ No newline at end of file +} + +#[allow(clippy::too_many_arguments)] +async fn run_serve( + key: String, + authorized_keys: Option, + cert_authority: Option, + transport: ServeTransportModeArg, + listen: String, + tls_cert: Option, + tls_key: Option, + acme_domain: Option, + stealth: bool, + proxy: Option, + iroh_relay: Option, + max_connections_per_ip: usize, + max_auth_attempts: usize, +) -> Result<()> { + let transport_mode: ServeTransportMode = transport.into(); + + if acme_domain.is_some() { + #[cfg(not(feature = "acme"))] + { + return Err(anyhow!( + "ACME support is not available (wraith built without 'acme' feature)" + )); + } + } + + if stealth && transport_mode != ServeTransportMode::Tls { + return Err(anyhow!( + "stealth mode requires TLS transport (--transport tls)" + )); + } + + let mut opts = ServeOptions::new(KeySource::File(key.into())) + .transport_mode(transport_mode.clone()) + .listen_addr(&listen) + .stealth(stealth) + .max_connections_per_ip(max_connections_per_ip) + .max_auth_attempts(max_auth_attempts); + + if let Some(ref path) = authorized_keys { + opts = opts.authorized_keys(KeySource::File(path.into())); + } + if let Some(ref path) = cert_authority { + opts = opts.cert_authority(KeySource::File(path.into())); + } + if let Some(ref path) = tls_cert { + opts = opts.tls_cert(path); + } + if let Some(ref path) = tls_key { + opts = opts.tls_key(path); + } + if let Some(ref domain) = acme_domain { + opts = opts.acme_domain(domain); + } + if let Some(ref url) = proxy { + opts = opts.proxy(url); + } + if let Some(ref url) = iroh_relay { + opts = opts.iroh_relay(url); + } + + opts.validate().map_err(|e| anyhow!("{e}"))?; + + let server = Server::new(opts).map_err(|e| anyhow!("{e}"))?; + + match transport_mode { + ServeTransportMode::Tcp => { + let addr: SocketAddr = listen + .parse() + .map_err(|e| anyhow!("invalid listen address: {e}"))?; + let acceptor = wraith_core::transport::TcpAcceptor::bind(addr) + .await + .map_err(|e| anyhow!("bind failed: {e}"))?; + server.run(acceptor, None).await.map_err(|e| anyhow!("{e}")) + } + ServeTransportMode::Tls => { + #[cfg(not(feature = "tls"))] + { + Err(anyhow!( + "TLS transport is not available (wraith-core built without 'tls' feature)" + )) + } + #[cfg(feature = "acme")] + { + if let Some(ref domain) = acme_domain { + let addr: SocketAddr = listen + .parse() + .map_err(|e| anyhow!("invalid listen address: {e}"))?; + let provider = Arc::new( + wraith_core::transport::AcmeCertProvider::domain(domain) + .with_production_directory(), + ); + let acceptor = + wraith_core::transport::AcmeTlsAcceptor::bind_acme(addr, provider) + .await + .map_err(|e| anyhow!("ACME bind failed: {e}"))?; + return server.run(acceptor, None).await.map_err(|e| anyhow!("{e}")); + } + } + #[cfg(feature = "tls")] + { + use rustls_pki_types::{CertificateDer, PrivateKeyDer}; + let addr: SocketAddr = listen + .parse() + .map_err(|e| anyhow!("invalid listen address: {e}"))?; + let cert_path = tls_cert.ok_or_else(|| { + anyhow!("--tls-cert is required for TLS transport (or use --acme-domain)") + })?; + let key_path = tls_key.ok_or_else(|| { + anyhow!("--tls-key is required for TLS transport (or use --acme-domain)") + })?; + let cert_data = std::fs::read(&cert_path) + .map_err(|e| anyhow!("failed to read TLS cert '{}': {e}", cert_path))?; + let key_data = std::fs::read(&key_path) + .map_err(|e| anyhow!("failed to read TLS key '{}': {e}", key_path))?; + let certs: Vec> = + rustls_pemfile::certs(&mut &cert_data[..]) + .collect::, _>>() + .map_err(|e| anyhow!("failed to parse TLS certificates: {e}"))?; + let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut &key_data[..]) + .map_err(|e| anyhow!("failed to parse TLS private key: {e}"))? + .ok_or_else(|| anyhow!("no private key found in {}", key_path))?; + let acceptor = wraith_core::transport::TlsAcceptor::bind(addr, certs, key, None) + .await + .map_err(|e| anyhow!("TLS bind failed: {e}"))?; + server.run(acceptor, None).await.map_err(|e| anyhow!("{e}")) + } + } + ServeTransportMode::Iroh => { + #[cfg(not(feature = "iroh"))] + { + Err(anyhow!( + "iroh transport is not available (wraith-core built without 'iroh' feature)" + )) + } + #[cfg(feature = "iroh")] + { + use iroh::RelayUrl; + let relay_url: Option = match iroh_relay.as_deref() { + Some(u) => Some( + u.parse() + .map_err(|e| anyhow!("invalid iroh relay URL: {e}"))?, + ), + None => None, + }; + let proxy_url: Option = match proxy.as_deref() { + Some(u) => Some(u.parse().map_err(|e| anyhow!("invalid proxy URL: {e}"))?), + None => None, + }; + let acceptor = wraith_core::transport::IrohAcceptor::bind(relay_url, proxy_url) + .await + .map_err(|e| anyhow!("iroh bind failed: {e}"))?; + let endpoint_id = acceptor.endpoint_id(); + eprintln!("iroh endpoint ID: {endpoint_id}"); + server + .run(acceptor, Some(&endpoint_id)) + .await + .map_err(|e| anyhow!("{e}")) + } + } + } +}