refactor!: rebrand wraith to alknet

Rename all crates, CLI commands, constants, type names, doc comments,
and documentation from wraith to alknet. Includes wire-protocol changes:
ALPN wraith-ssh -> alknet-ssh, reserved destination prefix wraith- ->
alknet-, SSH auth username wraith -> alknet.
This commit is contained in:
2026-06-05 10:04:32 +00:00
parent af7f4d0006
commit 596c89ce24
101 changed files with 552 additions and 552 deletions

View File

@@ -0,0 +1,563 @@
//! Outbound connection proxy for SSH channel targets.
//!
//! Connects to the requested `host:port` either directly, via SOCKS5 proxy, or
//! via HTTP CONNECT proxy, then proxies bytes bidirectionally between the SSH
//! channel and the outbound TCP stream.
use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use super::handler::{ProxyConfig, ProxyMode};
#[derive(Debug, thiserror::Error)]
pub enum ChannelProxyError {
#[error("connection refused")]
ConnectionRefused,
#[error("target unreachable")]
TargetUnreachable,
#[error("socks5 proxy handshake failed")]
Socks5HandshakeFailed,
#[error("socks5 proxy rejected connection")]
Socks5ProxyRejected,
#[error("http connect proxy handshake failed")]
HttpConnectHandshakeFailed,
#[error("http connect proxy rejected: {0}")]
HttpConnectProxyRejected(String),
#[error("io error")]
Io(#[from] std::io::Error),
}
pub async fn connect_outbound(
target: SocketAddr,
proxy: &ProxyConfig,
) -> Result<TcpStream, ChannelProxyError> {
match &proxy.mode {
ProxyMode::Direct => connect_direct(target).await,
ProxyMode::Socks5(addr) => connect_socks5(target, *addr).await,
ProxyMode::HttpConnect(addr) => connect_http_connect(target, *addr).await,
}
}
async fn connect_direct(target: SocketAddr) -> Result<TcpStream, ChannelProxyError> {
TcpStream::connect(target)
.await
.map_err(|e| map_connection_error(e, target))
}
async fn connect_socks5(target: SocketAddr, proxy_addr: SocketAddr) -> Result<TcpStream, ChannelProxyError> {
let mut stream = TcpStream::connect(proxy_addr)
.await
.map_err(ChannelProxyError::from)?;
stream.write_all(&[0x05, 0x01, 0x00]).await?;
stream.flush().await?;
let mut resp = [0u8; 2];
stream.read_exact(&mut resp).await?;
if resp[0] != 0x05 || resp[1] != 0x00 {
return Err(ChannelProxyError::Socks5HandshakeFailed);
}
let ip_bytes = target.ip().to_string();
let mut connect_req = vec![0x05, 0x01, 0x00, 0x03];
connect_req.push(ip_bytes.len() as u8);
connect_req.extend_from_slice(ip_bytes.as_bytes());
connect_req.extend_from_slice(&target.port().to_be_bytes());
stream.write_all(&connect_req).await?;
stream.flush().await?;
let mut reply_header = [0u8; 4];
stream.read_exact(&mut reply_header).await?;
if reply_header[0] != 0x05 {
return Err(ChannelProxyError::Socks5HandshakeFailed);
}
if reply_header[1] != 0x00 {
return Err(ChannelProxyError::Socks5ProxyRejected);
}
let atyp = reply_header[3];
match atyp {
0x01 => {
let mut _addr = [0u8; 4];
stream.read_exact(&mut _addr).await?;
}
0x04 => {
let mut _addr = [0u8; 16];
stream.read_exact(&mut _addr).await?;
}
0x03 => {
let len = stream.read_u8().await?;
let mut _domain = vec![0u8; len as usize];
stream.read_exact(&mut _domain).await?;
}
_ => {
return Err(ChannelProxyError::Socks5HandshakeFailed);
}
}
let mut _port = [0u8; 2];
stream.read_exact(&mut _port).await?;
Ok(stream)
}
async fn connect_http_connect(
target: SocketAddr,
proxy_addr: SocketAddr,
) -> Result<TcpStream, ChannelProxyError> {
let mut stream = TcpStream::connect(proxy_addr)
.await
.map_err(ChannelProxyError::from)?;
let connect_request = format!(
"CONNECT {}:{} HTTP/1.1\r\nHost: {}:{}\r\n\r\n",
target.ip(),
target.port(),
target.ip(),
target.port()
);
stream.write_all(connect_request.as_bytes()).await?;
stream.flush().await?;
let mut response = Vec::new();
let mut buf = [0u8; 1024];
loop {
let n = stream.read(&mut buf).await?;
if n == 0 {
return Err(ChannelProxyError::HttpConnectHandshakeFailed);
}
response.extend_from_slice(&buf[..n]);
if response.windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
}
let response_str = String::from_utf8_lossy(&response);
let status_line = response_str
.lines()
.next()
.unwrap_or("");
if status_line.contains("200") {
Ok(stream)
} else {
Err(ChannelProxyError::HttpConnectProxyRejected(
status_line.to_string(),
))
}
}
fn map_connection_error(e: std::io::Error, _target: SocketAddr) -> ChannelProxyError {
match e.kind() {
std::io::ErrorKind::ConnectionRefused => ChannelProxyError::ConnectionRefused,
std::io::ErrorKind::AddrNotAvailable
| std::io::ErrorKind::NetworkUnreachable
| std::io::ErrorKind::HostUnreachable => ChannelProxyError::TargetUnreachable,
_ => ChannelProxyError::Io(e),
}
}
pub async fn proxy_channel<S>(channel: S, target: SocketAddr, proxy: &ProxyConfig)
where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
if let Ok(outbound) = connect_outbound(target, proxy).await {
let (mut read_chan, mut write_chan) = tokio::io::split(channel);
let (mut read_out, mut write_out) = outbound.into_split();
let client_to_target = tokio::spawn(async move {
let _ = tokio::io::copy(&mut read_chan, &mut write_out).await;
let _ = write_out.shutdown().await;
});
let target_to_client = tokio::spawn(async move {
let _ = tokio::io::copy(&mut read_out, &mut write_chan).await;
let _ = write_chan.shutdown().await;
});
let _ = client_to_target.await;
let _ = target_to_client.await;
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
use tokio::net::TcpListener;
fn direct_config() -> ProxyConfig {
ProxyConfig {
mode: ProxyMode::Direct,
}
}
fn socks5_config(addr: SocketAddr) -> ProxyConfig {
ProxyConfig {
mode: ProxyMode::Socks5(addr),
}
}
fn http_connect_config(addr: SocketAddr) -> ProxyConfig {
ProxyConfig {
mode: ProxyMode::HttpConnect(addr),
}
}
#[tokio::test]
async fn direct_connection_to_echo_server() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server = tokio::spawn(async move {
let (mut sock, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 64];
let n = sock.read(&mut buf).await.unwrap();
sock.write_all(&buf[..n]).await.unwrap();
});
let stream = connect_outbound(addr, &direct_config()).await.unwrap();
let (mut read, mut write) = stream.into_split();
write.write_all(b"hello").await.unwrap();
let mut buf = [0u8; 5];
read.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello");
let _ = server.await;
}
#[tokio::test]
async fn direct_connection_target_unreachable() {
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
let result = connect_outbound(target, &direct_config()).await;
assert!(result.is_err());
}
#[tokio::test]
async fn socks5_proxy_handshake() {
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let proxy_addr = proxy_listener.local_addr().unwrap();
let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let target_addr = target_listener.local_addr().unwrap();
let target_server = tokio::spawn(async move {
let (mut sock, _) = target_listener.accept().await.unwrap();
let mut buf = [0u8; 64];
let n = sock.read(&mut buf).await.unwrap();
sock.write_all(&buf[..n]).await.unwrap();
});
let proxy_server = tokio::spawn(async move {
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
let mut greeting = [0u8; 3];
proxy_sock.read_exact(&mut greeting).await.unwrap();
assert_eq!(greeting[0], 0x05);
proxy_sock.write_all(&[0x05, 0x00]).await.unwrap();
let mut req_header = [0u8; 4];
proxy_sock.read_exact(&mut req_header).await.unwrap();
assert_eq!(req_header[0], 0x05);
assert_eq!(req_header[1], 0x01);
let atyp = req_header[3];
assert_eq!(atyp, 0x03);
let domain_len = proxy_sock.read_u8().await.unwrap() as usize;
let mut domain = vec![0u8; domain_len];
proxy_sock.read_exact(&mut domain).await.unwrap();
let mut port_bytes = [0u8; 2];
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
let target: SocketAddr = format!(
"{}:{}",
String::from_utf8_lossy(&domain),
u16::from_be_bytes(port_bytes)
)
.parse()
.unwrap();
let reply = vec![
0x05, 0x00, 0x00, 0x01,
0, 0, 0, 0,
0, 0,
];
proxy_sock.write_all(&reply).await.unwrap();
let mut target_stream = TcpStream::connect(target).await.unwrap();
let _ = tokio::io::copy_bidirectional(&mut proxy_sock, &mut target_stream).await;
});
let config = socks5_config(proxy_addr);
let mut stream = connect_outbound(target_addr, &config).await.unwrap();
stream.write_all(b"hello socks").await.unwrap();
let mut buf = [0u8; 11];
stream.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello socks");
drop(stream);
let _ = target_server.await;
let _ = proxy_server.await;
}
#[tokio::test]
async fn socks5_proxy_rejected() {
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let proxy_addr = proxy_listener.local_addr().unwrap();
let proxy_server = tokio::spawn(async move {
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
let mut greeting = [0u8; 3];
proxy_sock.read_exact(&mut greeting).await.unwrap();
proxy_sock.write_all(&[0x05, 0x00]).await.unwrap();
let mut req_header = [0u8; 4];
proxy_sock.read_exact(&mut req_header).await.unwrap();
let domain_len = proxy_sock.read_u8().await.unwrap() as usize;
let mut domain = vec![0u8; domain_len];
proxy_sock.read_exact(&mut domain).await.unwrap();
let mut port_bytes = [0u8; 2];
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
let reply = vec![
0x05, 0x05, 0x00, 0x01,
0, 0, 0, 0,
0, 0,
];
proxy_sock.write_all(&reply).await.unwrap();
});
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
let config = socks5_config(proxy_addr);
let result = connect_outbound(target, &config).await;
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
ChannelProxyError::Socks5ProxyRejected
));
let _ = proxy_server.await;
}
#[tokio::test]
async fn http_connect_proxy_handshake() {
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let proxy_addr = proxy_listener.local_addr().unwrap();
let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let target_addr = target_listener.local_addr().unwrap();
let target_server = tokio::spawn(async move {
let (mut sock, _) = target_listener.accept().await.unwrap();
let mut buf = [0u8; 64];
let n = sock.read(&mut buf).await.unwrap();
sock.write_all(&buf[..n]).await.unwrap();
});
let proxy_server = tokio::spawn(async move {
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
let mut request = Vec::new();
let mut buf = [0u8; 1024];
loop {
let n = proxy_sock.read(&mut buf).await.unwrap();
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
}
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
proxy_sock.write_all(response.as_bytes()).await.unwrap();
let target_str = extract_connect_target(&String::from_utf8_lossy(&request));
let mut target_stream = TcpStream::connect(target_str).await.unwrap();
let _ = tokio::io::copy_bidirectional(&mut proxy_sock, &mut target_stream).await;
});
let config = http_connect_config(proxy_addr);
let mut stream = connect_outbound(target_addr, &config).await.unwrap();
stream.write_all(b"hello http").await.unwrap();
let mut buf = [0u8; 10];
stream.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello http");
drop(stream);
let _ = target_server.await;
let _ = proxy_server.await;
}
fn extract_connect_target(request: &str) -> String {
let connect_line = request.lines().next().unwrap_or("");
let parts: Vec<&str> = connect_line.split_whitespace().collect();
if parts.len() >= 2 {
parts[1].to_string()
} else {
String::new()
}
}
#[tokio::test]
async fn http_connect_proxy_rejected() {
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let proxy_addr = proxy_listener.local_addr().unwrap();
let proxy_server = tokio::spawn(async move {
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
let mut request = Vec::new();
let mut buf = [0u8; 1024];
loop {
let n = proxy_sock.read(&mut buf).await.unwrap();
if n == 0 {
break;
}
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
}
let response = "HTTP/1.1 403 Forbidden\r\n\r\n";
proxy_sock.write_all(response.as_bytes()).await.unwrap();
});
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
let config = http_connect_config(proxy_addr);
let result = connect_outbound(target, &config).await;
assert!(result.is_err());
match result.unwrap_err() {
ChannelProxyError::HttpConnectProxyRejected(msg) => {
assert!(msg.contains("403"));
}
other => panic!("expected HttpConnectProxyRejected, got {:?}", other),
}
let _ = proxy_server.await;
}
#[tokio::test]
async fn target_unreachable_returns_appropriate_error() {
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
let result = connect_outbound(target, &direct_config()).await;
match result.unwrap_err() {
ChannelProxyError::TargetUnreachable
| ChannelProxyError::ConnectionRefused
| ChannelProxyError::Io(_) => {}
other => panic!("unexpected error type: {:?}", other),
}
}
#[tokio::test]
async fn socks5_proxy_unreachable() {
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
let bad_proxy: SocketAddr = "127.0.0.1:1".parse().unwrap();
let config = socks5_config(bad_proxy);
let result = connect_outbound(target, &config).await;
assert!(result.is_err());
}
#[tokio::test]
async fn http_connect_proxy_unreachable() {
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
let bad_proxy: SocketAddr = "127.0.0.1:1".parse().unwrap();
let config = http_connect_config(bad_proxy);
let result = connect_outbound(target, &config).await;
assert!(result.is_err());
}
struct MockChannel {
read_half: tokio::io::ReadHalf<DuplexStream>,
write_half: tokio::io::WriteHalf<DuplexStream>,
}
impl tokio::io::AsyncRead for MockChannel {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.get_mut().read_half).poll_read(cx, buf)
}
}
impl tokio::io::AsyncWrite for MockChannel {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
std::pin::Pin::new(&mut self.get_mut().write_half).poll_write(cx, buf)
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.get_mut().write_half).poll_flush(cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.get_mut().write_half).poll_shutdown(cx)
}
}
fn make_mock_channel() -> (MockChannel, DuplexStream) {
let (client, server) = duplex(4096);
let (read_half, write_half) = tokio::io::split(client);
(
MockChannel {
read_half,
write_half,
},
server,
)
}
#[tokio::test]
async fn proxy_channel_bidirectional_data_flow() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let target_addr = listener.local_addr().unwrap();
let echo_server = tokio::spawn(async move {
let (mut sock, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 64];
let n = sock.read(&mut buf).await.unwrap();
sock.write_all(&buf[..n]).await.unwrap();
});
let (channel, mut channel_peer) = make_mock_channel();
let target = target_addr;
let proxy = direct_config();
tokio::spawn(async move {
proxy_channel(channel, target, &proxy).await;
});
channel_peer.write_all(b"ping").await.unwrap();
channel_peer.flush().await.unwrap();
let mut buf = [0u8; 4];
channel_peer.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"ping");
drop(channel_peer);
let _ = echo_server.await;
}
#[tokio::test]
async fn proxy_channel_target_unreachable_closes_cleanly() {
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
let (channel, _channel_peer) = make_mock_channel();
let proxy = direct_config();
proxy_channel(channel, target, &proxy).await;
}
}

