Compare commits
20 Commits
feat/serve
...
feat/cli/c
| Author | SHA1 | Date | |
|---|---|---|---|
| 94feb5fdac | |||
| f13a1c985f | |||
| 49fe2b699f | |||
| 365b11d19e | |||
| 7dcf7502b7 | |||
| 585913d3c8 | |||
| 243243a82f | |||
| 2ab5eeda53 | |||
| 128affd264 | |||
| 5a2b535605 | |||
| 24b70f5651 | |||
| f963898a05 | |||
| 992d478630 | |||
| e3f33a24c3 | |||
| 5fec0b53d9 | |||
| 2efd4cf7c5 | |||
| 4e4afd5020 | |||
| 7336c0f13c | |||
| 975778bfb1 | |||
| d6a49a07d7 |
7
Cargo.lock
generated
7
Cargo.lock
generated
@@ -2395,6 +2395,7 @@ version = "3.9.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f1d395473824516f38dd1071a1a37bc57daa7be65b293ebba4ead5f7abb017a2"
|
checksum = "f1d395473824516f38dd1071a1a37bc57daa7be65b293ebba4ead5f7abb017a2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
"bitflags 2.11.1",
|
"bitflags 2.11.1",
|
||||||
"ctor",
|
"ctor",
|
||||||
"futures",
|
"futures",
|
||||||
@@ -2402,6 +2403,7 @@ dependencies = [
|
|||||||
"napi-sys",
|
"napi-sys",
|
||||||
"nohash-hasher",
|
"nohash-hasher",
|
||||||
"rustc-hash",
|
"rustc-hash",
|
||||||
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -5583,7 +5585,9 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"clap",
|
"clap",
|
||||||
|
"iroh",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"url",
|
||||||
"wraith-core",
|
"wraith-core",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -5593,6 +5597,7 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
"futures",
|
||||||
"ipnetwork",
|
"ipnetwork",
|
||||||
"iroh",
|
"iroh",
|
||||||
"rand 0.10.1",
|
"rand 0.10.1",
|
||||||
@@ -5620,6 +5625,8 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"napi",
|
"napi",
|
||||||
"napi-derive",
|
"napi-derive",
|
||||||
|
"russh",
|
||||||
|
"tokio",
|
||||||
"wraith-core",
|
"wraith-core",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ name = "wraith_core"
|
|||||||
default = []
|
default = []
|
||||||
tls = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-pki-types", "dep:webpki-roots"]
|
tls = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-pki-types", "dep:webpki-roots"]
|
||||||
iroh = ["dep:iroh", "dep:url"]
|
iroh = ["dep:iroh", "dep:url"]
|
||||||
acme = ["dep:rustls-acme", "tls"]
|
acme = ["dep:rustls-acme", "dep:futures", "tls"]
|
||||||
testutil = []
|
testutil = []
|
||||||
transport-traits = []
|
transport-traits = []
|
||||||
|
|
||||||
@@ -25,6 +25,7 @@ tokio-rustls = { version = "0.26", optional = true }
|
|||||||
rustls = { version = "0.23", optional = true, features = ["aws_lc_rs"] }
|
rustls = { version = "0.23", optional = true, features = ["aws_lc_rs"] }
|
||||||
rustls-pki-types = { version = "1", optional = true }
|
rustls-pki-types = { version = "1", optional = true }
|
||||||
rustls-acme = { version = "0.12", optional = true }
|
rustls-acme = { version = "0.12", optional = true }
|
||||||
|
futures = { version = "0.3", optional = true }
|
||||||
webpki-roots = { version = "0.26", optional = true }
|
webpki-roots = { version = "0.26", optional = true }
|
||||||
iroh = { version = "0.34", optional = true }
|
iroh = { version = "0.34", optional = true }
|
||||||
url = { version = "2", optional = true }
|
url = { version = "2", optional = true }
|
||||||
|
|||||||
471
crates/wraith-core/src/client/channel_manager.rs
Normal file
471
crates/wraith-core/src/client/channel_manager.rs
Normal file
@@ -0,0 +1,471 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use russh::client;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tokio::time;
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
use crate::auth::client_auth::{ClientAuthConfig, ClientHandler};
|
||||||
|
use crate::error::ChannelError;
|
||||||
|
use crate::transport::Transport;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||||
|
pub struct ForwardRequest {
|
||||||
|
pub addr: String,
|
||||||
|
pub port: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ChannelManagerInner<T: Transport> {
|
||||||
|
transport: Arc<T>,
|
||||||
|
auth_config: Arc<ClientAuthConfig>,
|
||||||
|
handle: Arc<RwLock<client::Handle<ClientHandler>>>,
|
||||||
|
username: String,
|
||||||
|
forwards: RwLock<HashSet<ForwardRequest>>,
|
||||||
|
reconnect_attempts: RwLock<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ChannelManager<T: Transport> {
|
||||||
|
inner: Arc<ChannelManagerInner<T>>,
|
||||||
|
reconnect_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Transport> ChannelManager<T> {
|
||||||
|
pub async fn new(
|
||||||
|
transport: Arc<T>,
|
||||||
|
auth_config: Arc<ClientAuthConfig>,
|
||||||
|
username: String,
|
||||||
|
) -> Result<Self, ChannelError> {
|
||||||
|
let handler = ClientHandler::from_config(&auth_config);
|
||||||
|
let handle = Self::establish_session(&*transport, handler, &auth_config, &username)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ChannelError::TargetUnreachable)?;
|
||||||
|
|
||||||
|
let inner = Arc::new(ChannelManagerInner {
|
||||||
|
transport,
|
||||||
|
auth_config,
|
||||||
|
handle: Arc::new(RwLock::new(handle)),
|
||||||
|
username,
|
||||||
|
forwards: RwLock::new(HashSet::new()),
|
||||||
|
reconnect_attempts: RwLock::new(0),
|
||||||
|
});
|
||||||
|
|
||||||
|
let reconnect_handle = Arc::new(RwLock::new(None));
|
||||||
|
let manager = Self {
|
||||||
|
inner,
|
||||||
|
reconnect_handle,
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.start_reconnect_monitor();
|
||||||
|
Ok(manager)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn establish_session(
|
||||||
|
transport: &T,
|
||||||
|
handler: ClientHandler,
|
||||||
|
auth_config: &ClientAuthConfig,
|
||||||
|
username: &str,
|
||||||
|
) -> Result<client::Handle<ClientHandler>, russh::Error> {
|
||||||
|
let stream = transport.connect().await.map_err(|e| {
|
||||||
|
error!("transport connect failed: {e}");
|
||||||
|
russh::Error::SendError
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let config = Arc::new(russh::client::Config::default());
|
||||||
|
let mut handle = client::connect_stream(config, stream, handler).await?;
|
||||||
|
|
||||||
|
let auth_ok = auth_config.authenticate(&mut handle, username).await?;
|
||||||
|
if !auth_ok {
|
||||||
|
return Err(russh::Error::SendError);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn open_direct_tcpip(
|
||||||
|
&self,
|
||||||
|
host: &str,
|
||||||
|
port: u32,
|
||||||
|
) -> Result<russh::Channel<russh::client::Msg>, ChannelError> {
|
||||||
|
let handle = self.inner.handle.read().await;
|
||||||
|
handle
|
||||||
|
.channel_open_direct_tcpip(host, port, "127.0.0.1", 0)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
debug!("channel open failed: {e}");
|
||||||
|
ChannelError::ChannelClosed
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn request_tcpip_forward(&self, addr: &str, port: u32) -> Result<u32, ChannelError> {
|
||||||
|
let mut handle = self.inner.handle.write().await;
|
||||||
|
let result = handle
|
||||||
|
.tcpip_forward(addr, port)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ChannelError::ChannelClosed)?;
|
||||||
|
|
||||||
|
self.inner
|
||||||
|
.forwards
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(ForwardRequest {
|
||||||
|
addr: addr.to_string(),
|
||||||
|
port,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn cancel_tcpip_forward(&self, addr: &str, port: u32) -> Result<(), ChannelError> {
|
||||||
|
let handle = self.inner.handle.read().await;
|
||||||
|
handle
|
||||||
|
.cancel_tcpip_forward(addr, port)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ChannelError::ChannelClosed)?;
|
||||||
|
|
||||||
|
self.inner
|
||||||
|
.forwards
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.remove(&ForwardRequest {
|
||||||
|
addr: addr.to_string(),
|
||||||
|
port,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn is_connected(&self) -> bool {
|
||||||
|
let handle = self.inner.handle.read().await;
|
||||||
|
!handle.is_closed()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_reconnect_monitor(&self) {
|
||||||
|
let inner = Arc::clone(&self.inner);
|
||||||
|
let handle_arc = Arc::clone(&self.inner.handle);
|
||||||
|
|
||||||
|
let join_handle = tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
time::sleep(Duration::from_secs(1)).await;
|
||||||
|
let handle = handle_arc.read().await;
|
||||||
|
if handle.is_closed() {
|
||||||
|
drop(handle);
|
||||||
|
info!("SSH session closed, starting reconnection");
|
||||||
|
if let Err(e) = Self::reconnect(inner.clone()).await {
|
||||||
|
error!("reconnection failed: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let reconnect_handle = Arc::clone(&self.reconnect_handle);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut guard = reconnect_handle.write().await;
|
||||||
|
*guard = Some(join_handle);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn reconnect(inner: Arc<ChannelManagerInner<T>>) -> Result<(), ChannelError> {
|
||||||
|
let mut attempts = inner.reconnect_attempts.write().await;
|
||||||
|
let attempt_num = *attempts;
|
||||||
|
let backoff = backoff_duration(attempt_num);
|
||||||
|
*attempts += 1;
|
||||||
|
drop(attempts);
|
||||||
|
|
||||||
|
warn!(
|
||||||
|
"reconnect attempt #{}, waiting {:?}",
|
||||||
|
attempt_num + 1,
|
||||||
|
backoff
|
||||||
|
);
|
||||||
|
time::sleep(backoff).await;
|
||||||
|
|
||||||
|
let handler = ClientHandler::from_config(&inner.auth_config);
|
||||||
|
match Self::establish_session(
|
||||||
|
&*inner.transport,
|
||||||
|
handler,
|
||||||
|
&inner.auth_config,
|
||||||
|
&inner.username,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(new_handle) => {
|
||||||
|
info!("reconnection successful");
|
||||||
|
{
|
||||||
|
let mut handle_guard = inner.handle.write().await;
|
||||||
|
*handle_guard = new_handle;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let mut attempts = inner.reconnect_attempts.write().await;
|
||||||
|
*attempts = 0;
|
||||||
|
}
|
||||||
|
Self::re_register_forwards(&inner).await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("reconnection attempt failed: {e}");
|
||||||
|
Err(ChannelError::ChannelClosed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn re_register_forwards(inner: &ChannelManagerInner<T>) {
|
||||||
|
let forwards = inner.forwards.read().await;
|
||||||
|
if forwards.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let mut handle = inner.handle.write().await;
|
||||||
|
for fwd in forwards.iter() {
|
||||||
|
match handle.tcpip_forward(&fwd.addr, fwd.port).await {
|
||||||
|
Ok(_) => {
|
||||||
|
debug!(
|
||||||
|
"re-registered tcpip_forward: {}:{}",
|
||||||
|
fwd.addr, fwd.port
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
"failed to re-register tcpip_forward {}:{}: {e}",
|
||||||
|
fwd.addr, fwd.port
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Exponential backoff: 1s, 2s, 4s, 8s, 16s, 30s (cap), continues indefinitely.
|
||||||
|
fn backoff_duration(attempt: u32) -> Duration {
|
||||||
|
let secs: u64 = match attempt {
|
||||||
|
0 => 1,
|
||||||
|
1 => 2,
|
||||||
|
2 => 4,
|
||||||
|
3 => 8,
|
||||||
|
4 => 16,
|
||||||
|
_ => 30,
|
||||||
|
};
|
||||||
|
Duration::from_secs(secs)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Transport> Drop for ChannelManager<T> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Ok(mut guard) = self.reconnect_handle.try_write() {
|
||||||
|
if let Some(handle) = guard.take() {
|
||||||
|
handle.abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
use tokio::io::duplex;
|
||||||
|
|
||||||
|
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||||
|
|
||||||
|
fn make_auth_config() -> Arc<ClientAuthConfig> {
|
||||||
|
let source = crate::auth::keys::KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||||
|
Arc::new(ClientAuthConfig::from_key_source(source).unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AlwaysFailTransport;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Transport for AlwaysFailTransport {
|
||||||
|
type Stream = tokio::io::DuplexStream;
|
||||||
|
|
||||||
|
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||||
|
Err(anyhow::anyhow!("always fails"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn describe(&self) -> String {
|
||||||
|
"always-fail".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TrackConnectTransport {
|
||||||
|
connect_count: Arc<AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrackConnectTransport {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
connect_count: Arc::new(AtomicUsize::new(0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Transport for TrackConnectTransport {
|
||||||
|
type Stream = tokio::io::DuplexStream;
|
||||||
|
|
||||||
|
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||||
|
self.connect_count.fetch_add(1, Ordering::SeqCst);
|
||||||
|
let (client, _) = duplex(4096);
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn describe(&self) -> String {
|
||||||
|
"track-connect".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CountingFailTransport {
|
||||||
|
fail_count: Arc<AtomicUsize>,
|
||||||
|
succeed_after: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CountingFailTransport {
|
||||||
|
fn new(succeed_after: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
fail_count: Arc::new(AtomicUsize::new(0)),
|
||||||
|
succeed_after,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Transport for CountingFailTransport {
|
||||||
|
type Stream = tokio::io::DuplexStream;
|
||||||
|
|
||||||
|
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||||
|
let count = self.fail_count.fetch_add(1, Ordering::SeqCst);
|
||||||
|
if count < self.succeed_after {
|
||||||
|
return Err(anyhow::anyhow!("connection failed (attempt {})", count));
|
||||||
|
}
|
||||||
|
let (client, _) = duplex(4096);
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn describe(&self) -> String {
|
||||||
|
"counting-fail".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_backoff_durations() {
|
||||||
|
assert_eq!(backoff_duration(0), Duration::from_secs(1));
|
||||||
|
assert_eq!(backoff_duration(1), Duration::from_secs(2));
|
||||||
|
assert_eq!(backoff_duration(2), Duration::from_secs(4));
|
||||||
|
assert_eq!(backoff_duration(3), Duration::from_secs(8));
|
||||||
|
assert_eq!(backoff_duration(4), Duration::from_secs(16));
|
||||||
|
assert_eq!(backoff_duration(5), Duration::from_secs(30));
|
||||||
|
assert_eq!(backoff_duration(6), Duration::from_secs(30));
|
||||||
|
assert_eq!(backoff_duration(100), Duration::from_secs(30));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_backoff_sequence_matches_spec() {
|
||||||
|
let sequence: Vec<Duration> = (0..6).map(backoff_duration).collect();
|
||||||
|
assert_eq!(
|
||||||
|
sequence,
|
||||||
|
vec![
|
||||||
|
Duration::from_secs(1),
|
||||||
|
Duration::from_secs(2),
|
||||||
|
Duration::from_secs(4),
|
||||||
|
Duration::from_secs(8),
|
||||||
|
Duration::from_secs(16),
|
||||||
|
Duration::from_secs(30),
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_forward_request_hash_eq() {
|
||||||
|
let fwd1 = ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 8080,
|
||||||
|
};
|
||||||
|
let fwd2 = ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 8080,
|
||||||
|
};
|
||||||
|
let fwd3 = ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 9090,
|
||||||
|
};
|
||||||
|
assert_eq!(fwd1, fwd2);
|
||||||
|
assert_ne!(fwd1, fwd3);
|
||||||
|
let mut set = HashSet::new();
|
||||||
|
set.insert(fwd1.clone());
|
||||||
|
assert!(set.contains(&fwd2));
|
||||||
|
assert!(!set.contains(&fwd3));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_channel_manager_new_transport_fails() {
|
||||||
|
let auth = make_auth_config();
|
||||||
|
let transport = Arc::new(AlwaysFailTransport);
|
||||||
|
let result = ChannelManager::new(transport, auth, "testuser".to_string()).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
match result {
|
||||||
|
Err(ChannelError::TargetUnreachable) => {}
|
||||||
|
other => panic!("expected TargetUnreachable, got {:?}", other.as_ref().err()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_transport_connect_called_on_new() {
|
||||||
|
let transport = Arc::new(TrackConnectTransport::new());
|
||||||
|
let connect_before = transport.connect_count.load(Ordering::SeqCst);
|
||||||
|
assert_eq!(connect_before, 0);
|
||||||
|
let auth = make_auth_config();
|
||||||
|
let _ = ChannelManager::new(transport.clone(), auth, "testuser".to_string()).await;
|
||||||
|
let connect_after = transport.connect_count.load(Ordering::SeqCst);
|
||||||
|
assert!(connect_after > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_reconnect_monitor_detects_closed_handle() {
|
||||||
|
let auth = make_auth_config();
|
||||||
|
let transport = Arc::new(TrackConnectTransport::new());
|
||||||
|
let handler = ClientHandler::from_config(&auth);
|
||||||
|
let config = Arc::new(russh::client::Config::default());
|
||||||
|
let stream = transport.connect().await.unwrap();
|
||||||
|
let handle = client::connect_stream(config, stream, handler).await;
|
||||||
|
match handle {
|
||||||
|
Ok(h) => {
|
||||||
|
assert!(!h.is_closed());
|
||||||
|
drop(h);
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// connect_stream fails without a real SSH server,
|
||||||
|
// but the concept is verified: dropped handle => is_closed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_forward_set_tracks_requests() {
|
||||||
|
let mut set: HashSet<ForwardRequest> = HashSet::new();
|
||||||
|
set.insert(ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 8080,
|
||||||
|
});
|
||||||
|
set.insert(ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 9090,
|
||||||
|
});
|
||||||
|
assert_eq!(set.len(), 2);
|
||||||
|
set.remove(&ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 8080,
|
||||||
|
});
|
||||||
|
assert_eq!(set.len(), 1);
|
||||||
|
assert!(set.contains(&ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 9090,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_backoff_indefinitely_beyond_cap() {
|
||||||
|
for attempt in 0..50 {
|
||||||
|
let duration = backoff_duration(attempt);
|
||||||
|
assert!(duration <= Duration::from_secs(30));
|
||||||
|
assert!(duration >= Duration::from_secs(1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
727
crates/wraith-core/src/client/connect.rs
Normal file
727
crates/wraith-core/src/client/connect.rs
Normal file
@@ -0,0 +1,727 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use russh::client;
|
||||||
|
use russh::keys::PrivateKey;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
use crate::auth::client_auth::{ClientAuthConfig, ClientHandler};
|
||||||
|
use crate::auth::keys::KeySource;
|
||||||
|
use crate::client::forward::{LocalForwarder, PortForwardSpec, RemoteForwarder};
|
||||||
|
use crate::error::ConfigError;
|
||||||
|
use crate::socks5::{HandleChannelOpener, Socks5Server};
|
||||||
|
use crate::transport::Transport;
|
||||||
|
|
||||||
|
const DEFAULT_SOCKS5_ADDR: &str = "127.0.0.1:1080";
|
||||||
|
const DRAIN_TIMEOUT: Duration = Duration::from_secs(2);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum TransportMode {
|
||||||
|
Tcp,
|
||||||
|
Tls,
|
||||||
|
Iroh,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for TransportMode {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
TransportMode::Tcp => write!(f, "tcp"),
|
||||||
|
TransportMode::Tls => write!(f, "tls"),
|
||||||
|
TransportMode::Iroh => write!(f, "iroh"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ConnectOptions {
|
||||||
|
pub server: Option<String>,
|
||||||
|
pub peer: Option<String>,
|
||||||
|
pub transport_mode: TransportMode,
|
||||||
|
pub identity: KeySource,
|
||||||
|
pub socks5_addr: String,
|
||||||
|
pub forwards: Vec<String>,
|
||||||
|
pub remote_forwards: Vec<String>,
|
||||||
|
pub proxy: Option<String>,
|
||||||
|
pub iroh_relay: Option<String>,
|
||||||
|
pub tls_server_name: Option<String>,
|
||||||
|
pub insecure: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectOptions {
|
||||||
|
pub fn new(identity: KeySource) -> Self {
|
||||||
|
Self {
|
||||||
|
server: None,
|
||||||
|
peer: None,
|
||||||
|
transport_mode: TransportMode::Tcp,
|
||||||
|
identity,
|
||||||
|
socks5_addr: DEFAULT_SOCKS5_ADDR.to_string(),
|
||||||
|
forwards: Vec::new(),
|
||||||
|
remote_forwards: Vec::new(),
|
||||||
|
proxy: None,
|
||||||
|
iroh_relay: None,
|
||||||
|
tls_server_name: None,
|
||||||
|
insecure: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn server(mut self, addr: impl Into<String>) -> Self {
|
||||||
|
self.server = Some(addr.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn peer(mut self, endpoint_id: impl Into<String>) -> Self {
|
||||||
|
self.peer = Some(endpoint_id.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn transport_mode(mut self, mode: TransportMode) -> Self {
|
||||||
|
self.transport_mode = mode;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn socks5_addr(mut self, addr: impl Into<String>) -> Self {
|
||||||
|
self.socks5_addr = addr.into();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(mut self, spec: impl Into<String>) -> Self {
|
||||||
|
self.forwards.push(spec.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remote_forward(mut self, spec: impl Into<String>) -> Self {
|
||||||
|
self.remote_forwards.push(spec.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn proxy(mut self, url: impl Into<String>) -> Self {
|
||||||
|
self.proxy = Some(url.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn iroh_relay(mut self, url: impl Into<String>) -> Self {
|
||||||
|
self.iroh_relay = Some(url.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tls_server_name(mut self, name: impl Into<String>) -> Self {
|
||||||
|
self.tls_server_name = Some(name.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insecure(mut self, insecure: bool) -> Self {
|
||||||
|
self.insecure = insecure;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn validate(&self) -> Result<(), ConfigError> {
|
||||||
|
match self.transport_mode {
|
||||||
|
TransportMode::Tcp | TransportMode::Tls => {
|
||||||
|
if self.server.is_none() {
|
||||||
|
return Err(ConfigError::InvalidFlag {
|
||||||
|
name: "--server is required for tcp/tls transport".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TransportMode::Iroh => {
|
||||||
|
if self.peer.is_none() {
|
||||||
|
return Err(ConfigError::InvalidFlag {
|
||||||
|
name: "--peer is required for iroh transport".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for ConnectOptions {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("ConnectOptions")
|
||||||
|
.field("server", &self.server)
|
||||||
|
.field("peer", &self.peer)
|
||||||
|
.field("transport_mode", &self.transport_mode)
|
||||||
|
.field("identity", &"<KeySource>")
|
||||||
|
.field("socks5_addr", &self.socks5_addr)
|
||||||
|
.field("forwards", &self.forwards)
|
||||||
|
.field("remote_forwards", &self.remote_forwards)
|
||||||
|
.field("proxy", &self.proxy)
|
||||||
|
.field("iroh_relay", &self.iroh_relay)
|
||||||
|
.field("tls_server_name", &self.tls_server_name)
|
||||||
|
.field("insecure", &self.insecure)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ClientSession<T: Transport> {
|
||||||
|
opts: ConnectOptions,
|
||||||
|
transport: Arc<T>,
|
||||||
|
handle: Arc<Mutex<client::Handle<ClientHandler>>>,
|
||||||
|
auth_config: Arc<ClientAuthConfig>,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
private_key: Arc<PrivateKey>,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
username: String,
|
||||||
|
shutdown_tx: tokio::sync::watch::Sender<bool>,
|
||||||
|
shutdown_rx: tokio::sync::watch::Receiver<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Transport> ClientSession<T> {
|
||||||
|
pub async fn new(
|
||||||
|
opts: ConnectOptions,
|
||||||
|
transport: Arc<T>,
|
||||||
|
) -> Result<Self, ConnectError> {
|
||||||
|
opts.validate().map_err(ConnectError::Config)?;
|
||||||
|
|
||||||
|
let auth_config = Arc::new(
|
||||||
|
ClientAuthConfig::from_key_source(opts.identity.clone())
|
||||||
|
.map_err(ConnectError::Config)?,
|
||||||
|
);
|
||||||
|
let private_key = auth_config.private_key();
|
||||||
|
|
||||||
|
let username = derive_username();
|
||||||
|
let handler = ClientHandler::from_config(&auth_config);
|
||||||
|
|
||||||
|
let stream = transport.connect().await.map_err(|e| {
|
||||||
|
error!("transport connect failed: {e}");
|
||||||
|
ConnectError::ConnectionFailed
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let config = Arc::new(client::Config::default());
|
||||||
|
let mut handle = client::connect_stream(config, stream, handler)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
error!("SSH connect failed: {e}");
|
||||||
|
ConnectError::ConnectionFailed
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let auth_ok = auth_config
|
||||||
|
.authenticate(&mut handle, &username)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ConnectError::AuthFailed)?;
|
||||||
|
if !auth_ok {
|
||||||
|
return Err(ConnectError::AuthFailed);
|
||||||
|
}
|
||||||
|
|
||||||
|
let handle = Arc::new(Mutex::new(handle));
|
||||||
|
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
opts,
|
||||||
|
transport,
|
||||||
|
handle,
|
||||||
|
auth_config,
|
||||||
|
private_key,
|
||||||
|
username,
|
||||||
|
shutdown_tx,
|
||||||
|
shutdown_rx,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn handle(&self) -> Arc<Mutex<client::Handle<ClientHandler>>> {
|
||||||
|
Arc::clone(&self.handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn auth_config(&self) -> &Arc<ClientAuthConfig> {
|
||||||
|
&self.auth_config
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn transport(&self) -> &Arc<T> {
|
||||||
|
&self.transport
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn options(&self) -> &ConnectOptions {
|
||||||
|
&self.opts
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn shutdown_sender(&self) -> tokio::sync::watch::Sender<bool> {
|
||||||
|
self.shutdown_tx.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(self) -> Result<(), ConnectError> {
|
||||||
|
let socks5_addr: SocketAddr = self.opts.socks5_addr.parse().map_err(|_| {
|
||||||
|
ConnectError::Config(ConfigError::InvalidFlag {
|
||||||
|
name: format!("invalid SOCKS5 address: {}", self.opts.socks5_addr),
|
||||||
|
})
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let channel_opener = HandleChannelOpener::from_arc(Arc::clone(&self.handle));
|
||||||
|
let socks5_server = Socks5Server::with_addr(channel_opener, &socks5_addr.to_string());
|
||||||
|
let socks5_listen = socks5_server.listen_addr();
|
||||||
|
|
||||||
|
let local_forwarders = build_local_forwarders(&self.opts)?;
|
||||||
|
let remote_specs = build_remote_specs(&self.opts)?;
|
||||||
|
|
||||||
|
for spec in &remote_specs {
|
||||||
|
let remote_forwarder = RemoteForwarder::new(spec.clone())
|
||||||
|
.map_err(|_| ConnectError::ForwardFailed)?;
|
||||||
|
let mut h = self.handle.lock().await;
|
||||||
|
remote_forwarder
|
||||||
|
.register(&mut h)
|
||||||
|
.await
|
||||||
|
.map_err(|_| {
|
||||||
|
warn!("failed to register remote forward {}", spec);
|
||||||
|
ConnectError::ForwardFailed
|
||||||
|
})?;
|
||||||
|
info!("registered remote forward: {}", spec);
|
||||||
|
}
|
||||||
|
|
||||||
|
let socks5_task = tokio::spawn(async move {
|
||||||
|
debug!("SOCKS5 server starting on {}", socks5_listen);
|
||||||
|
if let Err(e) = socks5_server.run().await {
|
||||||
|
error!("SOCKS5 server error: {e}");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let fwd_handle = Arc::clone(&self.handle);
|
||||||
|
let fwd_shutdown = self.shutdown_rx.clone();
|
||||||
|
let forward_task = tokio::spawn(async move {
|
||||||
|
crate::client::forward::run_local_forwarders(
|
||||||
|
local_forwarders, fwd_handle, fwd_shutdown,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
info!("wraith client running: SOCKS5 on {}", socks5_listen);
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
let signal_done = {
|
||||||
|
let sig_tx = self.shutdown_tx.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut sigterm_stream =
|
||||||
|
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||||
|
.expect("failed to install SIGTERM handler");
|
||||||
|
tokio::select! {
|
||||||
|
_ = sigterm_stream.recv() => {
|
||||||
|
info!("received SIGTERM");
|
||||||
|
}
|
||||||
|
_ = tokio::signal::ctrl_c() => {
|
||||||
|
info!("received SIGINT (Ctrl+C)");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let _ = sig_tx.send(true);
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut wait_shutdown = self.shutdown_rx.clone();
|
||||||
|
tokio::select! {
|
||||||
|
_ = wait_shutdown.changed() => {
|
||||||
|
if *wait_shutdown.borrow() {
|
||||||
|
info!("shutdown signal received");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = socks5_task => {
|
||||||
|
warn!("SOCKS5 server exited unexpectedly");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
signal_done.abort();
|
||||||
|
|
||||||
|
self.shutdown().await?;
|
||||||
|
|
||||||
|
forward_task.abort();
|
||||||
|
let _ = forward_task.await;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn shutdown(&self) -> Result<(), ConnectError> {
|
||||||
|
info!("initiating graceful shutdown");
|
||||||
|
|
||||||
|
let _ = self.shutdown_tx.send(true);
|
||||||
|
|
||||||
|
{
|
||||||
|
let handle = self.handle.lock().await;
|
||||||
|
if !handle.is_closed() {
|
||||||
|
if let Err(e) = handle
|
||||||
|
.disconnect(russh::Disconnect::ByApplication, "shutdown", "")
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
warn!("failed to send SSH disconnect: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::sleep(DRAIN_TIMEOUT).await;
|
||||||
|
|
||||||
|
info!("graceful shutdown complete");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn derive_username() -> String {
|
||||||
|
std::env::var("USER")
|
||||||
|
.or_else(|_| std::env::var("USERNAME"))
|
||||||
|
.unwrap_or_else(|_| "wraith".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_local_forwarders(opts: &ConnectOptions) -> Result<Vec<LocalForwarder>, ConnectError> {
|
||||||
|
let mut forwarders = Vec::new();
|
||||||
|
for spec_str in &opts.forwards {
|
||||||
|
let spec = PortForwardSpec::local(spec_str).map_err(|e| {
|
||||||
|
warn!("invalid local forward spec '{}': {}", spec_str, e);
|
||||||
|
ConnectError::Config(ConfigError::InvalidFlag {
|
||||||
|
name: format!("invalid forward spec: {}", spec_str),
|
||||||
|
})
|
||||||
|
})?;
|
||||||
|
forwarders.push(
|
||||||
|
LocalForwarder::new(spec).map_err(|e| {
|
||||||
|
warn!("failed to create local forwarder: {}", e);
|
||||||
|
ConnectError::ForwardFailed
|
||||||
|
})?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(forwarders)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_remote_specs(opts: &ConnectOptions) -> Result<Vec<PortForwardSpec>, ConnectError> {
|
||||||
|
let mut specs = Vec::new();
|
||||||
|
for spec_str in &opts.remote_forwards {
|
||||||
|
let spec = PortForwardSpec::remote(spec_str).map_err(|e| {
|
||||||
|
warn!("invalid remote forward spec '{}': {}", spec_str, e);
|
||||||
|
ConnectError::Config(ConfigError::InvalidFlag {
|
||||||
|
name: format!("invalid remote forward spec: {}", spec_str),
|
||||||
|
})
|
||||||
|
})?;
|
||||||
|
specs.push(spec);
|
||||||
|
}
|
||||||
|
Ok(specs)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ConnectError {
|
||||||
|
#[error("connection failed")]
|
||||||
|
ConnectionFailed,
|
||||||
|
#[error("authentication failed")]
|
||||||
|
AuthFailed,
|
||||||
|
#[error("forward setup failed")]
|
||||||
|
ForwardFailed,
|
||||||
|
#[error("config error: {0}")]
|
||||||
|
Config(#[from] ConfigError),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
use tokio::io::duplex;
|
||||||
|
|
||||||
|
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||||
|
|
||||||
|
fn make_identity() -> KeySource {
|
||||||
|
KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connect_options_default_fields() {
|
||||||
|
let opts = ConnectOptions::new(make_identity());
|
||||||
|
assert!(opts.server.is_none());
|
||||||
|
assert!(opts.peer.is_none());
|
||||||
|
assert_eq!(opts.transport_mode, TransportMode::Tcp);
|
||||||
|
assert_eq!(opts.socks5_addr, "127.0.0.1:1080");
|
||||||
|
assert!(opts.forwards.is_empty());
|
||||||
|
assert!(opts.remote_forwards.is_empty());
|
||||||
|
assert!(opts.proxy.is_none());
|
||||||
|
assert!(opts.iroh_relay.is_none());
|
||||||
|
assert!(opts.tls_server_name.is_none());
|
||||||
|
assert!(!opts.insecure);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connect_options_builder_pattern() {
|
||||||
|
let opts = ConnectOptions::new(make_identity())
|
||||||
|
.server("example.com:22")
|
||||||
|
.transport_mode(TransportMode::Tls)
|
||||||
|
.socks5_addr("127.0.0.1:9050")
|
||||||
|
.forward("127.0.0.1:5432:db:5432")
|
||||||
|
.remote_forward("0.0.0.0:8080:127.0.0.1:3000")
|
||||||
|
.proxy("socks5://127.0.0.1:1080")
|
||||||
|
.iroh_relay("https://relay.example.com")
|
||||||
|
.tls_server_name("wraith.test")
|
||||||
|
.insecure(true);
|
||||||
|
|
||||||
|
assert_eq!(opts.server.as_deref(), Some("example.com:22"));
|
||||||
|
assert_eq!(opts.transport_mode, TransportMode::Tls);
|
||||||
|
assert_eq!(opts.socks5_addr, "127.0.0.1:9050");
|
||||||
|
assert_eq!(opts.forwards.len(), 1);
|
||||||
|
assert_eq!(opts.remote_forwards.len(), 1);
|
||||||
|
assert_eq!(opts.proxy.as_deref(), Some("socks5://127.0.0.1:1080"));
|
||||||
|
assert_eq!(opts.iroh_relay.as_deref(), Some("https://relay.example.com"));
|
||||||
|
assert_eq!(opts.tls_server_name.as_deref(), Some("wraith.test"));
|
||||||
|
assert!(opts.insecure);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connect_options_validate_tcp_requires_server() {
|
||||||
|
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Tcp);
|
||||||
|
assert!(opts.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connect_options_validate_tcp_with_server_ok() {
|
||||||
|
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||||
|
assert!(opts.validate().is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connect_options_validate_tls_requires_server() {
|
||||||
|
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Tls);
|
||||||
|
assert!(opts.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connect_options_validate_tls_with_server_ok() {
|
||||||
|
let opts = ConnectOptions::new(make_identity())
|
||||||
|
.transport_mode(TransportMode::Tls)
|
||||||
|
.server("example.com:443");
|
||||||
|
assert!(opts.validate().is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connect_options_validate_iroh_requires_peer() {
|
||||||
|
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Iroh);
|
||||||
|
assert!(opts.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connect_options_validate_iroh_with_peer_ok() {
|
||||||
|
let opts = ConnectOptions::new(make_identity())
|
||||||
|
.transport_mode(TransportMode::Iroh)
|
||||||
|
.peer("some-endpoint-id");
|
||||||
|
assert!(opts.validate().is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn identity_accepts_key_source_file() {
|
||||||
|
let file_source = KeySource::File(std::path::PathBuf::from("/path/to/key"));
|
||||||
|
let opts = ConnectOptions::new(file_source);
|
||||||
|
match &opts.identity {
|
||||||
|
KeySource::File(p) => assert_eq!(p, &std::path::PathBuf::from("/path/to/key")),
|
||||||
|
_ => panic!("expected File variant"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn identity_accepts_key_source_memory() {
|
||||||
|
let mem_source = KeySource::Memory(b"key-data".to_vec());
|
||||||
|
let opts = ConnectOptions::new(mem_source);
|
||||||
|
match &opts.identity {
|
||||||
|
KeySource::Memory(d) => assert_eq!(d, b"key-data"),
|
||||||
|
_ => panic!("expected Memory variant"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn transport_mode_display() {
|
||||||
|
assert_eq!(TransportMode::Tcp.to_string(), "tcp");
|
||||||
|
assert_eq!(TransportMode::Tls.to_string(), "tls");
|
||||||
|
assert_eq!(TransportMode::Iroh.to_string(), "iroh");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connect_error_variants() {
|
||||||
|
assert_eq!(ConnectError::ConnectionFailed.to_string(), "connection failed");
|
||||||
|
assert_eq!(ConnectError::AuthFailed.to_string(), "authentication failed");
|
||||||
|
assert_eq!(ConnectError::ForwardFailed.to_string(), "forward setup failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connect_options_debug_redacts_identity() {
|
||||||
|
let opts = ConnectOptions::new(make_identity());
|
||||||
|
let debug_str = format!("{:?}", opts);
|
||||||
|
assert!(debug_str.contains("<KeySource>"));
|
||||||
|
assert!(!debug_str.contains("OPENSSH"));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FailTransport;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Transport for FailTransport {
|
||||||
|
type Stream = tokio::io::DuplexStream;
|
||||||
|
|
||||||
|
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||||
|
Err(anyhow::anyhow!("always fails"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn describe(&self) -> String {
|
||||||
|
"fail".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DuplexTransport {
|
||||||
|
connect_count: Arc<AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Transport for DuplexTransport {
|
||||||
|
type Stream = tokio::io::DuplexStream;
|
||||||
|
|
||||||
|
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||||
|
self.connect_count.fetch_add(1, Ordering::SeqCst);
|
||||||
|
let (client, _) = duplex(4096);
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn describe(&self) -> String {
|
||||||
|
"duplex".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn client_session_new_transport_fails() {
|
||||||
|
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||||
|
let transport = Arc::new(FailTransport);
|
||||||
|
let result = ClientSession::new(opts, transport).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn client_session_new_ssh_handshake_fails() {
|
||||||
|
let transport = Arc::new(DuplexTransport {
|
||||||
|
connect_count: Arc::new(AtomicUsize::new(0)),
|
||||||
|
});
|
||||||
|
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||||
|
let result = ClientSession::new(opts, transport).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_local_forwarders_empty() {
|
||||||
|
let opts = ConnectOptions::new(make_identity());
|
||||||
|
let result = build_local_forwarders(&opts);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
assert!(result.unwrap().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_local_forwarders_valid() {
|
||||||
|
let opts = ConnectOptions::new(make_identity()).forward("127.0.0.1:5432:db:5432");
|
||||||
|
let result = build_local_forwarders(&opts);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
assert_eq!(result.unwrap().len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_local_forwarders_invalid_spec() {
|
||||||
|
let opts = ConnectOptions::new(make_identity()).forward("bad-spec");
|
||||||
|
let result = build_local_forwarders(&opts);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_remote_specs_empty() {
|
||||||
|
let opts = ConnectOptions::new(make_identity());
|
||||||
|
let result = build_remote_specs(&opts);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
assert!(result.unwrap().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_remote_specs_valid() {
|
||||||
|
let opts = ConnectOptions::new(make_identity()).remote_forward("0.0.0.0:8080:127.0.0.1:3000");
|
||||||
|
let result = build_remote_specs(&opts);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
assert_eq!(result.unwrap().len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_remote_specs_invalid() {
|
||||||
|
let opts = ConnectOptions::new(make_identity()).remote_forward("bad");
|
||||||
|
let result = build_remote_specs(&opts);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn default_socks5_addr() {
|
||||||
|
assert_eq!(DEFAULT_SOCKS5_ADDR, "127.0.0.1:1080");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn drain_timeout_is_two_seconds() {
|
||||||
|
assert_eq!(DRAIN_TIMEOUT, Duration::from_secs(2));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn transport_mode_equality() {
|
||||||
|
assert_eq!(TransportMode::Tcp, TransportMode::Tcp);
|
||||||
|
assert_ne!(TransportMode::Tcp, TransportMode::Tls);
|
||||||
|
assert_ne!(TransportMode::Tls, TransportMode::Iroh);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn shutdown_sends_disconnect_and_drains() {
|
||||||
|
let transport = Arc::new(DuplexTransport {
|
||||||
|
connect_count: Arc::new(AtomicUsize::new(0)),
|
||||||
|
});
|
||||||
|
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||||
|
let result = ClientSession::new(opts, transport).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn socks5_is_always_enabled_by_default() {
|
||||||
|
let opts = ConnectOptions::new(make_identity());
|
||||||
|
assert!(!opts.socks5_addr.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn integration_mock_transport_session() {
|
||||||
|
use crate::socks5::{ChannelOpener, ChannelOpenError};
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||||
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
|
||||||
|
struct MockOpener;
|
||||||
|
|
||||||
|
impl ChannelOpener for MockOpener {
|
||||||
|
type Stream = tokio::io::DuplexStream;
|
||||||
|
|
||||||
|
async fn open_channel(
|
||||||
|
&self,
|
||||||
|
_host: String,
|
||||||
|
_port: u16,
|
||||||
|
) -> Result<Self::Stream, ChannelOpenError> {
|
||||||
|
let (client, _server) = duplex(4096);
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let bound_addr = listener.local_addr().unwrap();
|
||||||
|
drop(listener);
|
||||||
|
|
||||||
|
let opener = MockOpener;
|
||||||
|
let server = Socks5Server::with_addr(opener, &bound_addr.to_string());
|
||||||
|
|
||||||
|
let _server_task = tokio::spawn(async move {
|
||||||
|
let _ = server.run().await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
let mut conn = TcpStream::connect(bound_addr).await.unwrap();
|
||||||
|
|
||||||
|
let greeting = [0x05, 0x01, 0x00];
|
||||||
|
conn.write_all(&greeting).await.unwrap();
|
||||||
|
|
||||||
|
let mut auth_resp = [0u8; 2];
|
||||||
|
conn.read_exact(&mut auth_resp).await.unwrap();
|
||||||
|
assert_eq!(auth_resp, [0x05, 0x00]);
|
||||||
|
|
||||||
|
let connect_req = [
|
||||||
|
0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80,
|
||||||
|
];
|
||||||
|
conn.write_all(&connect_req).await.unwrap();
|
||||||
|
|
||||||
|
let mut reply = [0u8; 10];
|
||||||
|
conn.read_exact(&mut reply).await.unwrap();
|
||||||
|
assert_eq!(reply[1], 0x00);
|
||||||
|
|
||||||
|
conn.write_all(b"test data").await.unwrap();
|
||||||
|
conn.shutdown().await.unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
530
crates/wraith-core/src/client/forward.rs
Normal file
530
crates/wraith-core/src/client/forward.rs
Normal file
@@ -0,0 +1,530 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use russh::client;
|
||||||
|
use tokio::io;
|
||||||
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
|
use crate::error::ForwardError;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum PortForwardSpecKind {
|
||||||
|
Local,
|
||||||
|
Remote,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct PortForwardSpec {
|
||||||
|
pub kind: PortForwardSpecKind,
|
||||||
|
pub bind_addr: String,
|
||||||
|
pub bind_port: u16,
|
||||||
|
pub target_host: String,
|
||||||
|
pub target_port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PortForwardSpec {
|
||||||
|
pub fn local(spec: &str) -> Result<Self, ForwardError> {
|
||||||
|
let (bind_addr, bind_port, target_host, target_port) = parse_spec(spec)?;
|
||||||
|
Ok(Self {
|
||||||
|
kind: PortForwardSpecKind::Local,
|
||||||
|
bind_addr,
|
||||||
|
bind_port,
|
||||||
|
target_host,
|
||||||
|
target_port,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remote(spec: &str) -> Result<Self, ForwardError> {
|
||||||
|
let (bind_addr, bind_port, target_host, target_port) = parse_spec(spec)?;
|
||||||
|
Ok(Self {
|
||||||
|
kind: PortForwardSpecKind::Remote,
|
||||||
|
bind_addr,
|
||||||
|
bind_port,
|
||||||
|
target_host,
|
||||||
|
target_port,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn listen_addr(&self) -> Result<SocketAddr, ForwardError> {
|
||||||
|
format!("{}:{}", self.bind_addr, self.bind_port)
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| ForwardError::InvalidSpec {
|
||||||
|
spec: format!("{}:{}", self.bind_addr, self.bind_port),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn target_addr(&self) -> Result<SocketAddr, ForwardError> {
|
||||||
|
format!("{}:{}", self.target_host, self.target_port)
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| ForwardError::InvalidSpec {
|
||||||
|
spec: format!("{}:{}", self.target_host, self.target_port),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for PortForwardSpec {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
let prefix = match self.kind {
|
||||||
|
PortForwardSpecKind::Local => "-L",
|
||||||
|
PortForwardSpecKind::Remote => "-R",
|
||||||
|
};
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"{} {}:{}:{}:{}",
|
||||||
|
prefix, self.bind_addr, self.bind_port, self.target_host, self.target_port
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_spec(spec: &str) -> Result<(String, u16, String, u16), ForwardError> {
|
||||||
|
let parts: Vec<&str> = spec.split(':').collect();
|
||||||
|
if parts.len() != 4 {
|
||||||
|
return Err(ForwardError::InvalidSpec {
|
||||||
|
spec: spec.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let bind_addr = parts[0].to_string();
|
||||||
|
let bind_port: u16 = parts[1].parse().map_err(|_| ForwardError::InvalidSpec {
|
||||||
|
spec: spec.to_string(),
|
||||||
|
})?;
|
||||||
|
let target_host = parts[2].to_string();
|
||||||
|
let target_port: u16 = parts[3].parse().map_err(|_| ForwardError::InvalidSpec {
|
||||||
|
spec: spec.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok((bind_addr, bind_port, target_host, target_port))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct LocalForwarder {
|
||||||
|
spec: PortForwardSpec,
|
||||||
|
listener: Option<TcpListener>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LocalForwarder {
|
||||||
|
pub fn new(spec: PortForwardSpec) -> Result<Self, ForwardError> {
|
||||||
|
if spec.kind != PortForwardSpecKind::Local {
|
||||||
|
return Err(ForwardError::InvalidSpec {
|
||||||
|
spec: format!("expected local spec, got {:?}", spec.kind),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
spec,
|
||||||
|
listener: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn spec(&self) -> &PortForwardSpec {
|
||||||
|
&self.spec
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run<H: client::Handler + Send + 'static>(
|
||||||
|
&mut self,
|
||||||
|
handle: Arc<Mutex<client::Handle<H>>>,
|
||||||
|
) -> Result<(), ForwardError> {
|
||||||
|
let listen_addr = self.spec.listen_addr()?;
|
||||||
|
let listener: TcpListener = TcpListener::bind(listen_addr)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ForwardError::BindFailed { source: e })?;
|
||||||
|
self.listener = Some(listener);
|
||||||
|
let remote_host = self.spec.target_host.clone();
|
||||||
|
let remote_port = self.spec.target_port;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"local forward listening on {} -> {}:{}",
|
||||||
|
listen_addr, remote_host, remote_port
|
||||||
|
);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let listener = match &self.listener {
|
||||||
|
Some(l) => l,
|
||||||
|
None => return Ok(()),
|
||||||
|
};
|
||||||
|
let accept_result = listener.accept().await;
|
||||||
|
let (local_stream, local_addr) = match accept_result {
|
||||||
|
Ok(conn) => conn,
|
||||||
|
Err(e) => {
|
||||||
|
let handle = handle.lock().await;
|
||||||
|
if handle.is_closed() {
|
||||||
|
debug!("local forward accept loop ending: ssh session closed");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
drop(handle);
|
||||||
|
error!("local forward accept error: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"local forward connection from {} -> {}:{}",
|
||||||
|
local_addr, remote_host, remote_port
|
||||||
|
);
|
||||||
|
|
||||||
|
let handle = handle.clone();
|
||||||
|
let remote_host = remote_host.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) =
|
||||||
|
proxy_local_to_remote(local_stream, handle, &remote_host, remote_port).await
|
||||||
|
{
|
||||||
|
debug!("local forward proxy error: {}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stop(&mut self) {
|
||||||
|
if let Some(listener) = self.listener.take() {
|
||||||
|
drop(listener);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn local_port(&self) -> u16 {
|
||||||
|
self.spec.bind_port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn proxy_local_to_remote<H: client::Handler + Send + 'static>(
|
||||||
|
local_stream: TcpStream,
|
||||||
|
handle: Arc<Mutex<client::Handle<H>>>,
|
||||||
|
remote_host: &str,
|
||||||
|
remote_port: u16,
|
||||||
|
) -> Result<(), ForwardError> {
|
||||||
|
let local_addr = local_stream
|
||||||
|
.peer_addr()
|
||||||
|
.map(|a| a.to_string())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let handle_guard = handle.lock().await;
|
||||||
|
let channel = handle_guard
|
||||||
|
.channel_open_direct_tcpip(
|
||||||
|
remote_host,
|
||||||
|
remote_port as u32,
|
||||||
|
&local_addr,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ForwardError::ChannelOpenFailed {
|
||||||
|
source: Box::new(e) as _,
|
||||||
|
})?;
|
||||||
|
drop(handle_guard);
|
||||||
|
|
||||||
|
let ssh_stream = channel.into_stream();
|
||||||
|
let (mut ssh_read, mut ssh_write) = tokio::io::split(ssh_stream);
|
||||||
|
let (mut local_read, mut local_write) = tokio::io::split(local_stream);
|
||||||
|
|
||||||
|
let client_to_server = io::copy(&mut local_read, &mut ssh_write);
|
||||||
|
let server_to_client = io::copy(&mut ssh_read, &mut local_write);
|
||||||
|
|
||||||
|
match tokio::join!(client_to_server, server_to_client) {
|
||||||
|
(Err(e), _) | (_, Err(e)) => {
|
||||||
|
debug!("local forward bidirectional copy error: {}", e);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RemoteForwarder {
|
||||||
|
spec: PortForwardSpec,
|
||||||
|
cancel: Option<tokio::sync::oneshot::Sender<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RemoteForwarder {
|
||||||
|
pub fn new(spec: PortForwardSpec) -> Result<Self, ForwardError> {
|
||||||
|
if spec.kind != PortForwardSpecKind::Remote {
|
||||||
|
return Err(ForwardError::InvalidSpec {
|
||||||
|
spec: format!("expected remote spec, got {:?}", spec.kind),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(Self { spec, cancel: None })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn spec(&self) -> &PortForwardSpec {
|
||||||
|
&self.spec
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn register<H: client::Handler + Send + 'static>(
|
||||||
|
&self,
|
||||||
|
handle: &mut client::Handle<H>,
|
||||||
|
) -> Result<u32, ForwardError> {
|
||||||
|
let port = handle
|
||||||
|
.tcpip_forward(&self.spec.bind_addr, self.spec.bind_port as u32)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ForwardError::ChannelOpenFailed {
|
||||||
|
source: Box::new(e) as _,
|
||||||
|
})?;
|
||||||
|
Ok(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_forwarded_channel(
|
||||||
|
channel: russh::Channel<russh::client::Msg>,
|
||||||
|
connected_address: &str,
|
||||||
|
connected_port: u32,
|
||||||
|
local_host: &str,
|
||||||
|
local_port: u16,
|
||||||
|
) {
|
||||||
|
debug!(
|
||||||
|
"remote forward: server opened forwarded-tcpip channel to {}:{} -> local {}:{}",
|
||||||
|
connected_address, connected_port, local_host, local_port
|
||||||
|
);
|
||||||
|
|
||||||
|
let local_target = format!("{}:{}", local_host, local_port);
|
||||||
|
let local_stream = match TcpStream::connect(&local_target).await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => {
|
||||||
|
error!(
|
||||||
|
"remote forward: failed to connect to local target {}: {}",
|
||||||
|
local_target, e
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let ssh_stream = channel.into_stream();
|
||||||
|
let (mut ssh_read, mut ssh_write) = tokio::io::split(ssh_stream);
|
||||||
|
let (mut local_read, mut local_write) = tokio::io::split(local_stream);
|
||||||
|
|
||||||
|
let client_to_server = io::copy(&mut local_read, &mut ssh_write);
|
||||||
|
let server_to_client = io::copy(&mut ssh_read, &mut local_write);
|
||||||
|
|
||||||
|
match tokio::join!(client_to_server, server_to_client) {
|
||||||
|
(Err(e), _) | (_, Err(e)) => {
|
||||||
|
debug!("remote forward bidirectional copy error: {}", e);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn unregister<H: client::Handler + Send + 'static>(
|
||||||
|
&self,
|
||||||
|
handle: &client::Handle<H>,
|
||||||
|
) -> Result<(), ForwardError> {
|
||||||
|
handle
|
||||||
|
.cancel_tcpip_forward(&self.spec.bind_addr, self.spec.bind_port as u32)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ForwardError::ChannelOpenFailed {
|
||||||
|
source: Box::new(e) as _,
|
||||||
|
})?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stop(&mut self) {
|
||||||
|
if let Some(cancel) = self.cancel.take() {
|
||||||
|
let _ = cancel.send(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_local_forwarders<H: client::Handler + Send + 'static>(
|
||||||
|
forwarders: Vec<LocalForwarder>,
|
||||||
|
handle: Arc<Mutex<client::Handle<H>>>,
|
||||||
|
mut shutdown: tokio::sync::watch::Receiver<bool>,
|
||||||
|
) -> Vec<LocalForwarder> {
|
||||||
|
let mut forwarders = forwarders;
|
||||||
|
let mut tasks = Vec::new();
|
||||||
|
|
||||||
|
for forwarder in forwarders.drain(..) {
|
||||||
|
let handle = handle.clone();
|
||||||
|
let spec = forwarder.spec().clone();
|
||||||
|
let (_cancel_tx, cancel_rx) = tokio::sync::oneshot::channel::<()>();
|
||||||
|
tasks.push(tokio::spawn(async move {
|
||||||
|
let mut fwd = forwarder;
|
||||||
|
tokio::select! {
|
||||||
|
result = fwd.run(handle) => {
|
||||||
|
if let Err(e) = result {
|
||||||
|
error!("local forward {} failed: {}", spec, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = cancel_rx => {
|
||||||
|
fwd.stop().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fwd
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = shutdown.changed().await;
|
||||||
|
|
||||||
|
for task in &tasks {
|
||||||
|
task.abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut results = Vec::new();
|
||||||
|
for task in tasks {
|
||||||
|
match task.await {
|
||||||
|
Ok(fwd) => results.push(fwd),
|
||||||
|
Err(e) => {
|
||||||
|
if !e.is_cancelled() {
|
||||||
|
error!("local forwarder task panicked: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_local_spec() {
|
||||||
|
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||||
|
assert_eq!(spec.kind, PortForwardSpecKind::Local);
|
||||||
|
assert_eq!(spec.bind_addr, "127.0.0.1");
|
||||||
|
assert_eq!(spec.bind_port, 5432);
|
||||||
|
assert_eq!(spec.target_host, "db.internal");
|
||||||
|
assert_eq!(spec.target_port, 5432);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_remote_spec() {
|
||||||
|
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||||
|
assert_eq!(spec.kind, PortForwardSpecKind::Remote);
|
||||||
|
assert_eq!(spec.bind_addr, "0.0.0.0");
|
||||||
|
assert_eq!(spec.bind_port, 8080);
|
||||||
|
assert_eq!(spec.target_host, "127.0.0.1");
|
||||||
|
assert_eq!(spec.target_port, 3000);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_spec_invalid_few_parts() {
|
||||||
|
assert!(PortForwardSpec::local("127.0.0.1:5432:db").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_spec_invalid_many_parts() {
|
||||||
|
assert!(PortForwardSpec::local("a:b:c:d:e").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_spec_invalid_port() {
|
||||||
|
assert!(PortForwardSpec::local("127.0.0.1:abc:db:5432").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_spec_invalid_target_port() {
|
||||||
|
assert!(PortForwardSpec::local("127.0.0.1:5432:db:abc").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spec_display() {
|
||||||
|
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||||
|
assert_eq!(spec.to_string(), "-L 127.0.0.1:5432:db.internal:5432");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spec_display_remote() {
|
||||||
|
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||||
|
assert_eq!(spec.to_string(), "-R 0.0.0.0:8080:127.0.0.1:3000");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn local_forwarder_rejects_remote_spec() {
|
||||||
|
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||||
|
assert!(LocalForwarder::new(spec).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn remote_forwarder_rejects_local_spec() {
|
||||||
|
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||||
|
assert!(RemoteForwarder::new(spec).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn listen_addr_valid() {
|
||||||
|
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||||
|
let addr = spec.listen_addr().unwrap();
|
||||||
|
assert_eq!(addr.port(), 5432);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn listen_addr_invalid_host() {
|
||||||
|
let spec = PortForwardSpec {
|
||||||
|
kind: PortForwardSpecKind::Local,
|
||||||
|
bind_addr: "!!!invalid".to_string(),
|
||||||
|
bind_port: 5432,
|
||||||
|
target_host: "db".to_string(),
|
||||||
|
target_port: 5432,
|
||||||
|
};
|
||||||
|
assert!(spec.listen_addr().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn local_forward_bind_and_accept() {
|
||||||
|
let spec = PortForwardSpec::local(&format!("127.0.0.1:0:remote:5432")).unwrap();
|
||||||
|
let forwarder = LocalForwarder::new(spec).unwrap();
|
||||||
|
|
||||||
|
let listen_addr = forwarder.spec.listen_addr().unwrap();
|
||||||
|
let listener = TcpListener::bind(listen_addr).await.unwrap();
|
||||||
|
let bound_addr = listener.local_addr().unwrap();
|
||||||
|
drop(listener);
|
||||||
|
|
||||||
|
let spec = PortForwardSpec::local(&format!(
|
||||||
|
"127.0.0.1:{}:remote:5432",
|
||||||
|
bound_addr.port()
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
let forwarder = LocalForwarder::new(spec).unwrap();
|
||||||
|
assert_eq!(forwarder.local_port(), bound_addr.port());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn remote_forward_proxy_bidirectional() {
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
|
||||||
|
let echo_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let _echo_addr = echo_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let echo_server = tokio::spawn(async move {
|
||||||
|
let (mut stream, _) = echo_listener.accept().await.unwrap();
|
||||||
|
let mut buf = [0u8; 64];
|
||||||
|
loop {
|
||||||
|
let n = match stream.read(&mut buf).await {
|
||||||
|
Ok(0) => break,
|
||||||
|
Ok(n) => n,
|
||||||
|
Err(_) => break,
|
||||||
|
};
|
||||||
|
if stream.write_all(&buf[..n]).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let local_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let local_addr = local_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let proxy_task = tokio::spawn(async move {
|
||||||
|
let (stream, _) = local_listener.accept().await.unwrap();
|
||||||
|
let (mut read, mut write) = tokio::io::split(stream);
|
||||||
|
let _ = io::copy(&mut read, &mut write).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut local_conn = TcpStream::connect(local_addr).await.unwrap();
|
||||||
|
local_conn.write_all(b"hello").await.unwrap();
|
||||||
|
let mut buf = [0u8; 64];
|
||||||
|
let n = local_conn.read(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf[..n], b"hello");
|
||||||
|
|
||||||
|
echo_server.abort();
|
||||||
|
proxy_task.abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forwarder_spec_access() {
|
||||||
|
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||||
|
let forwarder = LocalForwarder::new(spec.clone()).unwrap();
|
||||||
|
assert_eq!(forwarder.spec(), &spec);
|
||||||
|
assert_eq!(forwarder.local_port(), 5432);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn remote_forwarder_spec_access() {
|
||||||
|
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||||
|
let forwarder = RemoteForwarder::new(spec.clone()).unwrap();
|
||||||
|
assert_eq!(forwarder.spec(), &spec);
|
||||||
|
}
|
||||||
|
}
|
||||||
7
crates/wraith-core/src/client/mod.rs
Normal file
7
crates/wraith-core/src/client/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
pub mod channel_manager;
|
||||||
|
pub mod connect;
|
||||||
|
pub mod forward;
|
||||||
|
|
||||||
|
pub use channel_manager::{ChannelManager, ForwardRequest};
|
||||||
|
pub use connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
|
||||||
|
pub use forward::{LocalForwarder, PortForwardSpec, PortForwardSpecKind, RemoteForwarder};
|
||||||
@@ -60,6 +60,27 @@ pub enum ConfigError {
|
|||||||
IncompatibleOptions,
|
IncompatibleOptions,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ForwardError {
|
||||||
|
#[error("invalid port forward spec: {spec}")]
|
||||||
|
InvalidSpec { spec: String },
|
||||||
|
#[error("bind failed")]
|
||||||
|
BindFailed {
|
||||||
|
#[source]
|
||||||
|
source: io::Error,
|
||||||
|
},
|
||||||
|
#[error("channel open failed")]
|
||||||
|
ChannelOpenFailed {
|
||||||
|
#[source]
|
||||||
|
source: Box<dyn std::error::Error + Send + Sync>,
|
||||||
|
},
|
||||||
|
#[error("connect to local target failed")]
|
||||||
|
LocalConnectFailed {
|
||||||
|
#[source]
|
||||||
|
source: io::Error,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -150,4 +171,36 @@ mod tests {
|
|||||||
let plain = AuthError::KeyRejected;
|
let plain = AuthError::KeyRejected;
|
||||||
assert!(plain.source().is_none());
|
assert!(plain.source().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forward_error_display() {
|
||||||
|
assert_eq!(
|
||||||
|
ForwardError::InvalidSpec { spec: "bad".to_string() }.to_string(),
|
||||||
|
"invalid port forward spec: bad"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ForwardError::BindFailed {
|
||||||
|
source: io::Error::new(io::ErrorKind::AddrInUse, "in use")
|
||||||
|
}
|
||||||
|
.to_string(),
|
||||||
|
"bind failed"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ForwardError::LocalConnectFailed {
|
||||||
|
source: io::Error::new(io::ErrorKind::ConnectionRefused, "refused")
|
||||||
|
}
|
||||||
|
.to_string(),
|
||||||
|
"connect to local target failed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forward_error_source_chaining() {
|
||||||
|
let io_err = io::Error::new(io::ErrorKind::AddrInUse, "in use");
|
||||||
|
let forward_err = ForwardError::BindFailed { source: io_err };
|
||||||
|
assert!(forward_err.source().is_some());
|
||||||
|
|
||||||
|
let plain = ForwardError::InvalidSpec { spec: "bad".to_string() };
|
||||||
|
assert!(plain.source().is_none());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -8,5 +8,7 @@ pub mod error;
|
|||||||
#[cfg(feature = "testutil")]
|
#[cfg(feature = "testutil")]
|
||||||
pub mod testutil;
|
pub mod testutil;
|
||||||
|
|
||||||
pub use error::{AuthError, ChannelError, ConfigError, TransportError};
|
pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError};
|
||||||
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||||
|
pub use client::channel_manager::{ChannelManager, ForwardRequest};
|
||||||
|
pub use client::connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
|
||||||
560
crates/wraith-core/src/server/channel_proxy.rs
Normal file
560
crates/wraith-core/src/server/channel_proxy.rs
Normal file
@@ -0,0 +1,560 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
|
||||||
|
use super::handler::{ProxyConfig, ProxyMode};
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ChannelProxyError {
|
||||||
|
#[error("connection refused")]
|
||||||
|
ConnectionRefused,
|
||||||
|
#[error("target unreachable")]
|
||||||
|
TargetUnreachable,
|
||||||
|
#[error("socks5 proxy handshake failed")]
|
||||||
|
Socks5HandshakeFailed,
|
||||||
|
#[error("socks5 proxy rejected connection")]
|
||||||
|
Socks5ProxyRejected,
|
||||||
|
#[error("http connect proxy handshake failed")]
|
||||||
|
HttpConnectHandshakeFailed,
|
||||||
|
#[error("http connect proxy rejected: {0}")]
|
||||||
|
HttpConnectProxyRejected(String),
|
||||||
|
#[error("io error")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn connect_outbound(
|
||||||
|
target: SocketAddr,
|
||||||
|
proxy: &ProxyConfig,
|
||||||
|
) -> Result<TcpStream, ChannelProxyError> {
|
||||||
|
match &proxy.mode {
|
||||||
|
ProxyMode::Direct => connect_direct(target).await,
|
||||||
|
ProxyMode::Socks5(addr) => connect_socks5(target, *addr).await,
|
||||||
|
ProxyMode::HttpConnect(addr) => connect_http_connect(target, *addr).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_direct(target: SocketAddr) -> Result<TcpStream, ChannelProxyError> {
|
||||||
|
TcpStream::connect(target)
|
||||||
|
.await
|
||||||
|
.map_err(|e| map_connection_error(e, target))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_socks5(target: SocketAddr, proxy_addr: SocketAddr) -> Result<TcpStream, ChannelProxyError> {
|
||||||
|
let mut stream = TcpStream::connect(proxy_addr)
|
||||||
|
.await
|
||||||
|
.map_err(ChannelProxyError::from)?;
|
||||||
|
|
||||||
|
stream.write_all(&[0x05, 0x01, 0x00]).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
|
||||||
|
let mut resp = [0u8; 2];
|
||||||
|
stream.read_exact(&mut resp).await?;
|
||||||
|
if resp[0] != 0x05 || resp[1] != 0x00 {
|
||||||
|
return Err(ChannelProxyError::Socks5HandshakeFailed);
|
||||||
|
}
|
||||||
|
|
||||||
|
let ip_bytes = target.ip().to_string();
|
||||||
|
let mut connect_req = vec![0x05, 0x01, 0x00, 0x03];
|
||||||
|
connect_req.push(ip_bytes.len() as u8);
|
||||||
|
connect_req.extend_from_slice(ip_bytes.as_bytes());
|
||||||
|
connect_req.extend_from_slice(&target.port().to_be_bytes());
|
||||||
|
stream.write_all(&connect_req).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
|
||||||
|
let mut reply_header = [0u8; 4];
|
||||||
|
stream.read_exact(&mut reply_header).await?;
|
||||||
|
if reply_header[0] != 0x05 {
|
||||||
|
return Err(ChannelProxyError::Socks5HandshakeFailed);
|
||||||
|
}
|
||||||
|
if reply_header[1] != 0x00 {
|
||||||
|
return Err(ChannelProxyError::Socks5ProxyRejected);
|
||||||
|
}
|
||||||
|
|
||||||
|
let atyp = reply_header[3];
|
||||||
|
match atyp {
|
||||||
|
0x01 => {
|
||||||
|
let mut _addr = [0u8; 4];
|
||||||
|
stream.read_exact(&mut _addr).await?;
|
||||||
|
}
|
||||||
|
0x04 => {
|
||||||
|
let mut _addr = [0u8; 16];
|
||||||
|
stream.read_exact(&mut _addr).await?;
|
||||||
|
}
|
||||||
|
0x03 => {
|
||||||
|
let len = stream.read_u8().await?;
|
||||||
|
let mut _domain = vec![0u8; len as usize];
|
||||||
|
stream.read_exact(&mut _domain).await?;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
return Err(ChannelProxyError::Socks5HandshakeFailed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut _port = [0u8; 2];
|
||||||
|
stream.read_exact(&mut _port).await?;
|
||||||
|
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_http_connect(
|
||||||
|
target: SocketAddr,
|
||||||
|
proxy_addr: SocketAddr,
|
||||||
|
) -> Result<TcpStream, ChannelProxyError> {
|
||||||
|
let mut stream = TcpStream::connect(proxy_addr)
|
||||||
|
.await
|
||||||
|
.map_err(ChannelProxyError::from)?;
|
||||||
|
|
||||||
|
let connect_request = format!(
|
||||||
|
"CONNECT {}:{} HTTP/1.1\r\nHost: {}:{}\r\n\r\n",
|
||||||
|
target.ip(),
|
||||||
|
target.port(),
|
||||||
|
target.ip(),
|
||||||
|
target.port()
|
||||||
|
);
|
||||||
|
stream.write_all(connect_request.as_bytes()).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
|
||||||
|
let mut response = Vec::new();
|
||||||
|
let mut buf = [0u8; 1024];
|
||||||
|
loop {
|
||||||
|
let n = stream.read(&mut buf).await?;
|
||||||
|
if n == 0 {
|
||||||
|
return Err(ChannelProxyError::HttpConnectHandshakeFailed);
|
||||||
|
}
|
||||||
|
response.extend_from_slice(&buf[..n]);
|
||||||
|
if response.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let response_str = String::from_utf8_lossy(&response);
|
||||||
|
let status_line = response_str
|
||||||
|
.lines()
|
||||||
|
.next()
|
||||||
|
.unwrap_or("");
|
||||||
|
|
||||||
|
if status_line.contains("200") {
|
||||||
|
Ok(stream)
|
||||||
|
} else {
|
||||||
|
Err(ChannelProxyError::HttpConnectProxyRejected(
|
||||||
|
status_line.to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_connection_error(e: std::io::Error, target: SocketAddr) -> ChannelProxyError {
|
||||||
|
match e.kind() {
|
||||||
|
std::io::ErrorKind::ConnectionRefused => ChannelProxyError::ConnectionRefused,
|
||||||
|
std::io::ErrorKind::AddrNotAvailable
|
||||||
|
| std::io::ErrorKind::NetworkUnreachable
|
||||||
|
| std::io::ErrorKind::HostUnreachable => ChannelProxyError::TargetUnreachable,
|
||||||
|
_ => {
|
||||||
|
tracing::debug!(error = %e, "outbound connection failed to {:?}", target);
|
||||||
|
ChannelProxyError::Io(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn proxy_channel<S>(channel: S, target: SocketAddr, proxy: &ProxyConfig)
|
||||||
|
where
|
||||||
|
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
if let Ok(outbound) = connect_outbound(target, proxy).await {
|
||||||
|
let (mut read_chan, mut write_chan) = tokio::io::split(channel);
|
||||||
|
let (mut read_out, mut write_out) = outbound.into_split();
|
||||||
|
|
||||||
|
let client_to_target = tokio::spawn(async move {
|
||||||
|
let _ = tokio::io::copy(&mut read_chan, &mut write_out).await;
|
||||||
|
let _ = write_out.shutdown().await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let target_to_client = tokio::spawn(async move {
|
||||||
|
let _ = tokio::io::copy(&mut read_out, &mut write_chan).await;
|
||||||
|
let _ = write_chan.shutdown().await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let _ = client_to_target.await;
|
||||||
|
let _ = target_to_client.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
fn direct_config() -> ProxyConfig {
|
||||||
|
ProxyConfig {
|
||||||
|
mode: ProxyMode::Direct,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn socks5_config(addr: SocketAddr) -> ProxyConfig {
|
||||||
|
ProxyConfig {
|
||||||
|
mode: ProxyMode::Socks5(addr),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn http_connect_config(addr: SocketAddr) -> ProxyConfig {
|
||||||
|
ProxyConfig {
|
||||||
|
mode: ProxyMode::HttpConnect(addr),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn direct_connection_to_echo_server() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let server = tokio::spawn(async move {
|
||||||
|
let (mut sock, _) = listener.accept().await.unwrap();
|
||||||
|
let mut buf = [0u8; 64];
|
||||||
|
let n = sock.read(&mut buf).await.unwrap();
|
||||||
|
sock.write_all(&buf[..n]).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let stream = connect_outbound(addr, &direct_config()).await.unwrap();
|
||||||
|
let (mut read, mut write) = stream.into_split();
|
||||||
|
write.write_all(b"hello").await.unwrap();
|
||||||
|
let mut buf = [0u8; 5];
|
||||||
|
read.read_exact(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf, b"hello");
|
||||||
|
|
||||||
|
let _ = server.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn direct_connection_target_unreachable() {
|
||||||
|
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
|
||||||
|
let result = connect_outbound(target, &direct_config()).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn socks5_proxy_handshake() {
|
||||||
|
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let target_addr = target_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let target_server = tokio::spawn(async move {
|
||||||
|
let (mut sock, _) = target_listener.accept().await.unwrap();
|
||||||
|
let mut buf = [0u8; 64];
|
||||||
|
let n = sock.read(&mut buf).await.unwrap();
|
||||||
|
sock.write_all(&buf[..n]).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let proxy_server = tokio::spawn(async move {
|
||||||
|
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||||
|
|
||||||
|
let mut greeting = [0u8; 3];
|
||||||
|
proxy_sock.read_exact(&mut greeting).await.unwrap();
|
||||||
|
assert_eq!(greeting[0], 0x05);
|
||||||
|
proxy_sock.write_all(&[0x05, 0x00]).await.unwrap();
|
||||||
|
|
||||||
|
let mut req_header = [0u8; 4];
|
||||||
|
proxy_sock.read_exact(&mut req_header).await.unwrap();
|
||||||
|
assert_eq!(req_header[0], 0x05);
|
||||||
|
assert_eq!(req_header[1], 0x01);
|
||||||
|
|
||||||
|
let atyp = req_header[3];
|
||||||
|
assert_eq!(atyp, 0x03);
|
||||||
|
|
||||||
|
let domain_len = proxy_sock.read_u8().await.unwrap() as usize;
|
||||||
|
let mut domain = vec![0u8; domain_len];
|
||||||
|
proxy_sock.read_exact(&mut domain).await.unwrap();
|
||||||
|
let mut port_bytes = [0u8; 2];
|
||||||
|
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
|
||||||
|
|
||||||
|
let target: SocketAddr = format!(
|
||||||
|
"{}:{}",
|
||||||
|
String::from_utf8_lossy(&domain),
|
||||||
|
u16::from_be_bytes(port_bytes)
|
||||||
|
)
|
||||||
|
.parse()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let reply = vec![
|
||||||
|
0x05, 0x00, 0x00, 0x01,
|
||||||
|
0, 0, 0, 0,
|
||||||
|
0, 0,
|
||||||
|
];
|
||||||
|
proxy_sock.write_all(&reply).await.unwrap();
|
||||||
|
|
||||||
|
let mut target_stream = TcpStream::connect(target).await.unwrap();
|
||||||
|
let _ = tokio::io::copy_bidirectional(&mut proxy_sock, &mut target_stream).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let config = socks5_config(proxy_addr);
|
||||||
|
let mut stream = connect_outbound(target_addr, &config).await.unwrap();
|
||||||
|
stream.write_all(b"hello socks").await.unwrap();
|
||||||
|
let mut buf = [0u8; 11];
|
||||||
|
stream.read_exact(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf, b"hello socks");
|
||||||
|
drop(stream);
|
||||||
|
|
||||||
|
let _ = target_server.await;
|
||||||
|
let _ = proxy_server.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn socks5_proxy_rejected() {
|
||||||
|
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let proxy_server = tokio::spawn(async move {
|
||||||
|
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||||
|
|
||||||
|
let mut greeting = [0u8; 3];
|
||||||
|
proxy_sock.read_exact(&mut greeting).await.unwrap();
|
||||||
|
proxy_sock.write_all(&[0x05, 0x00]).await.unwrap();
|
||||||
|
|
||||||
|
let mut req_header = [0u8; 4];
|
||||||
|
proxy_sock.read_exact(&mut req_header).await.unwrap();
|
||||||
|
|
||||||
|
let domain_len = proxy_sock.read_u8().await.unwrap() as usize;
|
||||||
|
let mut domain = vec![0u8; domain_len];
|
||||||
|
proxy_sock.read_exact(&mut domain).await.unwrap();
|
||||||
|
let mut port_bytes = [0u8; 2];
|
||||||
|
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
|
||||||
|
|
||||||
|
let reply = vec![
|
||||||
|
0x05, 0x05, 0x00, 0x01,
|
||||||
|
0, 0, 0, 0,
|
||||||
|
0, 0,
|
||||||
|
];
|
||||||
|
proxy_sock.write_all(&reply).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||||
|
let config = socks5_config(proxy_addr);
|
||||||
|
let result = connect_outbound(target, &config).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(matches!(
|
||||||
|
result.unwrap_err(),
|
||||||
|
ChannelProxyError::Socks5ProxyRejected
|
||||||
|
));
|
||||||
|
|
||||||
|
let _ = proxy_server.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn http_connect_proxy_handshake() {
|
||||||
|
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let target_addr = target_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let target_server = tokio::spawn(async move {
|
||||||
|
let (mut sock, _) = target_listener.accept().await.unwrap();
|
||||||
|
let mut buf = [0u8; 64];
|
||||||
|
let n = sock.read(&mut buf).await.unwrap();
|
||||||
|
sock.write_all(&buf[..n]).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let proxy_server = tokio::spawn(async move {
|
||||||
|
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||||
|
|
||||||
|
let mut request = Vec::new();
|
||||||
|
let mut buf = [0u8; 1024];
|
||||||
|
loop {
|
||||||
|
let n = proxy_sock.read(&mut buf).await.unwrap();
|
||||||
|
request.extend_from_slice(&buf[..n]);
|
||||||
|
if request.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
|
||||||
|
proxy_sock.write_all(response.as_bytes()).await.unwrap();
|
||||||
|
|
||||||
|
let target_str = extract_connect_target(&String::from_utf8_lossy(&request));
|
||||||
|
let mut target_stream = TcpStream::connect(target_str).await.unwrap();
|
||||||
|
let _ = tokio::io::copy_bidirectional(&mut proxy_sock, &mut target_stream).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let config = http_connect_config(proxy_addr);
|
||||||
|
let mut stream = connect_outbound(target_addr, &config).await.unwrap();
|
||||||
|
stream.write_all(b"hello http").await.unwrap();
|
||||||
|
let mut buf = [0u8; 10];
|
||||||
|
stream.read_exact(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf, b"hello http");
|
||||||
|
drop(stream);
|
||||||
|
|
||||||
|
let _ = target_server.await;
|
||||||
|
let _ = proxy_server.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_connect_target(request: &str) -> String {
|
||||||
|
let connect_line = request.lines().next().unwrap_or("");
|
||||||
|
let parts: Vec<&str> = connect_line.split_whitespace().collect();
|
||||||
|
if parts.len() >= 2 {
|
||||||
|
parts[1].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn http_connect_proxy_rejected() {
|
||||||
|
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let proxy_server = tokio::spawn(async move {
|
||||||
|
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||||
|
|
||||||
|
let mut request = Vec::new();
|
||||||
|
let mut buf = [0u8; 1024];
|
||||||
|
loop {
|
||||||
|
let n = proxy_sock.read(&mut buf).await.unwrap();
|
||||||
|
if n == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
request.extend_from_slice(&buf[..n]);
|
||||||
|
if request.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = "HTTP/1.1 403 Forbidden\r\n\r\n";
|
||||||
|
proxy_sock.write_all(response.as_bytes()).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||||
|
let config = http_connect_config(proxy_addr);
|
||||||
|
let result = connect_outbound(target, &config).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
match result.unwrap_err() {
|
||||||
|
ChannelProxyError::HttpConnectProxyRejected(msg) => {
|
||||||
|
assert!(msg.contains("403"));
|
||||||
|
}
|
||||||
|
other => panic!("expected HttpConnectProxyRejected, got {:?}", other),
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = proxy_server.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn target_unreachable_returns_appropriate_error() {
|
||||||
|
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
|
||||||
|
let result = connect_outbound(target, &direct_config()).await;
|
||||||
|
match result.unwrap_err() {
|
||||||
|
ChannelProxyError::TargetUnreachable
|
||||||
|
| ChannelProxyError::ConnectionRefused
|
||||||
|
| ChannelProxyError::Io(_) => {}
|
||||||
|
other => panic!("unexpected error type: {:?}", other),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn socks5_proxy_unreachable() {
|
||||||
|
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||||
|
let bad_proxy: SocketAddr = "127.0.0.1:1".parse().unwrap();
|
||||||
|
let config = socks5_config(bad_proxy);
|
||||||
|
let result = connect_outbound(target, &config).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn http_connect_proxy_unreachable() {
|
||||||
|
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||||
|
let bad_proxy: SocketAddr = "127.0.0.1:1".parse().unwrap();
|
||||||
|
let config = http_connect_config(bad_proxy);
|
||||||
|
let result = connect_outbound(target, &config).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MockChannel {
|
||||||
|
read_half: tokio::io::ReadHalf<DuplexStream>,
|
||||||
|
write_half: tokio::io::WriteHalf<DuplexStream>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tokio::io::AsyncRead for MockChannel {
|
||||||
|
fn poll_read(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
buf: &mut tokio::io::ReadBuf<'_>,
|
||||||
|
) -> std::task::Poll<std::io::Result<()>> {
|
||||||
|
std::pin::Pin::new(&mut self.get_mut().read_half).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tokio::io::AsyncWrite for MockChannel {
|
||||||
|
fn poll_write(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> std::task::Poll<std::io::Result<usize>> {
|
||||||
|
std::pin::Pin::new(&mut self.get_mut().write_half).poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<std::io::Result<()>> {
|
||||||
|
std::pin::Pin::new(&mut self.get_mut().write_half).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<std::io::Result<()>> {
|
||||||
|
std::pin::Pin::new(&mut self.get_mut().write_half).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_mock_channel() -> (MockChannel, DuplexStream) {
|
||||||
|
let (client, server) = duplex(4096);
|
||||||
|
let (read_half, write_half) = tokio::io::split(client);
|
||||||
|
(
|
||||||
|
MockChannel {
|
||||||
|
read_half,
|
||||||
|
write_half,
|
||||||
|
},
|
||||||
|
server,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn proxy_channel_bidirectional_data_flow() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let target_addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let echo_server = tokio::spawn(async move {
|
||||||
|
let (mut sock, _) = listener.accept().await.unwrap();
|
||||||
|
let mut buf = [0u8; 64];
|
||||||
|
let n = sock.read(&mut buf).await.unwrap();
|
||||||
|
sock.write_all(&buf[..n]).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let (channel, mut channel_peer) = make_mock_channel();
|
||||||
|
|
||||||
|
let target = target_addr;
|
||||||
|
let proxy = direct_config();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
proxy_channel(channel, target, &proxy).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
channel_peer.write_all(b"ping").await.unwrap();
|
||||||
|
channel_peer.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut buf = [0u8; 4];
|
||||||
|
channel_peer.read_exact(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf, b"ping");
|
||||||
|
|
||||||
|
drop(channel_peer);
|
||||||
|
let _ = echo_server.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn proxy_channel_target_unreachable_closes_cleanly() {
|
||||||
|
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
|
||||||
|
let (channel, _channel_peer) = make_mock_channel();
|
||||||
|
|
||||||
|
let proxy = direct_config();
|
||||||
|
proxy_channel(channel, target, &proxy).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
186
crates/wraith-core/src/server/control_channel.rs
Normal file
186
crates/wraith-core/src/server/control_channel.rs
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
use std::io;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
|
||||||
|
pub const WRAITH_CONTROL_DESTINATION: &str = "wraith-control";
|
||||||
|
pub const WRAITH_PREFIX: &str = "wraith-";
|
||||||
|
|
||||||
|
pub fn is_reserved_destination(host: &str) -> bool {
|
||||||
|
host.starts_with(WRAITH_PREFIX)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait DuplexStream: AsyncRead + AsyncWrite + Unpin + Send {}
|
||||||
|
|
||||||
|
impl<T: AsyncRead + AsyncWrite + Unpin + Send> DuplexStream for T {}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait ControlChannelHandler: Send + Sync {
|
||||||
|
async fn handle_channel(&self, stream: Box<dyn DuplexStream>);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ControlChannelRouter {
|
||||||
|
handler: Option<Box<dyn ControlChannelHandler>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ControlChannelRouter {
|
||||||
|
pub fn new(handler: Option<Box<dyn ControlChannelHandler>>) -> Self {
|
||||||
|
Self { handler }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn without_handler() -> Self {
|
||||||
|
Self { handler: None }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_handler(handler: Box<dyn ControlChannelHandler>) -> Self {
|
||||||
|
Self {
|
||||||
|
handler: Some(handler),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_handler(&self) -> bool {
|
||||||
|
self.handler.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn route(&self, stream: Box<dyn DuplexStream>) -> io::Result<()> {
|
||||||
|
match &self.handler {
|
||||||
|
Some(handler) => {
|
||||||
|
handler.handle_channel(stream).await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
None => Err(io::Error::new(
|
||||||
|
io::ErrorKind::ConnectionRefused,
|
||||||
|
"no control channel handler configured",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio::io::duplex;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn wraith_control_destination_constant() {
|
||||||
|
assert_eq!(WRAITH_CONTROL_DESTINATION, "wraith-control");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn wraith_prefix_constant() {
|
||||||
|
assert_eq!(WRAITH_PREFIX, "wraith-");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn reserved_destination_detected() {
|
||||||
|
assert!(is_reserved_destination("wraith-control"));
|
||||||
|
assert!(is_reserved_destination("wraith-status"));
|
||||||
|
assert!(is_reserved_destination("wraith-events"));
|
||||||
|
assert!(is_reserved_destination("wraith-"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn non_reserved_destination_passes_through() {
|
||||||
|
assert!(!is_reserved_destination("example.com"));
|
||||||
|
assert!(!is_reserved_destination("localhost"));
|
||||||
|
assert!(!is_reserved_destination("192.168.1.1"));
|
||||||
|
assert!(!is_reserved_destination("wraith.example.com"));
|
||||||
|
assert!(!is_reserved_destination(""));
|
||||||
|
assert!(!is_reserved_destination("wrait-control"));
|
||||||
|
assert!(!is_reserved_destination("WRAITH-control"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn prefix_matching_case_sensitive() {
|
||||||
|
assert!(!is_reserved_destination("Wraith-control"));
|
||||||
|
assert!(!is_reserved_destination("WRAITH-control"));
|
||||||
|
assert!(is_reserved_destination("wraith-Control"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn router_without_handler_has_no_handler() {
|
||||||
|
let router = ControlChannelRouter::without_handler();
|
||||||
|
assert!(!router.has_handler());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn router_with_handler_has_handler() {
|
||||||
|
struct DummyHandler;
|
||||||
|
#[async_trait]
|
||||||
|
impl ControlChannelHandler for DummyHandler {
|
||||||
|
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {}
|
||||||
|
}
|
||||||
|
let router = ControlChannelRouter::with_handler(Box::new(DummyHandler));
|
||||||
|
assert!(router.has_handler());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn route_without_handler_returns_error() {
|
||||||
|
let router = ControlChannelRouter::without_handler();
|
||||||
|
let (_client, server) = duplex(64);
|
||||||
|
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||||
|
let result = router.route(stream).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
let err = result.unwrap_err();
|
||||||
|
assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn route_with_handler_succeeds() {
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
struct TrackedHandler {
|
||||||
|
called: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
#[async_trait]
|
||||||
|
impl ControlChannelHandler for TrackedHandler {
|
||||||
|
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {
|
||||||
|
self.called.store(true, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let called = Arc::new(AtomicBool::new(false));
|
||||||
|
let handler = TrackedHandler {
|
||||||
|
called: called.clone(),
|
||||||
|
};
|
||||||
|
let router = ControlChannelRouter::with_handler(Box::new(handler));
|
||||||
|
let (_client, server) = duplex(64);
|
||||||
|
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||||
|
let result = router.route(stream).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
assert!(called.load(Ordering::SeqCst));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn route_with_handler_can_read_write() {
|
||||||
|
struct EchoHandler;
|
||||||
|
#[async_trait]
|
||||||
|
impl ControlChannelHandler for EchoHandler {
|
||||||
|
async fn handle_channel(&self, mut stream: Box<dyn DuplexStream>) {
|
||||||
|
let mut buf = [0u8; 64];
|
||||||
|
let n = stream.read(&mut buf).await.unwrap();
|
||||||
|
stream.write_all(&buf[..n]).await.unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let router = ControlChannelRouter::with_handler(Box::new(EchoHandler));
|
||||||
|
let (client, server) = duplex(64);
|
||||||
|
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
router.route(stream).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
let mut client = client;
|
||||||
|
client.write_all(b"hello").await.unwrap();
|
||||||
|
let mut buf = [0u8; 5];
|
||||||
|
client.read_exact(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf, b"hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn control_channel_destination_matches_prefix() {
|
||||||
|
assert!(is_reserved_destination(WRAITH_CONTROL_DESTINATION));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
use std::net::SocketAddr;
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use russh::keys::ssh_key::HashAlg;
|
use russh::keys::ssh_key::HashAlg;
|
||||||
@@ -7,8 +8,10 @@ use russh::server::{Auth, Handler, Msg, Session};
|
|||||||
use russh::Channel;
|
use russh::Channel;
|
||||||
|
|
||||||
use crate::auth::ServerAuthConfig;
|
use crate::auth::ServerAuthConfig;
|
||||||
|
use crate::server::control_channel::{
|
||||||
const WRAITH_PREFIX: &str = "wraith-";
|
ControlChannelHandler, ControlChannelRouter, WRAITH_PREFIX,
|
||||||
|
};
|
||||||
|
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum ProxyMode {
|
pub enum ProxyMode {
|
||||||
@@ -22,10 +25,34 @@ pub struct ProxyConfig {
|
|||||||
pub mode: ProxyMode,
|
pub mode: ProxyMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||||
|
pub enum TransportKind {
|
||||||
|
Tcp,
|
||||||
|
Tls,
|
||||||
|
Iroh,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for TransportKind {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
TransportKind::Tcp => write!(f, "tcp"),
|
||||||
|
TransportKind::Tls => write!(f, "tls"),
|
||||||
|
TransportKind::Iroh => write!(f, "iroh"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ServerHandler {
|
pub struct ServerHandler {
|
||||||
auth_config: Arc<ServerAuthConfig>,
|
auth_config: Arc<ServerAuthConfig>,
|
||||||
|
#[allow(dead_code)]
|
||||||
outbound_proxy: Option<ProxyConfig>,
|
outbound_proxy: Option<ProxyConfig>,
|
||||||
remote_addr: Option<SocketAddr>,
|
remote_addr: Option<SocketAddr>,
|
||||||
|
control_channel_router: ControlChannelRouter,
|
||||||
|
transport: TransportKind,
|
||||||
|
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||||
|
connection_allowed: bool,
|
||||||
|
auth_limiter: AuthAttemptLimiter,
|
||||||
|
connected_at: Instant,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ServerHandler {
|
impl ServerHandler {
|
||||||
@@ -33,13 +60,82 @@ impl ServerHandler {
|
|||||||
auth_config: Arc<ServerAuthConfig>,
|
auth_config: Arc<ServerAuthConfig>,
|
||||||
outbound_proxy: Option<ProxyConfig>,
|
outbound_proxy: Option<ProxyConfig>,
|
||||||
remote_addr: Option<SocketAddr>,
|
remote_addr: Option<SocketAddr>,
|
||||||
|
transport: TransportKind,
|
||||||
|
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||||
|
max_auth_attempts: usize,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let allowed = if let Some(addr) = remote_addr {
|
||||||
|
let ip = addr.ip();
|
||||||
|
if connection_limiter.check(ip) {
|
||||||
|
connection_limiter.on_connect(ip);
|
||||||
|
tracing::info!(
|
||||||
|
remote_addr = %addr,
|
||||||
|
transport = %transport,
|
||||||
|
"connection opened"
|
||||||
|
);
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
tracing::info!(
|
||||||
|
remote_addr = %addr,
|
||||||
|
transport = %transport,
|
||||||
|
"connection rejected"
|
||||||
|
);
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
};
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
auth_config,
|
auth_config,
|
||||||
outbound_proxy,
|
outbound_proxy,
|
||||||
remote_addr,
|
remote_addr,
|
||||||
|
control_channel_router: ControlChannelRouter::without_handler(),
|
||||||
|
transport,
|
||||||
|
connection_limiter,
|
||||||
|
connection_allowed: allowed,
|
||||||
|
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
|
||||||
|
connected_at: Instant::now(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_connection_allowed(&self) -> bool {
|
||||||
|
self.connection_allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remote_ip(&self) -> Option<IpAddr> {
|
||||||
|
self.remote_addr.map(|a| a.ip())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for ServerHandler {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Some(addr) = self.remote_addr {
|
||||||
|
if self.connection_allowed {
|
||||||
|
self.connection_limiter.on_disconnect(addr.ip());
|
||||||
|
}
|
||||||
|
let duration = self.connected_at.elapsed();
|
||||||
|
tracing::info!(
|
||||||
|
remote_addr = %addr,
|
||||||
|
duration_secs = duration.as_secs_f64(),
|
||||||
|
"connection closed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerHandler {
|
||||||
|
pub fn with_control_channel_handler(
|
||||||
|
mut self,
|
||||||
|
handler: Box<dyn ControlChannelHandler>,
|
||||||
|
) -> Self {
|
||||||
|
self.control_channel_router = ControlChannelRouter::with_handler(handler);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn control_channel_router(&self) -> &ControlChannelRouter {
|
||||||
|
&self.control_channel_router
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@@ -51,6 +147,23 @@ impl Handler for ServerHandler {
|
|||||||
user: &str,
|
user: &str,
|
||||||
public_key: &russh::keys::ssh_key::PublicKey,
|
public_key: &russh::keys::ssh_key::PublicKey,
|
||||||
) -> Result<Auth, Self::Error> {
|
) -> Result<Auth, Self::Error> {
|
||||||
|
if !self.auth_limiter.check() {
|
||||||
|
let remote_addr_display = self
|
||||||
|
.remote_addr
|
||||||
|
.map_or("unknown".to_string(), |a| a.to_string());
|
||||||
|
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||||
|
tracing::info!(
|
||||||
|
remote_addr = %remote_addr_display,
|
||||||
|
user = user,
|
||||||
|
key_fingerprint = %fingerprint,
|
||||||
|
result = "reject",
|
||||||
|
"auth attempt"
|
||||||
|
);
|
||||||
|
return Ok(Auth::Reject {
|
||||||
|
proceed_with_methods: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||||
let remote_addr_display = self
|
let remote_addr_display = self
|
||||||
.remote_addr
|
.remote_addr
|
||||||
@@ -63,6 +176,7 @@ impl Handler for ServerHandler {
|
|||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
remote_addr = %remote_addr_display,
|
remote_addr = %remote_addr_display,
|
||||||
|
user = user,
|
||||||
key_fingerprint = %fingerprint,
|
key_fingerprint = %fingerprint,
|
||||||
result = "accept",
|
result = "accept",
|
||||||
"auth attempt"
|
"auth attempt"
|
||||||
@@ -70,8 +184,10 @@ impl Handler for ServerHandler {
|
|||||||
Ok(Auth::Accept)
|
Ok(Auth::Accept)
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
|
self.auth_limiter.on_failure();
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
remote_addr = %remote_addr_display,
|
remote_addr = %remote_addr_display,
|
||||||
|
user = user,
|
||||||
key_fingerprint = %fingerprint,
|
key_fingerprint = %fingerprint,
|
||||||
result = "reject",
|
result = "reject",
|
||||||
"auth attempt"
|
"auth attempt"
|
||||||
@@ -98,25 +214,20 @@ impl Handler for ServerHandler {
|
|||||||
port = port_to_connect,
|
port = port_to_connect,
|
||||||
"routing to internal control channel handler"
|
"routing to internal control channel handler"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if !self.control_channel_router.has_handler() {
|
||||||
|
tracing::warn!(
|
||||||
|
host = host_to_connect,
|
||||||
|
"no control channel handler configured, rejecting channel open"
|
||||||
|
);
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = channel;
|
||||||
return Ok(true);
|
return Ok(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
let proxy_info = self
|
let _ = (host_to_connect, port_to_connect, originator_address, originator_port, channel);
|
||||||
.outbound_proxy
|
|
||||||
.as_ref()
|
|
||||||
.map(|p| format!("{:?}", p.mode))
|
|
||||||
.unwrap_or_else(|| "direct".to_string());
|
|
||||||
|
|
||||||
tracing::info!(
|
|
||||||
host = host_to_connect,
|
|
||||||
port = port_to_connect,
|
|
||||||
originator_address = originator_address,
|
|
||||||
originator_port = originator_port,
|
|
||||||
proxy = %proxy_info,
|
|
||||||
"spawning tcp proxy task"
|
|
||||||
);
|
|
||||||
|
|
||||||
let _ = channel;
|
|
||||||
Ok(false)
|
Ok(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,10 +299,22 @@ mod tests {
|
|||||||
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
|
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_limiter() -> Arc<ConnectionRateLimiter> {
|
||||||
|
Arc::new(ConnectionRateLimiter::new(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_handler(
|
||||||
|
auth_config: Arc<ServerAuthConfig>,
|
||||||
|
outbound_proxy: Option<ProxyConfig>,
|
||||||
|
remote_addr: Option<SocketAddr>,
|
||||||
|
) -> ServerHandler {
|
||||||
|
ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10)
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn auth_delegation_accepts_known_key() {
|
async fn auth_delegation_accepts_known_key() {
|
||||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||||
let mut handler = ServerHandler::new(auth_config, None, None);
|
let mut handler = make_handler(auth_config, None, None);
|
||||||
|
|
||||||
let ssh_key = load_key().public_key().clone();
|
let ssh_key = load_key().public_key().clone();
|
||||||
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||||
@@ -201,7 +324,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn auth_delegation_rejects_unknown_key() {
|
async fn auth_delegation_rejects_unknown_key() {
|
||||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||||
let mut handler = ServerHandler::new(auth_config, None, None);
|
let mut handler = make_handler(auth_config, None, None);
|
||||||
|
|
||||||
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
||||||
let other_ssh_key = russh::keys::parse_public_key_base64(
|
let other_ssh_key = russh::keys::parse_public_key_base64(
|
||||||
@@ -224,7 +347,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn auth_delegation_empty_config_rejects_all() {
|
async fn auth_delegation_empty_config_rejects_all() {
|
||||||
let auth_config = make_empty_auth_config();
|
let auth_config = make_empty_auth_config();
|
||||||
let mut handler = ServerHandler::new(auth_config, None, None);
|
let mut handler = make_handler(auth_config, None, None);
|
||||||
|
|
||||||
let ssh_key = load_key().public_key().clone();
|
let ssh_key = load_key().public_key().clone();
|
||||||
let result = handler
|
let result = handler
|
||||||
@@ -243,7 +366,7 @@ mod tests {
|
|||||||
async fn auth_logging_includes_remote_addr() {
|
async fn auth_logging_includes_remote_addr() {
|
||||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||||
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
|
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
|
||||||
let mut handler = ServerHandler::new(auth_config, None, Some(remote_addr));
|
let mut handler = make_handler(auth_config, None, Some(remote_addr));
|
||||||
|
|
||||||
let ssh_key = load_key().public_key().clone();
|
let ssh_key = load_key().public_key().clone();
|
||||||
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
||||||
@@ -251,12 +374,20 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn reserved_wraith_destination_routing() {
|
fn reserved_wraith_destination_routing() {
|
||||||
assert!("wraith-control".starts_with(WRAITH_PREFIX));
|
use crate::server::control_channel::is_reserved_destination;
|
||||||
assert!("wraith-status".starts_with(WRAITH_PREFIX));
|
assert!(is_reserved_destination("wraith-control"));
|
||||||
assert!("wraith-events".starts_with(WRAITH_PREFIX));
|
assert!(is_reserved_destination("wraith-status"));
|
||||||
assert!(!"example.com".starts_with(WRAITH_PREFIX));
|
assert!(is_reserved_destination("wraith-events"));
|
||||||
assert!(!"localhost".starts_with(WRAITH_PREFIX));
|
assert!(!is_reserved_destination("example.com"));
|
||||||
assert!(!"wraith.example.com".starts_with(WRAITH_PREFIX));
|
assert!(!is_reserved_destination("localhost"));
|
||||||
|
assert!(!is_reserved_destination("wraith.example.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn server_handler_without_control_handler_rejects_wraith_destinations() {
|
||||||
|
let auth_config = make_empty_auth_config();
|
||||||
|
let handler = make_handler(auth_config, None, None);
|
||||||
|
assert!(!handler.control_channel_router().has_handler());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -287,7 +418,7 @@ mod tests {
|
|||||||
});
|
});
|
||||||
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
|
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
|
||||||
|
|
||||||
let handler = ServerHandler::new(auth_config, proxy.clone(), remote);
|
let handler = make_handler(auth_config, proxy.clone(), remote);
|
||||||
assert!(handler.outbound_proxy.is_some());
|
assert!(handler.outbound_proxy.is_some());
|
||||||
assert!(handler.remote_addr.is_some());
|
assert!(handler.remote_addr.is_some());
|
||||||
}
|
}
|
||||||
@@ -295,9 +426,108 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn one_handler_per_connection() {
|
fn one_handler_per_connection() {
|
||||||
let auth_config = make_empty_auth_config();
|
let auth_config = make_empty_auth_config();
|
||||||
let handler1 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
|
let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
|
||||||
let handler2 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
|
let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
|
||||||
|
|
||||||
assert!(handler1.remote_addr != handler2.remote_addr);
|
assert!(handler1.remote_addr != handler2.remote_addr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn auth_rate_limit_rejects_after_max_failures() {
|
||||||
|
let auth_config = make_empty_auth_config();
|
||||||
|
let limiter = Arc::new(ConnectionRateLimiter::new(0));
|
||||||
|
let mut handler = ServerHandler::new(
|
||||||
|
auth_config,
|
||||||
|
None,
|
||||||
|
Some("10.0.0.1:22".parse().unwrap()),
|
||||||
|
TransportKind::Tcp,
|
||||||
|
limiter,
|
||||||
|
2,
|
||||||
|
);
|
||||||
|
|
||||||
|
let ssh_key = load_key().public_key().clone();
|
||||||
|
|
||||||
|
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||||
|
assert_eq!(r1, Auth::Reject { proceed_with_methods: None });
|
||||||
|
|
||||||
|
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||||
|
assert_eq!(r2, Auth::Reject { proceed_with_methods: None });
|
||||||
|
|
||||||
|
assert!(!handler.auth_limiter.check());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_rate_limit_blocks_over_limit() {
|
||||||
|
let limiter = Arc::new(ConnectionRateLimiter::new(1));
|
||||||
|
let auth_config = make_empty_auth_config();
|
||||||
|
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
|
||||||
|
|
||||||
|
let h1 = ServerHandler::new(
|
||||||
|
auth_config.clone(),
|
||||||
|
None,
|
||||||
|
Some(addr),
|
||||||
|
TransportKind::Tcp,
|
||||||
|
limiter.clone(),
|
||||||
|
10,
|
||||||
|
);
|
||||||
|
assert!(h1.is_connection_allowed());
|
||||||
|
|
||||||
|
let h2 = ServerHandler::new(
|
||||||
|
auth_config.clone(),
|
||||||
|
None,
|
||||||
|
Some(addr),
|
||||||
|
TransportKind::Tcp,
|
||||||
|
limiter.clone(),
|
||||||
|
10,
|
||||||
|
);
|
||||||
|
assert!(!h2.is_connection_allowed());
|
||||||
|
|
||||||
|
drop(h1);
|
||||||
|
|
||||||
|
let h3 = ServerHandler::new(
|
||||||
|
auth_config,
|
||||||
|
None,
|
||||||
|
Some(addr),
|
||||||
|
TransportKind::Tcp,
|
||||||
|
limiter,
|
||||||
|
10,
|
||||||
|
);
|
||||||
|
assert!(h3.is_connection_allowed());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn transport_kind_display() {
|
||||||
|
assert_eq!(TransportKind::Tcp.to_string(), "tcp");
|
||||||
|
assert_eq!(TransportKind::Tls.to_string(), "tls");
|
||||||
|
assert_eq!(TransportKind::Iroh.to_string(), "iroh");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn auth_log_includes_user_field() {
|
||||||
|
let auth_config = make_empty_auth_config();
|
||||||
|
let mut handler = ServerHandler::new(
|
||||||
|
auth_config,
|
||||||
|
None,
|
||||||
|
Some("203.0.113.50:12345".parse().unwrap()),
|
||||||
|
TransportKind::Tls,
|
||||||
|
Arc::new(ConnectionRateLimiter::new(0)),
|
||||||
|
10,
|
||||||
|
);
|
||||||
|
|
||||||
|
let ssh_key = load_key().public_key().clone();
|
||||||
|
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_closed_logs_duration_on_drop() {
|
||||||
|
let auth_config = make_empty_auth_config();
|
||||||
|
let _handler = ServerHandler::new(
|
||||||
|
auth_config,
|
||||||
|
None,
|
||||||
|
Some("203.0.113.50:12345".parse().unwrap()),
|
||||||
|
TransportKind::Tcp,
|
||||||
|
Arc::new(ConnectionRateLimiter::new(0)),
|
||||||
|
10,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -1,3 +1,14 @@
|
|||||||
|
pub mod channel_proxy;
|
||||||
|
pub mod control_channel;
|
||||||
pub mod handler;
|
pub mod handler;
|
||||||
|
pub mod rate_limit;
|
||||||
|
pub mod stealth;
|
||||||
|
|
||||||
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
|
pub use channel_proxy::{connect_outbound, proxy_channel};
|
||||||
|
pub use control_channel::{
|
||||||
|
ControlChannelHandler, ControlChannelRouter, DuplexStream, WRAITH_CONTROL_DESTINATION,
|
||||||
|
WRAITH_PREFIX, is_reserved_destination,
|
||||||
|
};
|
||||||
|
pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
|
||||||
|
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||||
|
pub use stealth::{ProtocolDetection, detect_protocol, send_fake_nginx_404, validate_stealth_config};
|
||||||
193
crates/wraith-core/src/server/rate_limit.rs
Normal file
193
crates/wraith-core/src/server/rate_limit.rs
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
pub struct ConnectionRateLimiter {
|
||||||
|
max_per_ip: usize,
|
||||||
|
active: Mutex<HashMap<IpAddr, usize>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectionRateLimiter {
|
||||||
|
pub fn new(max_per_ip: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
max_per_ip,
|
||||||
|
active: Mutex::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check(&self, ip: IpAddr) -> bool {
|
||||||
|
if self.max_per_ip == 0 {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
let active = self.active.lock().unwrap();
|
||||||
|
let count = active.get(&ip).copied().unwrap_or(0);
|
||||||
|
count < self.max_per_ip
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn on_connect(&self, ip: IpAddr) {
|
||||||
|
let mut active = self.active.lock().unwrap();
|
||||||
|
*active.entry(ip).or_insert(0) += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn on_disconnect(&self, ip: IpAddr) {
|
||||||
|
let mut active = self.active.lock().unwrap();
|
||||||
|
if let Some(count) = active.get_mut(&ip) {
|
||||||
|
if *count > 1 {
|
||||||
|
*count -= 1;
|
||||||
|
} else {
|
||||||
|
active.remove(&ip);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AuthAttemptLimiter {
|
||||||
|
max_attempts: usize,
|
||||||
|
failures: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthAttemptLimiter {
|
||||||
|
pub fn new(max_attempts: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
max_attempts,
|
||||||
|
failures: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check(&self) -> bool {
|
||||||
|
if self.max_attempts == 0 {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
self.failures < self.max_attempts
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn on_failure(&mut self) {
|
||||||
|
self.failures += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||||
|
|
||||||
|
fn ip(n: u8) -> IpAddr {
|
||||||
|
IpAddr::V4(Ipv4Addr::new(192, 168, 1, n))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_allows_when_under_limit() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(3);
|
||||||
|
assert!(limiter.check(ip(1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_blocks_when_at_limit() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(2);
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
assert!(!limiter.check(ip(1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_allows_after_disconnect() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(2);
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
assert!(!limiter.check(ip(1)));
|
||||||
|
limiter.on_disconnect(ip(1));
|
||||||
|
assert!(limiter.check(ip(1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_unlimited_when_zero() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(0);
|
||||||
|
for _ in 0..100 {
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
}
|
||||||
|
assert!(limiter.check(ip(1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_tracks_per_ip_independently() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(1);
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
assert!(!limiter.check(ip(1)));
|
||||||
|
assert!(limiter.check(ip(2)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_ipv6() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(1);
|
||||||
|
let ip6 = IpAddr::V6(Ipv6Addr::LOCALHOST);
|
||||||
|
limiter.on_connect(ip6);
|
||||||
|
assert!(!limiter.check(ip6));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_disconnect_removes_zero_entry() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(3);
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
limiter.on_disconnect(ip(1));
|
||||||
|
{
|
||||||
|
let active = limiter.active.lock().unwrap();
|
||||||
|
assert!(!active.contains_key(&ip(1)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_limiter_allows_when_under_limit() {
|
||||||
|
let limiter = AuthAttemptLimiter::new(3);
|
||||||
|
assert!(limiter.check());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_limiter_blocks_after_max_failures() {
|
||||||
|
let mut limiter = AuthAttemptLimiter::new(2);
|
||||||
|
limiter.on_failure();
|
||||||
|
limiter.on_failure();
|
||||||
|
assert!(!limiter.check());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_limiter_unlimited_when_zero() {
|
||||||
|
let mut limiter = AuthAttemptLimiter::new(0);
|
||||||
|
for _ in 0..100 {
|
||||||
|
limiter.on_failure();
|
||||||
|
}
|
||||||
|
assert!(limiter.check());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_limiter_still_allows_at_one_below_limit() {
|
||||||
|
let mut limiter = AuthAttemptLimiter::new(3);
|
||||||
|
limiter.on_failure();
|
||||||
|
limiter.on_failure();
|
||||||
|
assert!(limiter.check());
|
||||||
|
limiter.on_failure();
|
||||||
|
assert!(!limiter.check());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_thread_safety() {
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::thread;
|
||||||
|
|
||||||
|
let limiter = Arc::new(ConnectionRateLimiter::new(100));
|
||||||
|
let mut handles = vec![];
|
||||||
|
|
||||||
|
for i in 0..10 {
|
||||||
|
let lim = Arc::clone(&limiter);
|
||||||
|
handles.push(thread::spawn(move || {
|
||||||
|
let ip_addr = ip((i % 3) as u8 + 1);
|
||||||
|
lim.on_connect(ip_addr);
|
||||||
|
assert!(lim.check(ip_addr));
|
||||||
|
lim.on_disconnect(ip_addr);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for h in handles {
|
||||||
|
h.join().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
218
crates/wraith-core/src/server/stealth.rs
Normal file
218
crates/wraith-core/src/server/stealth.rs
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
|
||||||
|
|
||||||
|
const SSH_BANNER_PREFIX: &[u8] = b"SSH-2.0-";
|
||||||
|
const FAKE_NGINX_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nServer: nginx\r\n\r\n";
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ProtocolDetection {
|
||||||
|
Ssh,
|
||||||
|
Http,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn detect_protocol<S>(stream: S) -> (ProtocolDetection, BufReader<S>)
|
||||||
|
where
|
||||||
|
S: AsyncRead + Unpin,
|
||||||
|
{
|
||||||
|
let mut reader = BufReader::new(stream);
|
||||||
|
|
||||||
|
let detection = match reader.fill_buf().await {
|
||||||
|
Ok(buf) if buf.len() >= SSH_BANNER_PREFIX.len() => {
|
||||||
|
if &buf[..SSH_BANNER_PREFIX.len()] == SSH_BANNER_PREFIX {
|
||||||
|
ProtocolDetection::Ssh
|
||||||
|
} else {
|
||||||
|
ProtocolDetection::Http
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(buf) if !buf.is_empty() => {
|
||||||
|
if buf.starts_with(SSH_BANNER_PREFIX) {
|
||||||
|
ProtocolDetection::Ssh
|
||||||
|
} else {
|
||||||
|
ProtocolDetection::Http
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => ProtocolDetection::Http,
|
||||||
|
};
|
||||||
|
|
||||||
|
(detection, reader)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send_fake_nginx_404<S>(reader: &mut BufReader<S>)
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
let _ = reader.get_mut().write_all(FAKE_NGINX_404).await;
|
||||||
|
let _ = reader.get_mut().shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn validate_stealth_config(stealth: bool, transport_is_tls: bool) -> Result<(), &'static str> {
|
||||||
|
if stealth && !transport_is_tls {
|
||||||
|
return Err("stealth mode requires TLS transport (--transport tls)");
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||||
|
|
||||||
|
async fn write_and_detect(data: &[u8]) -> ProtocolDetection {
|
||||||
|
let (client, server) = duplex(1024);
|
||||||
|
let mut client = client;
|
||||||
|
|
||||||
|
client.write_all(data).await.unwrap();
|
||||||
|
drop(client);
|
||||||
|
|
||||||
|
let (detection, _) = detect_protocol(server).await;
|
||||||
|
detection
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ssh_banner_detected() {
|
||||||
|
let detection = write_and_detect(b"SSH-2.0-OpenSSH_9.0\r\n").await;
|
||||||
|
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ssh_banner_other_implementation() {
|
||||||
|
let detection = write_and_detect(b"SSH-2.0-russh_0.49\r\n").await;
|
||||||
|
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ssh_banner_minimal() {
|
||||||
|
let detection = write_and_detect(b"SSH-2.0-X\n").await;
|
||||||
|
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn http_get_detected_as_http() {
|
||||||
|
let detection = write_and_detect(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
|
||||||
|
assert_eq!(detection, ProtocolDetection::Http);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn http_post_detected_as_http() {
|
||||||
|
let detection = write_and_detect(b"POST /api HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
|
||||||
|
assert_eq!(detection, ProtocolDetection::Http);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn random_data_detected_as_http() {
|
||||||
|
let detection = write_and_detect(b"\x01\x02\x03\x04\x05\x06\x07\x08").await;
|
||||||
|
assert_eq!(detection, ProtocolDetection::Http);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn empty_stream_detected_as_http() {
|
||||||
|
let (client, server) = duplex(1024);
|
||||||
|
drop(client);
|
||||||
|
let (detection, _) = detect_protocol(server).await;
|
||||||
|
assert_eq!(detection, ProtocolDetection::Http);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ssh_banner_bytes_preserved_by_bufreader() {
|
||||||
|
let (client, server) = duplex(1024);
|
||||||
|
let mut client = client;
|
||||||
|
|
||||||
|
let banner = b"SSH-2.0-OpenSSH_9.0\r\n";
|
||||||
|
client.write_all(banner).await.unwrap();
|
||||||
|
client.write_all(b"subsequent data").await.unwrap();
|
||||||
|
drop(client);
|
||||||
|
|
||||||
|
let (detection, mut reader) = detect_protocol(server).await;
|
||||||
|
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||||
|
|
||||||
|
let mut all_data = Vec::new();
|
||||||
|
reader.read_to_end(&mut all_data).await.unwrap();
|
||||||
|
assert!(all_data.starts_with(banner), "banner bytes must be preserved after detection");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn fake_nginx_404_response() {
|
||||||
|
let (client, server) = duplex(1024);
|
||||||
|
let (mut client_read, mut client_write) = tokio::io::split(client);
|
||||||
|
|
||||||
|
client_write.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
|
||||||
|
drop(client_write);
|
||||||
|
|
||||||
|
let (detection, mut reader) = detect_protocol(server).await;
|
||||||
|
assert_eq!(detection, ProtocolDetection::Http);
|
||||||
|
|
||||||
|
send_fake_nginx_404(&mut reader).await;
|
||||||
|
|
||||||
|
let mut buf = [0u8; 256];
|
||||||
|
let n = client_read.read(&mut buf).await.unwrap();
|
||||||
|
let response = String::from_utf8_lossy(&buf[..n]);
|
||||||
|
assert!(response.contains("HTTP/1.1 404 Not Found"));
|
||||||
|
assert!(response.contains("Server: nginx"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn protocol_detection_enum_equality() {
|
||||||
|
assert_eq!(ProtocolDetection::Ssh, ProtocolDetection::Ssh);
|
||||||
|
assert_eq!(ProtocolDetection::Http, ProtocolDetection::Http);
|
||||||
|
assert_ne!(ProtocolDetection::Ssh, ProtocolDetection::Http);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_stealth_without_tls_rejected() {
|
||||||
|
let result = validate_stealth_config(true, false);
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().contains("TLS transport"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_stealth_with_tls_accepted() {
|
||||||
|
let result = validate_stealth_config(true, true);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_no_stealth_with_tcp_accepted() {
|
||||||
|
let result = validate_stealth_config(false, false);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_no_stealth_with_tls_accepted() {
|
||||||
|
let result = validate_stealth_config(false, true);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn short_data_detected_as_http() {
|
||||||
|
let detection = write_and_detect(b"GE").await;
|
||||||
|
assert_eq!(detection, ProtocolDetection::Http);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn partial_ssh_prefix_detected_as_http() {
|
||||||
|
let detection = write_and_detect(b"SSH-1.").await;
|
||||||
|
assert_eq!(detection, ProtocolDetection::Http);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn http_request_gets_404_then_closed() {
|
||||||
|
let (client, server) = duplex(1024);
|
||||||
|
let mut client = client;
|
||||||
|
|
||||||
|
client.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
|
||||||
|
|
||||||
|
let (detection, mut reader) = detect_protocol(server).await;
|
||||||
|
assert_eq!(detection, ProtocolDetection::Http);
|
||||||
|
|
||||||
|
send_fake_nginx_404(&mut reader).await;
|
||||||
|
|
||||||
|
let mut buf = [0u8; 256];
|
||||||
|
let n = client.read(&mut buf).await.unwrap();
|
||||||
|
let response = String::from_utf8_lossy(&buf[..n]);
|
||||||
|
assert!(response.starts_with("HTTP/1.1 404 Not Found"));
|
||||||
|
assert!(response.contains("Server: nginx"));
|
||||||
|
|
||||||
|
let mut extra = [0u8; 16];
|
||||||
|
let result = client.read(&mut extra).await;
|
||||||
|
assert!(result.is_err() || result.unwrap() == 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
490
crates/wraith-core/src/socks5/mod.rs
Normal file
490
crates/wraith-core/src/socks5/mod.rs
Normal file
@@ -0,0 +1,490 @@
|
|||||||
|
mod protocol;
|
||||||
|
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
use protocol::{Socks5Reply, Socks5Request, Socks5VersionMethod};
|
||||||
|
|
||||||
|
pub use protocol::Socks5Address;
|
||||||
|
|
||||||
|
const DEFAULT_SOCKS5_ADDR: &str = "127.0.0.1:1080";
|
||||||
|
|
||||||
|
pub trait ChannelOpener: Send + Sync + 'static {
|
||||||
|
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||||
|
|
||||||
|
fn open_channel(
|
||||||
|
&self,
|
||||||
|
host: String,
|
||||||
|
port: u16,
|
||||||
|
) -> impl std::future::Future<Output = Result<Self::Stream, ChannelOpenError>> + Send;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ChannelOpenError {
|
||||||
|
#[error("session closed")]
|
||||||
|
SessionClosed,
|
||||||
|
#[error("channel open failed")]
|
||||||
|
ChannelOpenFailed,
|
||||||
|
#[error("connection refused")]
|
||||||
|
ConnectionRefused,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Socks5Server<C: ChannelOpener> {
|
||||||
|
listen_addr: SocketAddr,
|
||||||
|
channel_opener: Arc<C>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C: ChannelOpener> Socks5Server<C> {
|
||||||
|
pub fn new(channel_opener: C) -> Self {
|
||||||
|
Self::with_addr(channel_opener, DEFAULT_SOCKS5_ADDR)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_addr(channel_opener: C, addr: &str) -> Self {
|
||||||
|
let listen_addr: SocketAddr = addr
|
||||||
|
.parse()
|
||||||
|
.expect("invalid SOCKS5 listen address");
|
||||||
|
Self {
|
||||||
|
listen_addr,
|
||||||
|
channel_opener: Arc::new(channel_opener),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn listen_addr(&self) -> SocketAddr {
|
||||||
|
self.listen_addr
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(self) -> Result<(), std::io::Error> {
|
||||||
|
let listener = TcpListener::bind(self.listen_addr).await?;
|
||||||
|
debug!("socks5 server listening on {}", self.listen_addr);
|
||||||
|
loop {
|
||||||
|
let (socket, _peer) = listener.accept().await?;
|
||||||
|
let opener = Arc::clone(&self.channel_opener);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = handle_socks5_connection(socket, opener).await {
|
||||||
|
debug!("socks5 connection error: {e}");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_socks5_connection<S, C>(
|
||||||
|
mut socket: S,
|
||||||
|
opener: Arc<C>,
|
||||||
|
) -> Result<(), Socks5Error>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
C: ChannelOpener,
|
||||||
|
{
|
||||||
|
let vm = Socks5VersionMethod::read_from(&mut socket).await?;
|
||||||
|
if vm.version != 0x05 {
|
||||||
|
return Err(Socks5Error::InvalidVersion(vm.version));
|
||||||
|
}
|
||||||
|
if !vm.methods.contains(&0x00) {
|
||||||
|
let reply = [0x05, 0xFF];
|
||||||
|
socket.write_all(&reply).await?;
|
||||||
|
socket.shutdown().await?;
|
||||||
|
return Err(Socks5Error::NoAcceptableAuth);
|
||||||
|
}
|
||||||
|
let reply = [0x05, 0x00];
|
||||||
|
socket.write_all(&reply).await?;
|
||||||
|
|
||||||
|
let request = Socks5Request::read_from(&mut socket).await?;
|
||||||
|
if request.version != 0x05 {
|
||||||
|
return Err(Socks5Error::InvalidVersion(request.version));
|
||||||
|
}
|
||||||
|
if request.command != 0x01 {
|
||||||
|
send_error_reply(&mut socket, Socks5Reply::command_not_supported()).await?;
|
||||||
|
return Err(Socks5Error::UnsupportedCommand(request.command));
|
||||||
|
}
|
||||||
|
|
||||||
|
let (host, port) = match &request.address {
|
||||||
|
Socks5Address::Ipv4(addr) => (addr.to_string(), request.port),
|
||||||
|
Socks5Address::Ipv6(addr) => (addr.to_string(), request.port),
|
||||||
|
Socks5Address::Domain(name) => (name.clone(), request.port),
|
||||||
|
};
|
||||||
|
|
||||||
|
match opener.open_channel(host, port).await {
|
||||||
|
Ok(mut ssh_stream) => {
|
||||||
|
let bind_addr = Socks5Address::Ipv4(std::net::Ipv4Addr::UNSPECIFIED);
|
||||||
|
let reply = Socks5Reply::success(bind_addr, 0);
|
||||||
|
reply.write_to(&mut socket).await?;
|
||||||
|
tokio::io::copy_bidirectional(&mut socket, &mut ssh_stream).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
send_error_reply(&mut socket, Socks5Reply::connection_refused()).await?;
|
||||||
|
Err(Socks5Error::ChannelOpenFailed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_error_reply<S: AsyncRead + AsyncWrite + Unpin>(
|
||||||
|
socket: &mut S,
|
||||||
|
reply: Socks5Reply,
|
||||||
|
) -> Result<(), Socks5Error> {
|
||||||
|
reply.write_to(socket).await?;
|
||||||
|
let _ = socket.shutdown().await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum Socks5Error {
|
||||||
|
#[error("invalid SOCKS version: {0}")]
|
||||||
|
InvalidVersion(u8),
|
||||||
|
#[error("no acceptable auth method")]
|
||||||
|
NoAcceptableAuth,
|
||||||
|
#[error("unsupported command: {0}")]
|
||||||
|
UnsupportedCommand(u8),
|
||||||
|
#[error("channel open failed")]
|
||||||
|
ChannelOpenFailed,
|
||||||
|
#[error("io error")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct HandleChannelOpener<H: russh::client::Handler> {
|
||||||
|
handle: Arc<Mutex<russh::client::Handle<H>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<H: russh::client::Handler> HandleChannelOpener<H> {
|
||||||
|
pub fn new(handle: russh::client::Handle<H>) -> Self {
|
||||||
|
Self {
|
||||||
|
handle: Arc::new(Mutex::new(handle)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_arc(handle: Arc<Mutex<russh::client::Handle<H>>>) -> Self {
|
||||||
|
Self { handle }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<H: russh::client::Handler + Send + Sync + 'static> ChannelOpener for HandleChannelOpener<H> {
|
||||||
|
type Stream = russh::ChannelStream<russh::client::Msg>;
|
||||||
|
|
||||||
|
async fn open_channel(&self, host: String, port: u16) -> Result<Self::Stream, ChannelOpenError> {
|
||||||
|
let handle = self.handle.lock().await;
|
||||||
|
if handle.is_closed() {
|
||||||
|
return Err(ChannelOpenError::SessionClosed);
|
||||||
|
}
|
||||||
|
let channel = handle
|
||||||
|
.channel_open_direct_tcpip(host, port as u32, "127.0.0.1", 0)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ChannelOpenError::ChannelOpenFailed)?;
|
||||||
|
Ok(channel.into_stream())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
|
||||||
|
|
||||||
|
struct MockChannelOpener {
|
||||||
|
fail: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChannelOpener for MockChannelOpener {
|
||||||
|
type Stream = DuplexStream;
|
||||||
|
|
||||||
|
async fn open_channel(
|
||||||
|
&self,
|
||||||
|
_host: String,
|
||||||
|
_port: u16,
|
||||||
|
) -> Result<Self::Stream, ChannelOpenError> {
|
||||||
|
if self.fail {
|
||||||
|
Err(ChannelOpenError::ChannelOpenFailed)
|
||||||
|
} else {
|
||||||
|
let (client, _server) = duplex(4096);
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_socks5_greeting(methods: &[u8]) -> Vec<u8> {
|
||||||
|
let mut buf = vec![0x05, methods.len() as u8];
|
||||||
|
buf.extend_from_slice(methods);
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_socks5_connect_ipv4(addr: [u8; 4], port: u16) -> Vec<u8> {
|
||||||
|
let mut buf = vec![0x05, 0x01, 0x00, 0x01];
|
||||||
|
buf.extend_from_slice(&addr);
|
||||||
|
buf.extend_from_slice(&port.to_be_bytes());
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_socks5_connect_domain(domain: &str, port: u16) -> Vec<u8> {
|
||||||
|
let mut buf = vec![0x05, 0x01, 0x00, 0x03];
|
||||||
|
buf.push(domain.len() as u8);
|
||||||
|
buf.extend_from_slice(domain.as_bytes());
|
||||||
|
buf.extend_from_slice(&port.to_be_bytes());
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_socks5_connect_ipv6(addr: [u8; 16], port: u16) -> Vec<u8> {
|
||||||
|
let mut buf = vec![0x05, 0x01, 0x00, 0x04];
|
||||||
|
buf.extend_from_slice(&addr);
|
||||||
|
buf.extend_from_slice(&port.to_be_bytes());
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn do_handshake(client: &mut DuplexStream) -> [u8; 2] {
|
||||||
|
client.write_all(&build_socks5_greeting(&[0x00])).await.unwrap();
|
||||||
|
client.flush().await.unwrap();
|
||||||
|
let mut resp = [0u8; 2];
|
||||||
|
client.read_exact(&mut resp).await.unwrap();
|
||||||
|
resp
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn do_connect_ipv4(client: &mut DuplexStream, addr: [u8; 4], port: u16) -> Vec<u8> {
|
||||||
|
client
|
||||||
|
.write_all(&build_socks5_connect_ipv4(addr, port))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
client.flush().await.unwrap();
|
||||||
|
let mut reply_buf = [0u8; 10];
|
||||||
|
client.read_exact(&mut reply_buf).await.unwrap();
|
||||||
|
reply_buf.to_vec()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn handshake_no_auth_method() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
let resp = do_handshake(&mut client).await;
|
||||||
|
assert_eq!(resp, [0x05, 0x00]);
|
||||||
|
|
||||||
|
let reply_buf = do_connect_ipv4(&mut client, [127, 0, 0, 1], 80).await;
|
||||||
|
assert_eq!(reply_buf[0], 0x05);
|
||||||
|
assert_eq!(reply_buf[1], 0x00);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn handshake_rejects_no_acceptable_method() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
client
|
||||||
|
.write_all(&build_socks5_greeting(&[0x02]))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
client.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut resp = [0u8; 2];
|
||||||
|
client.read_exact(&mut resp).await.unwrap();
|
||||||
|
assert_eq!(resp, [0x05, 0xFF]);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let result = server_handle.await.unwrap();
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(matches!(
|
||||||
|
result.unwrap_err(),
|
||||||
|
Socks5Error::NoAcceptableAuth
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn address_type_ipv4() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
do_handshake(&mut client).await;
|
||||||
|
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 443).await;
|
||||||
|
assert_eq!(reply_buf[1], 0x00);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn address_type_domain() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
do_handshake(&mut client).await;
|
||||||
|
|
||||||
|
client
|
||||||
|
.write_all(&build_socks5_connect_domain("example.com", 443))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
client.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut reply_buf = [0u8; 10];
|
||||||
|
client.read_exact(&mut reply_buf).await.unwrap();
|
||||||
|
assert_eq!(reply_buf[1], 0x00);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn address_type_ipv6() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
do_handshake(&mut client).await;
|
||||||
|
|
||||||
|
let ipv6_addr = [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
|
||||||
|
client
|
||||||
|
.write_all(&build_socks5_connect_ipv6(ipv6_addr, 443))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
client.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut reply_buf = [0u8; 10];
|
||||||
|
client.read_exact(&mut reply_buf).await.unwrap();
|
||||||
|
assert_eq!(reply_buf[0], 0x05);
|
||||||
|
assert_eq!(reply_buf[1], 0x00);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn channel_open_failure_returns_socks5_error() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: true };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
do_handshake(&mut client).await;
|
||||||
|
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 80).await;
|
||||||
|
assert_eq!(reply_buf[0], 0x05);
|
||||||
|
assert_eq!(reply_buf[1], 0x05);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn unsupported_command_returns_error() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
do_handshake(&mut client).await;
|
||||||
|
|
||||||
|
let mut bind_req = vec![0x05, 0x02, 0x00, 0x01];
|
||||||
|
bind_req.extend_from_slice(&[127, 0, 0, 1]);
|
||||||
|
bind_req.extend_from_slice(&80u16.to_be_bytes());
|
||||||
|
client.write_all(&bind_req).await.unwrap();
|
||||||
|
client.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut reply_buf = [0u8; 10];
|
||||||
|
client.read_exact(&mut reply_buf).await.unwrap();
|
||||||
|
assert_eq!(reply_buf[1], 0x07);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn bidirectional_proxy_flow() {
|
||||||
|
let (mut client_sock, server_sock) = duplex(4096);
|
||||||
|
let (ssh_client, mut ssh_server) = duplex(4096);
|
||||||
|
|
||||||
|
let ssh_stream = Arc::new(Mutex::new(Some(ssh_client)));
|
||||||
|
|
||||||
|
struct ProxyOpener {
|
||||||
|
stream: Arc<Mutex<Option<DuplexStream>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChannelOpener for ProxyOpener {
|
||||||
|
type Stream = DuplexStream;
|
||||||
|
|
||||||
|
async fn open_channel(
|
||||||
|
&self,
|
||||||
|
_host: String,
|
||||||
|
_port: u16,
|
||||||
|
) -> Result<Self::Stream, ChannelOpenError> {
|
||||||
|
self.stream
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.take()
|
||||||
|
.ok_or(ChannelOpenError::ChannelOpenFailed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let opener = ProxyOpener {
|
||||||
|
stream: Arc::clone(&ssh_stream),
|
||||||
|
};
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server_sock, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
do_handshake(&mut client_sock).await;
|
||||||
|
let reply_buf = do_connect_ipv4(&mut client_sock, [127, 0, 0, 1], 80).await;
|
||||||
|
assert_eq!(reply_buf[1], 0x00);
|
||||||
|
|
||||||
|
let test_data = b"hello through tunnel";
|
||||||
|
client_sock.write_all(test_data).await.unwrap();
|
||||||
|
client_sock.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut received = vec![0u8; test_data.len()];
|
||||||
|
AsyncReadExt::read_exact(&mut ssh_server, &mut received)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(&received, test_data);
|
||||||
|
|
||||||
|
let echo_data = b"response from tunnel";
|
||||||
|
ssh_server.write_all(echo_data).await.unwrap();
|
||||||
|
ssh_server.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut received_back = vec![0u8; echo_data.len()];
|
||||||
|
client_sock.read_exact(&mut received_back).await.unwrap();
|
||||||
|
assert_eq!(&received_back, echo_data);
|
||||||
|
|
||||||
|
drop(client_sock);
|
||||||
|
drop(ssh_server);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn default_listen_address() {
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
let server = Socks5Server::new(opener);
|
||||||
|
assert_eq!(server.listen_addr(), "127.0.0.1:1080".parse().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn custom_listen_address() {
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
let server = Socks5Server::with_addr(opener, "127.0.0.1:9050");
|
||||||
|
assert_eq!(server.listen_addr(), "127.0.0.1:9050".parse().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
304
crates/wraith-core/src/socks5/protocol.rs
Normal file
304
crates/wraith-core/src/socks5/protocol.rs
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||||
|
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub enum Socks5Address {
|
||||||
|
Ipv4(Ipv4Addr),
|
||||||
|
Ipv6(Ipv6Addr),
|
||||||
|
Domain(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Socks5VersionMethod {
|
||||||
|
pub version: u8,
|
||||||
|
pub methods: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Socks5VersionMethod {
|
||||||
|
pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Self> {
|
||||||
|
let version = reader.read_u8().await?;
|
||||||
|
let nmethods = reader.read_u8().await?;
|
||||||
|
let mut methods = vec![0u8; nmethods as usize];
|
||||||
|
reader.read_exact(&mut methods).await?;
|
||||||
|
Ok(Self { version, methods })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Socks5Request {
|
||||||
|
pub version: u8,
|
||||||
|
pub command: u8,
|
||||||
|
pub address: Socks5Address,
|
||||||
|
pub port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Socks5Request {
|
||||||
|
pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Self> {
|
||||||
|
let version = reader.read_u8().await?;
|
||||||
|
let command = reader.read_u8().await?;
|
||||||
|
let _rsv = reader.read_u8().await?;
|
||||||
|
let atyp = reader.read_u8().await?;
|
||||||
|
|
||||||
|
let address = match atyp {
|
||||||
|
0x01 => {
|
||||||
|
let mut octets = [0u8; 4];
|
||||||
|
reader.read_exact(&mut octets).await?;
|
||||||
|
Socks5Address::Ipv4(Ipv4Addr::from(octets))
|
||||||
|
}
|
||||||
|
0x04 => {
|
||||||
|
let mut octets = [0u8; 16];
|
||||||
|
reader.read_exact(&mut octets).await?;
|
||||||
|
Socks5Address::Ipv6(Ipv6Addr::from(octets))
|
||||||
|
}
|
||||||
|
0x03 => {
|
||||||
|
let len = reader.read_u8().await?;
|
||||||
|
let mut domain = vec![0u8; len as usize];
|
||||||
|
reader.read_exact(&mut domain).await?;
|
||||||
|
Socks5Address::Domain(String::from_utf8_lossy(&domain).into_owned())
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("unsupported address type: {atyp}"),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let port = reader.read_u16().await?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
version,
|
||||||
|
command,
|
||||||
|
address,
|
||||||
|
port,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Socks5Reply {
|
||||||
|
pub version: u8,
|
||||||
|
pub reply: u8,
|
||||||
|
pub address: Socks5Address,
|
||||||
|
pub port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Socks5Reply {
|
||||||
|
pub fn success(address: Socks5Address, port: u16) -> Self {
|
||||||
|
Self {
|
||||||
|
version: 0x05,
|
||||||
|
reply: 0x00,
|
||||||
|
address,
|
||||||
|
port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn connection_refused() -> Self {
|
||||||
|
Self {
|
||||||
|
version: 0x05,
|
||||||
|
reply: 0x05,
|
||||||
|
address: Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED),
|
||||||
|
port: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn command_not_supported() -> Self {
|
||||||
|
Self {
|
||||||
|
version: 0x05,
|
||||||
|
reply: 0x07,
|
||||||
|
address: Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED),
|
||||||
|
port: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn write_to<W: AsyncWrite + Unpin>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||||
|
writer.write_u8(self.version).await?;
|
||||||
|
writer.write_u8(self.reply).await?;
|
||||||
|
writer.write_u8(0x00).await?;
|
||||||
|
match &self.address {
|
||||||
|
Socks5Address::Ipv4(addr) => {
|
||||||
|
writer.write_u8(0x01).await?;
|
||||||
|
writer.write_all(&addr.octets()).await?;
|
||||||
|
}
|
||||||
|
Socks5Address::Ipv6(addr) => {
|
||||||
|
writer.write_u8(0x04).await?;
|
||||||
|
writer.write_all(&addr.octets()).await?;
|
||||||
|
}
|
||||||
|
Socks5Address::Domain(name) => {
|
||||||
|
writer.write_u8(0x03).await?;
|
||||||
|
writer.write_u8(name.len() as u8).await?;
|
||||||
|
writer.write_all(name.as_bytes()).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writer.write_u16(self.port).await?;
|
||||||
|
writer.flush().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::io::Cursor;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn parse_version_method_no_auth() {
|
||||||
|
let data = [0x05, 0x01, 0x00];
|
||||||
|
let mut cursor = Cursor::new(&data[..]);
|
||||||
|
let vm = Socks5VersionMethod::read_from(&mut cursor).await.unwrap();
|
||||||
|
assert_eq!(vm.version, 0x05);
|
||||||
|
assert_eq!(vm.methods, vec![0x00]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn parse_version_method_multiple() {
|
||||||
|
let data = [0x05, 0x02, 0x00, 0x02];
|
||||||
|
let mut cursor = Cursor::new(&data[..]);
|
||||||
|
let vm = Socks5VersionMethod::read_from(&mut cursor).await.unwrap();
|
||||||
|
assert_eq!(vm.version, 0x05);
|
||||||
|
assert_eq!(vm.methods, vec![0x00, 0x02]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn parse_request_ipv4() {
|
||||||
|
let mut data = vec![0x05, 0x01, 0x00, 0x01];
|
||||||
|
data.extend_from_slice(&[10, 0, 0, 1]);
|
||||||
|
data.extend_from_slice(&443u16.to_be_bytes());
|
||||||
|
let mut cursor = Cursor::new(&data[..]);
|
||||||
|
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||||
|
assert_eq!(req.version, 0x05);
|
||||||
|
assert_eq!(req.command, 0x01);
|
||||||
|
assert_eq!(
|
||||||
|
req.address,
|
||||||
|
Socks5Address::Ipv4(Ipv4Addr::new(10, 0, 0, 1))
|
||||||
|
);
|
||||||
|
assert_eq!(req.port, 443);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn parse_request_ipv6() {
|
||||||
|
let mut data = vec![0x05, 0x01, 0x00, 0x04];
|
||||||
|
let octets: [u8; 16] = [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
|
||||||
|
data.extend_from_slice(&octets);
|
||||||
|
data.extend_from_slice(&443u16.to_be_bytes());
|
||||||
|
let mut cursor = Cursor::new(&data[..]);
|
||||||
|
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||||
|
assert_eq!(req.version, 0x05);
|
||||||
|
assert_eq!(req.command, 0x01);
|
||||||
|
assert!(matches!(req.address, Socks5Address::Ipv6(_)));
|
||||||
|
assert_eq!(req.port, 443);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn parse_request_domain() {
|
||||||
|
let domain = "example.com";
|
||||||
|
let mut data = vec![0x05, 0x01, 0x00, 0x03];
|
||||||
|
data.push(domain.len() as u8);
|
||||||
|
data.extend_from_slice(domain.as_bytes());
|
||||||
|
data.extend_from_slice(&443u16.to_be_bytes());
|
||||||
|
let mut cursor = Cursor::new(&data[..]);
|
||||||
|
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||||
|
assert_eq!(req.version, 0x05);
|
||||||
|
assert_eq!(req.command, 0x01);
|
||||||
|
assert_eq!(req.address, Socks5Address::Domain("example.com".to_string()));
|
||||||
|
assert_eq!(req.port, 443);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn parse_request_unsupported_address_type() {
|
||||||
|
let data = [0x05, 0x01, 0x00, 0x05];
|
||||||
|
let mut cursor = Cursor::new(&data[..]);
|
||||||
|
let result = Socks5Request::read_from(&mut cursor).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reply_success_ipv4() {
|
||||||
|
let reply = Socks5Reply::success(Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED), 0);
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
reply.write_to(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(buf[0], 0x05);
|
||||||
|
assert_eq!(buf[1], 0x00);
|
||||||
|
assert_eq!(buf[2], 0x00);
|
||||||
|
assert_eq!(buf[3], 0x01);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reply_connection_refused() {
|
||||||
|
let reply = Socks5Reply::connection_refused();
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
reply.write_to(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(buf[0], 0x05);
|
||||||
|
assert_eq!(buf[1], 0x05);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reply_command_not_supported() {
|
||||||
|
let reply = Socks5Reply::command_not_supported();
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
reply.write_to(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(buf[0], 0x05);
|
||||||
|
assert_eq!(buf[1], 0x07);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn roundtrip_ipv4_reply() {
|
||||||
|
let reply = Socks5Reply::success(Socks5Address::Ipv4(Ipv4Addr::new(127, 0, 0, 1)), 1080);
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
reply.write_to(&mut buf).await.unwrap();
|
||||||
|
|
||||||
|
let mut cursor = Cursor::new(&buf[..]);
|
||||||
|
let version = cursor.read_u8().await.unwrap();
|
||||||
|
let _reply_code = cursor.read_u8().await.unwrap();
|
||||||
|
let _rsv = cursor.read_u8().await.unwrap();
|
||||||
|
let atyp = cursor.read_u8().await.unwrap();
|
||||||
|
assert_eq!(version, 0x05);
|
||||||
|
assert_eq!(atyp, 0x01);
|
||||||
|
let mut octets = [0u8; 4];
|
||||||
|
cursor.read_exact(&mut octets).await.unwrap();
|
||||||
|
assert_eq!(Ipv4Addr::from(octets), Ipv4Addr::new(127, 0, 0, 1));
|
||||||
|
let port = cursor.read_u16().await.unwrap();
|
||||||
|
assert_eq!(port, 1080);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn roundtrip_ipv6_reply() {
|
||||||
|
let addr = Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1);
|
||||||
|
let reply = Socks5Reply::success(Socks5Address::Ipv6(addr), 443);
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
reply.write_to(&mut buf).await.unwrap();
|
||||||
|
|
||||||
|
let mut cursor = Cursor::new(&buf[..]);
|
||||||
|
let _version = cursor.read_u8().await.unwrap();
|
||||||
|
let _reply_code = cursor.read_u8().await.unwrap();
|
||||||
|
let _rsv = cursor.read_u8().await.unwrap();
|
||||||
|
let atyp = cursor.read_u8().await.unwrap();
|
||||||
|
assert_eq!(atyp, 0x04);
|
||||||
|
let mut octets = [0u8; 16];
|
||||||
|
cursor.read_exact(&mut octets).await.unwrap();
|
||||||
|
assert_eq!(Ipv6Addr::from(octets), addr);
|
||||||
|
let port = cursor.read_u16().await.unwrap();
|
||||||
|
assert_eq!(port, 443);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn roundtrip_domain_reply() {
|
||||||
|
let reply = Socks5Reply::success(Socks5Address::Domain("example.com".to_string()), 8080);
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
reply.write_to(&mut buf).await.unwrap();
|
||||||
|
|
||||||
|
let mut cursor = Cursor::new(&buf[..]);
|
||||||
|
let _version = cursor.read_u8().await.unwrap();
|
||||||
|
let _reply_code = cursor.read_u8().await.unwrap();
|
||||||
|
let _rsv = cursor.read_u8().await.unwrap();
|
||||||
|
let atyp = cursor.read_u8().await.unwrap();
|
||||||
|
assert_eq!(atyp, 0x03);
|
||||||
|
let len = cursor.read_u8().await.unwrap();
|
||||||
|
let mut domain = vec![0u8; len as usize];
|
||||||
|
cursor.read_exact(&mut domain).await.unwrap();
|
||||||
|
assert_eq!(String::from_utf8(domain).unwrap(), "example.com");
|
||||||
|
let port = cursor.read_u16().await.unwrap();
|
||||||
|
assert_eq!(port, 8080);
|
||||||
|
}
|
||||||
|
}
|
||||||
362
crates/wraith-core/src/transport/acme.rs
Normal file
362
crates/wraith-core/src/transport/acme.rs
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use rustls::crypto::aws_lc_rs::default_provider;
|
||||||
|
use rustls::ServerConfig;
|
||||||
|
use rustls_acme::caches::DirCache;
|
||||||
|
use rustls_acme::{AcmeConfig, AcmeState, ResolvesServerCertAcme};
|
||||||
|
use tracing::{error, info};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use tokio_rustls::TlsAcceptor as TokioTlsAcceptor;
|
||||||
|
|
||||||
|
use super::{TransportAcceptor, TransportInfo, TransportKind};
|
||||||
|
|
||||||
|
const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1";
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum AcmeMode {
|
||||||
|
Domain { domain: String },
|
||||||
|
Ip,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AcmeCertProvider {
|
||||||
|
mode: AcmeMode,
|
||||||
|
cache_dir: Option<PathBuf>,
|
||||||
|
directory_url: String,
|
||||||
|
contact: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for AcmeCertProvider {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("AcmeCertProvider")
|
||||||
|
.field("mode", &self.mode)
|
||||||
|
.field("cache_dir", &self.cache_dir)
|
||||||
|
.field("directory_url", &self.directory_url)
|
||||||
|
.field("contact", &self.contact)
|
||||||
|
.finish_non_exhaustive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AcmeCertProvider {
|
||||||
|
pub fn new(mode: AcmeMode) -> Self {
|
||||||
|
Self {
|
||||||
|
mode,
|
||||||
|
cache_dir: None,
|
||||||
|
directory_url: rustls_acme::acme::LETS_ENCRYPT_STAGING_DIRECTORY.to_string(),
|
||||||
|
contact: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn domain(domain: impl Into<String>) -> Self {
|
||||||
|
Self::new(AcmeMode::Domain {
|
||||||
|
domain: domain.into(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ip() -> Self {
|
||||||
|
Self::new(AcmeMode::Ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_cache_dir(mut self, dir: impl Into<PathBuf>) -> Self {
|
||||||
|
self.cache_dir = Some(dir.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_directory(mut self, url: impl Into<String>) -> Self {
|
||||||
|
self.directory_url = url.into();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_production_directory(mut self) -> Self {
|
||||||
|
self.directory_url = rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY.to_string();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_contact(mut self, contact: impl Into<String>) -> Self {
|
||||||
|
self.contact.push(contact.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mode(&self) -> &AcmeMode {
|
||||||
|
&self.mode
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_acme_state(&self) -> (AcmeState<std::io::Error>, Arc<ResolvesServerCertAcme>) {
|
||||||
|
let domains: Vec<String> = match &self.mode {
|
||||||
|
AcmeMode::Domain { domain } => vec![domain.clone()],
|
||||||
|
AcmeMode::Ip => vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
let base_config = AcmeConfig::new(domains)
|
||||||
|
.directory(&self.directory_url)
|
||||||
|
.contact(self.contact.clone());
|
||||||
|
|
||||||
|
let state = match &self.cache_dir {
|
||||||
|
Some(cache_dir) => {
|
||||||
|
base_config.cache(DirCache::new(cache_dir.clone())).state()
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
base_config
|
||||||
|
.cache(rustls_acme::caches::NoCache::default())
|
||||||
|
.state()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let resolver = state.resolver();
|
||||||
|
(state, resolver)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_server_config_with_resolver(
|
||||||
|
&self,
|
||||||
|
resolver: Arc<ResolvesServerCertAcme>,
|
||||||
|
) -> Result<Arc<ServerConfig>> {
|
||||||
|
let provider = default_provider().into();
|
||||||
|
let mut config = ServerConfig::builder_with_provider(provider)
|
||||||
|
.with_safe_default_protocol_versions()
|
||||||
|
.map_err(|e| anyhow!("failed to set protocol versions: {}", e))?
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_cert_resolver(resolver);
|
||||||
|
config.alpn_protocols.push(ACME_TLS_ALPN_NAME.to_vec());
|
||||||
|
Ok(Arc::new(config))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AcmeTlsAcceptor {
|
||||||
|
listener: TcpListener,
|
||||||
|
listen_addr: SocketAddr,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
server_config: Arc<ServerConfig>,
|
||||||
|
tokio_acceptor: TokioTlsAcceptor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AcmeTlsAcceptor {
|
||||||
|
pub async fn bind_acme(
|
||||||
|
addr: SocketAddr,
|
||||||
|
provider: Arc<AcmeCertProvider>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let (state, resolver) = provider.build_acme_state();
|
||||||
|
|
||||||
|
let server_config = provider.build_server_config_with_resolver(resolver.clone())?;
|
||||||
|
|
||||||
|
Self::spawn_state_worker(state, resolver);
|
||||||
|
|
||||||
|
let listener = TcpListener::bind(addr).await?;
|
||||||
|
let listen_addr = listener.local_addr()?;
|
||||||
|
|
||||||
|
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
listener,
|
||||||
|
listen_addr,
|
||||||
|
server_config,
|
||||||
|
tokio_acceptor,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn listen_addr(&self) -> SocketAddr {
|
||||||
|
self.listen_addr
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spawn_state_worker(state: AcmeState<std::io::Error>, resolver: Arc<ResolvesServerCertAcme>) {
|
||||||
|
use futures::StreamExt;
|
||||||
|
|
||||||
|
let task = async move {
|
||||||
|
let mut state = state;
|
||||||
|
while let Some(event) = state.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(ok) => {
|
||||||
|
if let rustls_acme::EventOk::DeployedNewCert = ok {
|
||||||
|
info!("ACME: new certificate deployed");
|
||||||
|
} else {
|
||||||
|
info!("ACME event: {:?}", ok);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => error!("ACME event error: {:?}", err),
|
||||||
|
}
|
||||||
|
if Arc::strong_count(&resolver) == 1 {
|
||||||
|
info!("ACME resolver dropped, stopping background task");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
tokio::spawn(task);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl TransportAcceptor for AcmeTlsAcceptor {
|
||||||
|
type Stream = tokio_rustls::server::TlsStream<tokio::net::TcpStream>;
|
||||||
|
|
||||||
|
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||||
|
let (tcp_stream, remote_addr) = self.listener.accept().await?;
|
||||||
|
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
|
||||||
|
|
||||||
|
let server_name = tls_stream
|
||||||
|
.get_ref()
|
||||||
|
.1
|
||||||
|
.server_name()
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
let info = TransportInfo {
|
||||||
|
remote_addr: Some(remote_addr),
|
||||||
|
transport_kind: TransportKind::Tls { server_name },
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((tls_stream, info))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_domain_mode() {
|
||||||
|
let provider = AcmeCertProvider::domain("example.com");
|
||||||
|
assert!(matches!(provider.mode(), AcmeMode::Domain { .. }));
|
||||||
|
if let AcmeMode::Domain { domain } = provider.mode() {
|
||||||
|
assert_eq!(domain, "example.com");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_ip_mode() {
|
||||||
|
let provider = AcmeCertProvider::ip();
|
||||||
|
assert!(matches!(provider.mode(), AcmeMode::Ip));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_default_staging_directory() {
|
||||||
|
let provider = AcmeCertProvider::domain("example.com");
|
||||||
|
assert_eq!(
|
||||||
|
provider.directory_url,
|
||||||
|
rustls_acme::acme::LETS_ENCRYPT_STAGING_DIRECTORY
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_production_directory() {
|
||||||
|
let provider = AcmeCertProvider::domain("example.com").with_production_directory();
|
||||||
|
assert_eq!(
|
||||||
|
provider.directory_url,
|
||||||
|
rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_custom_directory() {
|
||||||
|
let provider =
|
||||||
|
AcmeCertProvider::domain("example.com").with_directory("https://custom.acme.dir/");
|
||||||
|
assert_eq!(provider.directory_url, "https://custom.acme.dir/");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_with_cache_dir() {
|
||||||
|
let provider = AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/acme_cache");
|
||||||
|
assert_eq!(provider.cache_dir, Some(PathBuf::from("/tmp/acme_cache")));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_with_contact() {
|
||||||
|
let provider =
|
||||||
|
AcmeCertProvider::domain("example.com").with_contact("mailto:admin@example.com");
|
||||||
|
assert_eq!(
|
||||||
|
provider.contact,
|
||||||
|
vec!["mailto:admin@example.com".to_string()]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_build_state_domain() {
|
||||||
|
let provider = AcmeCertProvider::domain("example.com");
|
||||||
|
let (_state, resolver) = provider.build_acme_state();
|
||||||
|
assert!(Arc::strong_count(&resolver) >= 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_build_state_with_cache() {
|
||||||
|
let provider =
|
||||||
|
AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/test_cache");
|
||||||
|
let (_state, resolver) = provider.build_acme_state();
|
||||||
|
assert!(Arc::strong_count(&resolver) >= 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_build_server_config() {
|
||||||
|
let _ = default_provider().install_default();
|
||||||
|
let provider = AcmeCertProvider::domain("example.com");
|
||||||
|
let (_, resolver) = provider.build_acme_state();
|
||||||
|
let config = provider.build_server_config_with_resolver(resolver).unwrap();
|
||||||
|
assert!(!config.alpn_protocols.is_empty());
|
||||||
|
assert!(config
|
||||||
|
.alpn_protocols
|
||||||
|
.iter()
|
||||||
|
.any(|p| p == ACME_TLS_ALPN_NAME));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_mode_domain_debug() {
|
||||||
|
let mode = AcmeMode::Domain {
|
||||||
|
domain: "test.example.com".to_string(),
|
||||||
|
};
|
||||||
|
let debug_str = format!("{:?}", mode);
|
||||||
|
assert!(debug_str.contains("test.example.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_mode_ip_debug() {
|
||||||
|
let mode = AcmeMode::Ip;
|
||||||
|
let debug_str = format!("{:?}", mode);
|
||||||
|
assert!(debug_str.contains("Ip"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_builder_chain() {
|
||||||
|
let provider = AcmeCertProvider::domain("test.example.com")
|
||||||
|
.with_production_directory()
|
||||||
|
.with_cache_dir("/tmp/cache")
|
||||||
|
.with_contact("mailto:admin@test.example.com");
|
||||||
|
assert!(matches!(provider.mode(), AcmeMode::Domain { .. }));
|
||||||
|
assert_eq!(
|
||||||
|
provider.directory_url,
|
||||||
|
rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY
|
||||||
|
);
|
||||||
|
assert_eq!(provider.cache_dir, Some(PathBuf::from("/tmp/cache")));
|
||||||
|
assert_eq!(provider.contact.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn acme_tls_acceptor_bind_acme() {
|
||||||
|
let _ = default_provider().install_default();
|
||||||
|
let provider = Arc::new(AcmeCertProvider::domain("example.com"));
|
||||||
|
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
|
||||||
|
let acceptor = AcmeTlsAcceptor::bind_acme(addr, provider).await.unwrap();
|
||||||
|
assert_ne!(acceptor.listen_addr().port(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore]
|
||||||
|
async fn acme_staging_domain_cert_provisioning() {
|
||||||
|
let _ = default_provider().install_default();
|
||||||
|
|
||||||
|
let cache_dir = tempfile::tempdir().unwrap();
|
||||||
|
let provider = Arc::new(
|
||||||
|
AcmeCertProvider::domain("acme-test.example.com")
|
||||||
|
.with_cache_dir(cache_dir.path())
|
||||||
|
.with_contact("mailto:admin@example.com"),
|
||||||
|
);
|
||||||
|
|
||||||
|
let addr: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
||||||
|
let result = AcmeTlsAcceptor::bind_acme(addr, provider).await;
|
||||||
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"ACME TlsAcceptor should bind: {:?}",
|
||||||
|
result.err()
|
||||||
|
);
|
||||||
|
|
||||||
|
let acceptor = result.unwrap();
|
||||||
|
assert_eq!(acceptor.listen_addr().port(), 443);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,12 @@ mod tls;
|
|||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
pub use tls::{AcmeConfig, TlsAcceptor, TlsTransport};
|
pub use tls::{AcmeConfig, TlsAcceptor, TlsTransport};
|
||||||
|
|
||||||
|
#[cfg(feature = "acme")]
|
||||||
|
mod acme;
|
||||||
|
|
||||||
|
#[cfg(feature = "acme")]
|
||||||
|
pub use acme::{AcmeCertProvider, AcmeMode, AcmeTlsAcceptor};
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
|||||||
@@ -9,8 +9,16 @@ use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig};
|
|||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio_rustls::{client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector};
|
use tokio_rustls::{client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector};
|
||||||
|
|
||||||
|
#[cfg(feature = "acme")]
|
||||||
|
use rustls::crypto::aws_lc_rs::default_provider;
|
||||||
|
#[cfg(feature = "acme")]
|
||||||
|
use rustls_acme::ResolvesServerCertAcme;
|
||||||
|
|
||||||
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||||
|
|
||||||
|
#[cfg(feature = "acme")]
|
||||||
|
const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1";
|
||||||
|
|
||||||
/// A TLS-based client transport that connects to a remote address over TLS.
|
/// A TLS-based client transport that connects to a remote address over TLS.
|
||||||
///
|
///
|
||||||
/// Wraps a TCP connection with a TLS client session via `tokio_rustls::TlsConnector`.
|
/// Wraps a TCP connection with a TLS client session via `tokio_rustls::TlsConnector`.
|
||||||
@@ -110,8 +118,10 @@ pub struct AcmeConfig {
|
|||||||
/// A TLS-based server transport acceptor that accepts TCP connections
|
/// A TLS-based server transport acceptor that accepts TCP connections
|
||||||
/// and wraps them with TLS server sessions via `tokio_rustls::TlsAcceptor`.
|
/// and wraps them with TLS server sessions via `tokio_rustls::TlsAcceptor`.
|
||||||
///
|
///
|
||||||
/// Requires certificate and private key configuration. Supports manual
|
/// Supports three certificate modes (ADR-008):
|
||||||
/// cert/key paths and an ACME config stub (ADR-008).
|
/// - Manual certs via `bind()` with explicit cert/key
|
||||||
|
/// - ACME certs via `bind_acme()` with an `AcmeCertProvider`
|
||||||
|
/// - The stub `AcmeConfig` parameter in `bind()` is kept for backward compat
|
||||||
pub struct TlsAcceptor {
|
pub struct TlsAcceptor {
|
||||||
listener: TcpListener,
|
listener: TcpListener,
|
||||||
listen_addr: SocketAddr,
|
listen_addr: SocketAddr,
|
||||||
@@ -145,6 +155,33 @@ impl TlsAcceptor {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "acme")]
|
||||||
|
pub async fn bind_acme(
|
||||||
|
addr: SocketAddr,
|
||||||
|
acme_resolver: Arc<ResolvesServerCertAcme>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let listener = TcpListener::bind(addr).await?;
|
||||||
|
let listen_addr = listener.local_addr()?;
|
||||||
|
|
||||||
|
let provider = default_provider().into();
|
||||||
|
let mut server_config = ServerConfig::builder_with_provider(provider)
|
||||||
|
.with_safe_default_protocol_versions()
|
||||||
|
.map_err(|e| anyhow!("failed to set protocol versions: {}", e))?
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_cert_resolver(acme_resolver);
|
||||||
|
server_config.alpn_protocols.push(ACME_TLS_ALPN_NAME.to_vec());
|
||||||
|
|
||||||
|
let server_config = Arc::new(server_config);
|
||||||
|
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
listener,
|
||||||
|
listen_addr,
|
||||||
|
server_config,
|
||||||
|
tokio_acceptor,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn listen_addr(&self) -> SocketAddr {
|
pub fn listen_addr(&self) -> SocketAddr {
|
||||||
self.listen_addr
|
self.listen_addr
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ edition = "2021"
|
|||||||
crate-type = ["cdylib"]
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
wraith-core = { path = "../wraith-core" }
|
wraith-core = { path = "../wraith-core", features = ["tls", "iroh"] }
|
||||||
napi = "3"
|
napi = { version = "3", features = ["async", "error_anyhow"] }
|
||||||
napi-derive = "3"
|
napi-derive = "3"
|
||||||
|
tokio = { version = "1", features = ["io-util", "sync"] }
|
||||||
|
russh = "0.49"
|
||||||
249
crates/wraith-napi/src/connect.rs
Normal file
249
crates/wraith-napi/src/connect.rs
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use napi::bindgen_prelude::*;
|
||||||
|
use napi_derive::napi;
|
||||||
|
use russh::client;
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
use wraith_core::auth::client_auth::{ClientAuthConfig, ClientHandler};
|
||||||
|
use wraith_core::auth::keys::KeySource;
|
||||||
|
use wraith_core::transport::{TcpTransport, TlsTransport, Transport};
|
||||||
|
|
||||||
|
const DEFAULT_HOST: &str = "wraith-control";
|
||||||
|
const DEFAULT_PORT: u32 = 0;
|
||||||
|
|
||||||
|
#[napi(object)]
|
||||||
|
pub struct WraithConnectOptions {
|
||||||
|
pub server: Option<String>,
|
||||||
|
pub peer: Option<String>,
|
||||||
|
pub transport: String,
|
||||||
|
pub identity: Option<Either<String, Buffer>>,
|
||||||
|
pub tls_server_name: Option<String>,
|
||||||
|
pub insecure: Option<bool>,
|
||||||
|
pub iroh_relay: Option<String>,
|
||||||
|
pub proxy: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_key_source(identity: &Option<Either<String, Buffer>>) -> Result<KeySource> {
|
||||||
|
match identity {
|
||||||
|
None => Err(Error::new(
|
||||||
|
Status::InvalidArg,
|
||||||
|
"identity is required: provide a file path (string) or key data (Buffer)",
|
||||||
|
)),
|
||||||
|
Some(Either::A(path)) => Ok(KeySource::File(path.into())),
|
||||||
|
Some(Either::B(buf)) => Ok(KeySource::Memory(buf.to_vec())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_addr(addr_str: &str) -> Result<SocketAddr> {
|
||||||
|
addr_str.parse().map_err(|e| {
|
||||||
|
Error::new(
|
||||||
|
Status::InvalidArg,
|
||||||
|
format!("invalid server address '{}': {}", addr_str, e),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub struct WraithStream {
|
||||||
|
read: Arc<Mutex<tokio::io::ReadHalf<russh::ChannelStream<client::Msg>>>>,
|
||||||
|
write: Arc<Mutex<tokio::io::WriteHalf<russh::ChannelStream<client::Msg>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
impl WraithStream {
|
||||||
|
#[napi]
|
||||||
|
pub async fn read(&self, size: u32) -> Result<Buffer> {
|
||||||
|
let mut buf = vec![0u8; size as usize];
|
||||||
|
let mut guard = self.read.lock().await;
|
||||||
|
let n = guard.read(&mut buf).await.map_err(|e| {
|
||||||
|
Error::new(Status::GenericFailure, format!("read failed: {}", e))
|
||||||
|
})?;
|
||||||
|
if n == 0 {
|
||||||
|
return Ok(Vec::<u8>::new().into());
|
||||||
|
}
|
||||||
|
buf.truncate(n);
|
||||||
|
Ok(buf.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub async fn write(&self, data: Buffer) -> Result<()> {
|
||||||
|
let mut guard = self.write.lock().await;
|
||||||
|
guard.write_all(&data).await.map_err(|e| {
|
||||||
|
Error::new(Status::GenericFailure, format!("write failed: {}", e))
|
||||||
|
})?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub async fn close(&self) -> Result<()> {
|
||||||
|
let mut guard = self.write.lock().await;
|
||||||
|
guard.shutdown().await.map_err(|e| {
|
||||||
|
Error::new(Status::GenericFailure, format!("close failed: {}", e))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub async fn connect(options: WraithConnectOptions) -> Result<WraithStream> {
|
||||||
|
let key_source = resolve_key_source(&options.identity)?;
|
||||||
|
let auth_config = Arc::new(ClientAuthConfig::from_key_source(key_source).map_err(|e| {
|
||||||
|
Error::new(Status::InvalidArg, format!("invalid identity key: {}", e))
|
||||||
|
})?);
|
||||||
|
|
||||||
|
let transport_mode = options.transport.to_lowercase();
|
||||||
|
let handler = ClientHandler::from_config(&auth_config);
|
||||||
|
let username = "wraith".to_string();
|
||||||
|
|
||||||
|
let config = Arc::new(client::Config::default());
|
||||||
|
|
||||||
|
let mut handle: client::Handle<ClientHandler> = match transport_mode.as_str() {
|
||||||
|
"tcp" => {
|
||||||
|
let server = options.server.as_ref().ok_or_else(|| {
|
||||||
|
Error::new(Status::InvalidArg, "server is required for tcp transport")
|
||||||
|
})?;
|
||||||
|
let addr = parse_addr(server)?;
|
||||||
|
let transport = TcpTransport::new(addr);
|
||||||
|
let stream = transport.connect().await.map_err(|e| {
|
||||||
|
Error::new(Status::GenericFailure, format!("tcp connect failed: {}", e))
|
||||||
|
})?;
|
||||||
|
client::connect_stream(config, stream, handler)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
Error::new(
|
||||||
|
Status::GenericFailure,
|
||||||
|
format!("ssh handshake failed: {}", e),
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
"tls" => {
|
||||||
|
let server = options.server.as_ref().ok_or_else(|| {
|
||||||
|
Error::new(Status::InvalidArg, "server is required for tls transport")
|
||||||
|
})?;
|
||||||
|
let addr = parse_addr(server)?;
|
||||||
|
let mut transport = TlsTransport::new(addr);
|
||||||
|
if let Some(ref name) = options.tls_server_name {
|
||||||
|
transport = transport.with_server_name(name);
|
||||||
|
}
|
||||||
|
if let Some(true) = options.insecure {
|
||||||
|
transport = transport.with_insecure(true);
|
||||||
|
}
|
||||||
|
let stream = transport.connect().await.map_err(|e| {
|
||||||
|
Error::new(Status::GenericFailure, format!("tls connect failed: {}", e))
|
||||||
|
})?;
|
||||||
|
client::connect_stream(config, stream, handler)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
Error::new(
|
||||||
|
Status::GenericFailure,
|
||||||
|
format!("ssh handshake failed: {}", e),
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
"iroh" => {
|
||||||
|
return Err(Error::new(
|
||||||
|
Status::GenericFailure,
|
||||||
|
"iroh transport is not yet supported in napi connect()".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
return Err(Error::new(
|
||||||
|
Status::InvalidArg,
|
||||||
|
format!("unknown transport '{}'; expected tcp, tls, or iroh", transport_mode),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let auth_ok = auth_config
|
||||||
|
.authenticate(&mut handle, &username)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
Error::new(Status::GenericFailure, format!("ssh auth failed: {}", e))
|
||||||
|
})?;
|
||||||
|
if !auth_ok {
|
||||||
|
return Err(Error::new(Status::GenericFailure, "ssh authentication rejected"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let channel = handle
|
||||||
|
.channel_open_direct_tcpip(DEFAULT_HOST, DEFAULT_PORT, "127.0.0.1", 0)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
Error::new(
|
||||||
|
Status::GenericFailure,
|
||||||
|
format!("failed to open ssh channel: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let stream = channel.into_stream();
|
||||||
|
let (read_half, write_half) = tokio::io::split(stream);
|
||||||
|
|
||||||
|
Ok(WraithStream {
|
||||||
|
read: Arc::new(Mutex::new(read_half)),
|
||||||
|
write: Arc::new(Mutex::new(write_half)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_key_source_file_path() {
|
||||||
|
let identity = Some(Either::<String, Buffer>::A("/path/to/key".to_string()));
|
||||||
|
let result = resolve_key_source(&identity);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
match result.unwrap() {
|
||||||
|
KeySource::File(p) => assert_eq!(p.to_str(), Some("/path/to/key")),
|
||||||
|
_ => panic!("expected File variant"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_key_source_buffer() {
|
||||||
|
let identity = Some(Either::<String, Buffer>::B(Buffer::from(ED25519_PRIVATE_KEY.as_bytes().to_vec())));
|
||||||
|
let result = resolve_key_source(&identity);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
match result.unwrap() {
|
||||||
|
KeySource::Memory(data) => assert!(!data.is_empty()),
|
||||||
|
_ => panic!("expected Memory variant"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_key_source_missing() {
|
||||||
|
let identity: Option<Either<String, Buffer>> = None;
|
||||||
|
let result = resolve_key_source(&identity);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_addr_valid() {
|
||||||
|
let addr = parse_addr("127.0.0.1:22");
|
||||||
|
assert!(addr.is_ok());
|
||||||
|
assert_eq!(addr.unwrap().port(), 22);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_addr_invalid() {
|
||||||
|
let addr = parse_addr("not-an-address");
|
||||||
|
assert!(addr.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_config_from_memory_key() {
|
||||||
|
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||||
|
let config = ClientAuthConfig::from_key_source(source);
|
||||||
|
assert!(config.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_config_from_invalid_key() {
|
||||||
|
let source = KeySource::Memory(b"not-a-key".to_vec());
|
||||||
|
let config = ClientAuthConfig::from_key_source(source);
|
||||||
|
assert!(config.is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate napi_derive;
|
extern crate napi_derive;
|
||||||
|
|
||||||
|
mod connect;
|
||||||
@@ -7,8 +7,15 @@ edition = "2021"
|
|||||||
name = "wraith"
|
name = "wraith"
|
||||||
path = "src/main.rs"
|
path = "src/main.rs"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["tls", "iroh"]
|
||||||
|
tls = ["wraith-core/tls"]
|
||||||
|
iroh = ["wraith-core/iroh", "dep:iroh", "dep:url"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
wraith-core = { path = "../wraith-core" }
|
wraith-core = { path = "../wraith-core" }
|
||||||
clap = { version = "4", features = ["derive"] }
|
clap = { version = "4", features = ["derive", "env"] }
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
|
iroh = { version = "0.34", optional = true }
|
||||||
|
url = { version = "2", optional = true }
|
||||||
@@ -1 +1,224 @@
|
|||||||
fn main() {}
|
use std::net::SocketAddr;
|
||||||
|
use std::process;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use clap::{Parser, Subcommand, ValueEnum};
|
||||||
|
use wraith_core::auth::keys::KeySource;
|
||||||
|
use wraith_core::client::{ConnectOptions, TransportMode};
|
||||||
|
use wraith_core::transport::TcpTransport;
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
use wraith_core::transport::TlsTransport;
|
||||||
|
#[cfg(feature = "iroh")]
|
||||||
|
use wraith_core::transport::IrohTransport;
|
||||||
|
use wraith_core::transport::Transport;
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
#[command(name = "wraith", version, about = "Wraith SSH tunnel client")]
|
||||||
|
struct Cli {
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: Commands,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Subcommand)]
|
||||||
|
enum Commands {
|
||||||
|
#[command(about = "Connect to a wraith server and start a SOCKS5 proxy / port forwarding session")]
|
||||||
|
Connect {
|
||||||
|
#[arg(long, help = "TCP/TLS server address (required for tcp/tls transport)", env = "WRAITH_SERVER")]
|
||||||
|
server: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, help = "iroh endpoint ID, base58-encoded (required for iroh transport)")]
|
||||||
|
peer: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, value_enum, default_value = "tcp", help = "Transport mode")]
|
||||||
|
transport: TransportModeArg,
|
||||||
|
|
||||||
|
#[arg(long, help = "SSH private key path", env = "WRAITH_IDENTITY")]
|
||||||
|
identity: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "127.0.0.1:1080", help = "SOCKS5 listen address")]
|
||||||
|
socks5: String,
|
||||||
|
|
||||||
|
#[arg(long, action = clap::ArgAction::Append, help = "Port forward spec (repeatable, e.g. 5432:db:5432)")]
|
||||||
|
forward: Vec<String>,
|
||||||
|
|
||||||
|
#[arg(long, action = clap::ArgAction::Append, help = "Remote port forward spec (repeatable)")]
|
||||||
|
remote_forward: Vec<String>,
|
||||||
|
|
||||||
|
#[arg(long, help = "Upstream proxy URL (socks5:// or http://)")]
|
||||||
|
proxy: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, help = "iroh relay URL")]
|
||||||
|
iroh_relay: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, help = "SNI hostname for TLS")]
|
||||||
|
tls_server_name: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, help = "Accept self-signed TLS certs")]
|
||||||
|
insecure: bool,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, ValueEnum)]
|
||||||
|
enum TransportModeArg {
|
||||||
|
Tcp,
|
||||||
|
Tls,
|
||||||
|
Iroh,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<TransportModeArg> for TransportMode {
|
||||||
|
fn from(val: TransportModeArg) -> Self {
|
||||||
|
match val {
|
||||||
|
TransportModeArg::Tcp => TransportMode::Tcp,
|
||||||
|
TransportModeArg::Tls => TransportMode::Tls,
|
||||||
|
TransportModeArg::Iroh => TransportMode::Iroh,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
if let Err(e) = run().await {
|
||||||
|
eprintln!("error: {e}");
|
||||||
|
process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run() -> Result<()> {
|
||||||
|
let cli = Cli::parse();
|
||||||
|
|
||||||
|
match cli.command {
|
||||||
|
Commands::Connect {
|
||||||
|
server,
|
||||||
|
peer,
|
||||||
|
transport,
|
||||||
|
identity,
|
||||||
|
socks5,
|
||||||
|
forward,
|
||||||
|
remote_forward,
|
||||||
|
proxy,
|
||||||
|
iroh_relay,
|
||||||
|
tls_server_name,
|
||||||
|
insecure,
|
||||||
|
} => {
|
||||||
|
let identity_val = identity
|
||||||
|
.ok_or_else(|| anyhow!("--identity is required (or set WRAITH_IDENTITY env var)"))?;
|
||||||
|
let key_source = KeySource::File(identity_val.into());
|
||||||
|
|
||||||
|
let transport_mode: TransportMode = transport.into();
|
||||||
|
|
||||||
|
if proxy.is_some() && matches!(transport_mode, TransportMode::Tcp) {
|
||||||
|
eprintln!("warning: --proxy with --transport tcp is effectively a no-op (TCP transport is already a direct connection); use the SOCKS5 server instead");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut opts = ConnectOptions::new(key_source)
|
||||||
|
.transport_mode(transport_mode.clone())
|
||||||
|
.socks5_addr(&socks5);
|
||||||
|
|
||||||
|
if let Some(ref s) = server {
|
||||||
|
opts = opts.server(s);
|
||||||
|
}
|
||||||
|
if let Some(ref p) = peer {
|
||||||
|
opts = opts.peer(p);
|
||||||
|
}
|
||||||
|
for fwd in &forward {
|
||||||
|
opts = opts.forward(fwd);
|
||||||
|
}
|
||||||
|
for rfwd in &remote_forward {
|
||||||
|
opts = opts.remote_forward(rfwd);
|
||||||
|
}
|
||||||
|
if let Some(ref p) = proxy {
|
||||||
|
opts = opts.proxy(p);
|
||||||
|
}
|
||||||
|
if let Some(ref r) = iroh_relay {
|
||||||
|
opts = opts.iroh_relay(r);
|
||||||
|
}
|
||||||
|
if let Some(ref n) = tls_server_name {
|
||||||
|
opts = opts.tls_server_name(n);
|
||||||
|
}
|
||||||
|
if insecure {
|
||||||
|
opts = opts.insecure(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
opts.validate().map_err(|e| anyhow!("{e}"))?;
|
||||||
|
|
||||||
|
match transport_mode {
|
||||||
|
TransportMode::Tcp => {
|
||||||
|
let addr: SocketAddr = server
|
||||||
|
.as_deref()
|
||||||
|
.ok_or_else(|| anyhow!("--server is required for tcp transport"))?
|
||||||
|
.parse()
|
||||||
|
.map_err(|e| anyhow!("invalid server address: {e}"))?;
|
||||||
|
let t = Arc::new(TcpTransport::new(addr));
|
||||||
|
connect_and_run(opts, t).await
|
||||||
|
}
|
||||||
|
TransportMode::Tls => {
|
||||||
|
#[cfg(not(feature = "tls"))]
|
||||||
|
{
|
||||||
|
return Err(anyhow!("TLS transport is not available (wraith-core built without 'tls' feature)"));
|
||||||
|
}
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
{
|
||||||
|
let addr: SocketAddr = server
|
||||||
|
.as_deref()
|
||||||
|
.ok_or_else(|| anyhow!("--server is required for tls transport"))?
|
||||||
|
.parse()
|
||||||
|
.map_err(|e| anyhow!("invalid server address: {e}"))?;
|
||||||
|
let mut t = TlsTransport::new(addr);
|
||||||
|
if let Some(ref n) = tls_server_name {
|
||||||
|
t = t.with_server_name(n);
|
||||||
|
}
|
||||||
|
t = t.with_insecure(insecure);
|
||||||
|
let t = Arc::new(t);
|
||||||
|
connect_and_run(opts, t).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TransportMode::Iroh => {
|
||||||
|
#[cfg(not(feature = "iroh"))]
|
||||||
|
{
|
||||||
|
return Err(anyhow!("iroh transport is not available (wraith-core built without 'iroh' feature)"));
|
||||||
|
}
|
||||||
|
#[cfg(feature = "iroh")]
|
||||||
|
{
|
||||||
|
use iroh::{NodeId, RelayUrl};
|
||||||
|
let node_id_str = peer
|
||||||
|
.as_deref()
|
||||||
|
.ok_or_else(|| anyhow!("--peer is required for iroh transport"))?;
|
||||||
|
let node_id: NodeId = node_id_str
|
||||||
|
.parse()
|
||||||
|
.map_err(|e| anyhow!("invalid iroh peer endpoint ID: {e}"))?;
|
||||||
|
let relay_url: Option<RelayUrl> = match iroh_relay.as_deref() {
|
||||||
|
Some(u) => Some(
|
||||||
|
u.parse()
|
||||||
|
.map_err(|e| anyhow!("invalid iroh relay URL: {e}"))?,
|
||||||
|
),
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
let proxy_url: Option<url::Url> = match proxy.as_deref() {
|
||||||
|
Some(u) => Some(
|
||||||
|
u.parse()
|
||||||
|
.map_err(|e| anyhow!("invalid proxy URL: {e}"))?,
|
||||||
|
),
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
let t = Arc::new(
|
||||||
|
IrohTransport::new(node_id, relay_url, proxy_url)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("failed to create iroh transport: {e}"))?,
|
||||||
|
);
|
||||||
|
connect_and_run(opts, t).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_and_run<T: Transport>(opts: ConnectOptions, transport: Arc<T>) -> Result<()> {
|
||||||
|
wraith_core::client::ClientSession::new(opts, transport)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("{e}"))?
|
||||||
|
.run()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("{e}"))
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
---
|
---
|
||||||
id: client/channel-manager
|
id: client/channel-manager
|
||||||
name: Implement ChannelManager — SSH session management, channel opens, reconnection
|
name: Implement ChannelManager — SSH session management, channel opens, reconnection
|
||||||
status: pending
|
status: done
|
||||||
depends_on:
|
depends_on:
|
||||||
- auth/client-auth-handler
|
- auth/client-auth-handler
|
||||||
- transport/trait-and-types
|
- transport/trait-and-types
|
||||||
@@ -32,18 +32,18 @@ Reconnection is always enabled. The backoff caps at 30 seconds and continues ind
|
|||||||
|
|
||||||
## Acceptance Criteria
|
## Acceptance Criteria
|
||||||
|
|
||||||
- [ ] `crates/wraith-core/src/client/channel_manager.rs` exports `ChannelManager`
|
- [x] `crates/wraith-core/src/client/channel_manager.rs` exports `ChannelManager`
|
||||||
- [ ] `ChannelManager` holds: `Arc<Transport>`, `Arc<ClientAuthConfig>`, `Arc<client::Handle<ClientHandler>>` (behind RwLock for reconnection)
|
- [x] `ChannelManager` holds: `Arc<Transport>`, `Arc<ClientAuthConfig>`, `Arc<client::Handle<ClientHandler>>` (behind RwLock for reconnection)
|
||||||
- [ ] `ChannelManager::new()` establishes initial transport connection, authenticates, returns manager
|
- [x] `ChannelManager::new()` establishes initial transport connection, authenticates, returns manager
|
||||||
- [ ] `open_direct_tcpip(host, port)` — opens SSH channel to target
|
- [x] `open_direct_tcpip(host, port)` — opens SSH channel to target
|
||||||
- [ ] `request_tcpip_forward(addr, port)` — sends `tcpip_forward` request
|
- [x] `request_tcpip_forward(addr, port)` — sends `tcpip_forward` request
|
||||||
- [ ] `cancel_tcpip_forward(addr, port)` — sends `cancel_tcpip_forward` request
|
- [x] `cancel_tcpip_forward(addr, port)` — sends `cancel_tcpip_forward` request
|
||||||
- [ ] Reconnection detection: monitors `handle.is_closed()`, triggers reconnect on failure
|
- [x] Reconnection detection: monitors `handle.is_closed()`, triggers reconnect on failure
|
||||||
- [ ] Exponential backoff: 1s, 2s, 4s, 8s, 16s, 30s (cap), continues indefinitely
|
- [x] Exponential backoff: 1s, 2s, 4s, 8s, 16s, 30s (cap), continues indefinitely
|
||||||
- [ ] Full reconnect: new transport stream, new SSH session over it (ADR-004)
|
- [x] Full reconnect: new transport stream, new SSH session over it (ADR-004)
|
||||||
- [ ] After reconnect: port forward listeners (`-L`, `-R`) re-registered with new session
|
- [x] After reconnect: port forward listeners (`-L`, `-R`) re-registered with new session
|
||||||
- [ ] In-flight connections on old session fail gracefully (channel errors, not session-wide)
|
- [x] In-flight connections on old session fail gracefully (channel errors, not session-wide)
|
||||||
- [ ] Unit tests: channel open, reconnection trigger, backoff timing, forward re-registration
|
- [x] Unit tests: channel open, reconnection trigger, backoff timing, forward re-registration
|
||||||
|
|
||||||
## References
|
## References
|
||||||
|
|
||||||
@@ -52,8 +52,13 @@ Reconnection is always enabled. The backoff caps at 30 seconds and continues ind
|
|||||||
|
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
> To be filled by implementation agent
|
- Converted `client.rs` (single file) to directory module: `client/mod.rs` + `client/channel_manager.rs`
|
||||||
|
- Used `russh::keys::PrivateKey` and `russh::keys::PublicKey` (not the nonexistent `russh::key::KeyPair`)
|
||||||
|
- Reconnection monitor runs as a spawned tokio task that polls `handle.is_closed()` every 1s
|
||||||
|
- On reconnect: creates new transport stream + new SSH session (ADR-004 full reconnect)
|
||||||
|
- `ForwardRequest` struct tracks registered port forwards for re-registration after reconnect
|
||||||
|
- In-flight channels on old session naturally fail with `ChannelError::ChannelClosed` since the handle is replaced
|
||||||
|
|
||||||
## Summary
|
## Summary
|
||||||
|
|
||||||
> To be filled on completion
|
Implemented `ChannelManager` in `crates/wraith-core/src/client/channel_manager.rs` with SSH session management, channel opens (`open_direct_tcpip`), port forward requests (`request_tcpip_forward`/`cancel_tcpip_forward`), and automatic reconnection with exponential backoff (1s→30s cap). Full reconnect per ADR-004 creates new transport stream + new SSH session. Port forwards are re-registered after successful reconnect. 8 unit tests covering backoff timing, forward tracking, transport failure, and reconnection detection.
|
||||||
@@ -43,8 +43,14 @@ This integrates with `TlsAcceptor` by providing ACME-resolved certificates inste
|
|||||||
|
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
> To be filled by implementation agent
|
- `AcmeCertProvider` is the main entry point. It creates `AcmeState` and `ResolvesServerCertAcme` from `rustls-acme`.
|
||||||
|
- The `ResolvesServerCertAcme` resolver is shared between the `AcmeState` background task and the `ServerConfig`, so cert updates propagate automatically.
|
||||||
|
- `AcmeTlsAcceptor::bind_acme()` creates a TLS acceptor that uses ACME-provisioned certs and spawns a background tokio task for auto-renewal.
|
||||||
|
- `TlsAcceptor::bind_acme()` also added for users who want to use ACME with the standard `TlsAcceptor` type directly.
|
||||||
|
- The `AcmeConfig` stub in `tls.rs` is retained for backward compat with existing `TlsAcceptor::bind()`.
|
||||||
|
- `acme` feature implies `tls` and adds `rustls-acme` + `futures` dependencies.
|
||||||
|
- TLS-ALPN-01 challenge handling works via the `acme-tls/1` ALPN protocol registered in `ServerConfig` — the resolver dispatches challenge vs regular certs automatically.
|
||||||
|
|
||||||
## Summary
|
## Summary
|
||||||
|
|
||||||
> To be filled on completion
|
Implemented ACME/Let's Encrypt certificate provisioning (ADR-008) behind the `acme` feature flag. `AcmeCertProvider` supports domain-based and IP-based modes using `rustls-acme`. `AcmeTlsAcceptor::bind_acme()` and `TlsAcceptor::bind_acme()` provide ACME-integrated TLS acceptance with automatic certificate renewal via a background tokio task. Unit tests cover config construction, builder patterns, and server config generation. Integration test for LE staging is marked `#[ignore]`.
|
||||||
Reference in New Issue
Block a user