From 24b70f56515dd58c323b9c2f8196dca95c1f2c5d Mon Sep 17 00:00:00 2001 From: "glm-5.1" Date: Tue, 2 Jun 2026 11:02:55 +0000 Subject: [PATCH] Implement server rate limiting and fail2ban-friendly structured logging Add ConnectionRateLimiter (HashMap) and AuthAttemptLimiter with check/on_connect/on_disconnect and check/on_failure methods. Integrate into ServerHandler with structured tracing::info! logging for auth attempts, connection opened/closed events. No logging of tunnel destinations per ADR-006. Also add ForwardError type and fix type annotation in forward.rs to unblock compilation. --- crates/wraith-core/src/client/forward.rs | 2 +- crates/wraith-core/src/error.rs | 16 ++ crates/wraith-core/src/server/handler.rs | 225 +++++++++++++++++++- crates/wraith-core/src/server/mod.rs | 4 +- crates/wraith-core/src/server/rate_limit.rs | 193 +++++++++++++++++ 5 files changed, 430 insertions(+), 10 deletions(-) create mode 100644 crates/wraith-core/src/server/rate_limit.rs diff --git a/crates/wraith-core/src/client/forward.rs b/crates/wraith-core/src/client/forward.rs index b8987f9..eea3de0 100644 --- a/crates/wraith-core/src/client/forward.rs +++ b/crates/wraith-core/src/client/forward.rs @@ -125,7 +125,7 @@ impl LocalForwarder { handle: Arc>>, ) -> Result<(), ForwardError> { let listen_addr = self.spec.listen_addr()?; - let listener = TcpListener::bind(listen_addr) + let listener: TcpListener = TcpListener::bind(listen_addr) .await .map_err(|e| ForwardError::BindFailed { source: e })?; self.listener = Some(listener); diff --git a/crates/wraith-core/src/error.rs b/crates/wraith-core/src/error.rs index 3b4c152..bf22b38 100644 --- a/crates/wraith-core/src/error.rs +++ b/crates/wraith-core/src/error.rs @@ -60,6 +60,22 @@ pub enum ConfigError { IncompatibleOptions, } +#[derive(Debug, thiserror::Error)] +pub enum ForwardError { + #[error("invalid forward spec: {spec}")] + InvalidSpec { spec: String }, + #[error("bind failed")] + BindFailed { + #[source] + source: io::Error, + }, + #[error("channel open failed")] + ChannelOpenFailed { + #[source] + source: Box, + }, +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/wraith-core/src/server/handler.rs b/crates/wraith-core/src/server/handler.rs index bf226c9..f931fc3 100644 --- a/crates/wraith-core/src/server/handler.rs +++ b/crates/wraith-core/src/server/handler.rs @@ -1,5 +1,6 @@ -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; +use std::time::Instant; use async_trait::async_trait; use russh::keys::ssh_key::HashAlg; @@ -7,6 +8,7 @@ use russh::server::{Auth, Handler, Msg, Session}; use russh::Channel; use crate::auth::ServerAuthConfig; +use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter}; const WRAITH_PREFIX: &str = "wraith-"; @@ -22,10 +24,32 @@ pub struct ProxyConfig { pub mode: ProxyMode, } +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum TransportKind { + Tcp, + Tls, + Iroh, +} + +impl std::fmt::Display for TransportKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TransportKind::Tcp => write!(f, "tcp"), + TransportKind::Tls => write!(f, "tls"), + TransportKind::Iroh => write!(f, "iroh"), + } + } +} + pub struct ServerHandler { auth_config: Arc, outbound_proxy: Option, remote_addr: Option, + transport: TransportKind, + connection_limiter: Arc, + connection_allowed: bool, + auth_limiter: AuthAttemptLimiter, + connected_at: Instant, } impl ServerHandler { @@ -33,11 +57,65 @@ impl ServerHandler { auth_config: Arc, outbound_proxy: Option, remote_addr: Option, + transport: TransportKind, + connection_limiter: Arc, + max_auth_attempts: usize, ) -> Self { + let allowed = if let Some(addr) = remote_addr { + let ip = addr.ip(); + if connection_limiter.check(ip) { + connection_limiter.on_connect(ip); + tracing::info!( + remote_addr = %addr, + transport = %transport, + "connection opened" + ); + true + } else { + tracing::info!( + remote_addr = %addr, + transport = %transport, + "connection rejected" + ); + false + } + } else { + true + }; + Self { auth_config, outbound_proxy, remote_addr, + transport, + connection_limiter, + connection_allowed: allowed, + auth_limiter: AuthAttemptLimiter::new(max_auth_attempts), + connected_at: Instant::now(), + } + } + + pub fn is_connection_allowed(&self) -> bool { + self.connection_allowed + } + + pub fn remote_ip(&self) -> Option { + self.remote_addr.map(|a| a.ip()) + } +} + +impl Drop for ServerHandler { + fn drop(&mut self) { + if let Some(addr) = self.remote_addr { + if self.connection_allowed { + self.connection_limiter.on_disconnect(addr.ip()); + } + let duration = self.connected_at.elapsed(); + tracing::info!( + remote_addr = %addr, + duration_secs = duration.as_secs_f64(), + "connection closed" + ); } } } @@ -51,6 +129,23 @@ impl Handler for ServerHandler { user: &str, public_key: &russh::keys::ssh_key::PublicKey, ) -> Result { + if !self.auth_limiter.check() { + let remote_addr_display = self + .remote_addr + .map_or("unknown".to_string(), |a| a.to_string()); + let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256)); + tracing::info!( + remote_addr = %remote_addr_display, + user = user, + key_fingerprint = %fingerprint, + result = "reject", + "auth attempt" + ); + return Ok(Auth::Reject { + proceed_with_methods: None, + }); + } + let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256)); let remote_addr_display = self .remote_addr @@ -63,6 +158,7 @@ impl Handler for ServerHandler { Ok(()) => { tracing::info!( remote_addr = %remote_addr_display, + user = user, key_fingerprint = %fingerprint, result = "accept", "auth attempt" @@ -70,8 +166,10 @@ impl Handler for ServerHandler { Ok(Auth::Accept) } Err(_) => { + self.auth_limiter.on_failure(); tracing::info!( remote_addr = %remote_addr_display, + user = user, key_fingerprint = %fingerprint, result = "reject", "auth attempt" @@ -188,10 +286,22 @@ mod tests { Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap()) } + fn default_limiter() -> Arc { + Arc::new(ConnectionRateLimiter::new(0)) + } + + fn make_handler( + auth_config: Arc, + outbound_proxy: Option, + remote_addr: Option, + ) -> ServerHandler { + ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10) + } + #[tokio::test] async fn auth_delegation_accepts_known_key() { let auth_config = make_auth_config(ED25519_PUBLIC_KEY); - let mut handler = ServerHandler::new(auth_config, None, None); + let mut handler = make_handler(auth_config, None, None); let ssh_key = load_key().public_key().clone(); let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap(); @@ -201,7 +311,7 @@ mod tests { #[tokio::test] async fn auth_delegation_rejects_unknown_key() { let auth_config = make_auth_config(ED25519_PUBLIC_KEY); - let mut handler = ServerHandler::new(auth_config, None, None); + let mut handler = make_handler(auth_config, None, None); let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host"; let other_ssh_key = russh::keys::parse_public_key_base64( @@ -224,7 +334,7 @@ mod tests { #[tokio::test] async fn auth_delegation_empty_config_rejects_all() { let auth_config = make_empty_auth_config(); - let mut handler = ServerHandler::new(auth_config, None, None); + let mut handler = make_handler(auth_config, None, None); let ssh_key = load_key().public_key().clone(); let result = handler @@ -243,7 +353,7 @@ mod tests { async fn auth_logging_includes_remote_addr() { let auth_config = make_auth_config(ED25519_PUBLIC_KEY); let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap(); - let mut handler = ServerHandler::new(auth_config, None, Some(remote_addr)); + let mut handler = make_handler(auth_config, None, Some(remote_addr)); let ssh_key = load_key().public_key().clone(); let _ = handler.auth_publickey("root", &ssh_key).await.unwrap(); @@ -287,7 +397,7 @@ mod tests { }); let remote: Option = Some("10.0.0.1:22".parse().unwrap()); - let handler = ServerHandler::new(auth_config, proxy.clone(), remote); + let handler = make_handler(auth_config, proxy.clone(), remote); assert!(handler.outbound_proxy.is_some()); assert!(handler.remote_addr.is_some()); } @@ -295,9 +405,108 @@ mod tests { #[test] fn one_handler_per_connection() { let auth_config = make_empty_auth_config(); - let handler1 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap())); - let handler2 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap())); + let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap())); + let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap())); assert!(handler1.remote_addr != handler2.remote_addr); } + + #[tokio::test] + async fn auth_rate_limit_rejects_after_max_failures() { + let auth_config = make_empty_auth_config(); + let limiter = Arc::new(ConnectionRateLimiter::new(0)); + let mut handler = ServerHandler::new( + auth_config, + None, + Some("10.0.0.1:22".parse().unwrap()), + TransportKind::Tcp, + limiter, + 2, + ); + + let ssh_key = load_key().public_key().clone(); + + let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap(); + assert_eq!(r1, Auth::Reject { proceed_with_methods: None }); + + let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap(); + assert_eq!(r2, Auth::Reject { proceed_with_methods: None }); + + assert!(!handler.auth_limiter.check()); + } + + #[test] + fn connection_rate_limit_blocks_over_limit() { + let limiter = Arc::new(ConnectionRateLimiter::new(1)); + let auth_config = make_empty_auth_config(); + let addr: SocketAddr = "10.0.0.1:22".parse().unwrap(); + + let h1 = ServerHandler::new( + auth_config.clone(), + None, + Some(addr), + TransportKind::Tcp, + limiter.clone(), + 10, + ); + assert!(h1.is_connection_allowed()); + + let h2 = ServerHandler::new( + auth_config.clone(), + None, + Some(addr), + TransportKind::Tcp, + limiter.clone(), + 10, + ); + assert!(!h2.is_connection_allowed()); + + drop(h1); + + let h3 = ServerHandler::new( + auth_config, + None, + Some(addr), + TransportKind::Tcp, + limiter, + 10, + ); + assert!(h3.is_connection_allowed()); + } + + #[test] + fn transport_kind_display() { + assert_eq!(TransportKind::Tcp.to_string(), "tcp"); + assert_eq!(TransportKind::Tls.to_string(), "tls"); + assert_eq!(TransportKind::Iroh.to_string(), "iroh"); + } + + #[tokio::test] + async fn auth_log_includes_user_field() { + let auth_config = make_empty_auth_config(); + let mut handler = ServerHandler::new( + auth_config, + None, + Some("203.0.113.50:12345".parse().unwrap()), + TransportKind::Tls, + Arc::new(ConnectionRateLimiter::new(0)), + 10, + ); + + let ssh_key = load_key().public_key().clone(); + let _ = handler.auth_publickey("root", &ssh_key).await.unwrap(); + } + + #[test] + fn connection_closed_logs_duration_on_drop() { + let auth_config = make_empty_auth_config(); + let _handler = ServerHandler::new( + auth_config, + None, + Some("203.0.113.50:12345".parse().unwrap()), + TransportKind::Tcp, + Arc::new(ConnectionRateLimiter::new(0)), + 10, + ); + } } \ No newline at end of file diff --git a/crates/wraith-core/src/server/mod.rs b/crates/wraith-core/src/server/mod.rs index 1fd1705..b0dda5d 100644 --- a/crates/wraith-core/src/server/mod.rs +++ b/crates/wraith-core/src/server/mod.rs @@ -1,3 +1,5 @@ pub mod handler; +pub mod rate_limit; -pub use handler::{ProxyConfig, ProxyMode, ServerHandler}; \ No newline at end of file +pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind}; +pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter}; \ No newline at end of file diff --git a/crates/wraith-core/src/server/rate_limit.rs b/crates/wraith-core/src/server/rate_limit.rs new file mode 100644 index 0000000..4fcc8ee --- /dev/null +++ b/crates/wraith-core/src/server/rate_limit.rs @@ -0,0 +1,193 @@ +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::Mutex; + +pub struct ConnectionRateLimiter { + max_per_ip: usize, + active: Mutex>, +} + +impl ConnectionRateLimiter { + pub fn new(max_per_ip: usize) -> Self { + Self { + max_per_ip, + active: Mutex::new(HashMap::new()), + } + } + + pub fn check(&self, ip: IpAddr) -> bool { + if self.max_per_ip == 0 { + return true; + } + let active = self.active.lock().unwrap(); + let count = active.get(&ip).copied().unwrap_or(0); + count < self.max_per_ip + } + + pub fn on_connect(&self, ip: IpAddr) { + let mut active = self.active.lock().unwrap(); + *active.entry(ip).or_insert(0) += 1; + } + + pub fn on_disconnect(&self, ip: IpAddr) { + let mut active = self.active.lock().unwrap(); + if let Some(count) = active.get_mut(&ip) { + if *count > 1 { + *count -= 1; + } else { + active.remove(&ip); + } + } + } +} + +pub struct AuthAttemptLimiter { + max_attempts: usize, + failures: usize, +} + +impl AuthAttemptLimiter { + pub fn new(max_attempts: usize) -> Self { + Self { + max_attempts, + failures: 0, + } + } + + pub fn check(&self) -> bool { + if self.max_attempts == 0 { + return true; + } + self.failures < self.max_attempts + } + + pub fn on_failure(&mut self) { + self.failures += 1; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + fn ip(n: u8) -> IpAddr { + IpAddr::V4(Ipv4Addr::new(192, 168, 1, n)) + } + + #[test] + fn connection_limiter_allows_when_under_limit() { + let limiter = ConnectionRateLimiter::new(3); + assert!(limiter.check(ip(1))); + } + + #[test] + fn connection_limiter_blocks_when_at_limit() { + let limiter = ConnectionRateLimiter::new(2); + limiter.on_connect(ip(1)); + limiter.on_connect(ip(1)); + assert!(!limiter.check(ip(1))); + } + + #[test] + fn connection_limiter_allows_after_disconnect() { + let limiter = ConnectionRateLimiter::new(2); + limiter.on_connect(ip(1)); + limiter.on_connect(ip(1)); + assert!(!limiter.check(ip(1))); + limiter.on_disconnect(ip(1)); + assert!(limiter.check(ip(1))); + } + + #[test] + fn connection_limiter_unlimited_when_zero() { + let limiter = ConnectionRateLimiter::new(0); + for _ in 0..100 { + limiter.on_connect(ip(1)); + } + assert!(limiter.check(ip(1))); + } + + #[test] + fn connection_limiter_tracks_per_ip_independently() { + let limiter = ConnectionRateLimiter::new(1); + limiter.on_connect(ip(1)); + assert!(!limiter.check(ip(1))); + assert!(limiter.check(ip(2))); + } + + #[test] + fn connection_limiter_ipv6() { + let limiter = ConnectionRateLimiter::new(1); + let ip6 = IpAddr::V6(Ipv6Addr::LOCALHOST); + limiter.on_connect(ip6); + assert!(!limiter.check(ip6)); + } + + #[test] + fn connection_limiter_disconnect_removes_zero_entry() { + let limiter = ConnectionRateLimiter::new(3); + limiter.on_connect(ip(1)); + limiter.on_disconnect(ip(1)); + { + let active = limiter.active.lock().unwrap(); + assert!(!active.contains_key(&ip(1))); + } + } + + #[test] + fn auth_limiter_allows_when_under_limit() { + let limiter = AuthAttemptLimiter::new(3); + assert!(limiter.check()); + } + + #[test] + fn auth_limiter_blocks_after_max_failures() { + let mut limiter = AuthAttemptLimiter::new(2); + limiter.on_failure(); + limiter.on_failure(); + assert!(!limiter.check()); + } + + #[test] + fn auth_limiter_unlimited_when_zero() { + let mut limiter = AuthAttemptLimiter::new(0); + for _ in 0..100 { + limiter.on_failure(); + } + assert!(limiter.check()); + } + + #[test] + fn auth_limiter_still_allows_at_one_below_limit() { + let mut limiter = AuthAttemptLimiter::new(3); + limiter.on_failure(); + limiter.on_failure(); + assert!(limiter.check()); + limiter.on_failure(); + assert!(!limiter.check()); + } + + #[test] + fn connection_limiter_thread_safety() { + use std::sync::Arc; + use std::thread; + + let limiter = Arc::new(ConnectionRateLimiter::new(100)); + let mut handles = vec![]; + + for i in 0..10 { + let lim = Arc::clone(&limiter); + handles.push(thread::spawn(move || { + let ip_addr = ip((i % 3) as u8 + 1); + lim.on_connect(ip_addr); + assert!(lim.check(ip_addr)); + lim.on_disconnect(ip_addr); + })); + } + + for h in handles { + h.join().unwrap(); + } + } +} \ No newline at end of file