View File

@@ -0,0 +1,192 @@
//! Control channel routing for reserved `alknet-*` destinations.
//!
//! SSH channels opened with a destination starting with `alknet-` are intercepted
//! by the server and routed to a `ControlChannelHandler` instead of proxied to a
//! TCP target. See ADR-018 for the design rationale.
use std::io;
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
pub const ALKNET_CONTROL_DESTINATION: &str = "alknet-control";
pub const ALKNET_PREFIX: &str = "alknet-";
pub fn is_reserved_destination(host: &str) -> bool {
host.starts_with(ALKNET_PREFIX)
}
pub trait DuplexStream: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T: AsyncRead + AsyncWrite + Unpin + Send> DuplexStream for T {}
#[async_trait]
pub trait ControlChannelHandler: Send + Sync {
async fn handle_channel(&self, stream: Box<dyn DuplexStream>);
}
pub struct ControlChannelRouter {
handler: Option<Box<dyn ControlChannelHandler>>,
}
impl ControlChannelRouter {
pub fn new(handler: Option<Box<dyn ControlChannelHandler>>) -> Self {
Self { handler }
}
pub fn without_handler() -> Self {
Self { handler: None }
}
pub fn with_handler(handler: Box<dyn ControlChannelHandler>) -> Self {
Self {
handler: Some(handler),
}
}
pub fn has_handler(&self) -> bool {
self.handler.is_some()
}
pub async fn route(&self, stream: Box<dyn DuplexStream>) -> io::Result<()> {
match &self.handler {
Some(handler) => {
handler.handle_channel(stream).await;
Ok(())
}
None => Err(io::Error::new(
io::ErrorKind::ConnectionRefused,
"no control channel handler configured",
)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::duplex;
#[test]
fn alknet_control_destination_constant() {
assert_eq!(ALKNET_CONTROL_DESTINATION, "alknet-control");
}
#[test]
fn alknet_prefix_constant() {
assert_eq!(ALKNET_PREFIX, "alknet-");
}
#[test]
fn reserved_destination_detected() {
assert!(is_reserved_destination("alknet-control"));
assert!(is_reserved_destination("alknet-status"));
assert!(is_reserved_destination("alknet-events"));
assert!(is_reserved_destination("alknet-"));
}
#[test]
fn non_reserved_destination_passes_through() {
assert!(!is_reserved_destination("example.com"));
assert!(!is_reserved_destination("localhost"));
assert!(!is_reserved_destination("192.168.1.1"));
assert!(!is_reserved_destination("alknet.example.com"));
assert!(!is_reserved_destination(""));
assert!(!is_reserved_destination("alkne-control"));
assert!(!is_reserved_destination("ALKNET-control"));
}
#[test]
fn prefix_matching_case_sensitive() {
assert!(!is_reserved_destination("Alknet-control"));
assert!(!is_reserved_destination("ALKNET-control"));
assert!(is_reserved_destination("alknet-Control"));
}
#[test]
fn router_without_handler_has_no_handler() {
let router = ControlChannelRouter::without_handler();
assert!(!router.has_handler());
}
#[test]
fn router_with_handler_has_handler() {
struct DummyHandler;
#[async_trait]
impl ControlChannelHandler for DummyHandler {
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {}
}
let router = ControlChannelRouter::with_handler(Box::new(DummyHandler));
assert!(router.has_handler());
}
#[tokio::test]
async fn route_without_handler_returns_error() {
let router = ControlChannelRouter::without_handler();
let (_client, server) = duplex(64);
let stream: Box<dyn DuplexStream> = Box::new(server);
let result = router.route(stream).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused);
}
#[tokio::test]
async fn route_with_handler_succeeds() {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
struct TrackedHandler {
called: Arc<AtomicBool>,
}
#[async_trait]
impl ControlChannelHandler for TrackedHandler {
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {
self.called.store(true, Ordering::SeqCst);
}
}
let called = Arc::new(AtomicBool::new(false));
let handler = TrackedHandler {
called: called.clone(),
};
let router = ControlChannelRouter::with_handler(Box::new(handler));
let (_client, server) = duplex(64);
let stream: Box<dyn DuplexStream> = Box::new(server);
let result = router.route(stream).await;
assert!(result.is_ok());
assert!(called.load(Ordering::SeqCst));
}
#[tokio::test]
async fn route_with_handler_can_read_write() {
struct EchoHandler;
#[async_trait]
impl ControlChannelHandler for EchoHandler {
async fn handle_channel(&self, mut stream: Box<dyn DuplexStream>) {
let mut buf = [0u8; 64];
let n = stream.read(&mut buf).await.unwrap();
stream.write_all(&buf[..n]).await.unwrap();
}
}
let router = ControlChannelRouter::with_handler(Box::new(EchoHandler));
let (client, server) = duplex(64);
let stream: Box<dyn DuplexStream> = Box::new(server);
tokio::spawn(async move {
router.route(stream).await.unwrap();
});
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut client = client;
client.write_all(b"hello").await.unwrap();
let mut buf = [0u8; 5];
client.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello");
}
#[test]
fn control_channel_destination_matches_prefix() {
assert!(is_reserved_destination(ALKNET_CONTROL_DESTINATION));
}
}

View File

@@ -0,0 +1,736 @@
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use russh::keys::ssh_key::HashAlg;
use russh::server::{Auth, Handler, Msg, Session};
use russh::Channel;
use russh::ChannelId;
use crate::auth::ServerAuthConfig;
use crate::server::control_channel::{
ControlChannelHandler, ControlChannelRouter, ALKNET_PREFIX,
};
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
#[derive(Debug, Clone)]
pub enum ProxyMode {
Direct,
Socks5(SocketAddr),
HttpConnect(SocketAddr),
}
#[derive(Debug, Clone)]
pub struct ProxyConfig {
pub mode: ProxyMode,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TransportKind {
Tcp,
Tls,
Iroh,
}
impl std::fmt::Display for TransportKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportKind::Tcp => write!(f, "tcp"),
TransportKind::Tls => write!(f, "tls"),
TransportKind::Iroh => write!(f, "iroh"),
}
}
}
pub struct ServerHandler {
auth_config: Arc<ServerAuthConfig>,
#[allow(dead_code)]
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
control_channel_router: ControlChannelRouter,
#[allow(dead_code)]
transport: TransportKind,
connection_limiter: Arc<ConnectionRateLimiter>,
connection_allowed: bool,
auth_limiter: AuthAttemptLimiter,
connected_at: Instant,
}
impl ServerHandler {
pub fn new(
auth_config: Arc<ServerAuthConfig>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
transport: TransportKind,
connection_limiter: Arc<ConnectionRateLimiter>,
max_auth_attempts: usize,
) -> Self {
let allowed = if let Some(addr) = remote_addr {
let ip = addr.ip();
if connection_limiter.check(ip) {
connection_limiter.on_connect(ip);
tracing::info!(
remote_addr = %addr,
transport = %transport,
"connection opened"
);
true
} else {
tracing::info!(
remote_addr = %addr,
transport = %transport,
"connection rejected"
);
false
}
} else {
true
};
Self {
auth_config,
outbound_proxy,
remote_addr,
control_channel_router: ControlChannelRouter::without_handler(),
transport,
connection_limiter,
connection_allowed: allowed,
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
connected_at: Instant::now(),
}
}
pub fn is_connection_allowed(&self) -> bool {
self.connection_allowed
}
pub fn remote_ip(&self) -> Option<IpAddr> {
self.remote_addr.map(|a| a.ip())
}
}
impl Drop for ServerHandler {
fn drop(&mut self) {
if let Some(addr) = self.remote_addr {
if self.connection_allowed {
self.connection_limiter.on_disconnect(addr.ip());
}
let duration = self.connected_at.elapsed();
tracing::info!(
remote_addr = %addr,
duration_secs = duration.as_secs_f64(),
"connection closed"
);
}
}
}
impl ServerHandler {
pub fn with_control_channel_handler(
mut self,
handler: Box<dyn ControlChannelHandler>,
) -> Self {
self.control_channel_router = ControlChannelRouter::with_handler(handler);
self
}
pub fn control_channel_router(&self) -> &ControlChannelRouter {
&self.control_channel_router
}
}
#[async_trait]
impl Handler for ServerHandler {
type Error = russh::Error;
async fn auth_publickey(
&mut self,
user: &str,
public_key: &russh::keys::ssh_key::PublicKey,
) -> Result<Auth, Self::Error> {
if !self.auth_limiter.check() {
let remote_addr_display = self
.remote_addr
.map_or("unknown".to_string(), |a| a.to_string());
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "reject",
"auth attempt"
);
return Ok(Auth::Reject {
proceed_with_methods: None,
});
}
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
let remote_addr_display = self
.remote_addr
.map_or("unknown".to_string(), |a| a.to_string());
let russh_pub = russh::keys::PublicKey::new(public_key.key_data().clone(), user);
let result = self.auth_config.authenticate_publickey(&russh_pub);
match result {
Ok(()) => {
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "accept",
"auth attempt"
);
Ok(Auth::Accept)
}
Err(_) => {
self.auth_limiter.on_failure();
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "reject",
"auth attempt"
);
Ok(Auth::Reject {
proceed_with_methods: None,
})
}
}
}
async fn channel_open_direct_tcpip(
&mut self,
channel: Channel<Msg>,
host_to_connect: &str,
port_to_connect: u32,
originator_address: &str,
originator_port: u32,
_session: &mut Session,
) -> Result<bool, Self::Error> {
if host_to_connect.starts_with(ALKNET_PREFIX) {
if !self.control_channel_router.has_handler() {
return Ok(false);
}
let _ = channel;
return Ok(true);
}
let target_host = host_to_connect.to_string();
let target_port = port_to_connect;
let proxy_config = self.outbound_proxy.clone().unwrap_or(ProxyConfig {
mode: ProxyMode::Direct,
});
tokio::spawn(async move {
let target = match format!("{target_host}:{target_port}").parse::<std::net::SocketAddr>() {
Ok(addr) => addr,
Err(_) => match tokio::net::lookup_host((&target_host[..], target_port as u16)).await {
Ok(mut addrs) => match addrs.next() {
Some(addr) => addr,
None => return,
},
Err(_) => return,
},
};
crate::server::channel_proxy::proxy_channel(channel.into_stream(), target, &proxy_config).await;
});
let _ = (originator_address, originator_port);
Ok(true)
}
async fn channel_open_session(
&mut self,
_channel: Channel<Msg>,
session: &mut Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
"rejected session channel (shell/exec not supported)"
);
let _ = session;
Ok(false)
}
async fn channel_open_x11(
&mut self,
_channel: Channel<Msg>,
_originator_address: &str,
_originator_port: u32,
session: &mut Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
"rejected x11 channel"
);
let _ = session;
Ok(false)
}
async fn channel_open_forwarded_tcpip(
&mut self,
_channel: Channel<Msg>,
host_to_connect: &str,
port_to_connect: u32,
_originator_address: &str,
_originator_port: u32,
session: &mut Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
target = %format!("{host_to_connect}:{port_to_connect}"),
"rejected forwarded-tcpip channel (remote port forwarding not supported)"
);
let _ = session;
Ok(false)
}
async fn exec_request(
&mut self,
channel: ChannelId,
data: &[u8],
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
data_len = data.len(),
"rejected exec request on channel (shell/exec not supported)"
);
let _ = session.channel_failure(channel);
Ok(())
}
async fn shell_request(
&mut self,
channel: ChannelId,
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
"rejected shell request on channel"
);
let _ = session.channel_failure(channel);
Ok(())
}
async fn subsystem_request(
&mut self,
channel: ChannelId,
name: &str,
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
subsystem = name,
"rejected subsystem request on channel"
);
let _ = session.channel_failure(channel);
Ok(())
}
async fn pty_request(
&mut self,
channel: ChannelId,
term: &str,
col_width: u32,
row_height: u32,
pix_width: u32,
pix_height: u32,
modes: &[(russh::Pty, u32)],
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
term = term,
"rejected pty request on channel"
);
let _ = (col_width, row_height, pix_width, pix_height, modes);
let _ = session.channel_failure(channel);
Ok(())
}
async fn env_request(
&mut self,
channel: ChannelId,
variable_name: &str,
variable_value: &str,
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
variable = variable_name,
"rejected env request on channel"
);
let _ = variable_value;
let _ = session.channel_failure(channel);
Ok(())
}
async fn x11_request(
&mut self,
channel: ChannelId,
single_connection: bool,
x11_auth_protocol: &str,
x11_auth_cookie: &str,
x11_screen_number: u32,
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
"rejected x11 request on channel"
);
let _ = (single_connection, x11_auth_protocol, x11_auth_cookie, x11_screen_number);
let _ = session.channel_failure(channel);
Ok(())
}
async fn agent_request(
&mut self,
channel: ChannelId,
session: &mut Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
"rejected agent forwarding request on channel"
);
let _ = session;
Ok(false)
}
async fn tcpip_forward(
&mut self,
address: &str,
port: &mut u32,
session: &mut Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
address = address,
port = *port,
"rejected tcpip-forward request (remote port forwarding not supported)"
);
let _ = session;
Ok(false)
}
async fn cancel_tcpip_forward(
&mut self,
address: &str,
port: u32,
session: &mut Session,
) -> Result<bool, Self::Error> {
let _ = (address, port, session);
Ok(false)
}
async fn streamlocal_forward(
&mut self,
socket_path: &str,
session: &mut Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
socket_path = socket_path,
"rejected streamlocal-forward request"
);
let _ = session;
Ok(false)
}
async fn signal(
&mut self,
channel: ChannelId,
signal: russh::Sig,
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::debug!(
remote_addr = ?self.remote_addr,
channel = %channel,
signal = ?signal,
"received signal on channel (ignored)"
);
let _ = session;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::keys::KeySource;
use russh::keys::{decode_secret_key, PrivateKey};
use std::io::Write;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
fn make_authorized_keys_file(keys_content: &str) -> tempfile::NamedTempFile {
let mut f = tempfile::NamedTempFile::new().unwrap();
f.write_all(keys_content.as_bytes()).unwrap();
f.flush().unwrap();
f
}
fn load_key() -> PrivateKey {
decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap()
}
fn make_auth_config(keys_content: &str) -> Arc<ServerAuthConfig> {
let f = make_authorized_keys_file(keys_content);
Arc::new(
ServerAuthConfig::from_keys_and_ca(
Some(KeySource::File(f.path().to_path_buf())),
None,
)
.unwrap(),
)
}
fn make_empty_auth_config() -> Arc<ServerAuthConfig> {
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
}
fn default_limiter() -> Arc<ConnectionRateLimiter> {
Arc::new(ConnectionRateLimiter::new(0))
}
fn make_handler(
auth_config: Arc<ServerAuthConfig>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
) -> ServerHandler {
ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10)
}
#[tokio::test]
async fn auth_delegation_accepts_known_key() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let mut handler = make_handler(auth_config, None, None);
let ssh_key = load_key().public_key().clone();
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
assert_eq!(result, Auth::Accept);
}
#[tokio::test]
async fn auth_delegation_rejects_unknown_key() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let mut handler = make_handler(auth_config, None, None);
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
let other_ssh_key = russh::keys::parse_public_key_base64(
other_key_text.split_whitespace().nth(1).unwrap(),
)
.unwrap();
let result = handler
.auth_publickey("testuser", &other_ssh_key)
.await
.unwrap();
assert_eq!(
result,
Auth::Reject {
proceed_with_methods: None
}
);
}
#[tokio::test]
async fn auth_delegation_empty_config_rejects_all() {
let auth_config = make_empty_auth_config();
let mut handler = make_handler(auth_config, None, None);
let ssh_key = load_key().public_key().clone();
let result = handler
.auth_publickey("testuser", &ssh_key)
.await
.unwrap();
assert_eq!(
result,
Auth::Reject {
proceed_with_methods: None
}
);
}
#[tokio::test]
async fn auth_logging_includes_remote_addr() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
let mut handler = make_handler(auth_config, None, Some(remote_addr));
let ssh_key = load_key().public_key().clone();
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
}
#[test]
fn reserved_alknet_destination_routing() {
use crate::server::control_channel::is_reserved_destination;
assert!(is_reserved_destination("alknet-control"));
assert!(is_reserved_destination("alknet-status"));
assert!(is_reserved_destination("alknet-events"));
assert!(!is_reserved_destination("example.com"));
assert!(!is_reserved_destination("localhost"));
assert!(!is_reserved_destination("alknet.example.com"));
}
#[test]
fn server_handler_without_control_handler_rejects_alknet_destinations() {
let auth_config = make_empty_auth_config();
let handler = make_handler(auth_config, None, None);
assert!(!handler.control_channel_router().has_handler());
}
#[test]
fn proxy_mode_variants() {
let direct = ProxyMode::Direct;
let socks5 = ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap());
let http = ProxyMode::HttpConnect("127.0.0.1:8080".parse().unwrap());
match direct {
ProxyMode::Direct => {}
_ => panic!("expected Direct"),
}
match socks5 {
ProxyMode::Socks5(_) => {}
_ => panic!("expected Socks5"),
}
match http {
ProxyMode::HttpConnect(_) => {}
_ => panic!("expected HttpConnect"),
}
}
#[test]
fn server_handler_holds_config() {
let auth_config = make_empty_auth_config();
let proxy = Some(ProxyConfig {
mode: ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap()),
});
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
let handler = make_handler(auth_config, proxy.clone(), remote);
assert!(handler.outbound_proxy.is_some());
assert!(handler.remote_addr.is_some());
}
#[test]
fn one_handler_per_connection() {
let auth_config = make_empty_auth_config();
let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
assert!(handler1.remote_addr != handler2.remote_addr);
}
#[tokio::test]
async fn auth_rate_limit_rejects_after_max_failures() {
let auth_config = make_empty_auth_config();
let limiter = Arc::new(ConnectionRateLimiter::new(0));
let mut handler = ServerHandler::new(
auth_config,
None,
Some("10.0.0.1:22".parse().unwrap()),
TransportKind::Tcp,
limiter,
2,
);
let ssh_key = load_key().public_key().clone();
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
assert_eq!(r1, Auth::Reject { proceed_with_methods: None });
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
assert_eq!(r2, Auth::Reject { proceed_with_methods: None });
assert!(!handler.auth_limiter.check());
}
#[test]
fn connection_rate_limit_blocks_over_limit() {
let limiter = Arc::new(ConnectionRateLimiter::new(1));
let auth_config = make_empty_auth_config();
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
let h1 = ServerHandler::new(
auth_config.clone(),
None,
Some(addr),
TransportKind::Tcp,
limiter.clone(),
10,
);
assert!(h1.is_connection_allowed());
let h2 = ServerHandler::new(
auth_config.clone(),
None,
Some(addr),
TransportKind::Tcp,
limiter.clone(),
10,
);
assert!(!h2.is_connection_allowed());
drop(h1);
let h3 = ServerHandler::new(
auth_config,
None,
Some(addr),
TransportKind::Tcp,
limiter,
10,
);
assert!(h3.is_connection_allowed());
}
#[test]
fn transport_kind_display() {
assert_eq!(TransportKind::Tcp.to_string(), "tcp");
assert_eq!(TransportKind::Tls.to_string(), "tls");
assert_eq!(TransportKind::Iroh.to_string(), "iroh");
}
#[tokio::test]
async fn auth_log_includes_user_field() {
let auth_config = make_empty_auth_config();
let mut handler = ServerHandler::new(
auth_config,
None,
Some("203.0.113.50:12345".parse().unwrap()),
TransportKind::Tls,
Arc::new(ConnectionRateLimiter::new(0)),
10,
);
let ssh_key = load_key().public_key().clone();
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
}
#[test]
fn connection_closed_logs_duration_on_drop() {
let auth_config = make_empty_auth_config();
let _handler = ServerHandler::new(
auth_config,
None,
Some("203.0.113.50:12345".parse().unwrap()),
TransportKind::Tcp,
Arc::new(ConnectionRateLimiter::new(0)),
10,
);
}
}

