1 Commits

Author SHA1 Message Date
24b70f5651 Implement server rate limiting and fail2ban-friendly structured logging
Add ConnectionRateLimiter (HashMap<IpAddr, usize>) and AuthAttemptLimiter
with check/on_connect/on_disconnect and check/on_failure methods.
Integrate into ServerHandler with structured tracing::info! logging for
auth attempts, connection opened/closed events. No logging of tunnel
destinations per ADR-006. Also add ForwardError type and fix type
annotation in forward.rs to unblock compilation.
2026-06-02 11:02:55 +00:00
5 changed files with 430 additions and 10 deletions

View File

@@ -125,7 +125,7 @@ impl LocalForwarder {
handle: Arc<Mutex<client::Handle<H>>>,
) -> Result<(), ForwardError> {
let listen_addr = self.spec.listen_addr()?;
let listener = TcpListener::bind(listen_addr)
let listener: TcpListener = TcpListener::bind(listen_addr)
.await
.map_err(|e| ForwardError::BindFailed { source: e })?;
self.listener = Some(listener);

View File

@@ -60,6 +60,22 @@ pub enum ConfigError {
IncompatibleOptions,
}
#[derive(Debug, thiserror::Error)]
pub enum ForwardError {
#[error("invalid forward spec: {spec}")]
InvalidSpec { spec: String },
#[error("bind failed")]
BindFailed {
#[source]
source: io::Error,
},
#[error("channel open failed")]
ChannelOpenFailed {
#[source]
source: Box<dyn std::error::Error + Send + Sync>,
},
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,5 +1,6 @@
use std::net::SocketAddr;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use russh::keys::ssh_key::HashAlg;
@@ -7,6 +8,7 @@ use russh::server::{Auth, Handler, Msg, Session};
use russh::Channel;
use crate::auth::ServerAuthConfig;
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
const WRAITH_PREFIX: &str = "wraith-";
@@ -22,10 +24,32 @@ pub struct ProxyConfig {
pub mode: ProxyMode,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TransportKind {
Tcp,
Tls,
Iroh,
}
impl std::fmt::Display for TransportKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportKind::Tcp => write!(f, "tcp"),
TransportKind::Tls => write!(f, "tls"),
TransportKind::Iroh => write!(f, "iroh"),
}
}
}
pub struct ServerHandler {
auth_config: Arc<ServerAuthConfig>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
transport: TransportKind,
connection_limiter: Arc<ConnectionRateLimiter>,
connection_allowed: bool,
auth_limiter: AuthAttemptLimiter,
connected_at: Instant,
}
impl ServerHandler {
@@ -33,11 +57,65 @@ impl ServerHandler {
auth_config: Arc<ServerAuthConfig>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
transport: TransportKind,
connection_limiter: Arc<ConnectionRateLimiter>,
max_auth_attempts: usize,
) -> Self {
let allowed = if let Some(addr) = remote_addr {
let ip = addr.ip();
if connection_limiter.check(ip) {
connection_limiter.on_connect(ip);
tracing::info!(
remote_addr = %addr,
transport = %transport,
"connection opened"
);
true
} else {
tracing::info!(
remote_addr = %addr,
transport = %transport,
"connection rejected"
);
false
}
} else {
true
};
Self {
auth_config,
outbound_proxy,
remote_addr,
transport,
connection_limiter,
connection_allowed: allowed,
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
connected_at: Instant::now(),
}
}
pub fn is_connection_allowed(&self) -> bool {
self.connection_allowed
}
pub fn remote_ip(&self) -> Option<IpAddr> {
self.remote_addr.map(|a| a.ip())
}
}
impl Drop for ServerHandler {
fn drop(&mut self) {
if let Some(addr) = self.remote_addr {
if self.connection_allowed {
self.connection_limiter.on_disconnect(addr.ip());
}
let duration = self.connected_at.elapsed();
tracing::info!(
remote_addr = %addr,
duration_secs = duration.as_secs_f64(),
"connection closed"
);
}
}
}
@@ -51,6 +129,23 @@ impl Handler for ServerHandler {
user: &str,
public_key: &russh::keys::ssh_key::PublicKey,
) -> Result<Auth, Self::Error> {
if !self.auth_limiter.check() {
let remote_addr_display = self
.remote_addr
.map_or("unknown".to_string(), |a| a.to_string());
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "reject",
"auth attempt"
);
return Ok(Auth::Reject {
proceed_with_methods: None,
});
}
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
let remote_addr_display = self
.remote_addr
@@ -63,6 +158,7 @@ impl Handler for ServerHandler {
Ok(()) => {
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "accept",
"auth attempt"
@@ -70,8 +166,10 @@ impl Handler for ServerHandler {
Ok(Auth::Accept)
}
Err(_) => {
self.auth_limiter.on_failure();
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "reject",
"auth attempt"
@@ -188,10 +286,22 @@ mod tests {
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
}
fn default_limiter() -> Arc<ConnectionRateLimiter> {
Arc::new(ConnectionRateLimiter::new(0))
}
fn make_handler(
auth_config: Arc<ServerAuthConfig>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
) -> ServerHandler {
ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10)
}
#[tokio::test]
async fn auth_delegation_accepts_known_key() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let mut handler = ServerHandler::new(auth_config, None, None);
let mut handler = make_handler(auth_config, None, None);
let ssh_key = load_key().public_key().clone();
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
@@ -201,7 +311,7 @@ mod tests {
#[tokio::test]
async fn auth_delegation_rejects_unknown_key() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let mut handler = ServerHandler::new(auth_config, None, None);
let mut handler = make_handler(auth_config, None, None);
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
let other_ssh_key = russh::keys::parse_public_key_base64(
@@ -224,7 +334,7 @@ mod tests {
#[tokio::test]
async fn auth_delegation_empty_config_rejects_all() {
let auth_config = make_empty_auth_config();
let mut handler = ServerHandler::new(auth_config, None, None);
let mut handler = make_handler(auth_config, None, None);
let ssh_key = load_key().public_key().clone();
let result = handler
@@ -243,7 +353,7 @@ mod tests {
async fn auth_logging_includes_remote_addr() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
let mut handler = ServerHandler::new(auth_config, None, Some(remote_addr));
let mut handler = make_handler(auth_config, None, Some(remote_addr));
let ssh_key = load_key().public_key().clone();
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
@@ -287,7 +397,7 @@ mod tests {
});
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
let handler = ServerHandler::new(auth_config, proxy.clone(), remote);
let handler = make_handler(auth_config, proxy.clone(), remote);
assert!(handler.outbound_proxy.is_some());
assert!(handler.remote_addr.is_some());
}
@@ -295,9 +405,108 @@ mod tests {
#[test]
fn one_handler_per_connection() {
let auth_config = make_empty_auth_config();
let handler1 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
let handler2 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
assert!(handler1.remote_addr != handler2.remote_addr);
}
#[tokio::test]
async fn auth_rate_limit_rejects_after_max_failures() {
let auth_config = make_empty_auth_config();
let limiter = Arc::new(ConnectionRateLimiter::new(0));
let mut handler = ServerHandler::new(
auth_config,
None,
Some("10.0.0.1:22".parse().unwrap()),
TransportKind::Tcp,
limiter,
2,
);
let ssh_key = load_key().public_key().clone();
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
assert_eq!(r1, Auth::Reject { proceed_with_methods: None });
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
assert_eq!(r2, Auth::Reject { proceed_with_methods: None });
assert!(!handler.auth_limiter.check());
}
#[test]
fn connection_rate_limit_blocks_over_limit() {
let limiter = Arc::new(ConnectionRateLimiter::new(1));
let auth_config = make_empty_auth_config();
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
let h1 = ServerHandler::new(
auth_config.clone(),
None,
Some(addr),
TransportKind::Tcp,
limiter.clone(),
10,
);
assert!(h1.is_connection_allowed());
let h2 = ServerHandler::new(
auth_config.clone(),
None,
Some(addr),
TransportKind::Tcp,
limiter.clone(),
10,
);
assert!(!h2.is_connection_allowed());
drop(h1);
let h3 = ServerHandler::new(
auth_config,
None,
Some(addr),
TransportKind::Tcp,
limiter,
10,
);
assert!(h3.is_connection_allowed());
}
#[test]
fn transport_kind_display() {
assert_eq!(TransportKind::Tcp.to_string(), "tcp");
assert_eq!(TransportKind::Tls.to_string(), "tls");
assert_eq!(TransportKind::Iroh.to_string(), "iroh");
}
#[tokio::test]
async fn auth_log_includes_user_field() {
let auth_config = make_empty_auth_config();
let mut handler = ServerHandler::new(
auth_config,
None,
Some("203.0.113.50:12345".parse().unwrap()),
TransportKind::Tls,
Arc::new(ConnectionRateLimiter::new(0)),
10,
);
let ssh_key = load_key().public_key().clone();
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
}
#[test]
fn connection_closed_logs_duration_on_drop() {
let auth_config = make_empty_auth_config();
let _handler = ServerHandler::new(
auth_config,
None,
Some("203.0.113.50:12345".parse().unwrap()),
TransportKind::Tcp,
Arc::new(ConnectionRateLimiter::new(0)),
10,
);
}
}

View File

@@ -1,3 +1,5 @@
pub mod handler;
pub mod rate_limit;
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};

View File

@@ -0,0 +1,193 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Mutex;
pub struct ConnectionRateLimiter {
max_per_ip: usize,
active: Mutex<HashMap<IpAddr, usize>>,
}
impl ConnectionRateLimiter {
pub fn new(max_per_ip: usize) -> Self {
Self {
max_per_ip,
active: Mutex::new(HashMap::new()),
}
}
pub fn check(&self, ip: IpAddr) -> bool {
if self.max_per_ip == 0 {
return true;
}
let active = self.active.lock().unwrap();
let count = active.get(&ip).copied().unwrap_or(0);
count < self.max_per_ip
}
pub fn on_connect(&self, ip: IpAddr) {
let mut active = self.active.lock().unwrap();
*active.entry(ip).or_insert(0) += 1;
}
pub fn on_disconnect(&self, ip: IpAddr) {
let mut active = self.active.lock().unwrap();
if let Some(count) = active.get_mut(&ip) {
if *count > 1 {
*count -= 1;
} else {
active.remove(&ip);
}
}
}
}
pub struct AuthAttemptLimiter {
max_attempts: usize,
failures: usize,
}
impl AuthAttemptLimiter {
pub fn new(max_attempts: usize) -> Self {
Self {
max_attempts,
failures: 0,
}
}
pub fn check(&self) -> bool {
if self.max_attempts == 0 {
return true;
}
self.failures < self.max_attempts
}
pub fn on_failure(&mut self) {
self.failures += 1;
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
fn ip(n: u8) -> IpAddr {
IpAddr::V4(Ipv4Addr::new(192, 168, 1, n))
}
#[test]
fn connection_limiter_allows_when_under_limit() {
let limiter = ConnectionRateLimiter::new(3);
assert!(limiter.check(ip(1)));
}
#[test]
fn connection_limiter_blocks_when_at_limit() {
let limiter = ConnectionRateLimiter::new(2);
limiter.on_connect(ip(1));
limiter.on_connect(ip(1));
assert!(!limiter.check(ip(1)));
}
#[test]
fn connection_limiter_allows_after_disconnect() {
let limiter = ConnectionRateLimiter::new(2);
limiter.on_connect(ip(1));
limiter.on_connect(ip(1));
assert!(!limiter.check(ip(1)));
limiter.on_disconnect(ip(1));
assert!(limiter.check(ip(1)));
}
#[test]
fn connection_limiter_unlimited_when_zero() {
let limiter = ConnectionRateLimiter::new(0);
for _ in 0..100 {
limiter.on_connect(ip(1));
}
assert!(limiter.check(ip(1)));
}
#[test]
fn connection_limiter_tracks_per_ip_independently() {
let limiter = ConnectionRateLimiter::new(1);
limiter.on_connect(ip(1));
assert!(!limiter.check(ip(1)));
assert!(limiter.check(ip(2)));
}
#[test]
fn connection_limiter_ipv6() {
let limiter = ConnectionRateLimiter::new(1);
let ip6 = IpAddr::V6(Ipv6Addr::LOCALHOST);
limiter.on_connect(ip6);
assert!(!limiter.check(ip6));
}
#[test]
fn connection_limiter_disconnect_removes_zero_entry() {
let limiter = ConnectionRateLimiter::new(3);
limiter.on_connect(ip(1));
limiter.on_disconnect(ip(1));
{
let active = limiter.active.lock().unwrap();
assert!(!active.contains_key(&ip(1)));
}
}
#[test]
fn auth_limiter_allows_when_under_limit() {
let limiter = AuthAttemptLimiter::new(3);
assert!(limiter.check());
}
#[test]
fn auth_limiter_blocks_after_max_failures() {
let mut limiter = AuthAttemptLimiter::new(2);
limiter.on_failure();
limiter.on_failure();
assert!(!limiter.check());
}
#[test]
fn auth_limiter_unlimited_when_zero() {
let mut limiter = AuthAttemptLimiter::new(0);
for _ in 0..100 {
limiter.on_failure();
}
assert!(limiter.check());
}
#[test]
fn auth_limiter_still_allows_at_one_below_limit() {
let mut limiter = AuthAttemptLimiter::new(3);
limiter.on_failure();
limiter.on_failure();
assert!(limiter.check());
limiter.on_failure();
assert!(!limiter.check());
}
#[test]
fn connection_limiter_thread_safety() {
use std::sync::Arc;
use std::thread;
let limiter = Arc::new(ConnectionRateLimiter::new(100));
let mut handles = vec![];
for i in 0..10 {
let lim = Arc::clone(&limiter);
handles.push(thread::spawn(move || {
let ip_addr = ip((i % 3) as u8 + 1);
lim.on_connect(ip_addr);
assert!(lim.check(ip_addr));
lim.on_disconnect(ip_addr);
}));
}
for h in handles {
h.join().unwrap();
}
}
}