Compare commits
2 Commits
feat/serve
...
feat/clien
| Author | SHA1 | Date | |
|---|---|---|---|
| 128affd264 | |||
| f963898a05 |
727
crates/wraith-core/src/client/connect.rs
Normal file
727
crates/wraith-core/src/client/connect.rs
Normal file
@@ -0,0 +1,727 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use russh::client;
|
||||
use russh::keys::PrivateKey;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::client_auth::{ClientAuthConfig, ClientHandler};
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::client::forward::{LocalForwarder, PortForwardSpec, RemoteForwarder};
|
||||
use crate::error::ConfigError;
|
||||
use crate::socks5::{HandleChannelOpener, Socks5Server};
|
||||
use crate::transport::Transport;
|
||||
|
||||
const DEFAULT_SOCKS5_ADDR: &str = "127.0.0.1:1080";
|
||||
const DRAIN_TIMEOUT: Duration = Duration::from_secs(2);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TransportMode {
|
||||
Tcp,
|
||||
Tls,
|
||||
Iroh,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TransportMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TransportMode::Tcp => write!(f, "tcp"),
|
||||
TransportMode::Tls => write!(f, "tls"),
|
||||
TransportMode::Iroh => write!(f, "iroh"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ConnectOptions {
|
||||
pub server: Option<String>,
|
||||
pub peer: Option<String>,
|
||||
pub transport_mode: TransportMode,
|
||||
pub identity: KeySource,
|
||||
pub socks5_addr: String,
|
||||
pub forwards: Vec<String>,
|
||||
pub remote_forwards: Vec<String>,
|
||||
pub proxy: Option<String>,
|
||||
pub iroh_relay: Option<String>,
|
||||
pub tls_server_name: Option<String>,
|
||||
pub insecure: bool,
|
||||
}
|
||||
|
||||
impl ConnectOptions {
|
||||
pub fn new(identity: KeySource) -> Self {
|
||||
Self {
|
||||
server: None,
|
||||
peer: None,
|
||||
transport_mode: TransportMode::Tcp,
|
||||
identity,
|
||||
socks5_addr: DEFAULT_SOCKS5_ADDR.to_string(),
|
||||
forwards: Vec::new(),
|
||||
remote_forwards: Vec::new(),
|
||||
proxy: None,
|
||||
iroh_relay: None,
|
||||
tls_server_name: None,
|
||||
insecure: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn server(mut self, addr: impl Into<String>) -> Self {
|
||||
self.server = Some(addr.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn peer(mut self, endpoint_id: impl Into<String>) -> Self {
|
||||
self.peer = Some(endpoint_id.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn transport_mode(mut self, mode: TransportMode) -> Self {
|
||||
self.transport_mode = mode;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn socks5_addr(mut self, addr: impl Into<String>) -> Self {
|
||||
self.socks5_addr = addr.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn forward(mut self, spec: impl Into<String>) -> Self {
|
||||
self.forwards.push(spec.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn remote_forward(mut self, spec: impl Into<String>) -> Self {
|
||||
self.remote_forwards.push(spec.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn proxy(mut self, url: impl Into<String>) -> Self {
|
||||
self.proxy = Some(url.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn iroh_relay(mut self, url: impl Into<String>) -> Self {
|
||||
self.iroh_relay = Some(url.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tls_server_name(mut self, name: impl Into<String>) -> Self {
|
||||
self.tls_server_name = Some(name.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn insecure(mut self, insecure: bool) -> Self {
|
||||
self.insecure = insecure;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> Result<(), ConfigError> {
|
||||
match self.transport_mode {
|
||||
TransportMode::Tcp | TransportMode::Tls => {
|
||||
if self.server.is_none() {
|
||||
return Err(ConfigError::InvalidFlag {
|
||||
name: "--server is required for tcp/tls transport".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
TransportMode::Iroh => {
|
||||
if self.peer.is_none() {
|
||||
return Err(ConfigError::InvalidFlag {
|
||||
name: "--peer is required for iroh transport".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ConnectOptions {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ConnectOptions")
|
||||
.field("server", &self.server)
|
||||
.field("peer", &self.peer)
|
||||
.field("transport_mode", &self.transport_mode)
|
||||
.field("identity", &"<KeySource>")
|
||||
.field("socks5_addr", &self.socks5_addr)
|
||||
.field("forwards", &self.forwards)
|
||||
.field("remote_forwards", &self.remote_forwards)
|
||||
.field("proxy", &self.proxy)
|
||||
.field("iroh_relay", &self.iroh_relay)
|
||||
.field("tls_server_name", &self.tls_server_name)
|
||||
.field("insecure", &self.insecure)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ClientSession<T: Transport> {
|
||||
opts: ConnectOptions,
|
||||
transport: Arc<T>,
|
||||
handle: Arc<Mutex<client::Handle<ClientHandler>>>,
|
||||
auth_config: Arc<ClientAuthConfig>,
|
||||
#[allow(dead_code)]
|
||||
private_key: Arc<PrivateKey>,
|
||||
#[allow(dead_code)]
|
||||
username: String,
|
||||
shutdown_tx: tokio::sync::watch::Sender<bool>,
|
||||
shutdown_rx: tokio::sync::watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl<T: Transport> ClientSession<T> {
|
||||
pub async fn new(
|
||||
opts: ConnectOptions,
|
||||
transport: Arc<T>,
|
||||
) -> Result<Self, ConnectError> {
|
||||
opts.validate().map_err(ConnectError::Config)?;
|
||||
|
||||
let auth_config = Arc::new(
|
||||
ClientAuthConfig::from_key_source(opts.identity.clone())
|
||||
.map_err(ConnectError::Config)?,
|
||||
);
|
||||
let private_key = auth_config.private_key();
|
||||
|
||||
let username = derive_username();
|
||||
let handler = ClientHandler::from_config(&auth_config);
|
||||
|
||||
let stream = transport.connect().await.map_err(|e| {
|
||||
error!("transport connect failed: {e}");
|
||||
ConnectError::ConnectionFailed
|
||||
})?;
|
||||
|
||||
let config = Arc::new(client::Config::default());
|
||||
let mut handle = client::connect_stream(config, stream, handler)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("SSH connect failed: {e}");
|
||||
ConnectError::ConnectionFailed
|
||||
})?;
|
||||
|
||||
let auth_ok = auth_config
|
||||
.authenticate(&mut handle, &username)
|
||||
.await
|
||||
.map_err(|_| ConnectError::AuthFailed)?;
|
||||
if !auth_ok {
|
||||
return Err(ConnectError::AuthFailed);
|
||||
}
|
||||
|
||||
let handle = Arc::new(Mutex::new(handle));
|
||||
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
|
||||
|
||||
Ok(Self {
|
||||
opts,
|
||||
transport,
|
||||
handle,
|
||||
auth_config,
|
||||
private_key,
|
||||
username,
|
||||
shutdown_tx,
|
||||
shutdown_rx,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn handle(&self) -> Arc<Mutex<client::Handle<ClientHandler>>> {
|
||||
Arc::clone(&self.handle)
|
||||
}
|
||||
|
||||
pub fn auth_config(&self) -> &Arc<ClientAuthConfig> {
|
||||
&self.auth_config
|
||||
}
|
||||
|
||||
pub fn transport(&self) -> &Arc<T> {
|
||||
&self.transport
|
||||
}
|
||||
|
||||
pub fn options(&self) -> &ConnectOptions {
|
||||
&self.opts
|
||||
}
|
||||
|
||||
pub fn shutdown_sender(&self) -> tokio::sync::watch::Sender<bool> {
|
||||
self.shutdown_tx.clone()
|
||||
}
|
||||
|
||||
pub async fn run(self) -> Result<(), ConnectError> {
|
||||
let socks5_addr: SocketAddr = self.opts.socks5_addr.parse().map_err(|_| {
|
||||
ConnectError::Config(ConfigError::InvalidFlag {
|
||||
name: format!("invalid SOCKS5 address: {}", self.opts.socks5_addr),
|
||||
})
|
||||
})?;
|
||||
|
||||
let channel_opener = HandleChannelOpener::from_arc(Arc::clone(&self.handle));
|
||||
let socks5_server = Socks5Server::with_addr(channel_opener, &socks5_addr.to_string());
|
||||
let socks5_listen = socks5_server.listen_addr();
|
||||
|
||||
let local_forwarders = build_local_forwarders(&self.opts)?;
|
||||
let remote_specs = build_remote_specs(&self.opts)?;
|
||||
|
||||
for spec in &remote_specs {
|
||||
let remote_forwarder = RemoteForwarder::new(spec.clone())
|
||||
.map_err(|_| ConnectError::ForwardFailed)?;
|
||||
let mut h = self.handle.lock().await;
|
||||
remote_forwarder
|
||||
.register(&mut h)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
warn!("failed to register remote forward {}", spec);
|
||||
ConnectError::ForwardFailed
|
||||
})?;
|
||||
info!("registered remote forward: {}", spec);
|
||||
}
|
||||
|
||||
let socks5_task = tokio::spawn(async move {
|
||||
debug!("SOCKS5 server starting on {}", socks5_listen);
|
||||
if let Err(e) = socks5_server.run().await {
|
||||
error!("SOCKS5 server error: {e}");
|
||||
}
|
||||
});
|
||||
|
||||
let fwd_handle = Arc::clone(&self.handle);
|
||||
let fwd_shutdown = self.shutdown_rx.clone();
|
||||
let forward_task = tokio::spawn(async move {
|
||||
crate::client::forward::run_local_forwarders(
|
||||
local_forwarders, fwd_handle, fwd_shutdown,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
info!("wraith client running: SOCKS5 on {}", socks5_listen);
|
||||
|
||||
#[cfg(unix)]
|
||||
let signal_done = {
|
||||
let sig_tx = self.shutdown_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut sigterm_stream =
|
||||
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install SIGTERM handler");
|
||||
tokio::select! {
|
||||
_ = sigterm_stream.recv() => {
|
||||
info!("received SIGTERM");
|
||||
}
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
info!("received SIGINT (Ctrl+C)");
|
||||
}
|
||||
}
|
||||
let _ = sig_tx.send(true);
|
||||
})
|
||||
};
|
||||
|
||||
let mut wait_shutdown = self.shutdown_rx.clone();
|
||||
tokio::select! {
|
||||
_ = wait_shutdown.changed() => {
|
||||
if *wait_shutdown.borrow() {
|
||||
info!("shutdown signal received");
|
||||
}
|
||||
}
|
||||
_ = socks5_task => {
|
||||
warn!("SOCKS5 server exited unexpectedly");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
signal_done.abort();
|
||||
|
||||
self.shutdown().await?;
|
||||
|
||||
forward_task.abort();
|
||||
let _ = forward_task.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) -> Result<(), ConnectError> {
|
||||
info!("initiating graceful shutdown");
|
||||
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
|
||||
{
|
||||
let handle = self.handle.lock().await;
|
||||
if !handle.is_closed() {
|
||||
if let Err(e) = handle
|
||||
.disconnect(russh::Disconnect::ByApplication, "shutdown", "")
|
||||
.await
|
||||
{
|
||||
warn!("failed to send SSH disconnect: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(DRAIN_TIMEOUT).await;
|
||||
|
||||
info!("graceful shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn derive_username() -> String {
|
||||
std::env::var("USER")
|
||||
.or_else(|_| std::env::var("USERNAME"))
|
||||
.unwrap_or_else(|_| "wraith".to_string())
|
||||
}
|
||||
|
||||
fn build_local_forwarders(opts: &ConnectOptions) -> Result<Vec<LocalForwarder>, ConnectError> {
|
||||
let mut forwarders = Vec::new();
|
||||
for spec_str in &opts.forwards {
|
||||
let spec = PortForwardSpec::local(spec_str).map_err(|e| {
|
||||
warn!("invalid local forward spec '{}': {}", spec_str, e);
|
||||
ConnectError::Config(ConfigError::InvalidFlag {
|
||||
name: format!("invalid forward spec: {}", spec_str),
|
||||
})
|
||||
})?;
|
||||
forwarders.push(
|
||||
LocalForwarder::new(spec).map_err(|e| {
|
||||
warn!("failed to create local forwarder: {}", e);
|
||||
ConnectError::ForwardFailed
|
||||
})?,
|
||||
);
|
||||
}
|
||||
Ok(forwarders)
|
||||
}
|
||||
|
||||
fn build_remote_specs(opts: &ConnectOptions) -> Result<Vec<PortForwardSpec>, ConnectError> {
|
||||
let mut specs = Vec::new();
|
||||
for spec_str in &opts.remote_forwards {
|
||||
let spec = PortForwardSpec::remote(spec_str).map_err(|e| {
|
||||
warn!("invalid remote forward spec '{}': {}", spec_str, e);
|
||||
ConnectError::Config(ConfigError::InvalidFlag {
|
||||
name: format!("invalid remote forward spec: {}", spec_str),
|
||||
})
|
||||
})?;
|
||||
specs.push(spec);
|
||||
}
|
||||
Ok(specs)
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConnectError {
|
||||
#[error("connection failed")]
|
||||
ConnectionFailed,
|
||||
#[error("authentication failed")]
|
||||
AuthFailed,
|
||||
#[error("forward setup failed")]
|
||||
ForwardFailed,
|
||||
#[error("config error: {0}")]
|
||||
Config(#[from] ConfigError),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use tokio::io::duplex;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
fn make_identity() -> KeySource {
|
||||
KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_default_fields() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
assert!(opts.server.is_none());
|
||||
assert!(opts.peer.is_none());
|
||||
assert_eq!(opts.transport_mode, TransportMode::Tcp);
|
||||
assert_eq!(opts.socks5_addr, "127.0.0.1:1080");
|
||||
assert!(opts.forwards.is_empty());
|
||||
assert!(opts.remote_forwards.is_empty());
|
||||
assert!(opts.proxy.is_none());
|
||||
assert!(opts.iroh_relay.is_none());
|
||||
assert!(opts.tls_server_name.is_none());
|
||||
assert!(!opts.insecure);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_builder_pattern() {
|
||||
let opts = ConnectOptions::new(make_identity())
|
||||
.server("example.com:22")
|
||||
.transport_mode(TransportMode::Tls)
|
||||
.socks5_addr("127.0.0.1:9050")
|
||||
.forward("127.0.0.1:5432:db:5432")
|
||||
.remote_forward("0.0.0.0:8080:127.0.0.1:3000")
|
||||
.proxy("socks5://127.0.0.1:1080")
|
||||
.iroh_relay("https://relay.example.com")
|
||||
.tls_server_name("wraith.test")
|
||||
.insecure(true);
|
||||
|
||||
assert_eq!(opts.server.as_deref(), Some("example.com:22"));
|
||||
assert_eq!(opts.transport_mode, TransportMode::Tls);
|
||||
assert_eq!(opts.socks5_addr, "127.0.0.1:9050");
|
||||
assert_eq!(opts.forwards.len(), 1);
|
||||
assert_eq!(opts.remote_forwards.len(), 1);
|
||||
assert_eq!(opts.proxy.as_deref(), Some("socks5://127.0.0.1:1080"));
|
||||
assert_eq!(opts.iroh_relay.as_deref(), Some("https://relay.example.com"));
|
||||
assert_eq!(opts.tls_server_name.as_deref(), Some("wraith.test"));
|
||||
assert!(opts.insecure);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_tcp_requires_server() {
|
||||
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Tcp);
|
||||
assert!(opts.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_tcp_with_server_ok() {
|
||||
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||
assert!(opts.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_tls_requires_server() {
|
||||
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Tls);
|
||||
assert!(opts.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_tls_with_server_ok() {
|
||||
let opts = ConnectOptions::new(make_identity())
|
||||
.transport_mode(TransportMode::Tls)
|
||||
.server("example.com:443");
|
||||
assert!(opts.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_iroh_requires_peer() {
|
||||
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Iroh);
|
||||
assert!(opts.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_iroh_with_peer_ok() {
|
||||
let opts = ConnectOptions::new(make_identity())
|
||||
.transport_mode(TransportMode::Iroh)
|
||||
.peer("some-endpoint-id");
|
||||
assert!(opts.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_accepts_key_source_file() {
|
||||
let file_source = KeySource::File(std::path::PathBuf::from("/path/to/key"));
|
||||
let opts = ConnectOptions::new(file_source);
|
||||
match &opts.identity {
|
||||
KeySource::File(p) => assert_eq!(p, &std::path::PathBuf::from("/path/to/key")),
|
||||
_ => panic!("expected File variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_accepts_key_source_memory() {
|
||||
let mem_source = KeySource::Memory(b"key-data".to_vec());
|
||||
let opts = ConnectOptions::new(mem_source);
|
||||
match &opts.identity {
|
||||
KeySource::Memory(d) => assert_eq!(d, b"key-data"),
|
||||
_ => panic!("expected Memory variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_mode_display() {
|
||||
assert_eq!(TransportMode::Tcp.to_string(), "tcp");
|
||||
assert_eq!(TransportMode::Tls.to_string(), "tls");
|
||||
assert_eq!(TransportMode::Iroh.to_string(), "iroh");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_error_variants() {
|
||||
assert_eq!(ConnectError::ConnectionFailed.to_string(), "connection failed");
|
||||
assert_eq!(ConnectError::AuthFailed.to_string(), "authentication failed");
|
||||
assert_eq!(ConnectError::ForwardFailed.to_string(), "forward setup failed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_debug_redacts_identity() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
let debug_str = format!("{:?}", opts);
|
||||
assert!(debug_str.contains("<KeySource>"));
|
||||
assert!(!debug_str.contains("OPENSSH"));
|
||||
}
|
||||
|
||||
struct FailTransport;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for FailTransport {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||
Err(anyhow::anyhow!("always fails"))
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"fail".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
struct DuplexTransport {
|
||||
connect_count: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for DuplexTransport {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||
self.connect_count.fetch_add(1, Ordering::SeqCst);
|
||||
let (client, _) = duplex(4096);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"duplex".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_session_new_transport_fails() {
|
||||
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||
let transport = Arc::new(FailTransport);
|
||||
let result = ClientSession::new(opts, transport).await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_session_new_ssh_handshake_fails() {
|
||||
let transport = Arc::new(DuplexTransport {
|
||||
connect_count: Arc::new(AtomicUsize::new(0)),
|
||||
});
|
||||
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||
let result = ClientSession::new(opts, transport).await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_local_forwarders_empty() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
let result = build_local_forwarders(&opts);
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_local_forwarders_valid() {
|
||||
let opts = ConnectOptions::new(make_identity()).forward("127.0.0.1:5432:db:5432");
|
||||
let result = build_local_forwarders(&opts);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_local_forwarders_invalid_spec() {
|
||||
let opts = ConnectOptions::new(make_identity()).forward("bad-spec");
|
||||
let result = build_local_forwarders(&opts);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_remote_specs_empty() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
let result = build_remote_specs(&opts);
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_remote_specs_valid() {
|
||||
let opts = ConnectOptions::new(make_identity()).remote_forward("0.0.0.0:8080:127.0.0.1:3000");
|
||||
let result = build_remote_specs(&opts);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_remote_specs_invalid() {
|
||||
let opts = ConnectOptions::new(make_identity()).remote_forward("bad");
|
||||
let result = build_remote_specs(&opts);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_socks5_addr() {
|
||||
assert_eq!(DEFAULT_SOCKS5_ADDR, "127.0.0.1:1080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn drain_timeout_is_two_seconds() {
|
||||
assert_eq!(DRAIN_TIMEOUT, Duration::from_secs(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_mode_equality() {
|
||||
assert_eq!(TransportMode::Tcp, TransportMode::Tcp);
|
||||
assert_ne!(TransportMode::Tcp, TransportMode::Tls);
|
||||
assert_ne!(TransportMode::Tls, TransportMode::Iroh);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_sends_disconnect_and_drains() {
|
||||
let transport = Arc::new(DuplexTransport {
|
||||
connect_count: Arc::new(AtomicUsize::new(0)),
|
||||
});
|
||||
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||
let result = ClientSession::new(opts, transport).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn socks5_is_always_enabled_by_default() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
assert!(!opts.socks5_addr.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_mock_transport_session() {
|
||||
use crate::socks5::{ChannelOpener, ChannelOpenError};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
struct MockOpener;
|
||||
|
||||
impl ChannelOpener for MockOpener {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn open_channel(
|
||||
&self,
|
||||
_host: String,
|
||||
_port: u16,
|
||||
) -> Result<Self::Stream, ChannelOpenError> {
|
||||
let (client, _server) = duplex(4096);
|
||||
Ok(client)
|
||||
}
|
||||
}
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let bound_addr = listener.local_addr().unwrap();
|
||||
drop(listener);
|
||||
|
||||
let opener = MockOpener;
|
||||
let server = Socks5Server::with_addr(opener, &bound_addr.to_string());
|
||||
|
||||
let _server_task = tokio::spawn(async move {
|
||||
let _ = server.run().await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
let mut conn = TcpStream::connect(bound_addr).await.unwrap();
|
||||
|
||||
let greeting = [0x05, 0x01, 0x00];
|
||||
conn.write_all(&greeting).await.unwrap();
|
||||
|
||||
let mut auth_resp = [0u8; 2];
|
||||
conn.read_exact(&mut auth_resp).await.unwrap();
|
||||
assert_eq!(auth_resp, [0x05, 0x00]);
|
||||
|
||||
let connect_req = [
|
||||
0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80,
|
||||
];
|
||||
conn.write_all(&connect_req).await.unwrap();
|
||||
|
||||
let mut reply = [0u8; 10];
|
||||
conn.read_exact(&mut reply).await.unwrap();
|
||||
assert_eq!(reply[1], 0x00);
|
||||
|
||||
conn.write_all(b"test data").await.unwrap();
|
||||
conn.shutdown().await.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -125,7 +125,7 @@ impl LocalForwarder {
|
||||
handle: Arc<Mutex<client::Handle<H>>>,
|
||||
) -> Result<(), ForwardError> {
|
||||
let listen_addr = self.spec.listen_addr()?;
|
||||
let listener: TcpListener = TcpListener::bind(listen_addr)
|
||||
let listener = TcpListener::bind(listen_addr)
|
||||
.await
|
||||
.map_err(|e| ForwardError::BindFailed { source: e })?;
|
||||
self.listener = Some(listener);
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
pub mod channel_manager;
|
||||
pub mod connect;
|
||||
pub mod forward;
|
||||
|
||||
pub use channel_manager::{ChannelManager, ForwardRequest};
|
||||
pub use connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
|
||||
pub use forward::{LocalForwarder, PortForwardSpec, PortForwardSpecKind, RemoteForwarder};
|
||||
@@ -62,7 +62,7 @@ pub enum ConfigError {
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ForwardError {
|
||||
#[error("invalid forward spec: {spec}")]
|
||||
#[error("invalid forward specification: {spec}")]
|
||||
InvalidSpec { spec: String },
|
||||
#[error("bind failed")]
|
||||
BindFailed {
|
||||
|
||||
@@ -10,4 +10,5 @@ pub mod testutil;
|
||||
|
||||
pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError};
|
||||
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
pub use client::channel_manager::{ChannelManager, ForwardRequest};
|
||||
pub use client::channel_manager::{ChannelManager, ForwardRequest};
|
||||
pub use client::connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
|
||||
186
crates/wraith-core/src/server/control_channel.rs
Normal file
186
crates/wraith-core/src/server/control_channel.rs
Normal file
@@ -0,0 +1,186 @@
|
||||
use std::io;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub const WRAITH_CONTROL_DESTINATION: &str = "wraith-control";
|
||||
pub const WRAITH_PREFIX: &str = "wraith-";
|
||||
|
||||
pub fn is_reserved_destination(host: &str) -> bool {
|
||||
host.starts_with(WRAITH_PREFIX)
|
||||
}
|
||||
|
||||
pub trait DuplexStream: AsyncRead + AsyncWrite + Unpin + Send {}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin + Send> DuplexStream for T {}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ControlChannelHandler: Send + Sync {
|
||||
async fn handle_channel(&self, stream: Box<dyn DuplexStream>);
|
||||
}
|
||||
|
||||
pub struct ControlChannelRouter {
|
||||
handler: Option<Box<dyn ControlChannelHandler>>,
|
||||
}
|
||||
|
||||
impl ControlChannelRouter {
|
||||
pub fn new(handler: Option<Box<dyn ControlChannelHandler>>) -> Self {
|
||||
Self { handler }
|
||||
}
|
||||
|
||||
pub fn without_handler() -> Self {
|
||||
Self { handler: None }
|
||||
}
|
||||
|
||||
pub fn with_handler(handler: Box<dyn ControlChannelHandler>) -> Self {
|
||||
Self {
|
||||
handler: Some(handler),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_handler(&self) -> bool {
|
||||
self.handler.is_some()
|
||||
}
|
||||
|
||||
pub async fn route(&self, stream: Box<dyn DuplexStream>) -> io::Result<()> {
|
||||
match &self.handler {
|
||||
Some(handler) => {
|
||||
handler.handle_channel(stream).await;
|
||||
Ok(())
|
||||
}
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionRefused,
|
||||
"no control channel handler configured",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::duplex;
|
||||
|
||||
#[test]
|
||||
fn wraith_control_destination_constant() {
|
||||
assert_eq!(WRAITH_CONTROL_DESTINATION, "wraith-control");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wraith_prefix_constant() {
|
||||
assert_eq!(WRAITH_PREFIX, "wraith-");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reserved_destination_detected() {
|
||||
assert!(is_reserved_destination("wraith-control"));
|
||||
assert!(is_reserved_destination("wraith-status"));
|
||||
assert!(is_reserved_destination("wraith-events"));
|
||||
assert!(is_reserved_destination("wraith-"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_reserved_destination_passes_through() {
|
||||
assert!(!is_reserved_destination("example.com"));
|
||||
assert!(!is_reserved_destination("localhost"));
|
||||
assert!(!is_reserved_destination("192.168.1.1"));
|
||||
assert!(!is_reserved_destination("wraith.example.com"));
|
||||
assert!(!is_reserved_destination(""));
|
||||
assert!(!is_reserved_destination("wrait-control"));
|
||||
assert!(!is_reserved_destination("WRAITH-control"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prefix_matching_case_sensitive() {
|
||||
assert!(!is_reserved_destination("Wraith-control"));
|
||||
assert!(!is_reserved_destination("WRAITH-control"));
|
||||
assert!(is_reserved_destination("wraith-Control"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn router_without_handler_has_no_handler() {
|
||||
let router = ControlChannelRouter::without_handler();
|
||||
assert!(!router.has_handler());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn router_with_handler_has_handler() {
|
||||
struct DummyHandler;
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for DummyHandler {
|
||||
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {}
|
||||
}
|
||||
let router = ControlChannelRouter::with_handler(Box::new(DummyHandler));
|
||||
assert!(router.has_handler());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_without_handler_returns_error() {
|
||||
let router = ControlChannelRouter::without_handler();
|
||||
let (_client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
let result = router.route(stream).await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_with_handler_succeeds() {
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
struct TrackedHandler {
|
||||
called: Arc<AtomicBool>,
|
||||
}
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for TrackedHandler {
|
||||
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {
|
||||
self.called.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
let called = Arc::new(AtomicBool::new(false));
|
||||
let handler = TrackedHandler {
|
||||
called: called.clone(),
|
||||
};
|
||||
let router = ControlChannelRouter::with_handler(Box::new(handler));
|
||||
let (_client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
let result = router.route(stream).await;
|
||||
assert!(result.is_ok());
|
||||
assert!(called.load(Ordering::SeqCst));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_with_handler_can_read_write() {
|
||||
struct EchoHandler;
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for EchoHandler {
|
||||
async fn handle_channel(&self, mut stream: Box<dyn DuplexStream>) {
|
||||
let mut buf = [0u8; 64];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
stream.write_all(&buf[..n]).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
let router = ControlChannelRouter::with_handler(Box::new(EchoHandler));
|
||||
let (client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
tokio::spawn(async move {
|
||||
router.route(stream).await.unwrap();
|
||||
});
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
let mut client = client;
|
||||
client.write_all(b"hello").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
client.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn control_channel_destination_matches_prefix() {
|
||||
assert!(is_reserved_destination(WRAITH_CONTROL_DESTINATION));
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use russh::keys::ssh_key::HashAlg;
|
||||
@@ -8,9 +7,9 @@ use russh::server::{Auth, Handler, Msg, Session};
|
||||
use russh::Channel;
|
||||
|
||||
use crate::auth::ServerAuthConfig;
|
||||
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||
|
||||
const WRAITH_PREFIX: &str = "wraith-";
|
||||
use crate::server::control_channel::{
|
||||
ControlChannelHandler, ControlChannelRouter, WRAITH_PREFIX,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProxyMode {
|
||||
@@ -24,32 +23,11 @@ pub struct ProxyConfig {
|
||||
pub mode: ProxyMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum TransportKind {
|
||||
Tcp,
|
||||
Tls,
|
||||
Iroh,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TransportKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TransportKind::Tcp => write!(f, "tcp"),
|
||||
TransportKind::Tls => write!(f, "tls"),
|
||||
TransportKind::Iroh => write!(f, "iroh"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ServerHandler {
|
||||
auth_config: Arc<ServerAuthConfig>,
|
||||
outbound_proxy: Option<ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
transport: TransportKind,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
connection_allowed: bool,
|
||||
auth_limiter: AuthAttemptLimiter,
|
||||
connected_at: Instant,
|
||||
control_channel_router: ControlChannelRouter,
|
||||
}
|
||||
|
||||
impl ServerHandler {
|
||||
@@ -57,66 +35,25 @@ impl ServerHandler {
|
||||
auth_config: Arc<ServerAuthConfig>,
|
||||
outbound_proxy: Option<ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
transport: TransportKind,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
max_auth_attempts: usize,
|
||||
) -> Self {
|
||||
let allowed = if let Some(addr) = remote_addr {
|
||||
let ip = addr.ip();
|
||||
if connection_limiter.check(ip) {
|
||||
connection_limiter.on_connect(ip);
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
transport = %transport,
|
||||
"connection opened"
|
||||
);
|
||||
true
|
||||
} else {
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
transport = %transport,
|
||||
"connection rejected"
|
||||
);
|
||||
false
|
||||
}
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
Self {
|
||||
auth_config,
|
||||
outbound_proxy,
|
||||
remote_addr,
|
||||
transport,
|
||||
connection_limiter,
|
||||
connection_allowed: allowed,
|
||||
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
|
||||
connected_at: Instant::now(),
|
||||
control_channel_router: ControlChannelRouter::without_handler(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_connection_allowed(&self) -> bool {
|
||||
self.connection_allowed
|
||||
pub fn with_control_channel_handler(
|
||||
mut self,
|
||||
handler: Box<dyn ControlChannelHandler>,
|
||||
) -> Self {
|
||||
self.control_channel_router = ControlChannelRouter::with_handler(handler);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn remote_ip(&self) -> Option<IpAddr> {
|
||||
self.remote_addr.map(|a| a.ip())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ServerHandler {
|
||||
fn drop(&mut self) {
|
||||
if let Some(addr) = self.remote_addr {
|
||||
if self.connection_allowed {
|
||||
self.connection_limiter.on_disconnect(addr.ip());
|
||||
}
|
||||
let duration = self.connected_at.elapsed();
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
duration_secs = duration.as_secs_f64(),
|
||||
"connection closed"
|
||||
);
|
||||
}
|
||||
pub fn control_channel_router(&self) -> &ControlChannelRouter {
|
||||
&self.control_channel_router
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,23 +66,6 @@ impl Handler for ServerHandler {
|
||||
user: &str,
|
||||
public_key: &russh::keys::ssh_key::PublicKey,
|
||||
) -> Result<Auth, Self::Error> {
|
||||
if !self.auth_limiter.check() {
|
||||
let remote_addr_display = self
|
||||
.remote_addr
|
||||
.map_or("unknown".to_string(), |a| a.to_string());
|
||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "reject",
|
||||
"auth attempt"
|
||||
);
|
||||
return Ok(Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
});
|
||||
}
|
||||
|
||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||
let remote_addr_display = self
|
||||
.remote_addr
|
||||
@@ -158,7 +78,6 @@ impl Handler for ServerHandler {
|
||||
Ok(()) => {
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "accept",
|
||||
"auth attempt"
|
||||
@@ -166,10 +85,8 @@ impl Handler for ServerHandler {
|
||||
Ok(Auth::Accept)
|
||||
}
|
||||
Err(_) => {
|
||||
self.auth_limiter.on_failure();
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "reject",
|
||||
"auth attempt"
|
||||
@@ -196,6 +113,16 @@ impl Handler for ServerHandler {
|
||||
port = port_to_connect,
|
||||
"routing to internal control channel handler"
|
||||
);
|
||||
|
||||
if !self.control_channel_router.has_handler() {
|
||||
tracing::warn!(
|
||||
host = host_to_connect,
|
||||
"no control channel handler configured, rejecting channel open"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let _ = channel;
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
@@ -286,22 +213,10 @@ mod tests {
|
||||
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
|
||||
}
|
||||
|
||||
fn default_limiter() -> Arc<ConnectionRateLimiter> {
|
||||
Arc::new(ConnectionRateLimiter::new(0))
|
||||
}
|
||||
|
||||
fn make_handler(
|
||||
auth_config: Arc<ServerAuthConfig>,
|
||||
outbound_proxy: Option<ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
) -> ServerHandler {
|
||||
ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_delegation_accepts_known_key() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let mut handler = make_handler(auth_config, None, None);
|
||||
let mut handler = ServerHandler::new(auth_config, None, None);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||
@@ -311,7 +226,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_delegation_rejects_unknown_key() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let mut handler = make_handler(auth_config, None, None);
|
||||
let mut handler = ServerHandler::new(auth_config, None, None);
|
||||
|
||||
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
||||
let other_ssh_key = russh::keys::parse_public_key_base64(
|
||||
@@ -334,7 +249,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_delegation_empty_config_rejects_all() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let mut handler = make_handler(auth_config, None, None);
|
||||
let mut handler = ServerHandler::new(auth_config, None, None);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let result = handler
|
||||
@@ -353,7 +268,7 @@ mod tests {
|
||||
async fn auth_logging_includes_remote_addr() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
|
||||
let mut handler = make_handler(auth_config, None, Some(remote_addr));
|
||||
let mut handler = ServerHandler::new(auth_config, None, Some(remote_addr));
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
||||
@@ -361,12 +276,20 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn reserved_wraith_destination_routing() {
|
||||
assert!("wraith-control".starts_with(WRAITH_PREFIX));
|
||||
assert!("wraith-status".starts_with(WRAITH_PREFIX));
|
||||
assert!("wraith-events".starts_with(WRAITH_PREFIX));
|
||||
assert!(!"example.com".starts_with(WRAITH_PREFIX));
|
||||
assert!(!"localhost".starts_with(WRAITH_PREFIX));
|
||||
assert!(!"wraith.example.com".starts_with(WRAITH_PREFIX));
|
||||
use crate::server::control_channel::is_reserved_destination;
|
||||
assert!(is_reserved_destination("wraith-control"));
|
||||
assert!(is_reserved_destination("wraith-status"));
|
||||
assert!(is_reserved_destination("wraith-events"));
|
||||
assert!(!is_reserved_destination("example.com"));
|
||||
assert!(!is_reserved_destination("localhost"));
|
||||
assert!(!is_reserved_destination("wraith.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_handler_without_control_handler_rejects_wraith_destinations() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let handler = ServerHandler::new(auth_config, None, None);
|
||||
assert!(!handler.control_channel_router().has_handler());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -397,7 +320,7 @@ mod tests {
|
||||
});
|
||||
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
|
||||
|
||||
let handler = make_handler(auth_config, proxy.clone(), remote);
|
||||
let handler = ServerHandler::new(auth_config, proxy.clone(), remote);
|
||||
assert!(handler.outbound_proxy.is_some());
|
||||
assert!(handler.remote_addr.is_some());
|
||||
}
|
||||
@@ -405,108 +328,9 @@ mod tests {
|
||||
#[test]
|
||||
fn one_handler_per_connection() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
|
||||
let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
|
||||
let handler1 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
|
||||
let handler2 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
|
||||
|
||||
assert!(handler1.remote_addr != handler2.remote_addr);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_rate_limit_rejects_after_max_failures() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(0));
|
||||
let mut handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("10.0.0.1:22".parse().unwrap()),
|
||||
TransportKind::Tcp,
|
||||
limiter,
|
||||
2,
|
||||
);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
|
||||
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||
assert_eq!(r1, Auth::Reject { proceed_with_methods: None });
|
||||
|
||||
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||
assert_eq!(r2, Auth::Reject { proceed_with_methods: None });
|
||||
|
||||
assert!(!handler.auth_limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_rate_limit_blocks_over_limit() {
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(1));
|
||||
let auth_config = make_empty_auth_config();
|
||||
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
|
||||
|
||||
let h1 = ServerHandler::new(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter.clone(),
|
||||
10,
|
||||
);
|
||||
assert!(h1.is_connection_allowed());
|
||||
|
||||
let h2 = ServerHandler::new(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter.clone(),
|
||||
10,
|
||||
);
|
||||
assert!(!h2.is_connection_allowed());
|
||||
|
||||
drop(h1);
|
||||
|
||||
let h3 = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter,
|
||||
10,
|
||||
);
|
||||
assert!(h3.is_connection_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_kind_display() {
|
||||
assert_eq!(TransportKind::Tcp.to_string(), "tcp");
|
||||
assert_eq!(TransportKind::Tls.to_string(), "tls");
|
||||
assert_eq!(TransportKind::Iroh.to_string(), "iroh");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_log_includes_user_field() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let mut handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("203.0.113.50:12345".parse().unwrap()),
|
||||
TransportKind::Tls,
|
||||
Arc::new(ConnectionRateLimiter::new(0)),
|
||||
10,
|
||||
);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_closed_logs_duration_on_drop() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let _handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("203.0.113.50:12345".parse().unwrap()),
|
||||
TransportKind::Tcp,
|
||||
Arc::new(ConnectionRateLimiter::new(0)),
|
||||
10,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,8 @@
|
||||
pub mod control_channel;
|
||||
pub mod handler;
|
||||
pub mod rate_limit;
|
||||
|
||||
pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
|
||||
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||
pub use control_channel::{
|
||||
ControlChannelHandler, ControlChannelRouter, DuplexStream, WRAITH_CONTROL_DESTINATION,
|
||||
WRAITH_PREFIX, is_reserved_destination,
|
||||
};
|
||||
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
|
||||
@@ -1,193 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Mutex;
|
||||
|
||||
pub struct ConnectionRateLimiter {
|
||||
max_per_ip: usize,
|
||||
active: Mutex<HashMap<IpAddr, usize>>,
|
||||
}
|
||||
|
||||
impl ConnectionRateLimiter {
|
||||
pub fn new(max_per_ip: usize) -> Self {
|
||||
Self {
|
||||
max_per_ip,
|
||||
active: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(&self, ip: IpAddr) -> bool {
|
||||
if self.max_per_ip == 0 {
|
||||
return true;
|
||||
}
|
||||
let active = self.active.lock().unwrap();
|
||||
let count = active.get(&ip).copied().unwrap_or(0);
|
||||
count < self.max_per_ip
|
||||
}
|
||||
|
||||
pub fn on_connect(&self, ip: IpAddr) {
|
||||
let mut active = self.active.lock().unwrap();
|
||||
*active.entry(ip).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
pub fn on_disconnect(&self, ip: IpAddr) {
|
||||
let mut active = self.active.lock().unwrap();
|
||||
if let Some(count) = active.get_mut(&ip) {
|
||||
if *count > 1 {
|
||||
*count -= 1;
|
||||
} else {
|
||||
active.remove(&ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AuthAttemptLimiter {
|
||||
max_attempts: usize,
|
||||
failures: usize,
|
||||
}
|
||||
|
||||
impl AuthAttemptLimiter {
|
||||
pub fn new(max_attempts: usize) -> Self {
|
||||
Self {
|
||||
max_attempts,
|
||||
failures: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(&self) -> bool {
|
||||
if self.max_attempts == 0 {
|
||||
return true;
|
||||
}
|
||||
self.failures < self.max_attempts
|
||||
}
|
||||
|
||||
pub fn on_failure(&mut self) {
|
||||
self.failures += 1;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
|
||||
fn ip(n: u8) -> IpAddr {
|
||||
IpAddr::V4(Ipv4Addr::new(192, 168, 1, n))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_allows_when_under_limit() {
|
||||
let limiter = ConnectionRateLimiter::new(3);
|
||||
assert!(limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_blocks_when_at_limit() {
|
||||
let limiter = ConnectionRateLimiter::new(2);
|
||||
limiter.on_connect(ip(1));
|
||||
limiter.on_connect(ip(1));
|
||||
assert!(!limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_allows_after_disconnect() {
|
||||
let limiter = ConnectionRateLimiter::new(2);
|
||||
limiter.on_connect(ip(1));
|
||||
limiter.on_connect(ip(1));
|
||||
assert!(!limiter.check(ip(1)));
|
||||
limiter.on_disconnect(ip(1));
|
||||
assert!(limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_unlimited_when_zero() {
|
||||
let limiter = ConnectionRateLimiter::new(0);
|
||||
for _ in 0..100 {
|
||||
limiter.on_connect(ip(1));
|
||||
}
|
||||
assert!(limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_tracks_per_ip_independently() {
|
||||
let limiter = ConnectionRateLimiter::new(1);
|
||||
limiter.on_connect(ip(1));
|
||||
assert!(!limiter.check(ip(1)));
|
||||
assert!(limiter.check(ip(2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_ipv6() {
|
||||
let limiter = ConnectionRateLimiter::new(1);
|
||||
let ip6 = IpAddr::V6(Ipv6Addr::LOCALHOST);
|
||||
limiter.on_connect(ip6);
|
||||
assert!(!limiter.check(ip6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_disconnect_removes_zero_entry() {
|
||||
let limiter = ConnectionRateLimiter::new(3);
|
||||
limiter.on_connect(ip(1));
|
||||
limiter.on_disconnect(ip(1));
|
||||
{
|
||||
let active = limiter.active.lock().unwrap();
|
||||
assert!(!active.contains_key(&ip(1)));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_allows_when_under_limit() {
|
||||
let limiter = AuthAttemptLimiter::new(3);
|
||||
assert!(limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_blocks_after_max_failures() {
|
||||
let mut limiter = AuthAttemptLimiter::new(2);
|
||||
limiter.on_failure();
|
||||
limiter.on_failure();
|
||||
assert!(!limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_unlimited_when_zero() {
|
||||
let mut limiter = AuthAttemptLimiter::new(0);
|
||||
for _ in 0..100 {
|
||||
limiter.on_failure();
|
||||
}
|
||||
assert!(limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_still_allows_at_one_below_limit() {
|
||||
let mut limiter = AuthAttemptLimiter::new(3);
|
||||
limiter.on_failure();
|
||||
limiter.on_failure();
|
||||
assert!(limiter.check());
|
||||
limiter.on_failure();
|
||||
assert!(!limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_thread_safety() {
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(100));
|
||||
let mut handles = vec![];
|
||||
|
||||
for i in 0..10 {
|
||||
let lim = Arc::clone(&limiter);
|
||||
handles.push(thread::spawn(move || {
|
||||
let ip_addr = ip((i % 3) as u8 + 1);
|
||||
lim.on_connect(ip_addr);
|
||||
assert!(lim.check(ip_addr));
|
||||
lim.on_disconnect(ip_addr);
|
||||
}));
|
||||
}
|
||||
|
||||
for h in handles {
|
||||
h.join().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user