View File

@@ -0,0 +1,25 @@
//! Server-side SSH connection handling.
//!
//! Provides `Server` for accepting SSH connections over any transport and proxying
//! `direct-tcpip` channel requests to targets. Supports Ed25519 and certificate-authority
//! auth, connection rate limiting, auth attempt limiting, stealth mode (fake nginx 404),
//! and outbound proxy routing (direct/SOCKS5/HTTP CONNECT).
//!
//! Destination hosts starting with `alknet-` are reserved for internal use (control channel, ADR-018).
pub mod channel_proxy;
pub mod control_channel;
pub mod handler;
pub mod rate_limit;
pub mod serve;
pub mod stealth;
pub use channel_proxy::{connect_outbound, proxy_channel};
pub use control_channel::{
ControlChannelHandler, ControlChannelRouter, DuplexStream, ALKNET_CONTROL_DESTINATION,
ALKNET_PREFIX, is_reserved_destination,
};
pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
pub use serve::{Server, ServeError, ServeOptions, ServeTransportMode};
pub use stealth::{ProtocolDetection, detect_protocol, send_fake_nginx_404, validate_stealth_config};

View File

@@ -0,0 +1,200 @@
//! Connection rate limiting and auth attempt limiting.
//!
//! `ConnectionRateLimiter` tracks per-IP active connections (thread-safe).
//! `AuthAttemptLimiter` caps failed auth attempts per connection.
//! These complement fail2ban on Linux and provide abuse protection on all platforms.
//! See ADR-013.
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Mutex;
pub struct ConnectionRateLimiter {
max_per_ip: usize,
active: Mutex<HashMap<IpAddr, usize>>,
}
impl ConnectionRateLimiter {
pub fn new(max_per_ip: usize) -> Self {
Self {
max_per_ip,
active: Mutex::new(HashMap::new()),
}
}
pub fn check(&self, ip: IpAddr) -> bool {
if self.max_per_ip == 0 {
return true;
}
let active = self.active.lock().unwrap();
let count = active.get(&ip).copied().unwrap_or(0);
count < self.max_per_ip
}
pub fn on_connect(&self, ip: IpAddr) {
let mut active = self.active.lock().unwrap();
*active.entry(ip).or_insert(0) += 1;
}
pub fn on_disconnect(&self, ip: IpAddr) {
let mut active = self.active.lock().unwrap();
if let Some(count) = active.get_mut(&ip) {
if *count > 1 {
*count -= 1;
} else {
active.remove(&ip);
}
}
}
}
pub struct AuthAttemptLimiter {
max_attempts: usize,
failures: usize,
}
impl AuthAttemptLimiter {
pub fn new(max_attempts: usize) -> Self {
Self {
max_attempts,
failures: 0,
}
}
pub fn check(&self) -> bool {
if self.max_attempts == 0 {
return true;
}
self.failures < self.max_attempts
}
pub fn on_failure(&mut self) {
self.failures += 1;
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
fn ip(n: u8) -> IpAddr {
IpAddr::V4(Ipv4Addr::new(192, 168, 1, n))
}
#[test]
fn connection_limiter_allows_when_under_limit() {
let limiter = ConnectionRateLimiter::new(3);
assert!(limiter.check(ip(1)));
}
#[test]
fn connection_limiter_blocks_when_at_limit() {
let limiter = ConnectionRateLimiter::new(2);
limiter.on_connect(ip(1));
limiter.on_connect(ip(1));
assert!(!limiter.check(ip(1)));
}
#[test]
fn connection_limiter_allows_after_disconnect() {
let limiter = ConnectionRateLimiter::new(2);
limiter.on_connect(ip(1));
limiter.on_connect(ip(1));
assert!(!limiter.check(ip(1)));
limiter.on_disconnect(ip(1));
assert!(limiter.check(ip(1)));
}
#[test]
fn connection_limiter_unlimited_when_zero() {
let limiter = ConnectionRateLimiter::new(0);
for _ in 0..100 {
limiter.on_connect(ip(1));
}
assert!(limiter.check(ip(1)));
}
#[test]
fn connection_limiter_tracks_per_ip_independently() {
let limiter = ConnectionRateLimiter::new(1);
limiter.on_connect(ip(1));
assert!(!limiter.check(ip(1)));
assert!(limiter.check(ip(2)));
}
#[test]
fn connection_limiter_ipv6() {
let limiter = ConnectionRateLimiter::new(1);
let ip6 = IpAddr::V6(Ipv6Addr::LOCALHOST);
limiter.on_connect(ip6);
assert!(!limiter.check(ip6));
}
#[test]
fn connection_limiter_disconnect_removes_zero_entry() {
let limiter = ConnectionRateLimiter::new(3);
limiter.on_connect(ip(1));
limiter.on_disconnect(ip(1));
{
let active = limiter.active.lock().unwrap();
assert!(!active.contains_key(&ip(1)));
}
}
#[test]
fn auth_limiter_allows_when_under_limit() {
let limiter = AuthAttemptLimiter::new(3);
assert!(limiter.check());
}
#[test]
fn auth_limiter_blocks_after_max_failures() {
let mut limiter = AuthAttemptLimiter::new(2);
limiter.on_failure();
limiter.on_failure();
assert!(!limiter.check());
}
#[test]
fn auth_limiter_unlimited_when_zero() {
let mut limiter = AuthAttemptLimiter::new(0);
for _ in 0..100 {
limiter.on_failure();
}
assert!(limiter.check());
}
#[test]
fn auth_limiter_still_allows_at_one_below_limit() {
let mut limiter = AuthAttemptLimiter::new(3);
limiter.on_failure();
limiter.on_failure();
assert!(limiter.check());
limiter.on_failure();
assert!(!limiter.check());
}
#[test]
fn connection_limiter_thread_safety() {
use std::sync::Arc;
use std::thread;
let limiter = Arc::new(ConnectionRateLimiter::new(100));
let mut handles = vec![];
for i in 0..10 {
let lim = Arc::clone(&limiter);
handles.push(thread::spawn(move || {
let ip_addr = ip((i % 3) as u8 + 1);
lim.on_connect(ip_addr);
assert!(lim.check(ip_addr));
lim.on_disconnect(ip_addr);
}));
}
for h in handles {
h.join().unwrap();
}
}
}

View File

@@ -0,0 +1,765 @@
//! Server configuration and accept loop.
//!
//! `Server` binds to a transport acceptor and runs an accept loop, handling
//! authentication, stealth mode protocol detection, and graceful shutdown.
//! `ServeOptions` provides a builder-pattern API for programmatic configuration.
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use russh::server::{self, Config};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, warn};
use crate::auth::keys::KeySource;
use crate::auth::server_auth::ServerAuthConfig;
use crate::error::ConfigError;
use crate::server::handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
use crate::server::rate_limit::ConnectionRateLimiter;
use crate::server::stealth::{self, ProtocolDetection};
const DEFAULT_LISTEN_ADDR: &str = "0.0.0.0:22";
const DRAIN_TIMEOUT: Duration = Duration::from_secs(2);
/// Transport mode for the server listener.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServeTransportMode {
Tcp,
Tls,
Iroh,
}
impl std::fmt::Display for ServeTransportMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ServeTransportMode::Tcp => write!(f, "tcp"),
ServeTransportMode::Tls => write!(f, "tls"),
ServeTransportMode::Iroh => write!(f, "iroh"),
}
}
}
/// Programmatic configuration for an alknet server.
///
/// Construct with `ServeOptions::new(key_source)` and chain builder methods.
/// Call `validate()` before passing to `Server::new()`.
///
/// ```
/// use alknet_core::server::{ServeOptions, ServeTransportMode};
/// use alknet_core::auth::keys::KeySource;
///
/// let opts = ServeOptions::new(KeySource::File("/path/to/host_key".into()))
/// .transport_mode(ServeTransportMode::Tcp)
/// .listen_addr("0.0.0.0:22")
/// .max_connections_per_ip(5)
/// .max_auth_attempts(3);
/// opts.validate().unwrap();
/// ```
pub struct ServeOptions {
pub key: KeySource,
pub authorized_keys: Option<KeySource>,
pub cert_authority: Option<KeySource>,
pub transport_mode: ServeTransportMode,
pub listen_addr: String,
pub tls_cert: Option<String>,
pub tls_key: Option<String>,
pub acme_domain: Option<String>,
pub stealth: bool,
pub proxy: Option<String>,
pub iroh_relay: Option<String>,
pub max_connections_per_ip: usize,
pub max_auth_attempts: usize,
}
impl ServeOptions {
pub fn new(key: KeySource) -> Self {
Self {
key,
authorized_keys: None,
cert_authority: None,
transport_mode: ServeTransportMode::Tcp,
listen_addr: DEFAULT_LISTEN_ADDR.to_string(),
tls_cert: None,
tls_key: None,
acme_domain: None,
stealth: false,
proxy: None,
iroh_relay: None,
max_connections_per_ip: 0,
max_auth_attempts: 10,
}
}
pub fn authorized_keys(mut self, source: KeySource) -> Self {
self.authorized_keys = Some(source);
self
}
pub fn cert_authority(mut self, source: KeySource) -> Self {
self.cert_authority = Some(source);
self
}
pub fn transport_mode(mut self, mode: ServeTransportMode) -> Self {
self.transport_mode = mode;
self
}
pub fn listen_addr(mut self, addr: impl Into<String>) -> Self {
self.listen_addr = addr.into();
self
}
pub fn tls_cert(mut self, path: impl Into<String>) -> Self {
self.tls_cert = Some(path.into());
self
}
pub fn tls_key(mut self, path: impl Into<String>) -> Self {
self.tls_key = Some(path.into());
self
}
pub fn acme_domain(mut self, domain: impl Into<String>) -> Self {
self.acme_domain = Some(domain.into());
self
}
pub fn stealth(mut self, enabled: bool) -> Self {
self.stealth = enabled;
self
}
pub fn proxy(mut self, url: impl Into<String>) -> Self {
self.proxy = Some(url.into());
self
}
pub fn iroh_relay(mut self, url: impl Into<String>) -> Self {
self.iroh_relay = Some(url.into());
self
}
pub fn max_connections_per_ip(mut self, max: usize) -> Self {
self.max_connections_per_ip = max;
self
}
pub fn max_auth_attempts(mut self, max: usize) -> Self {
self.max_auth_attempts = max;
self
}
pub fn validate(&self) -> Result<(), ConfigError> {
if self.stealth && self.transport_mode != ServeTransportMode::Tls {
return Err(ConfigError::InvalidFlag {
name: "stealth mode requires TLS transport (--transport tls)".to_string(),
});
}
match self.transport_mode {
ServeTransportMode::Tls => {
if self.tls_cert.is_none() && self.acme_domain.is_none() {
return Err(ConfigError::InvalidFlag {
name: "TLS transport requires --tls-cert/--tls-key or --acme-domain"
.to_string(),
});
}
if self.tls_cert.is_some() && self.tls_key.is_none() {
return Err(ConfigError::InvalidFlag {
name: "--tls-cert requires --tls-key".to_string(),
});
}
if self.tls_key.is_some() && self.tls_cert.is_none() {
return Err(ConfigError::InvalidFlag {
name: "--tls-key requires --tls-cert".to_string(),
});
}
}
ServeTransportMode::Tcp | ServeTransportMode::Iroh => {
if self.tls_cert.is_some() || self.tls_key.is_some() || self.acme_domain.is_some() {
return Err(ConfigError::IncompatibleOptions);
}
}
}
Ok(())
}
}
impl std::fmt::Debug for ServeOptions {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ServeOptions")
.field("key", &"<KeySource>")
.field("authorized_keys", &"<KeySource>")
.field("cert_authority", &"<KeySource>")
.field("transport_mode", &self.transport_mode)
.field("listen_addr", &self.listen_addr)
.field("stealth", &self.stealth)
.field("max_connections_per_ip", &self.max_connections_per_ip)
.field("max_auth_attempts", &self.max_auth_attempts)
.finish()
}
}
/// Errors that can occur during server setup and operation.
#[derive(Debug, thiserror::Error)]
pub enum ServeError {
#[error("config error: {0}")]
Config(#[from] ConfigError),
#[error("bind failed: {0}")]
BindFailed(anyhow::Error),
#[error("key load failed: {0}")]
KeyLoadFailed(ConfigError),
#[error("accept failed")]
AcceptFailed,
}
struct ActiveSession {
handle: server::Handle,
join: tokio::task::JoinHandle<()>,
}
/// The alknet SSH server.
///
/// Accepts connections over any `TransportAcceptor`, authenticates via Ed25519 keys
/// or certificate authority, and proxies `direct-tcpip` channels to their targets.
/// Supports stealth mode (TLS only), outbound proxy routing, and connection rate limiting.
pub struct Server {
config: Arc<server::Config>,
auth_config: Arc<ServerAuthConfig>,
connection_limiter: Arc<ConnectionRateLimiter>,
outbound_proxy: Option<ProxyConfig>,
stealth: bool,
transport_mode: ServeTransportMode,
listen_addr: String,
max_auth_attempts: usize,
shutdown_tx: tokio::sync::watch::Sender<bool>,
shutdown_rx: tokio::sync::watch::Receiver<bool>,
sessions: Arc<tokio::sync::Mutex<Vec<ActiveSession>>>,
}
impl Server {
pub fn new(opts: ServeOptions) -> Result<Self, ServeError> {
opts.validate().map_err(ServeError::Config)?;
let private_key =
crate::auth::keys::load_private_key(opts.key.clone()).map_err(ServeError::KeyLoadFailed)?;
let auth_config = Arc::new(
ServerAuthConfig::from_keys_and_ca(opts.authorized_keys.clone(), opts.cert_authority.clone())
.map_err(ServeError::KeyLoadFailed)?,
);
let config = Arc::new(Config {
keys: vec![private_key],
max_auth_attempts: opts.max_auth_attempts,
methods: russh::MethodSet::PUBLICKEY,
preferred: russh::Preferred::DEFAULT,
..Default::default()
});
let outbound_proxy = parse_proxy_config(opts.proxy.as_deref());
let connection_limiter = Arc::new(ConnectionRateLimiter::new(opts.max_connections_per_ip));
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
Ok(Self {
config,
auth_config,
connection_limiter,
outbound_proxy,
stealth: opts.stealth,
transport_mode: opts.transport_mode,
listen_addr: opts.listen_addr,
max_auth_attempts: opts.max_auth_attempts,
shutdown_tx,
shutdown_rx,
sessions: Arc::new(tokio::sync::Mutex::new(Vec::new())),
})
}
pub fn shutdown_sender(&self) -> tokio::sync::watch::Sender<bool> {
self.shutdown_tx.clone()
}
pub async fn shutdown(&self) -> Result<(), ServeError> {
info!("initiating graceful shutdown");
let _ = self.shutdown_tx.send(true);
{
let sessions = self.sessions.lock().await;
for session in sessions.iter() {
if let Err(e) = session.handle.disconnect(
russh::Disconnect::ByApplication,
"shutdown".to_string(),
String::new(),
).await {
warn!("failed to send SSH disconnect: {e}");
}
}
}
tokio::time::sleep(DRAIN_TIMEOUT).await;
{
let mut sessions = self.sessions.lock().await;
for session in sessions.drain(..) {
session.join.abort();
}
}
info!("graceful shutdown complete");
Ok(())
}
pub fn is_shutdown(&self) -> bool {
*self.shutdown_rx.borrow()
}
pub async fn run<A>(self, acceptor: A, endpoint_info: Option<&str>) -> Result<(), ServeError>
where
A: crate::transport::TransportAcceptor,
{
let transport_kind = match self.transport_mode {
ServeTransportMode::Tcp => TransportKind::Tcp,
ServeTransportMode::Tls => TransportKind::Tls,
ServeTransportMode::Iroh => TransportKind::Iroh,
};
if self.transport_mode == ServeTransportMode::Iroh {
if let Some(id) = endpoint_info {
info!("alknet server running: transport=iroh endpoint_id={}", id);
} else {
info!("alknet server running: transport=iroh");
}
} else {
info!(
"alknet server running: transport={} listen={}",
self.transport_mode, self.listen_addr
);
}
let server = Arc::new(self);
let mut shutdown_rx = server.shutdown_rx.clone();
#[cfg(unix)]
let signal_done = {
let sig_tx = server.shutdown_tx.clone();
tokio::spawn(async move {
let mut sigterm_stream =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to install SIGTERM handler");
tokio::select! {
_ = sigterm_stream.recv() => {
info!("received SIGTERM");
}
_ = tokio::signal::ctrl_c() => {
info!("received SIGINT (Ctrl+C)");
}
}
let _ = sig_tx.send(true);
})
};
loop {
let shutdown = *shutdown_rx.borrow();
if shutdown {
info!("shutdown signaled, stopping accept loop");
break;
}
let accept_result = tokio::select! {
result = acceptor.accept() => result,
_ = shutdown_rx.changed() => {
info!("shutdown signaled while waiting for connection");
break;
}
};
let (stream, info) = match accept_result {
Ok(conn) => conn,
Err(e) => {
error!("accept failed: {e}");
continue;
}
};
let remote_addr = info.remote_addr;
let handler_transport_kind = transport_kind;
let handler = ServerHandler::new(
Arc::clone(&server.auth_config),
server.outbound_proxy.clone(),
remote_addr,
handler_transport_kind,
Arc::clone(&server.connection_limiter),
server.max_auth_attempts,
);
if !handler.is_connection_allowed() {
continue;
}
let config = Arc::clone(&server.config);
let sessions = Arc::clone(&server.sessions);
let stealth = server.stealth;
let transport_is_tls = server.transport_mode == ServeTransportMode::Tls;
tokio::spawn(async move {
let result = handle_connection(
stream,
config,
handler,
sessions,
stealth,
transport_is_tls,
)
.await;
if let Err(e) = result {
warn!("connection error: {e}");
}
});
}
#[cfg(unix)]
signal_done.abort();
server.shutdown().await?;
Ok(())
}
}
async fn handle_connection<S>(
stream: S,
config: Arc<Config>,
handler: ServerHandler,
sessions: Arc<tokio::sync::Mutex<Vec<ActiveSession>>>,
stealth: bool,
transport_is_tls: bool,
) -> Result<(), anyhow::Error>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
if stealth && transport_is_tls {
let (protocol, mut reader) = stealth::detect_protocol(stream).await;
match protocol {
ProtocolDetection::Http => {
stealth::send_fake_nginx_404(&mut reader).await;
return Ok(());
}
ProtocolDetection::Ssh => {
let running = server::run_stream(config, reader, handler).await?;
let handle = running.handle();
let join = tokio::spawn(async {
let _ = running.await;
});
sessions.lock().await.push(ActiveSession { handle, join });
return Ok(());
}
}
}
let running = server::run_stream(config, stream, handler).await?;
let handle = running.handle();
let join = tokio::spawn(async {
let _ = running.await;
});
sessions.lock().await.push(ActiveSession { handle, join });
Ok(())
}
fn parse_proxy_config(proxy: Option<&str>) -> Option<ProxyConfig> {
proxy.map(|url| {
if url.starts_with("socks5://") {
let addr: SocketAddr = url
.strip_prefix("socks5://")
.unwrap()
.parse()
.expect("invalid socks5 proxy address");
ProxyConfig {
mode: ProxyMode::Socks5(addr),
}
} else if url.starts_with("http://") {
let addr: SocketAddr = url
.strip_prefix("http://")
.unwrap()
.parse()
.expect("invalid http connect proxy address");
ProxyConfig {
mode: ProxyMode::HttpConnect(addr),
}
} else {
panic!("unsupported proxy URL scheme: {url}");
}
})
}
#[cfg(test)]
mod tests {
use super::*;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
fn make_key_source() -> KeySource {
KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec())
}
fn make_authorized_keys_source() -> KeySource {
KeySource::Memory(ED25519_PUBLIC_KEY.as_bytes().to_vec())
}
#[test]
fn serve_options_default_fields() {
let opts = ServeOptions::new(make_key_source());
assert!(opts.authorized_keys.is_none());
assert!(opts.cert_authority.is_none());
assert_eq!(opts.transport_mode, ServeTransportMode::Tcp);
assert_eq!(opts.listen_addr, "0.0.0.0:22");
assert!(opts.tls_cert.is_none());
assert!(opts.tls_key.is_none());
assert!(opts.acme_domain.is_none());
assert!(!opts.stealth);
assert!(opts.proxy.is_none());
assert!(opts.iroh_relay.is_none());
assert_eq!(opts.max_connections_per_ip, 0);
assert_eq!(opts.max_auth_attempts, 10);
}
#[test]
fn serve_options_builder_pattern() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.cert_authority(make_authorized_keys_source())
.transport_mode(ServeTransportMode::Tls)
.listen_addr("0.0.0.0:443")
.tls_cert("/etc/ssl/cert.pem")
.tls_key("/etc/ssl/key.pem")
.stealth(true)
.proxy("socks5://127.0.0.1:9050")
.iroh_relay("https://relay.example.com")
.max_connections_per_ip(5)
.max_auth_attempts(3);
assert!(opts.authorized_keys.is_some());
assert!(opts.cert_authority.is_some());
assert_eq!(opts.transport_mode, ServeTransportMode::Tls);
assert_eq!(opts.listen_addr, "0.0.0.0:443");
assert_eq!(opts.tls_cert.as_deref(), Some("/etc/ssl/cert.pem"));
assert_eq!(opts.tls_key.as_deref(), Some("/etc/ssl/key.pem"));
assert!(opts.stealth);
assert_eq!(opts.proxy.as_deref(), Some("socks5://127.0.0.1:9050"));
assert_eq!(
opts.iroh_relay.as_deref(),
Some("https://relay.example.com")
);
assert_eq!(opts.max_connections_per_ip, 5);
assert_eq!(opts.max_auth_attempts, 3);
}
#[test]
fn serve_options_validate_steam_without_tls_rejected() {
let opts = ServeOptions::new(make_key_source()).stealth(true);
assert!(opts.validate().is_err());
}
#[test]
fn serve_options_validate_stealth_with_tls_ok() {
let opts = ServeOptions::new(make_key_source())
.transport_mode(ServeTransportMode::Tls)
.tls_cert("/cert.pem")
.tls_key("/key.pem")
.stealth(true);
assert!(opts.validate().is_ok());
}
#[test]
fn serve_options_validate_tcp_no_tls_options_ok() {
let opts = ServeOptions::new(make_key_source());
assert!(opts.validate().is_ok());
}
#[test]
fn serve_options_validate_tls_requires_certs() {
let opts = ServeOptions::new(make_key_source()).transport_mode(ServeTransportMode::Tls);
assert!(opts.validate().is_err());
}
#[test]
fn serve_options_validate_tls_cert_without_key_rejected() {
let opts = ServeOptions::new(make_key_source())
.transport_mode(ServeTransportMode::Tls)
.tls_cert("/cert.pem");
assert!(opts.validate().is_err());
}
#[test]
fn serve_options_validate_tls_key_without_cert_rejected() {
let opts = ServeOptions::new(make_key_source())
.transport_mode(ServeTransportMode::Tls)
.tls_key("/key.pem");
assert!(opts.validate().is_err());
}
#[test]
fn serve_options_validate_tcp_with_acme_rejected() {
let opts =
ServeOptions::new(make_key_source()).acme_domain("example.com");
assert!(opts.validate().is_err());
}
#[test]
fn serve_options_validate_acme_domain_with_tls_ok() {
let opts = ServeOptions::new(make_key_source())
.transport_mode(ServeTransportMode::Tls)
.acme_domain("example.com");
assert!(opts.validate().is_ok());
}
#[test]
fn server_new_creates_server() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source());
let server = Server::new(opts).unwrap();
assert_eq!(server.max_auth_attempts, 10);
}
#[test]
fn server_new_stealth_without_tls_fails() {
let opts = ServeOptions::new(make_key_source()).stealth(true);
let result = Server::new(opts);
assert!(result.is_err());
}
#[test]
fn server_new_invalid_key_fails() {
let opts = ServeOptions::new(KeySource::Memory(b"not a key".to_vec()));
let result = Server::new(opts);
assert!(result.is_err());
}
#[test]
fn serve_transport_mode_display() {
assert_eq!(ServeTransportMode::Tcp.to_string(), "tcp");
assert_eq!(ServeTransportMode::Tls.to_string(), "tls");
assert_eq!(ServeTransportMode::Iroh.to_string(), "iroh");
}
#[test]
fn serve_transport_mode_equality() {
assert_eq!(ServeTransportMode::Tcp, ServeTransportMode::Tcp);
assert_ne!(ServeTransportMode::Tcp, ServeTransportMode::Tls);
assert_ne!(ServeTransportMode::Tls, ServeTransportMode::Iroh);
}
#[test]
fn serve_options_debug_redacts_keys() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source());
let debug_str = format!("{:?}", opts);
assert!(debug_str.contains("<KeySource>"));
assert!(!debug_str.contains("OPENSSH"));
}
#[test]
fn parse_proxy_config_socks5() {
let config = parse_proxy_config(Some("socks5://127.0.0.1:9050"));
assert!(config.is_some());
match config.unwrap().mode {
ProxyMode::Socks5(addr) => {
assert_eq!(addr, "127.0.0.1:9050".parse().unwrap());
}
_ => panic!("expected Socks5"),
}
}
#[test]
fn parse_proxy_config_http() {
let config = parse_proxy_config(Some("http://127.0.0.1:8080"));
assert!(config.is_some());
match config.unwrap().mode {
ProxyMode::HttpConnect(addr) => {
assert_eq!(addr, "127.0.0.1:8080".parse().unwrap());
}
_ => panic!("expected HttpConnect"),
}
}
#[test]
fn parse_proxy_config_none() {
assert!(parse_proxy_config(None).is_none());
}
#[test]
fn serve_error_variants() {
assert_eq!(ServeError::AcceptFailed.to_string(), "accept failed");
}
#[test]
fn default_listen_addr() {
assert_eq!(DEFAULT_LISTEN_ADDR, "0.0.0.0:22");
}
#[test]
fn drain_timeout_is_two_seconds() {
assert_eq!(DRAIN_TIMEOUT, Duration::from_secs(2));
}
#[test]
fn server_shutdown_sender_clones() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source());
let server = Server::new(opts).unwrap();
let sender = server.shutdown_sender();
assert!(!server.is_shutdown());
let _ = sender.send(true);
assert!(server.is_shutdown());
}
#[test]
fn server_holds_listen_addr() {
let opts = ServeOptions::new(make_key_source())
.listen_addr("0.0.0.0:443");
let server = Server::new(opts).unwrap();
assert_eq!(server.listen_addr, "0.0.0.0:443");
}
#[tokio::test]
async fn integration_server_accept_loop_and_shutdown() {
use crate::transport::TcpAcceptor;
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
.await
.unwrap();
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.listen_addr(acceptor.listen_addr().to_string());
let server = Server::new(opts).unwrap();
let shutdown_tx = server.shutdown_sender();
let server_handle = tokio::spawn(async move {
server
.run(acceptor, None)
.await
.expect("server run failed")
});
tokio::time::sleep(Duration::from_millis(50)).await;
let _ = shutdown_tx.send(true);
let result = tokio::time::timeout(Duration::from_secs(5), server_handle).await;
assert!(result.is_ok(), "server should have shut down within timeout");
}
}

