Merge remote-tracking branch 'origin/feat/server/rate-limiting-and-logging'
# Conflicts: # crates/wraith-core/src/error.rs # crates/wraith-core/src/server/handler.rs # crates/wraith-core/src/server/mod.rs
This commit is contained in:
@@ -125,7 +125,7 @@ impl LocalForwarder {
|
||||
handle: Arc<Mutex<client::Handle<H>>>,
|
||||
) -> Result<(), ForwardError> {
|
||||
let listen_addr = self.spec.listen_addr()?;
|
||||
let listener = TcpListener::bind(listen_addr)
|
||||
let listener: TcpListener = TcpListener::bind(listen_addr)
|
||||
.await
|
||||
.map_err(|e| ForwardError::BindFailed { source: e })?;
|
||||
self.listener = Some(listener);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use russh::keys::ssh_key::HashAlg;
|
||||
@@ -10,6 +11,7 @@ use crate::auth::ServerAuthConfig;
|
||||
use crate::server::control_channel::{
|
||||
ControlChannelHandler, ControlChannelRouter, WRAITH_PREFIX,
|
||||
};
|
||||
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProxyMode {
|
||||
@@ -23,11 +25,33 @@ pub struct ProxyConfig {
|
||||
pub mode: ProxyMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum TransportKind {
|
||||
Tcp,
|
||||
Tls,
|
||||
Iroh,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TransportKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TransportKind::Tcp => write!(f, "tcp"),
|
||||
TransportKind::Tls => write!(f, "tls"),
|
||||
TransportKind::Iroh => write!(f, "iroh"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ServerHandler {
|
||||
auth_config: Arc<ServerAuthConfig>,
|
||||
outbound_proxy: Option<ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
control_channel_router: ControlChannelRouter,
|
||||
transport: TransportKind,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
connection_allowed: bool,
|
||||
auth_limiter: AuthAttemptLimiter,
|
||||
connected_at: Instant,
|
||||
}
|
||||
|
||||
impl ServerHandler {
|
||||
@@ -35,15 +59,71 @@ impl ServerHandler {
|
||||
auth_config: Arc<ServerAuthConfig>,
|
||||
outbound_proxy: Option<ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
transport: TransportKind,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
max_auth_attempts: usize,
|
||||
) -> Self {
|
||||
let allowed = if let Some(addr) = remote_addr {
|
||||
let ip = addr.ip();
|
||||
if connection_limiter.check(ip) {
|
||||
connection_limiter.on_connect(ip);
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
transport = %transport,
|
||||
"connection opened"
|
||||
);
|
||||
true
|
||||
} else {
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
transport = %transport,
|
||||
"connection rejected"
|
||||
);
|
||||
false
|
||||
}
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
Self {
|
||||
auth_config,
|
||||
outbound_proxy,
|
||||
remote_addr,
|
||||
control_channel_router: ControlChannelRouter::without_handler(),
|
||||
transport,
|
||||
connection_limiter,
|
||||
connection_allowed: allowed,
|
||||
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
|
||||
connected_at: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_connection_allowed(&self) -> bool {
|
||||
self.connection_allowed
|
||||
}
|
||||
|
||||
pub fn remote_ip(&self) -> Option<IpAddr> {
|
||||
self.remote_addr.map(|a| a.ip())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ServerHandler {
|
||||
fn drop(&mut self) {
|
||||
if let Some(addr) = self.remote_addr {
|
||||
if self.connection_allowed {
|
||||
self.connection_limiter.on_disconnect(addr.ip());
|
||||
}
|
||||
let duration = self.connected_at.elapsed();
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
duration_secs = duration.as_secs_f64(),
|
||||
"connection closed"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerHandler {
|
||||
pub fn with_control_channel_handler(
|
||||
mut self,
|
||||
handler: Box<dyn ControlChannelHandler>,
|
||||
@@ -66,6 +146,23 @@ impl Handler for ServerHandler {
|
||||
user: &str,
|
||||
public_key: &russh::keys::ssh_key::PublicKey,
|
||||
) -> Result<Auth, Self::Error> {
|
||||
if !self.auth_limiter.check() {
|
||||
let remote_addr_display = self
|
||||
.remote_addr
|
||||
.map_or("unknown".to_string(), |a| a.to_string());
|
||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "reject",
|
||||
"auth attempt"
|
||||
);
|
||||
return Ok(Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
});
|
||||
}
|
||||
|
||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||
let remote_addr_display = self
|
||||
.remote_addr
|
||||
@@ -78,6 +175,7 @@ impl Handler for ServerHandler {
|
||||
Ok(()) => {
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "accept",
|
||||
"auth attempt"
|
||||
@@ -85,8 +183,10 @@ impl Handler for ServerHandler {
|
||||
Ok(Auth::Accept)
|
||||
}
|
||||
Err(_) => {
|
||||
self.auth_limiter.on_failure();
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "reject",
|
||||
"auth attempt"
|
||||
@@ -213,10 +313,22 @@ mod tests {
|
||||
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
|
||||
}
|
||||
|
||||
fn default_limiter() -> Arc<ConnectionRateLimiter> {
|
||||
Arc::new(ConnectionRateLimiter::new(0))
|
||||
}
|
||||
|
||||
fn make_handler(
|
||||
auth_config: Arc<ServerAuthConfig>,
|
||||
outbound_proxy: Option<ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
) -> ServerHandler {
|
||||
ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_delegation_accepts_known_key() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let mut handler = ServerHandler::new(auth_config, None, None);
|
||||
let mut handler = make_handler(auth_config, None, None);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||
@@ -226,7 +338,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_delegation_rejects_unknown_key() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let mut handler = ServerHandler::new(auth_config, None, None);
|
||||
let mut handler = make_handler(auth_config, None, None);
|
||||
|
||||
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
||||
let other_ssh_key = russh::keys::parse_public_key_base64(
|
||||
@@ -249,7 +361,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_delegation_empty_config_rejects_all() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let mut handler = ServerHandler::new(auth_config, None, None);
|
||||
let mut handler = make_handler(auth_config, None, None);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let result = handler
|
||||
@@ -268,7 +380,7 @@ mod tests {
|
||||
async fn auth_logging_includes_remote_addr() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
|
||||
let mut handler = ServerHandler::new(auth_config, None, Some(remote_addr));
|
||||
let mut handler = make_handler(auth_config, None, Some(remote_addr));
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
||||
@@ -288,7 +400,7 @@ mod tests {
|
||||
#[test]
|
||||
fn server_handler_without_control_handler_rejects_wraith_destinations() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let handler = ServerHandler::new(auth_config, None, None);
|
||||
let handler = make_handler(auth_config, None, None);
|
||||
assert!(!handler.control_channel_router().has_handler());
|
||||
}
|
||||
|
||||
@@ -320,7 +432,7 @@ mod tests {
|
||||
});
|
||||
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
|
||||
|
||||
let handler = ServerHandler::new(auth_config, proxy.clone(), remote);
|
||||
let handler = make_handler(auth_config, proxy.clone(), remote);
|
||||
assert!(handler.outbound_proxy.is_some());
|
||||
assert!(handler.remote_addr.is_some());
|
||||
}
|
||||
@@ -328,9 +440,108 @@ mod tests {
|
||||
#[test]
|
||||
fn one_handler_per_connection() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let handler1 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
|
||||
let handler2 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
|
||||
let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
|
||||
let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
|
||||
|
||||
assert!(handler1.remote_addr != handler2.remote_addr);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_rate_limit_rejects_after_max_failures() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(0));
|
||||
let mut handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("10.0.0.1:22".parse().unwrap()),
|
||||
TransportKind::Tcp,
|
||||
limiter,
|
||||
2,
|
||||
);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
|
||||
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||
assert_eq!(r1, Auth::Reject { proceed_with_methods: None });
|
||||
|
||||
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||
assert_eq!(r2, Auth::Reject { proceed_with_methods: None });
|
||||
|
||||
assert!(!handler.auth_limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_rate_limit_blocks_over_limit() {
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(1));
|
||||
let auth_config = make_empty_auth_config();
|
||||
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
|
||||
|
||||
let h1 = ServerHandler::new(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter.clone(),
|
||||
10,
|
||||
);
|
||||
assert!(h1.is_connection_allowed());
|
||||
|
||||
let h2 = ServerHandler::new(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter.clone(),
|
||||
10,
|
||||
);
|
||||
assert!(!h2.is_connection_allowed());
|
||||
|
||||
drop(h1);
|
||||
|
||||
let h3 = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter,
|
||||
10,
|
||||
);
|
||||
assert!(h3.is_connection_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_kind_display() {
|
||||
assert_eq!(TransportKind::Tcp.to_string(), "tcp");
|
||||
assert_eq!(TransportKind::Tls.to_string(), "tls");
|
||||
assert_eq!(TransportKind::Iroh.to_string(), "iroh");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_log_includes_user_field() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let mut handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("203.0.113.50:12345".parse().unwrap()),
|
||||
TransportKind::Tls,
|
||||
Arc::new(ConnectionRateLimiter::new(0)),
|
||||
10,
|
||||
);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_closed_logs_duration_on_drop() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let _handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("203.0.113.50:12345".parse().unwrap()),
|
||||
TransportKind::Tcp,
|
||||
Arc::new(ConnectionRateLimiter::new(0)),
|
||||
10,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
pub mod control_channel;
|
||||
pub mod handler;
|
||||
pub mod rate_limit;
|
||||
|
||||
pub use control_channel::{
|
||||
ControlChannelHandler, ControlChannelRouter, DuplexStream, WRAITH_CONTROL_DESTINATION,
|
||||
WRAITH_PREFIX, is_reserved_destination,
|
||||
};
|
||||
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
|
||||
pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
|
||||
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||
193
crates/wraith-core/src/server/rate_limit.rs
Normal file
193
crates/wraith-core/src/server/rate_limit.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Mutex;
|
||||
|
||||
pub struct ConnectionRateLimiter {
|
||||
max_per_ip: usize,
|
||||
active: Mutex<HashMap<IpAddr, usize>>,
|
||||
}
|
||||
|
||||
impl ConnectionRateLimiter {
|
||||
pub fn new(max_per_ip: usize) -> Self {
|
||||
Self {
|
||||
max_per_ip,
|
||||
active: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(&self, ip: IpAddr) -> bool {
|
||||
if self.max_per_ip == 0 {
|
||||
return true;
|
||||
}
|
||||
let active = self.active.lock().unwrap();
|
||||
let count = active.get(&ip).copied().unwrap_or(0);
|
||||
count < self.max_per_ip
|
||||
}
|
||||
|
||||
pub fn on_connect(&self, ip: IpAddr) {
|
||||
let mut active = self.active.lock().unwrap();
|
||||
*active.entry(ip).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
pub fn on_disconnect(&self, ip: IpAddr) {
|
||||
let mut active = self.active.lock().unwrap();
|
||||
if let Some(count) = active.get_mut(&ip) {
|
||||
if *count > 1 {
|
||||
*count -= 1;
|
||||
} else {
|
||||
active.remove(&ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AuthAttemptLimiter {
|
||||
max_attempts: usize,
|
||||
failures: usize,
|
||||
}
|
||||
|
||||
impl AuthAttemptLimiter {
|
||||
pub fn new(max_attempts: usize) -> Self {
|
||||
Self {
|
||||
max_attempts,
|
||||
failures: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(&self) -> bool {
|
||||
if self.max_attempts == 0 {
|
||||
return true;
|
||||
}
|
||||
self.failures < self.max_attempts
|
||||
}
|
||||
|
||||
pub fn on_failure(&mut self) {
|
||||
self.failures += 1;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
|
||||
fn ip(n: u8) -> IpAddr {
|
||||
IpAddr::V4(Ipv4Addr::new(192, 168, 1, n))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_allows_when_under_limit() {
|
||||
let limiter = ConnectionRateLimiter::new(3);
|
||||
assert!(limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_blocks_when_at_limit() {
|
||||
let limiter = ConnectionRateLimiter::new(2);
|
||||
limiter.on_connect(ip(1));
|
||||
limiter.on_connect(ip(1));
|
||||
assert!(!limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_allows_after_disconnect() {
|
||||
let limiter = ConnectionRateLimiter::new(2);
|
||||
limiter.on_connect(ip(1));
|
||||
limiter.on_connect(ip(1));
|
||||
assert!(!limiter.check(ip(1)));
|
||||
limiter.on_disconnect(ip(1));
|
||||
assert!(limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_unlimited_when_zero() {
|
||||
let limiter = ConnectionRateLimiter::new(0);
|
||||
for _ in 0..100 {
|
||||
limiter.on_connect(ip(1));
|
||||
}
|
||||
assert!(limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_tracks_per_ip_independently() {
|
||||
let limiter = ConnectionRateLimiter::new(1);
|
||||
limiter.on_connect(ip(1));
|
||||
assert!(!limiter.check(ip(1)));
|
||||
assert!(limiter.check(ip(2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_ipv6() {
|
||||
let limiter = ConnectionRateLimiter::new(1);
|
||||
let ip6 = IpAddr::V6(Ipv6Addr::LOCALHOST);
|
||||
limiter.on_connect(ip6);
|
||||
assert!(!limiter.check(ip6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_disconnect_removes_zero_entry() {
|
||||
let limiter = ConnectionRateLimiter::new(3);
|
||||
limiter.on_connect(ip(1));
|
||||
limiter.on_disconnect(ip(1));
|
||||
{
|
||||
let active = limiter.active.lock().unwrap();
|
||||
assert!(!active.contains_key(&ip(1)));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_allows_when_under_limit() {
|
||||
let limiter = AuthAttemptLimiter::new(3);
|
||||
assert!(limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_blocks_after_max_failures() {
|
||||
let mut limiter = AuthAttemptLimiter::new(2);
|
||||
limiter.on_failure();
|
||||
limiter.on_failure();
|
||||
assert!(!limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_unlimited_when_zero() {
|
||||
let mut limiter = AuthAttemptLimiter::new(0);
|
||||
for _ in 0..100 {
|
||||
limiter.on_failure();
|
||||
}
|
||||
assert!(limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_still_allows_at_one_below_limit() {
|
||||
let mut limiter = AuthAttemptLimiter::new(3);
|
||||
limiter.on_failure();
|
||||
limiter.on_failure();
|
||||
assert!(limiter.check());
|
||||
limiter.on_failure();
|
||||
assert!(!limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_thread_safety() {
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(100));
|
||||
let mut handles = vec![];
|
||||
|
||||
for i in 0..10 {
|
||||
let lim = Arc::clone(&limiter);
|
||||
handles.push(thread::spawn(move || {
|
||||
let ip_addr = ip((i % 3) as u8 + 1);
|
||||
lim.on_connect(ip_addr);
|
||||
assert!(lim.check(ip_addr));
|
||||
lim.on_disconnect(ip_addr);
|
||||
}));
|
||||
}
|
||||
|
||||
for h in handles {
|
||||
h.join().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user