Compare commits
1 Commits
feat/serve
...
feat/serve
| Author | SHA1 | Date | |
|---|---|---|---|
| 7dcf7502b7 |
@@ -62,7 +62,7 @@ pub enum ConfigError {
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ForwardError {
|
||||
#[error("invalid forward specification: {spec}")]
|
||||
#[error("invalid port forward spec: {spec}")]
|
||||
InvalidSpec { spec: String },
|
||||
#[error("bind failed")]
|
||||
BindFailed {
|
||||
@@ -74,6 +74,11 @@ pub enum ForwardError {
|
||||
#[source]
|
||||
source: Box<dyn std::error::Error + Send + Sync>,
|
||||
},
|
||||
#[error("connect to local target failed")]
|
||||
LocalConnectFailed {
|
||||
#[source]
|
||||
source: io::Error,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
use std::io;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub const WRAITH_CONTROL_DESTINATION: &str = "wraith-control";
|
||||
pub const WRAITH_PREFIX: &str = "wraith-";
|
||||
|
||||
pub fn is_reserved_destination(host: &str) -> bool {
|
||||
host.starts_with(WRAITH_PREFIX)
|
||||
}
|
||||
|
||||
pub trait DuplexStream: AsyncRead + AsyncWrite + Unpin + Send {}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin + Send> DuplexStream for T {}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ControlChannelHandler: Send + Sync {
|
||||
async fn handle_channel(&self, stream: Box<dyn DuplexStream>);
|
||||
}
|
||||
|
||||
pub struct ControlChannelRouter {
|
||||
handler: Option<Box<dyn ControlChannelHandler>>,
|
||||
}
|
||||
|
||||
impl ControlChannelRouter {
|
||||
pub fn new(handler: Option<Box<dyn ControlChannelHandler>>) -> Self {
|
||||
Self { handler }
|
||||
}
|
||||
|
||||
pub fn without_handler() -> Self {
|
||||
Self { handler: None }
|
||||
}
|
||||
|
||||
pub fn with_handler(handler: Box<dyn ControlChannelHandler>) -> Self {
|
||||
Self {
|
||||
handler: Some(handler),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_handler(&self) -> bool {
|
||||
self.handler.is_some()
|
||||
}
|
||||
|
||||
pub async fn route(&self, stream: Box<dyn DuplexStream>) -> io::Result<()> {
|
||||
match &self.handler {
|
||||
Some(handler) => {
|
||||
handler.handle_channel(stream).await;
|
||||
Ok(())
|
||||
}
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionRefused,
|
||||
"no control channel handler configured",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::duplex;
|
||||
|
||||
#[test]
|
||||
fn wraith_control_destination_constant() {
|
||||
assert_eq!(WRAITH_CONTROL_DESTINATION, "wraith-control");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wraith_prefix_constant() {
|
||||
assert_eq!(WRAITH_PREFIX, "wraith-");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reserved_destination_detected() {
|
||||
assert!(is_reserved_destination("wraith-control"));
|
||||
assert!(is_reserved_destination("wraith-status"));
|
||||
assert!(is_reserved_destination("wraith-events"));
|
||||
assert!(is_reserved_destination("wraith-"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_reserved_destination_passes_through() {
|
||||
assert!(!is_reserved_destination("example.com"));
|
||||
assert!(!is_reserved_destination("localhost"));
|
||||
assert!(!is_reserved_destination("192.168.1.1"));
|
||||
assert!(!is_reserved_destination("wraith.example.com"));
|
||||
assert!(!is_reserved_destination(""));
|
||||
assert!(!is_reserved_destination("wrait-control"));
|
||||
assert!(!is_reserved_destination("WRAITH-control"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prefix_matching_case_sensitive() {
|
||||
assert!(!is_reserved_destination("Wraith-control"));
|
||||
assert!(!is_reserved_destination("WRAITH-control"));
|
||||
assert!(is_reserved_destination("wraith-Control"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn router_without_handler_has_no_handler() {
|
||||
let router = ControlChannelRouter::without_handler();
|
||||
assert!(!router.has_handler());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn router_with_handler_has_handler() {
|
||||
struct DummyHandler;
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for DummyHandler {
|
||||
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {}
|
||||
}
|
||||
let router = ControlChannelRouter::with_handler(Box::new(DummyHandler));
|
||||
assert!(router.has_handler());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_without_handler_returns_error() {
|
||||
let router = ControlChannelRouter::without_handler();
|
||||
let (_client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
let result = router.route(stream).await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_with_handler_succeeds() {
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
struct TrackedHandler {
|
||||
called: Arc<AtomicBool>,
|
||||
}
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for TrackedHandler {
|
||||
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {
|
||||
self.called.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
let called = Arc::new(AtomicBool::new(false));
|
||||
let handler = TrackedHandler {
|
||||
called: called.clone(),
|
||||
};
|
||||
let router = ControlChannelRouter::with_handler(Box::new(handler));
|
||||
let (_client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
let result = router.route(stream).await;
|
||||
assert!(result.is_ok());
|
||||
assert!(called.load(Ordering::SeqCst));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_with_handler_can_read_write() {
|
||||
struct EchoHandler;
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for EchoHandler {
|
||||
async fn handle_channel(&self, mut stream: Box<dyn DuplexStream>) {
|
||||
let mut buf = [0u8; 64];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
stream.write_all(&buf[..n]).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
let router = ControlChannelRouter::with_handler(Box::new(EchoHandler));
|
||||
let (client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
tokio::spawn(async move {
|
||||
router.route(stream).await.unwrap();
|
||||
});
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
let mut client = client;
|
||||
client.write_all(b"hello").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
client.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn control_channel_destination_matches_prefix() {
|
||||
assert!(is_reserved_destination(WRAITH_CONTROL_DESTINATION));
|
||||
}
|
||||
}
|
||||
@@ -7,9 +7,8 @@ use russh::server::{Auth, Handler, Msg, Session};
|
||||
use russh::Channel;
|
||||
|
||||
use crate::auth::ServerAuthConfig;
|
||||
use crate::server::control_channel::{
|
||||
ControlChannelHandler, ControlChannelRouter, WRAITH_PREFIX,
|
||||
};
|
||||
|
||||
const WRAITH_PREFIX: &str = "wraith-";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProxyMode {
|
||||
@@ -27,7 +26,6 @@ pub struct ServerHandler {
|
||||
auth_config: Arc<ServerAuthConfig>,
|
||||
outbound_proxy: Option<ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
control_channel_router: ControlChannelRouter,
|
||||
}
|
||||
|
||||
impl ServerHandler {
|
||||
@@ -40,21 +38,8 @@ impl ServerHandler {
|
||||
auth_config,
|
||||
outbound_proxy,
|
||||
remote_addr,
|
||||
control_channel_router: ControlChannelRouter::without_handler(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_control_channel_handler(
|
||||
mut self,
|
||||
handler: Box<dyn ControlChannelHandler>,
|
||||
) -> Self {
|
||||
self.control_channel_router = ControlChannelRouter::with_handler(handler);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn control_channel_router(&self) -> &ControlChannelRouter {
|
||||
&self.control_channel_router
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -113,16 +98,6 @@ impl Handler for ServerHandler {
|
||||
port = port_to_connect,
|
||||
"routing to internal control channel handler"
|
||||
);
|
||||
|
||||
if !self.control_channel_router.has_handler() {
|
||||
tracing::warn!(
|
||||
host = host_to_connect,
|
||||
"no control channel handler configured, rejecting channel open"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let _ = channel;
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
@@ -276,20 +251,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn reserved_wraith_destination_routing() {
|
||||
use crate::server::control_channel::is_reserved_destination;
|
||||
assert!(is_reserved_destination("wraith-control"));
|
||||
assert!(is_reserved_destination("wraith-status"));
|
||||
assert!(is_reserved_destination("wraith-events"));
|
||||
assert!(!is_reserved_destination("example.com"));
|
||||
assert!(!is_reserved_destination("localhost"));
|
||||
assert!(!is_reserved_destination("wraith.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_handler_without_control_handler_rejects_wraith_destinations() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let handler = ServerHandler::new(auth_config, None, None);
|
||||
assert!(!handler.control_channel_router().has_handler());
|
||||
assert!("wraith-control".starts_with(WRAITH_PREFIX));
|
||||
assert!("wraith-status".starts_with(WRAITH_PREFIX));
|
||||
assert!("wraith-events".starts_with(WRAITH_PREFIX));
|
||||
assert!(!"example.com".starts_with(WRAITH_PREFIX));
|
||||
assert!(!"localhost".starts_with(WRAITH_PREFIX));
|
||||
assert!(!"wraith.example.com".starts_with(WRAITH_PREFIX));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
pub mod control_channel;
|
||||
pub mod handler;
|
||||
pub mod stealth;
|
||||
|
||||
pub use control_channel::{
|
||||
ControlChannelHandler, ControlChannelRouter, DuplexStream, WRAITH_CONTROL_DESTINATION,
|
||||
WRAITH_PREFIX, is_reserved_destination,
|
||||
};
|
||||
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
|
||||
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
|
||||
pub use stealth::{ProtocolDetection, detect_protocol, send_fake_nginx_404, validate_stealth_config};
|
||||
218
crates/wraith-core/src/server/stealth.rs
Normal file
218
crates/wraith-core/src/server/stealth.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
|
||||
|
||||
const SSH_BANNER_PREFIX: &[u8] = b"SSH-2.0-";
|
||||
const FAKE_NGINX_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nServer: nginx\r\n\r\n";
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ProtocolDetection {
|
||||
Ssh,
|
||||
Http,
|
||||
}
|
||||
|
||||
pub async fn detect_protocol<S>(stream: S) -> (ProtocolDetection, BufReader<S>)
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
let detection = match reader.fill_buf().await {
|
||||
Ok(buf) if buf.len() >= SSH_BANNER_PREFIX.len() => {
|
||||
if &buf[..SSH_BANNER_PREFIX.len()] == SSH_BANNER_PREFIX {
|
||||
ProtocolDetection::Ssh
|
||||
} else {
|
||||
ProtocolDetection::Http
|
||||
}
|
||||
}
|
||||
Ok(buf) if !buf.is_empty() => {
|
||||
if buf.starts_with(SSH_BANNER_PREFIX) {
|
||||
ProtocolDetection::Ssh
|
||||
} else {
|
||||
ProtocolDetection::Http
|
||||
}
|
||||
}
|
||||
_ => ProtocolDetection::Http,
|
||||
};
|
||||
|
||||
(detection, reader)
|
||||
}
|
||||
|
||||
pub async fn send_fake_nginx_404<S>(reader: &mut BufReader<S>)
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let _ = reader.get_mut().write_all(FAKE_NGINX_404).await;
|
||||
let _ = reader.get_mut().shutdown().await;
|
||||
}
|
||||
|
||||
pub fn validate_stealth_config(stealth: bool, transport_is_tls: bool) -> Result<(), &'static str> {
|
||||
if stealth && !transport_is_tls {
|
||||
return Err("stealth mode requires TLS transport (--transport tls)");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
async fn write_and_detect(data: &[u8]) -> ProtocolDetection {
|
||||
let (client, server) = duplex(1024);
|
||||
let mut client = client;
|
||||
|
||||
client.write_all(data).await.unwrap();
|
||||
drop(client);
|
||||
|
||||
let (detection, _) = detect_protocol(server).await;
|
||||
detection
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_banner_detected() {
|
||||
let detection = write_and_detect(b"SSH-2.0-OpenSSH_9.0\r\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_banner_other_implementation() {
|
||||
let detection = write_and_detect(b"SSH-2.0-russh_0.49\r\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_banner_minimal() {
|
||||
let detection = write_and_detect(b"SSH-2.0-X\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_get_detected_as_http() {
|
||||
let detection = write_and_detect(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_post_detected_as_http() {
|
||||
let detection = write_and_detect(b"POST /api HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn random_data_detected_as_http() {
|
||||
let detection = write_and_detect(b"\x01\x02\x03\x04\x05\x06\x07\x08").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_stream_detected_as_http() {
|
||||
let (client, server) = duplex(1024);
|
||||
drop(client);
|
||||
let (detection, _) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_banner_bytes_preserved_by_bufreader() {
|
||||
let (client, server) = duplex(1024);
|
||||
let mut client = client;
|
||||
|
||||
let banner = b"SSH-2.0-OpenSSH_9.0\r\n";
|
||||
client.write_all(banner).await.unwrap();
|
||||
client.write_all(b"subsequent data").await.unwrap();
|
||||
drop(client);
|
||||
|
||||
let (detection, mut reader) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||
|
||||
let mut all_data = Vec::new();
|
||||
reader.read_to_end(&mut all_data).await.unwrap();
|
||||
assert!(all_data.starts_with(banner), "banner bytes must be preserved after detection");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fake_nginx_404_response() {
|
||||
let (client, server) = duplex(1024);
|
||||
let (mut client_read, mut client_write) = tokio::io::split(client);
|
||||
|
||||
client_write.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
|
||||
drop(client_write);
|
||||
|
||||
let (detection, mut reader) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
|
||||
send_fake_nginx_404(&mut reader).await;
|
||||
|
||||
let mut buf = [0u8; 256];
|
||||
let n = client_read.read(&mut buf).await.unwrap();
|
||||
let response = String::from_utf8_lossy(&buf[..n]);
|
||||
assert!(response.contains("HTTP/1.1 404 Not Found"));
|
||||
assert!(response.contains("Server: nginx"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn protocol_detection_enum_equality() {
|
||||
assert_eq!(ProtocolDetection::Ssh, ProtocolDetection::Ssh);
|
||||
assert_eq!(ProtocolDetection::Http, ProtocolDetection::Http);
|
||||
assert_ne!(ProtocolDetection::Ssh, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_stealth_without_tls_rejected() {
|
||||
let result = validate_stealth_config(true, false);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("TLS transport"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_stealth_with_tls_accepted() {
|
||||
let result = validate_stealth_config(true, true);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_no_stealth_with_tcp_accepted() {
|
||||
let result = validate_stealth_config(false, false);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_no_stealth_with_tls_accepted() {
|
||||
let result = validate_stealth_config(false, true);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn short_data_detected_as_http() {
|
||||
let detection = write_and_detect(b"GE").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn partial_ssh_prefix_detected_as_http() {
|
||||
let detection = write_and_detect(b"SSH-1.").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_request_gets_404_then_closed() {
|
||||
let (client, server) = duplex(1024);
|
||||
let mut client = client;
|
||||
|
||||
client.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
|
||||
|
||||
let (detection, mut reader) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
|
||||
send_fake_nginx_404(&mut reader).await;
|
||||
|
||||
let mut buf = [0u8; 256];
|
||||
let n = client.read(&mut buf).await.unwrap();
|
||||
let response = String::from_utf8_lossy(&buf[..n]);
|
||||
assert!(response.starts_with("HTTP/1.1 404 Not Found"));
|
||||
assert!(response.contains("Server: nginx"));
|
||||
|
||||
let mut extra = [0u8; 16];
|
||||
let result = client.read(&mut extra).await;
|
||||
assert!(result.is_err() || result.unwrap() == 0);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user