View File

@@ -0,0 +1,226 @@
//! Stealth mode: protocol detection on TLS connections.
//!
//! When stealth mode is enabled with TLS transport, the server peeks at the first
//! bytes after the TLS handshake to determine whether the client is speaking SSH
//! or HTTP. Non-SSH connections receive a fake nginx 404 response, making the
//! server appear as an ordinary web server to port scanners and DPI systems.
//! See ADR-017.
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
const SSH_BANNER_PREFIX: &[u8] = b"SSH-2.0-";
const FAKE_NGINX_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nServer: nginx\r\n\r\n";
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolDetection {
Ssh,
Http,
}
pub async fn detect_protocol<S>(stream: S) -> (ProtocolDetection, BufReader<S>)
where
S: AsyncRead + Unpin,
{
let mut reader = BufReader::new(stream);
let detection = match reader.fill_buf().await {
Ok(buf) if buf.len() >= SSH_BANNER_PREFIX.len() => {
if &buf[..SSH_BANNER_PREFIX.len()] == SSH_BANNER_PREFIX {
ProtocolDetection::Ssh
} else {
ProtocolDetection::Http
}
}
Ok(buf) if !buf.is_empty() => {
if buf.starts_with(SSH_BANNER_PREFIX) {
ProtocolDetection::Ssh
} else {
ProtocolDetection::Http
}
}
_ => ProtocolDetection::Http,
};
(detection, reader)
}
pub async fn send_fake_nginx_404<S>(reader: &mut BufReader<S>)
where
S: AsyncRead + AsyncWrite + Unpin,
{
let _ = reader.get_mut().write_all(FAKE_NGINX_404).await;
let _ = reader.get_mut().shutdown().await;
}
pub fn validate_stealth_config(stealth: bool, transport_is_tls: bool) -> Result<(), &'static str> {
if stealth && !transport_is_tls {
return Err("stealth mode requires TLS transport (--transport tls)");
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
async fn write_and_detect(data: &[u8]) -> ProtocolDetection {
let (client, server) = duplex(1024);
let mut client = client;
client.write_all(data).await.unwrap();
drop(client);
let (detection, _) = detect_protocol(server).await;
detection
}
#[tokio::test]
async fn ssh_banner_detected() {
let detection = write_and_detect(b"SSH-2.0-OpenSSH_9.0\r\n").await;
assert_eq!(detection, ProtocolDetection::Ssh);
}
#[tokio::test]
async fn ssh_banner_other_implementation() {
let detection = write_and_detect(b"SSH-2.0-russh_0.49\r\n").await;
assert_eq!(detection, ProtocolDetection::Ssh);
}
#[tokio::test]
async fn ssh_banner_minimal() {
let detection = write_and_detect(b"SSH-2.0-X\n").await;
assert_eq!(detection, ProtocolDetection::Ssh);
}
#[tokio::test]
async fn http_get_detected_as_http() {
let detection = write_and_detect(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn http_post_detected_as_http() {
let detection = write_and_detect(b"POST /api HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn random_data_detected_as_http() {
let detection = write_and_detect(b"\x01\x02\x03\x04\x05\x06\x07\x08").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn empty_stream_detected_as_http() {
let (client, server) = duplex(1024);
drop(client);
let (detection, _) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn ssh_banner_bytes_preserved_by_bufreader() {
let (client, server) = duplex(1024);
let mut client = client;
let banner = b"SSH-2.0-OpenSSH_9.0\r\n";
client.write_all(banner).await.unwrap();
client.write_all(b"subsequent data").await.unwrap();
drop(client);
let (detection, mut reader) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Ssh);
let mut all_data = Vec::new();
reader.read_to_end(&mut all_data).await.unwrap();
assert!(all_data.starts_with(banner), "banner bytes must be preserved after detection");
}
#[tokio::test]
async fn fake_nginx_404_response() {
let (client, server) = duplex(1024);
let (mut client_read, mut client_write) = tokio::io::split(client);
client_write.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
drop(client_write);
let (detection, mut reader) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Http);
send_fake_nginx_404(&mut reader).await;
let mut buf = [0u8; 256];
let n = client_read.read(&mut buf).await.unwrap();
let response = String::from_utf8_lossy(&buf[..n]);
assert!(response.contains("HTTP/1.1 404 Not Found"));
assert!(response.contains("Server: nginx"));
}
#[tokio::test]
async fn protocol_detection_enum_equality() {
assert_eq!(ProtocolDetection::Ssh, ProtocolDetection::Ssh);
assert_eq!(ProtocolDetection::Http, ProtocolDetection::Http);
assert_ne!(ProtocolDetection::Ssh, ProtocolDetection::Http);
}
#[test]
fn validate_stealth_without_tls_rejected() {
let result = validate_stealth_config(true, false);
assert!(result.is_err());
assert!(result.unwrap_err().contains("TLS transport"));
}
#[test]
fn validate_stealth_with_tls_accepted() {
let result = validate_stealth_config(true, true);
assert!(result.is_ok());
}
#[test]
fn validate_no_stealth_with_tcp_accepted() {
let result = validate_stealth_config(false, false);
assert!(result.is_ok());
}
#[test]
fn validate_no_stealth_with_tls_accepted() {
let result = validate_stealth_config(false, true);
assert!(result.is_ok());
}
#[tokio::test]
async fn short_data_detected_as_http() {
let detection = write_and_detect(b"GE").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn partial_ssh_prefix_detected_as_http() {
let detection = write_and_detect(b"SSH-1.").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn http_request_gets_404_then_closed() {
let (client, server) = duplex(1024);
let mut client = client;
client.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
let (detection, mut reader) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Http);
send_fake_nginx_404(&mut reader).await;
let mut buf = [0u8; 256];
let n = client.read(&mut buf).await.unwrap();
let response = String::from_utf8_lossy(&buf[..n]);
assert!(response.starts_with("HTTP/1.1 404 Not Found"));
assert!(response.contains("Server: nginx"));
let mut extra = [0u8; 16];
let result = client.read(&mut extra).await;
assert!(result.is_err() || result.unwrap() == 0);
}
}