use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Instant; use async_trait::async_trait; use russh::keys::ssh_key::HashAlg; use russh::server::{Auth, Handler, Msg, Session}; use russh::Channel; use crate::auth::ServerAuthConfig; use crate::server::control_channel::{ ControlChannelHandler, ControlChannelRouter, WRAITH_PREFIX, }; use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter}; #[derive(Debug, Clone)] pub enum ProxyMode { Direct, Socks5(SocketAddr), HttpConnect(SocketAddr), } #[derive(Debug, Clone)] pub struct ProxyConfig { pub mode: ProxyMode, } #[derive(Debug, Clone, Copy, PartialEq)] pub enum TransportKind { Tcp, Tls, Iroh, } impl std::fmt::Display for TransportKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { TransportKind::Tcp => write!(f, "tcp"), TransportKind::Tls => write!(f, "tls"), TransportKind::Iroh => write!(f, "iroh"), } } } pub struct ServerHandler { auth_config: Arc, #[allow(dead_code)] outbound_proxy: Option, remote_addr: Option, control_channel_router: ControlChannelRouter, transport: TransportKind, connection_limiter: Arc, connection_allowed: bool, auth_limiter: AuthAttemptLimiter, connected_at: Instant, } impl ServerHandler { pub fn new( auth_config: Arc, outbound_proxy: Option, remote_addr: Option, transport: TransportKind, connection_limiter: Arc, max_auth_attempts: usize, ) -> Self { let allowed = if let Some(addr) = remote_addr { let ip = addr.ip(); if connection_limiter.check(ip) { connection_limiter.on_connect(ip); tracing::info!( remote_addr = %addr, transport = %transport, "connection opened" ); true } else { tracing::info!( remote_addr = %addr, transport = %transport, "connection rejected" ); false } } else { true }; Self { auth_config, outbound_proxy, remote_addr, control_channel_router: ControlChannelRouter::without_handler(), transport, connection_limiter, connection_allowed: allowed, auth_limiter: AuthAttemptLimiter::new(max_auth_attempts), connected_at: Instant::now(), } } pub fn is_connection_allowed(&self) -> bool { self.connection_allowed } pub fn remote_ip(&self) -> Option { self.remote_addr.map(|a| a.ip()) } } impl Drop for ServerHandler { fn drop(&mut self) { if let Some(addr) = self.remote_addr { if self.connection_allowed { self.connection_limiter.on_disconnect(addr.ip()); } let duration = self.connected_at.elapsed(); tracing::info!( remote_addr = %addr, duration_secs = duration.as_secs_f64(), "connection closed" ); } } } impl ServerHandler { pub fn with_control_channel_handler( mut self, handler: Box, ) -> Self { self.control_channel_router = ControlChannelRouter::with_handler(handler); self } pub fn control_channel_router(&self) -> &ControlChannelRouter { &self.control_channel_router } } #[async_trait] impl Handler for ServerHandler { type Error = russh::Error; async fn auth_publickey( &mut self, user: &str, public_key: &russh::keys::ssh_key::PublicKey, ) -> Result { if !self.auth_limiter.check() { let remote_addr_display = self .remote_addr .map_or("unknown".to_string(), |a| a.to_string()); let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256)); tracing::info!( remote_addr = %remote_addr_display, user = user, key_fingerprint = %fingerprint, result = "reject", "auth attempt" ); return Ok(Auth::Reject { proceed_with_methods: None, }); } let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256)); let remote_addr_display = self .remote_addr .map_or("unknown".to_string(), |a| a.to_string()); let russh_pub = russh::keys::PublicKey::new(public_key.key_data().clone(), user); let result = self.auth_config.authenticate_publickey(&russh_pub); match result { Ok(()) => { tracing::info!( remote_addr = %remote_addr_display, user = user, key_fingerprint = %fingerprint, result = "accept", "auth attempt" ); Ok(Auth::Accept) } Err(_) => { self.auth_limiter.on_failure(); tracing::info!( remote_addr = %remote_addr_display, user = user, key_fingerprint = %fingerprint, result = "reject", "auth attempt" ); Ok(Auth::Reject { proceed_with_methods: None, }) } } } async fn channel_open_direct_tcpip( &mut self, channel: Channel, host_to_connect: &str, port_to_connect: u32, originator_address: &str, originator_port: u32, _session: &mut Session, ) -> Result { if host_to_connect.starts_with(WRAITH_PREFIX) { tracing::info!( host = host_to_connect, port = port_to_connect, "routing to internal control channel handler" ); if !self.control_channel_router.has_handler() { tracing::warn!( host = host_to_connect, "no control channel handler configured, rejecting channel open" ); return Ok(false); } let _ = channel; return Ok(true); } let _ = (host_to_connect, port_to_connect, originator_address, originator_port, channel); Ok(false) } async fn channel_open_session( &mut self, _channel: Channel, _session: &mut Session, ) -> Result { Ok(false) } async fn channel_open_x11( &mut self, _channel: Channel, _originator_address: &str, _originator_port: u32, _session: &mut Session, ) -> Result { Ok(false) } async fn channel_open_forwarded_tcpip( &mut self, _channel: Channel, _host_to_connect: &str, _port_to_connect: u32, _originator_address: &str, _originator_port: u32, _session: &mut Session, ) -> Result { Ok(false) } } #[cfg(test)] mod tests { use super::*; use crate::auth::keys::KeySource; use russh::keys::{decode_secret_key, PrivateKey}; use std::io::Write; const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n"; const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096"; fn make_authorized_keys_file(keys_content: &str) -> tempfile::NamedTempFile { let mut f = tempfile::NamedTempFile::new().unwrap(); f.write_all(keys_content.as_bytes()).unwrap(); f.flush().unwrap(); f } fn load_key() -> PrivateKey { decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap() } fn make_auth_config(keys_content: &str) -> Arc { let f = make_authorized_keys_file(keys_content); Arc::new( ServerAuthConfig::from_keys_and_ca( Some(KeySource::File(f.path().to_path_buf())), None, ) .unwrap(), ) } fn make_empty_auth_config() -> Arc { Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap()) } fn default_limiter() -> Arc { Arc::new(ConnectionRateLimiter::new(0)) } fn make_handler( auth_config: Arc, outbound_proxy: Option, remote_addr: Option, ) -> ServerHandler { ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10) } #[tokio::test] async fn auth_delegation_accepts_known_key() { let auth_config = make_auth_config(ED25519_PUBLIC_KEY); let mut handler = make_handler(auth_config, None, None); let ssh_key = load_key().public_key().clone(); let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap(); assert_eq!(result, Auth::Accept); } #[tokio::test] async fn auth_delegation_rejects_unknown_key() { let auth_config = make_auth_config(ED25519_PUBLIC_KEY); let mut handler = make_handler(auth_config, None, None); let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host"; let other_ssh_key = russh::keys::parse_public_key_base64( other_key_text.split_whitespace().nth(1).unwrap(), ) .unwrap(); let result = handler .auth_publickey("testuser", &other_ssh_key) .await .unwrap(); assert_eq!( result, Auth::Reject { proceed_with_methods: None } ); } #[tokio::test] async fn auth_delegation_empty_config_rejects_all() { let auth_config = make_empty_auth_config(); let mut handler = make_handler(auth_config, None, None); let ssh_key = load_key().public_key().clone(); let result = handler .auth_publickey("testuser", &ssh_key) .await .unwrap(); assert_eq!( result, Auth::Reject { proceed_with_methods: None } ); } #[tokio::test] async fn auth_logging_includes_remote_addr() { let auth_config = make_auth_config(ED25519_PUBLIC_KEY); let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap(); let mut handler = make_handler(auth_config, None, Some(remote_addr)); let ssh_key = load_key().public_key().clone(); let _ = handler.auth_publickey("root", &ssh_key).await.unwrap(); } #[test] fn reserved_wraith_destination_routing() { use crate::server::control_channel::is_reserved_destination; assert!(is_reserved_destination("wraith-control")); assert!(is_reserved_destination("wraith-status")); assert!(is_reserved_destination("wraith-events")); assert!(!is_reserved_destination("example.com")); assert!(!is_reserved_destination("localhost")); assert!(!is_reserved_destination("wraith.example.com")); } #[test] fn server_handler_without_control_handler_rejects_wraith_destinations() { let auth_config = make_empty_auth_config(); let handler = make_handler(auth_config, None, None); assert!(!handler.control_channel_router().has_handler()); } #[test] fn proxy_mode_variants() { let direct = ProxyMode::Direct; let socks5 = ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap()); let http = ProxyMode::HttpConnect("127.0.0.1:8080".parse().unwrap()); match direct { ProxyMode::Direct => {} _ => panic!("expected Direct"), } match socks5 { ProxyMode::Socks5(_) => {} _ => panic!("expected Socks5"), } match http { ProxyMode::HttpConnect(_) => {} _ => panic!("expected HttpConnect"), } } #[test] fn server_handler_holds_config() { let auth_config = make_empty_auth_config(); let proxy = Some(ProxyConfig { mode: ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap()), }); let remote: Option = Some("10.0.0.1:22".parse().unwrap()); let handler = make_handler(auth_config, proxy.clone(), remote); assert!(handler.outbound_proxy.is_some()); assert!(handler.remote_addr.is_some()); } #[test] fn one_handler_per_connection() { let auth_config = make_empty_auth_config(); let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap())); let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap())); assert!(handler1.remote_addr != handler2.remote_addr); } #[tokio::test] async fn auth_rate_limit_rejects_after_max_failures() { let auth_config = make_empty_auth_config(); let limiter = Arc::new(ConnectionRateLimiter::new(0)); let mut handler = ServerHandler::new( auth_config, None, Some("10.0.0.1:22".parse().unwrap()), TransportKind::Tcp, limiter, 2, ); let ssh_key = load_key().public_key().clone(); let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap(); assert_eq!(r1, Auth::Reject { proceed_with_methods: None }); let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap(); assert_eq!(r2, Auth::Reject { proceed_with_methods: None }); assert!(!handler.auth_limiter.check()); } #[test] fn connection_rate_limit_blocks_over_limit() { let limiter = Arc::new(ConnectionRateLimiter::new(1)); let auth_config = make_empty_auth_config(); let addr: SocketAddr = "10.0.0.1:22".parse().unwrap(); let h1 = ServerHandler::new( auth_config.clone(), None, Some(addr), TransportKind::Tcp, limiter.clone(), 10, ); assert!(h1.is_connection_allowed()); let h2 = ServerHandler::new( auth_config.clone(), None, Some(addr), TransportKind::Tcp, limiter.clone(), 10, ); assert!(!h2.is_connection_allowed()); drop(h1); let h3 = ServerHandler::new( auth_config, None, Some(addr), TransportKind::Tcp, limiter, 10, ); assert!(h3.is_connection_allowed()); } #[test] fn transport_kind_display() { assert_eq!(TransportKind::Tcp.to_string(), "tcp"); assert_eq!(TransportKind::Tls.to_string(), "tls"); assert_eq!(TransportKind::Iroh.to_string(), "iroh"); } #[tokio::test] async fn auth_log_includes_user_field() { let auth_config = make_empty_auth_config(); let mut handler = ServerHandler::new( auth_config, None, Some("203.0.113.50:12345".parse().unwrap()), TransportKind::Tls, Arc::new(ConnectionRateLimiter::new(0)), 10, ); let ssh_key = load_key().public_key().clone(); let _ = handler.auth_publickey("root", &ssh_key).await.unwrap(); } #[test] fn connection_closed_logs_duration_on_drop() { let auth_config = make_empty_auth_config(); let _handler = ServerHandler::new( auth_config, None, Some("203.0.113.50:12345".parse().unwrap()), TransportKind::Tcp, Arc::new(ConnectionRateLimiter::new(0)), 10, ); } }