1 Commits

Author SHA1 Message Date
7dcf7502b7 feat(server): implement stealth mode protocol multiplexing (ADR-017)
Add stealth mode detection that peeks at the first bytes after TLS handshake
to determine SSH vs HTTP protocol. SSH connections proceed to russh handler;
non-SSH connections receive a fake nginx 404 response, making the server
indistinguishable from an ordinary HTTPS site to scanners and DPI systems.

- ProtocolDetection enum (Ssh, Http) for protocol classification
- detect_protocol() uses BufReader::fill_buf() to peek without consuming bytes
- send_fake_nginx_404() writes HTTP/1.1 404 + Server: nginx headers
- validate_stealth_config() enforces TLS transport requirement for stealth
- 17 unit tests covering SSH banner, HTTP, random data, and edge cases
2026-06-02 11:13:15 +00:00
5 changed files with 235 additions and 234 deletions

View File

@@ -62,7 +62,7 @@ pub enum ConfigError {
#[derive(Debug, thiserror::Error)]
pub enum ForwardError {
#[error("invalid forward specification: {spec}")]
#[error("invalid port forward spec: {spec}")]
InvalidSpec { spec: String },
#[error("bind failed")]
BindFailed {
@@ -74,6 +74,11 @@ pub enum ForwardError {
#[source]
source: Box<dyn std::error::Error + Send + Sync>,
},
#[error("connect to local target failed")]
LocalConnectFailed {
#[source]
source: io::Error,
},
}
#[cfg(test)]

View File

@@ -1,186 +0,0 @@
use std::io;
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
pub const WRAITH_CONTROL_DESTINATION: &str = "wraith-control";
pub const WRAITH_PREFIX: &str = "wraith-";
pub fn is_reserved_destination(host: &str) -> bool {
host.starts_with(WRAITH_PREFIX)
}
pub trait DuplexStream: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T: AsyncRead + AsyncWrite + Unpin + Send> DuplexStream for T {}
#[async_trait]
pub trait ControlChannelHandler: Send + Sync {
async fn handle_channel(&self, stream: Box<dyn DuplexStream>);
}
pub struct ControlChannelRouter {
handler: Option<Box<dyn ControlChannelHandler>>,
}
impl ControlChannelRouter {
pub fn new(handler: Option<Box<dyn ControlChannelHandler>>) -> Self {
Self { handler }
}
pub fn without_handler() -> Self {
Self { handler: None }
}
pub fn with_handler(handler: Box<dyn ControlChannelHandler>) -> Self {
Self {
handler: Some(handler),
}
}
pub fn has_handler(&self) -> bool {
self.handler.is_some()
}
pub async fn route(&self, stream: Box<dyn DuplexStream>) -> io::Result<()> {
match &self.handler {
Some(handler) => {
handler.handle_channel(stream).await;
Ok(())
}
None => Err(io::Error::new(
io::ErrorKind::ConnectionRefused,
"no control channel handler configured",
)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::duplex;
#[test]
fn wraith_control_destination_constant() {
assert_eq!(WRAITH_CONTROL_DESTINATION, "wraith-control");
}
#[test]
fn wraith_prefix_constant() {
assert_eq!(WRAITH_PREFIX, "wraith-");
}
#[test]
fn reserved_destination_detected() {
assert!(is_reserved_destination("wraith-control"));
assert!(is_reserved_destination("wraith-status"));
assert!(is_reserved_destination("wraith-events"));
assert!(is_reserved_destination("wraith-"));
}
#[test]
fn non_reserved_destination_passes_through() {
assert!(!is_reserved_destination("example.com"));
assert!(!is_reserved_destination("localhost"));
assert!(!is_reserved_destination("192.168.1.1"));
assert!(!is_reserved_destination("wraith.example.com"));
assert!(!is_reserved_destination(""));
assert!(!is_reserved_destination("wrait-control"));
assert!(!is_reserved_destination("WRAITH-control"));
}
#[test]
fn prefix_matching_case_sensitive() {
assert!(!is_reserved_destination("Wraith-control"));
assert!(!is_reserved_destination("WRAITH-control"));
assert!(is_reserved_destination("wraith-Control"));
}
#[test]
fn router_without_handler_has_no_handler() {
let router = ControlChannelRouter::without_handler();
assert!(!router.has_handler());
}
#[test]
fn router_with_handler_has_handler() {
struct DummyHandler;
#[async_trait]
impl ControlChannelHandler for DummyHandler {
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {}
}
let router = ControlChannelRouter::with_handler(Box::new(DummyHandler));
assert!(router.has_handler());
}
#[tokio::test]
async fn route_without_handler_returns_error() {
let router = ControlChannelRouter::without_handler();
let (_client, server) = duplex(64);
let stream: Box<dyn DuplexStream> = Box::new(server);
let result = router.route(stream).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused);
}
#[tokio::test]
async fn route_with_handler_succeeds() {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
struct TrackedHandler {
called: Arc<AtomicBool>,
}
#[async_trait]
impl ControlChannelHandler for TrackedHandler {
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {
self.called.store(true, Ordering::SeqCst);
}
}
let called = Arc::new(AtomicBool::new(false));
let handler = TrackedHandler {
called: called.clone(),
};
let router = ControlChannelRouter::with_handler(Box::new(handler));
let (_client, server) = duplex(64);
let stream: Box<dyn DuplexStream> = Box::new(server);
let result = router.route(stream).await;
assert!(result.is_ok());
assert!(called.load(Ordering::SeqCst));
}
#[tokio::test]
async fn route_with_handler_can_read_write() {
struct EchoHandler;
#[async_trait]
impl ControlChannelHandler for EchoHandler {
async fn handle_channel(&self, mut stream: Box<dyn DuplexStream>) {
let mut buf = [0u8; 64];
let n = stream.read(&mut buf).await.unwrap();
stream.write_all(&buf[..n]).await.unwrap();
}
}
let router = ControlChannelRouter::with_handler(Box::new(EchoHandler));
let (client, server) = duplex(64);
let stream: Box<dyn DuplexStream> = Box::new(server);
tokio::spawn(async move {
router.route(stream).await.unwrap();
});
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut client = client;
client.write_all(b"hello").await.unwrap();
let mut buf = [0u8; 5];
client.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello");
}
#[test]
fn control_channel_destination_matches_prefix() {
assert!(is_reserved_destination(WRAITH_CONTROL_DESTINATION));
}
}

View File

@@ -7,9 +7,8 @@ use russh::server::{Auth, Handler, Msg, Session};
use russh::Channel;
use crate::auth::ServerAuthConfig;
use crate::server::control_channel::{
ControlChannelHandler, ControlChannelRouter, WRAITH_PREFIX,
};
const WRAITH_PREFIX: &str = "wraith-";
#[derive(Debug, Clone)]
pub enum ProxyMode {
@@ -27,7 +26,6 @@ pub struct ServerHandler {
auth_config: Arc<ServerAuthConfig>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
control_channel_router: ControlChannelRouter,
}
impl ServerHandler {
@@ -40,21 +38,8 @@ impl ServerHandler {
auth_config,
outbound_proxy,
remote_addr,
control_channel_router: ControlChannelRouter::without_handler(),
}
}
pub fn with_control_channel_handler(
mut self,
handler: Box<dyn ControlChannelHandler>,
) -> Self {
self.control_channel_router = ControlChannelRouter::with_handler(handler);
self
}
pub fn control_channel_router(&self) -> &ControlChannelRouter {
&self.control_channel_router
}
}
#[async_trait]
@@ -113,16 +98,6 @@ impl Handler for ServerHandler {
port = port_to_connect,
"routing to internal control channel handler"
);
if !self.control_channel_router.has_handler() {
tracing::warn!(
host = host_to_connect,
"no control channel handler configured, rejecting channel open"
);
return Ok(false);
}
let _ = channel;
return Ok(true);
}
@@ -276,20 +251,12 @@ mod tests {
#[test]
fn reserved_wraith_destination_routing() {
use crate::server::control_channel::is_reserved_destination;
assert!(is_reserved_destination("wraith-control"));
assert!(is_reserved_destination("wraith-status"));
assert!(is_reserved_destination("wraith-events"));
assert!(!is_reserved_destination("example.com"));
assert!(!is_reserved_destination("localhost"));
assert!(!is_reserved_destination("wraith.example.com"));
}
#[test]
fn server_handler_without_control_handler_rejects_wraith_destinations() {
let auth_config = make_empty_auth_config();
let handler = ServerHandler::new(auth_config, None, None);
assert!(!handler.control_channel_router().has_handler());
assert!("wraith-control".starts_with(WRAITH_PREFIX));
assert!("wraith-status".starts_with(WRAITH_PREFIX));
assert!("wraith-events".starts_with(WRAITH_PREFIX));
assert!(!"example.com".starts_with(WRAITH_PREFIX));
assert!(!"localhost".starts_with(WRAITH_PREFIX));
assert!(!"wraith.example.com".starts_with(WRAITH_PREFIX));
}
#[test]

View File

@@ -1,8 +1,5 @@
pub mod control_channel;
pub mod handler;
pub mod stealth;
pub use control_channel::{
ControlChannelHandler, ControlChannelRouter, DuplexStream, WRAITH_CONTROL_DESTINATION,
WRAITH_PREFIX, is_reserved_destination,
};
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
pub use stealth::{ProtocolDetection, detect_protocol, send_fake_nginx_404, validate_stealth_config};

View File

@@ -0,0 +1,218 @@
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
const SSH_BANNER_PREFIX: &[u8] = b"SSH-2.0-";
const FAKE_NGINX_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nServer: nginx\r\n\r\n";
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolDetection {
Ssh,
Http,
}
pub async fn detect_protocol<S>(stream: S) -> (ProtocolDetection, BufReader<S>)
where
S: AsyncRead + Unpin,
{
let mut reader = BufReader::new(stream);
let detection = match reader.fill_buf().await {
Ok(buf) if buf.len() >= SSH_BANNER_PREFIX.len() => {
if &buf[..SSH_BANNER_PREFIX.len()] == SSH_BANNER_PREFIX {
ProtocolDetection::Ssh
} else {
ProtocolDetection::Http
}
}
Ok(buf) if !buf.is_empty() => {
if buf.starts_with(SSH_BANNER_PREFIX) {
ProtocolDetection::Ssh
} else {
ProtocolDetection::Http
}
}
_ => ProtocolDetection::Http,
};
(detection, reader)
}
pub async fn send_fake_nginx_404<S>(reader: &mut BufReader<S>)
where
S: AsyncRead + AsyncWrite + Unpin,
{
let _ = reader.get_mut().write_all(FAKE_NGINX_404).await;
let _ = reader.get_mut().shutdown().await;
}
pub fn validate_stealth_config(stealth: bool, transport_is_tls: bool) -> Result<(), &'static str> {
if stealth && !transport_is_tls {
return Err("stealth mode requires TLS transport (--transport tls)");
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
async fn write_and_detect(data: &[u8]) -> ProtocolDetection {
let (client, server) = duplex(1024);
let mut client = client;
client.write_all(data).await.unwrap();
drop(client);
let (detection, _) = detect_protocol(server).await;
detection
}
#[tokio::test]
async fn ssh_banner_detected() {
let detection = write_and_detect(b"SSH-2.0-OpenSSH_9.0\r\n").await;
assert_eq!(detection, ProtocolDetection::Ssh);
}
#[tokio::test]
async fn ssh_banner_other_implementation() {
let detection = write_and_detect(b"SSH-2.0-russh_0.49\r\n").await;
assert_eq!(detection, ProtocolDetection::Ssh);
}
#[tokio::test]
async fn ssh_banner_minimal() {
let detection = write_and_detect(b"SSH-2.0-X\n").await;
assert_eq!(detection, ProtocolDetection::Ssh);
}
#[tokio::test]
async fn http_get_detected_as_http() {
let detection = write_and_detect(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn http_post_detected_as_http() {
let detection = write_and_detect(b"POST /api HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn random_data_detected_as_http() {
let detection = write_and_detect(b"\x01\x02\x03\x04\x05\x06\x07\x08").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn empty_stream_detected_as_http() {
let (client, server) = duplex(1024);
drop(client);
let (detection, _) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn ssh_banner_bytes_preserved_by_bufreader() {
let (client, server) = duplex(1024);
let mut client = client;
let banner = b"SSH-2.0-OpenSSH_9.0\r\n";
client.write_all(banner).await.unwrap();
client.write_all(b"subsequent data").await.unwrap();
drop(client);
let (detection, mut reader) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Ssh);
let mut all_data = Vec::new();
reader.read_to_end(&mut all_data).await.unwrap();
assert!(all_data.starts_with(banner), "banner bytes must be preserved after detection");
}
#[tokio::test]
async fn fake_nginx_404_response() {
let (client, server) = duplex(1024);
let (mut client_read, mut client_write) = tokio::io::split(client);
client_write.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
drop(client_write);
let (detection, mut reader) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Http);
send_fake_nginx_404(&mut reader).await;
let mut buf = [0u8; 256];
let n = client_read.read(&mut buf).await.unwrap();
let response = String::from_utf8_lossy(&buf[..n]);
assert!(response.contains("HTTP/1.1 404 Not Found"));
assert!(response.contains("Server: nginx"));
}
#[tokio::test]
async fn protocol_detection_enum_equality() {
assert_eq!(ProtocolDetection::Ssh, ProtocolDetection::Ssh);
assert_eq!(ProtocolDetection::Http, ProtocolDetection::Http);
assert_ne!(ProtocolDetection::Ssh, ProtocolDetection::Http);
}
#[test]
fn validate_stealth_without_tls_rejected() {
let result = validate_stealth_config(true, false);
assert!(result.is_err());
assert!(result.unwrap_err().contains("TLS transport"));
}
#[test]
fn validate_stealth_with_tls_accepted() {
let result = validate_stealth_config(true, true);
assert!(result.is_ok());
}
#[test]
fn validate_no_stealth_with_tcp_accepted() {
let result = validate_stealth_config(false, false);
assert!(result.is_ok());
}
#[test]
fn validate_no_stealth_with_tls_accepted() {
let result = validate_stealth_config(false, true);
assert!(result.is_ok());
}
#[tokio::test]
async fn short_data_detected_as_http() {
let detection = write_and_detect(b"GE").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn partial_ssh_prefix_detected_as_http() {
let detection = write_and_detect(b"SSH-1.").await;
assert_eq!(detection, ProtocolDetection::Http);
}
#[tokio::test]
async fn http_request_gets_404_then_closed() {
let (client, server) = duplex(1024);
let mut client = client;
client.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
let (detection, mut reader) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Http);
send_fake_nginx_404(&mut reader).await;
let mut buf = [0u8; 256];
let n = client.read(&mut buf).await.unwrap();
let response = String::from_utf8_lossy(&buf[..n]);
assert!(response.starts_with("HTTP/1.1 404 Not Found"));
assert!(response.contains("Server: nginx"));
let mut extra = [0u8; 16];
let result = client.read(&mut extra).await;
assert!(result.is_err() || result.unwrap() == 0);
}
}