Files
alknet/crates/wraith-core/src/server/handler.rs
glm-5.1 f13a1c985f Merge remote-tracking branch 'origin/feat/server/channel-proxy'
# Conflicts:
#	crates/wraith-core/src/error.rs
#	crates/wraith-core/src/server/mod.rs
2026-06-02 11:32:28 +00:00

533 lines
16 KiB
Rust

use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use russh::keys::ssh_key::HashAlg;
use russh::server::{Auth, Handler, Msg, Session};
use russh::Channel;
use crate::auth::ServerAuthConfig;
use crate::server::control_channel::{
ControlChannelHandler, ControlChannelRouter, WRAITH_PREFIX,
};
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
#[derive(Debug, Clone)]
pub enum ProxyMode {
Direct,
Socks5(SocketAddr),
HttpConnect(SocketAddr),
}
#[derive(Debug, Clone)]
pub struct ProxyConfig {
pub mode: ProxyMode,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TransportKind {
Tcp,
Tls,
Iroh,
}
impl std::fmt::Display for TransportKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportKind::Tcp => write!(f, "tcp"),
TransportKind::Tls => write!(f, "tls"),
TransportKind::Iroh => write!(f, "iroh"),
}
}
}
pub struct ServerHandler {
auth_config: Arc<ServerAuthConfig>,
#[allow(dead_code)]
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
control_channel_router: ControlChannelRouter,
transport: TransportKind,
connection_limiter: Arc<ConnectionRateLimiter>,
connection_allowed: bool,
auth_limiter: AuthAttemptLimiter,
connected_at: Instant,
}
impl ServerHandler {
pub fn new(
auth_config: Arc<ServerAuthConfig>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
transport: TransportKind,
connection_limiter: Arc<ConnectionRateLimiter>,
max_auth_attempts: usize,
) -> Self {
let allowed = if let Some(addr) = remote_addr {
let ip = addr.ip();
if connection_limiter.check(ip) {
connection_limiter.on_connect(ip);
tracing::info!(
remote_addr = %addr,
transport = %transport,
"connection opened"
);
true
} else {
tracing::info!(
remote_addr = %addr,
transport = %transport,
"connection rejected"
);
false
}
} else {
true
};
Self {
auth_config,
outbound_proxy,
remote_addr,
control_channel_router: ControlChannelRouter::without_handler(),
transport,
connection_limiter,
connection_allowed: allowed,
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
connected_at: Instant::now(),
}
}
pub fn is_connection_allowed(&self) -> bool {
self.connection_allowed
}
pub fn remote_ip(&self) -> Option<IpAddr> {
self.remote_addr.map(|a| a.ip())
}
}
impl Drop for ServerHandler {
fn drop(&mut self) {
if let Some(addr) = self.remote_addr {
if self.connection_allowed {
self.connection_limiter.on_disconnect(addr.ip());
}
let duration = self.connected_at.elapsed();
tracing::info!(
remote_addr = %addr,
duration_secs = duration.as_secs_f64(),
"connection closed"
);
}
}
}
impl ServerHandler {
pub fn with_control_channel_handler(
mut self,
handler: Box<dyn ControlChannelHandler>,
) -> Self {
self.control_channel_router = ControlChannelRouter::with_handler(handler);
self
}
pub fn control_channel_router(&self) -> &ControlChannelRouter {
&self.control_channel_router
}
}
#[async_trait]
impl Handler for ServerHandler {
type Error = russh::Error;
async fn auth_publickey(
&mut self,
user: &str,
public_key: &russh::keys::ssh_key::PublicKey,
) -> Result<Auth, Self::Error> {
if !self.auth_limiter.check() {
let remote_addr_display = self
.remote_addr
.map_or("unknown".to_string(), |a| a.to_string());
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "reject",
"auth attempt"
);
return Ok(Auth::Reject {
proceed_with_methods: None,
});
}
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
let remote_addr_display = self
.remote_addr
.map_or("unknown".to_string(), |a| a.to_string());
let russh_pub = russh::keys::PublicKey::new(public_key.key_data().clone(), user);
let result = self.auth_config.authenticate_publickey(&russh_pub);
match result {
Ok(()) => {
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "accept",
"auth attempt"
);
Ok(Auth::Accept)
}
Err(_) => {
self.auth_limiter.on_failure();
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "reject",
"auth attempt"
);
Ok(Auth::Reject {
proceed_with_methods: None,
})
}
}
}
async fn channel_open_direct_tcpip(
&mut self,
channel: Channel<Msg>,
host_to_connect: &str,
port_to_connect: u32,
originator_address: &str,
originator_port: u32,
_session: &mut Session,
) -> Result<bool, Self::Error> {
if host_to_connect.starts_with(WRAITH_PREFIX) {
tracing::info!(
host = host_to_connect,
port = port_to_connect,
"routing to internal control channel handler"
);
if !self.control_channel_router.has_handler() {
tracing::warn!(
host = host_to_connect,
"no control channel handler configured, rejecting channel open"
);
return Ok(false);
}
let _ = channel;
return Ok(true);
}
let _ = (host_to_connect, port_to_connect, originator_address, originator_port, channel);
Ok(false)
}
async fn channel_open_session(
&mut self,
_channel: Channel<Msg>,
_session: &mut Session,
) -> Result<bool, Self::Error> {
Ok(false)
}
async fn channel_open_x11(
&mut self,
_channel: Channel<Msg>,
_originator_address: &str,
_originator_port: u32,
_session: &mut Session,
) -> Result<bool, Self::Error> {
Ok(false)
}
async fn channel_open_forwarded_tcpip(
&mut self,
_channel: Channel<Msg>,
_host_to_connect: &str,
_port_to_connect: u32,
_originator_address: &str,
_originator_port: u32,
_session: &mut Session,
) -> Result<bool, Self::Error> {
Ok(false)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::keys::KeySource;
use russh::keys::{decode_secret_key, PrivateKey};
use std::io::Write;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
fn make_authorized_keys_file(keys_content: &str) -> tempfile::NamedTempFile {
let mut f = tempfile::NamedTempFile::new().unwrap();
f.write_all(keys_content.as_bytes()).unwrap();
f.flush().unwrap();
f
}
fn load_key() -> PrivateKey {
decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap()
}
fn make_auth_config(keys_content: &str) -> Arc<ServerAuthConfig> {
let f = make_authorized_keys_file(keys_content);
Arc::new(
ServerAuthConfig::from_keys_and_ca(
Some(KeySource::File(f.path().to_path_buf())),
None,
)
.unwrap(),
)
}
fn make_empty_auth_config() -> Arc<ServerAuthConfig> {
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
}
fn default_limiter() -> Arc<ConnectionRateLimiter> {
Arc::new(ConnectionRateLimiter::new(0))
}
fn make_handler(
auth_config: Arc<ServerAuthConfig>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
) -> ServerHandler {
ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10)
}
#[tokio::test]
async fn auth_delegation_accepts_known_key() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let mut handler = make_handler(auth_config, None, None);
let ssh_key = load_key().public_key().clone();
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
assert_eq!(result, Auth::Accept);
}
#[tokio::test]
async fn auth_delegation_rejects_unknown_key() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let mut handler = make_handler(auth_config, None, None);
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
let other_ssh_key = russh::keys::parse_public_key_base64(
other_key_text.split_whitespace().nth(1).unwrap(),
)
.unwrap();
let result = handler
.auth_publickey("testuser", &other_ssh_key)
.await
.unwrap();
assert_eq!(
result,
Auth::Reject {
proceed_with_methods: None
}
);
}
#[tokio::test]
async fn auth_delegation_empty_config_rejects_all() {
let auth_config = make_empty_auth_config();
let mut handler = make_handler(auth_config, None, None);
let ssh_key = load_key().public_key().clone();
let result = handler
.auth_publickey("testuser", &ssh_key)
.await
.unwrap();
assert_eq!(
result,
Auth::Reject {
proceed_with_methods: None
}
);
}
#[tokio::test]
async fn auth_logging_includes_remote_addr() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
let mut handler = make_handler(auth_config, None, Some(remote_addr));
let ssh_key = load_key().public_key().clone();
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
}
#[test]
fn reserved_wraith_destination_routing() {
use crate::server::control_channel::is_reserved_destination;
assert!(is_reserved_destination("wraith-control"));
assert!(is_reserved_destination("wraith-status"));
assert!(is_reserved_destination("wraith-events"));
assert!(!is_reserved_destination("example.com"));
assert!(!is_reserved_destination("localhost"));
assert!(!is_reserved_destination("wraith.example.com"));
}
#[test]
fn server_handler_without_control_handler_rejects_wraith_destinations() {
let auth_config = make_empty_auth_config();
let handler = make_handler(auth_config, None, None);
assert!(!handler.control_channel_router().has_handler());
}
#[test]
fn proxy_mode_variants() {
let direct = ProxyMode::Direct;
let socks5 = ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap());
let http = ProxyMode::HttpConnect("127.0.0.1:8080".parse().unwrap());
match direct {
ProxyMode::Direct => {}
_ => panic!("expected Direct"),
}
match socks5 {
ProxyMode::Socks5(_) => {}
_ => panic!("expected Socks5"),
}
match http {
ProxyMode::HttpConnect(_) => {}
_ => panic!("expected HttpConnect"),
}
}
#[test]
fn server_handler_holds_config() {
let auth_config = make_empty_auth_config();
let proxy = Some(ProxyConfig {
mode: ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap()),
});
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
let handler = make_handler(auth_config, proxy.clone(), remote);
assert!(handler.outbound_proxy.is_some());
assert!(handler.remote_addr.is_some());
}
#[test]
fn one_handler_per_connection() {
let auth_config = make_empty_auth_config();
let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
assert!(handler1.remote_addr != handler2.remote_addr);
}
#[tokio::test]
async fn auth_rate_limit_rejects_after_max_failures() {
let auth_config = make_empty_auth_config();
let limiter = Arc::new(ConnectionRateLimiter::new(0));
let mut handler = ServerHandler::new(
auth_config,
None,
Some("10.0.0.1:22".parse().unwrap()),
TransportKind::Tcp,
limiter,
2,
);
let ssh_key = load_key().public_key().clone();
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
assert_eq!(r1, Auth::Reject { proceed_with_methods: None });
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
assert_eq!(r2, Auth::Reject { proceed_with_methods: None });
assert!(!handler.auth_limiter.check());
}
#[test]
fn connection_rate_limit_blocks_over_limit() {
let limiter = Arc::new(ConnectionRateLimiter::new(1));
let auth_config = make_empty_auth_config();
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
let h1 = ServerHandler::new(
auth_config.clone(),
None,
Some(addr),
TransportKind::Tcp,
limiter.clone(),
10,
);
assert!(h1.is_connection_allowed());
let h2 = ServerHandler::new(
auth_config.clone(),
None,
Some(addr),
TransportKind::Tcp,
limiter.clone(),
10,
);
assert!(!h2.is_connection_allowed());
drop(h1);
let h3 = ServerHandler::new(
auth_config,
None,
Some(addr),
TransportKind::Tcp,
limiter,
10,
);
assert!(h3.is_connection_allowed());
}
#[test]
fn transport_kind_display() {
assert_eq!(TransportKind::Tcp.to_string(), "tcp");
assert_eq!(TransportKind::Tls.to_string(), "tls");
assert_eq!(TransportKind::Iroh.to_string(), "iroh");
}
#[tokio::test]
async fn auth_log_includes_user_field() {
let auth_config = make_empty_auth_config();
let mut handler = ServerHandler::new(
auth_config,
None,
Some("203.0.113.50:12345".parse().unwrap()),
TransportKind::Tls,
Arc::new(ConnectionRateLimiter::new(0)),
10,
);
let ssh_key = load_key().public_key().clone();
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
}
#[test]
fn connection_closed_logs_duration_on_drop() {
let auth_config = make_empty_auth_config();
let _handler = ServerHandler::new(
auth_config,
None,
Some("203.0.113.50:12345".parse().unwrap()),
TransportKind::Tcp,
Arc::new(ConnectionRateLimiter::new(0)),
10,
);
}
}