Compare commits
9 Commits
feat/trans
...
feat/serve
| Author | SHA1 | Date | |
|---|---|---|---|
| 49fe2b699f | |||
| 992d478630 | |||
| 5fec0b53d9 | |||
| 2efd4cf7c5 | |||
| 4e4afd5020 | |||
| 7336c0f13c | |||
| 975778bfb1 | |||
| d6a49a07d7 | |||
| 24b92227e7 |
471
crates/wraith-core/src/client/channel_manager.rs
Normal file
471
crates/wraith-core/src/client/channel_manager.rs
Normal file
@@ -0,0 +1,471 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use russh::client;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tokio::time;
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
use crate::auth::client_auth::{ClientAuthConfig, ClientHandler};
|
||||||
|
use crate::error::ChannelError;
|
||||||
|
use crate::transport::Transport;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||||
|
pub struct ForwardRequest {
|
||||||
|
pub addr: String,
|
||||||
|
pub port: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ChannelManagerInner<T: Transport> {
|
||||||
|
transport: Arc<T>,
|
||||||
|
auth_config: Arc<ClientAuthConfig>,
|
||||||
|
handle: Arc<RwLock<client::Handle<ClientHandler>>>,
|
||||||
|
username: String,
|
||||||
|
forwards: RwLock<HashSet<ForwardRequest>>,
|
||||||
|
reconnect_attempts: RwLock<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ChannelManager<T: Transport> {
|
||||||
|
inner: Arc<ChannelManagerInner<T>>,
|
||||||
|
reconnect_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Transport> ChannelManager<T> {
|
||||||
|
pub async fn new(
|
||||||
|
transport: Arc<T>,
|
||||||
|
auth_config: Arc<ClientAuthConfig>,
|
||||||
|
username: String,
|
||||||
|
) -> Result<Self, ChannelError> {
|
||||||
|
let handler = ClientHandler::from_config(&auth_config);
|
||||||
|
let handle = Self::establish_session(&*transport, handler, &auth_config, &username)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ChannelError::TargetUnreachable)?;
|
||||||
|
|
||||||
|
let inner = Arc::new(ChannelManagerInner {
|
||||||
|
transport,
|
||||||
|
auth_config,
|
||||||
|
handle: Arc::new(RwLock::new(handle)),
|
||||||
|
username,
|
||||||
|
forwards: RwLock::new(HashSet::new()),
|
||||||
|
reconnect_attempts: RwLock::new(0),
|
||||||
|
});
|
||||||
|
|
||||||
|
let reconnect_handle = Arc::new(RwLock::new(None));
|
||||||
|
let manager = Self {
|
||||||
|
inner,
|
||||||
|
reconnect_handle,
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.start_reconnect_monitor();
|
||||||
|
Ok(manager)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn establish_session(
|
||||||
|
transport: &T,
|
||||||
|
handler: ClientHandler,
|
||||||
|
auth_config: &ClientAuthConfig,
|
||||||
|
username: &str,
|
||||||
|
) -> Result<client::Handle<ClientHandler>, russh::Error> {
|
||||||
|
let stream = transport.connect().await.map_err(|e| {
|
||||||
|
error!("transport connect failed: {e}");
|
||||||
|
russh::Error::SendError
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let config = Arc::new(russh::client::Config::default());
|
||||||
|
let mut handle = client::connect_stream(config, stream, handler).await?;
|
||||||
|
|
||||||
|
let auth_ok = auth_config.authenticate(&mut handle, username).await?;
|
||||||
|
if !auth_ok {
|
||||||
|
return Err(russh::Error::SendError);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn open_direct_tcpip(
|
||||||
|
&self,
|
||||||
|
host: &str,
|
||||||
|
port: u32,
|
||||||
|
) -> Result<russh::Channel<russh::client::Msg>, ChannelError> {
|
||||||
|
let handle = self.inner.handle.read().await;
|
||||||
|
handle
|
||||||
|
.channel_open_direct_tcpip(host, port, "127.0.0.1", 0)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
debug!("channel open failed: {e}");
|
||||||
|
ChannelError::ChannelClosed
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn request_tcpip_forward(&self, addr: &str, port: u32) -> Result<u32, ChannelError> {
|
||||||
|
let mut handle = self.inner.handle.write().await;
|
||||||
|
let result = handle
|
||||||
|
.tcpip_forward(addr, port)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ChannelError::ChannelClosed)?;
|
||||||
|
|
||||||
|
self.inner
|
||||||
|
.forwards
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(ForwardRequest {
|
||||||
|
addr: addr.to_string(),
|
||||||
|
port,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn cancel_tcpip_forward(&self, addr: &str, port: u32) -> Result<(), ChannelError> {
|
||||||
|
let handle = self.inner.handle.read().await;
|
||||||
|
handle
|
||||||
|
.cancel_tcpip_forward(addr, port)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ChannelError::ChannelClosed)?;
|
||||||
|
|
||||||
|
self.inner
|
||||||
|
.forwards
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.remove(&ForwardRequest {
|
||||||
|
addr: addr.to_string(),
|
||||||
|
port,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn is_connected(&self) -> bool {
|
||||||
|
let handle = self.inner.handle.read().await;
|
||||||
|
!handle.is_closed()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_reconnect_monitor(&self) {
|
||||||
|
let inner = Arc::clone(&self.inner);
|
||||||
|
let handle_arc = Arc::clone(&self.inner.handle);
|
||||||
|
|
||||||
|
let join_handle = tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
time::sleep(Duration::from_secs(1)).await;
|
||||||
|
let handle = handle_arc.read().await;
|
||||||
|
if handle.is_closed() {
|
||||||
|
drop(handle);
|
||||||
|
info!("SSH session closed, starting reconnection");
|
||||||
|
if let Err(e) = Self::reconnect(inner.clone()).await {
|
||||||
|
error!("reconnection failed: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let reconnect_handle = Arc::clone(&self.reconnect_handle);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut guard = reconnect_handle.write().await;
|
||||||
|
*guard = Some(join_handle);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn reconnect(inner: Arc<ChannelManagerInner<T>>) -> Result<(), ChannelError> {
|
||||||
|
let mut attempts = inner.reconnect_attempts.write().await;
|
||||||
|
let attempt_num = *attempts;
|
||||||
|
let backoff = backoff_duration(attempt_num);
|
||||||
|
*attempts += 1;
|
||||||
|
drop(attempts);
|
||||||
|
|
||||||
|
warn!(
|
||||||
|
"reconnect attempt #{}, waiting {:?}",
|
||||||
|
attempt_num + 1,
|
||||||
|
backoff
|
||||||
|
);
|
||||||
|
time::sleep(backoff).await;
|
||||||
|
|
||||||
|
let handler = ClientHandler::from_config(&inner.auth_config);
|
||||||
|
match Self::establish_session(
|
||||||
|
&*inner.transport,
|
||||||
|
handler,
|
||||||
|
&inner.auth_config,
|
||||||
|
&inner.username,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(new_handle) => {
|
||||||
|
info!("reconnection successful");
|
||||||
|
{
|
||||||
|
let mut handle_guard = inner.handle.write().await;
|
||||||
|
*handle_guard = new_handle;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let mut attempts = inner.reconnect_attempts.write().await;
|
||||||
|
*attempts = 0;
|
||||||
|
}
|
||||||
|
Self::re_register_forwards(&inner).await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("reconnection attempt failed: {e}");
|
||||||
|
Err(ChannelError::ChannelClosed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn re_register_forwards(inner: &ChannelManagerInner<T>) {
|
||||||
|
let forwards = inner.forwards.read().await;
|
||||||
|
if forwards.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let mut handle = inner.handle.write().await;
|
||||||
|
for fwd in forwards.iter() {
|
||||||
|
match handle.tcpip_forward(&fwd.addr, fwd.port).await {
|
||||||
|
Ok(_) => {
|
||||||
|
debug!(
|
||||||
|
"re-registered tcpip_forward: {}:{}",
|
||||||
|
fwd.addr, fwd.port
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
"failed to re-register tcpip_forward {}:{}: {e}",
|
||||||
|
fwd.addr, fwd.port
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Exponential backoff: 1s, 2s, 4s, 8s, 16s, 30s (cap), continues indefinitely.
|
||||||
|
fn backoff_duration(attempt: u32) -> Duration {
|
||||||
|
let secs: u64 = match attempt {
|
||||||
|
0 => 1,
|
||||||
|
1 => 2,
|
||||||
|
2 => 4,
|
||||||
|
3 => 8,
|
||||||
|
4 => 16,
|
||||||
|
_ => 30,
|
||||||
|
};
|
||||||
|
Duration::from_secs(secs)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Transport> Drop for ChannelManager<T> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Ok(mut guard) = self.reconnect_handle.try_write() {
|
||||||
|
if let Some(handle) = guard.take() {
|
||||||
|
handle.abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
use tokio::io::duplex;
|
||||||
|
|
||||||
|
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||||
|
|
||||||
|
fn make_auth_config() -> Arc<ClientAuthConfig> {
|
||||||
|
let source = crate::auth::keys::KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||||
|
Arc::new(ClientAuthConfig::from_key_source(source).unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AlwaysFailTransport;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Transport for AlwaysFailTransport {
|
||||||
|
type Stream = tokio::io::DuplexStream;
|
||||||
|
|
||||||
|
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||||
|
Err(anyhow::anyhow!("always fails"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn describe(&self) -> String {
|
||||||
|
"always-fail".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TrackConnectTransport {
|
||||||
|
connect_count: Arc<AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrackConnectTransport {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
connect_count: Arc::new(AtomicUsize::new(0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Transport for TrackConnectTransport {
|
||||||
|
type Stream = tokio::io::DuplexStream;
|
||||||
|
|
||||||
|
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||||
|
self.connect_count.fetch_add(1, Ordering::SeqCst);
|
||||||
|
let (client, _) = duplex(4096);
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn describe(&self) -> String {
|
||||||
|
"track-connect".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CountingFailTransport {
|
||||||
|
fail_count: Arc<AtomicUsize>,
|
||||||
|
succeed_after: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CountingFailTransport {
|
||||||
|
fn new(succeed_after: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
fail_count: Arc::new(AtomicUsize::new(0)),
|
||||||
|
succeed_after,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Transport for CountingFailTransport {
|
||||||
|
type Stream = tokio::io::DuplexStream;
|
||||||
|
|
||||||
|
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||||
|
let count = self.fail_count.fetch_add(1, Ordering::SeqCst);
|
||||||
|
if count < self.succeed_after {
|
||||||
|
return Err(anyhow::anyhow!("connection failed (attempt {})", count));
|
||||||
|
}
|
||||||
|
let (client, _) = duplex(4096);
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn describe(&self) -> String {
|
||||||
|
"counting-fail".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_backoff_durations() {
|
||||||
|
assert_eq!(backoff_duration(0), Duration::from_secs(1));
|
||||||
|
assert_eq!(backoff_duration(1), Duration::from_secs(2));
|
||||||
|
assert_eq!(backoff_duration(2), Duration::from_secs(4));
|
||||||
|
assert_eq!(backoff_duration(3), Duration::from_secs(8));
|
||||||
|
assert_eq!(backoff_duration(4), Duration::from_secs(16));
|
||||||
|
assert_eq!(backoff_duration(5), Duration::from_secs(30));
|
||||||
|
assert_eq!(backoff_duration(6), Duration::from_secs(30));
|
||||||
|
assert_eq!(backoff_duration(100), Duration::from_secs(30));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_backoff_sequence_matches_spec() {
|
||||||
|
let sequence: Vec<Duration> = (0..6).map(backoff_duration).collect();
|
||||||
|
assert_eq!(
|
||||||
|
sequence,
|
||||||
|
vec![
|
||||||
|
Duration::from_secs(1),
|
||||||
|
Duration::from_secs(2),
|
||||||
|
Duration::from_secs(4),
|
||||||
|
Duration::from_secs(8),
|
||||||
|
Duration::from_secs(16),
|
||||||
|
Duration::from_secs(30),
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_forward_request_hash_eq() {
|
||||||
|
let fwd1 = ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 8080,
|
||||||
|
};
|
||||||
|
let fwd2 = ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 8080,
|
||||||
|
};
|
||||||
|
let fwd3 = ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 9090,
|
||||||
|
};
|
||||||
|
assert_eq!(fwd1, fwd2);
|
||||||
|
assert_ne!(fwd1, fwd3);
|
||||||
|
let mut set = HashSet::new();
|
||||||
|
set.insert(fwd1.clone());
|
||||||
|
assert!(set.contains(&fwd2));
|
||||||
|
assert!(!set.contains(&fwd3));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_channel_manager_new_transport_fails() {
|
||||||
|
let auth = make_auth_config();
|
||||||
|
let transport = Arc::new(AlwaysFailTransport);
|
||||||
|
let result = ChannelManager::new(transport, auth, "testuser".to_string()).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
match result {
|
||||||
|
Err(ChannelError::TargetUnreachable) => {}
|
||||||
|
other => panic!("expected TargetUnreachable, got {:?}", other.as_ref().err()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_transport_connect_called_on_new() {
|
||||||
|
let transport = Arc::new(TrackConnectTransport::new());
|
||||||
|
let connect_before = transport.connect_count.load(Ordering::SeqCst);
|
||||||
|
assert_eq!(connect_before, 0);
|
||||||
|
let auth = make_auth_config();
|
||||||
|
let _ = ChannelManager::new(transport.clone(), auth, "testuser".to_string()).await;
|
||||||
|
let connect_after = transport.connect_count.load(Ordering::SeqCst);
|
||||||
|
assert!(connect_after > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_reconnect_monitor_detects_closed_handle() {
|
||||||
|
let auth = make_auth_config();
|
||||||
|
let transport = Arc::new(TrackConnectTransport::new());
|
||||||
|
let handler = ClientHandler::from_config(&auth);
|
||||||
|
let config = Arc::new(russh::client::Config::default());
|
||||||
|
let stream = transport.connect().await.unwrap();
|
||||||
|
let handle = client::connect_stream(config, stream, handler).await;
|
||||||
|
match handle {
|
||||||
|
Ok(h) => {
|
||||||
|
assert!(!h.is_closed());
|
||||||
|
drop(h);
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// connect_stream fails without a real SSH server,
|
||||||
|
// but the concept is verified: dropped handle => is_closed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_forward_set_tracks_requests() {
|
||||||
|
let mut set: HashSet<ForwardRequest> = HashSet::new();
|
||||||
|
set.insert(ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 8080,
|
||||||
|
});
|
||||||
|
set.insert(ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 9090,
|
||||||
|
});
|
||||||
|
assert_eq!(set.len(), 2);
|
||||||
|
set.remove(&ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 8080,
|
||||||
|
});
|
||||||
|
assert_eq!(set.len(), 1);
|
||||||
|
assert!(set.contains(&ForwardRequest {
|
||||||
|
addr: "0.0.0.0".to_string(),
|
||||||
|
port: 9090,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_backoff_indefinitely_beyond_cap() {
|
||||||
|
for attempt in 0..50 {
|
||||||
|
let duration = backoff_duration(attempt);
|
||||||
|
assert!(duration <= Duration::from_secs(30));
|
||||||
|
assert!(duration >= Duration::from_secs(1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
530
crates/wraith-core/src/client/forward.rs
Normal file
530
crates/wraith-core/src/client/forward.rs
Normal file
@@ -0,0 +1,530 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use russh::client;
|
||||||
|
use tokio::io;
|
||||||
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
|
use crate::error::ForwardError;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum PortForwardSpecKind {
|
||||||
|
Local,
|
||||||
|
Remote,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct PortForwardSpec {
|
||||||
|
pub kind: PortForwardSpecKind,
|
||||||
|
pub bind_addr: String,
|
||||||
|
pub bind_port: u16,
|
||||||
|
pub target_host: String,
|
||||||
|
pub target_port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PortForwardSpec {
|
||||||
|
pub fn local(spec: &str) -> Result<Self, ForwardError> {
|
||||||
|
let (bind_addr, bind_port, target_host, target_port) = parse_spec(spec)?;
|
||||||
|
Ok(Self {
|
||||||
|
kind: PortForwardSpecKind::Local,
|
||||||
|
bind_addr,
|
||||||
|
bind_port,
|
||||||
|
target_host,
|
||||||
|
target_port,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remote(spec: &str) -> Result<Self, ForwardError> {
|
||||||
|
let (bind_addr, bind_port, target_host, target_port) = parse_spec(spec)?;
|
||||||
|
Ok(Self {
|
||||||
|
kind: PortForwardSpecKind::Remote,
|
||||||
|
bind_addr,
|
||||||
|
bind_port,
|
||||||
|
target_host,
|
||||||
|
target_port,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn listen_addr(&self) -> Result<SocketAddr, ForwardError> {
|
||||||
|
format!("{}:{}", self.bind_addr, self.bind_port)
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| ForwardError::InvalidSpec {
|
||||||
|
spec: format!("{}:{}", self.bind_addr, self.bind_port),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn target_addr(&self) -> Result<SocketAddr, ForwardError> {
|
||||||
|
format!("{}:{}", self.target_host, self.target_port)
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| ForwardError::InvalidSpec {
|
||||||
|
spec: format!("{}:{}", self.target_host, self.target_port),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for PortForwardSpec {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
let prefix = match self.kind {
|
||||||
|
PortForwardSpecKind::Local => "-L",
|
||||||
|
PortForwardSpecKind::Remote => "-R",
|
||||||
|
};
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"{} {}:{}:{}:{}",
|
||||||
|
prefix, self.bind_addr, self.bind_port, self.target_host, self.target_port
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_spec(spec: &str) -> Result<(String, u16, String, u16), ForwardError> {
|
||||||
|
let parts: Vec<&str> = spec.split(':').collect();
|
||||||
|
if parts.len() != 4 {
|
||||||
|
return Err(ForwardError::InvalidSpec {
|
||||||
|
spec: spec.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let bind_addr = parts[0].to_string();
|
||||||
|
let bind_port: u16 = parts[1].parse().map_err(|_| ForwardError::InvalidSpec {
|
||||||
|
spec: spec.to_string(),
|
||||||
|
})?;
|
||||||
|
let target_host = parts[2].to_string();
|
||||||
|
let target_port: u16 = parts[3].parse().map_err(|_| ForwardError::InvalidSpec {
|
||||||
|
spec: spec.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok((bind_addr, bind_port, target_host, target_port))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct LocalForwarder {
|
||||||
|
spec: PortForwardSpec,
|
||||||
|
listener: Option<TcpListener>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LocalForwarder {
|
||||||
|
pub fn new(spec: PortForwardSpec) -> Result<Self, ForwardError> {
|
||||||
|
if spec.kind != PortForwardSpecKind::Local {
|
||||||
|
return Err(ForwardError::InvalidSpec {
|
||||||
|
spec: format!("expected local spec, got {:?}", spec.kind),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
spec,
|
||||||
|
listener: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn spec(&self) -> &PortForwardSpec {
|
||||||
|
&self.spec
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run<H: client::Handler + Send + 'static>(
|
||||||
|
&mut self,
|
||||||
|
handle: Arc<Mutex<client::Handle<H>>>,
|
||||||
|
) -> Result<(), ForwardError> {
|
||||||
|
let listen_addr = self.spec.listen_addr()?;
|
||||||
|
let listener: TcpListener = TcpListener::bind(listen_addr)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ForwardError::BindFailed { source: e })?;
|
||||||
|
self.listener = Some(listener);
|
||||||
|
let remote_host = self.spec.target_host.clone();
|
||||||
|
let remote_port = self.spec.target_port;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"local forward listening on {} -> {}:{}",
|
||||||
|
listen_addr, remote_host, remote_port
|
||||||
|
);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let listener = match &self.listener {
|
||||||
|
Some(l) => l,
|
||||||
|
None => return Ok(()),
|
||||||
|
};
|
||||||
|
let accept_result = listener.accept().await;
|
||||||
|
let (local_stream, local_addr) = match accept_result {
|
||||||
|
Ok(conn) => conn,
|
||||||
|
Err(e) => {
|
||||||
|
let handle = handle.lock().await;
|
||||||
|
if handle.is_closed() {
|
||||||
|
debug!("local forward accept loop ending: ssh session closed");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
drop(handle);
|
||||||
|
error!("local forward accept error: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"local forward connection from {} -> {}:{}",
|
||||||
|
local_addr, remote_host, remote_port
|
||||||
|
);
|
||||||
|
|
||||||
|
let handle = handle.clone();
|
||||||
|
let remote_host = remote_host.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) =
|
||||||
|
proxy_local_to_remote(local_stream, handle, &remote_host, remote_port).await
|
||||||
|
{
|
||||||
|
debug!("local forward proxy error: {}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stop(&mut self) {
|
||||||
|
if let Some(listener) = self.listener.take() {
|
||||||
|
drop(listener);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn local_port(&self) -> u16 {
|
||||||
|
self.spec.bind_port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn proxy_local_to_remote<H: client::Handler + Send + 'static>(
|
||||||
|
local_stream: TcpStream,
|
||||||
|
handle: Arc<Mutex<client::Handle<H>>>,
|
||||||
|
remote_host: &str,
|
||||||
|
remote_port: u16,
|
||||||
|
) -> Result<(), ForwardError> {
|
||||||
|
let local_addr = local_stream
|
||||||
|
.peer_addr()
|
||||||
|
.map(|a| a.to_string())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let handle_guard = handle.lock().await;
|
||||||
|
let channel = handle_guard
|
||||||
|
.channel_open_direct_tcpip(
|
||||||
|
remote_host,
|
||||||
|
remote_port as u32,
|
||||||
|
&local_addr,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ForwardError::ChannelOpenFailed {
|
||||||
|
source: Box::new(e) as _,
|
||||||
|
})?;
|
||||||
|
drop(handle_guard);
|
||||||
|
|
||||||
|
let ssh_stream = channel.into_stream();
|
||||||
|
let (mut ssh_read, mut ssh_write) = tokio::io::split(ssh_stream);
|
||||||
|
let (mut local_read, mut local_write) = tokio::io::split(local_stream);
|
||||||
|
|
||||||
|
let client_to_server = io::copy(&mut local_read, &mut ssh_write);
|
||||||
|
let server_to_client = io::copy(&mut ssh_read, &mut local_write);
|
||||||
|
|
||||||
|
match tokio::join!(client_to_server, server_to_client) {
|
||||||
|
(Err(e), _) | (_, Err(e)) => {
|
||||||
|
debug!("local forward bidirectional copy error: {}", e);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RemoteForwarder {
|
||||||
|
spec: PortForwardSpec,
|
||||||
|
cancel: Option<tokio::sync::oneshot::Sender<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RemoteForwarder {
|
||||||
|
pub fn new(spec: PortForwardSpec) -> Result<Self, ForwardError> {
|
||||||
|
if spec.kind != PortForwardSpecKind::Remote {
|
||||||
|
return Err(ForwardError::InvalidSpec {
|
||||||
|
spec: format!("expected remote spec, got {:?}", spec.kind),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(Self { spec, cancel: None })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn spec(&self) -> &PortForwardSpec {
|
||||||
|
&self.spec
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn register<H: client::Handler + Send + 'static>(
|
||||||
|
&self,
|
||||||
|
handle: &mut client::Handle<H>,
|
||||||
|
) -> Result<u32, ForwardError> {
|
||||||
|
let port = handle
|
||||||
|
.tcpip_forward(&self.spec.bind_addr, self.spec.bind_port as u32)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ForwardError::ChannelOpenFailed {
|
||||||
|
source: Box::new(e) as _,
|
||||||
|
})?;
|
||||||
|
Ok(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_forwarded_channel(
|
||||||
|
channel: russh::Channel<russh::client::Msg>,
|
||||||
|
connected_address: &str,
|
||||||
|
connected_port: u32,
|
||||||
|
local_host: &str,
|
||||||
|
local_port: u16,
|
||||||
|
) {
|
||||||
|
debug!(
|
||||||
|
"remote forward: server opened forwarded-tcpip channel to {}:{} -> local {}:{}",
|
||||||
|
connected_address, connected_port, local_host, local_port
|
||||||
|
);
|
||||||
|
|
||||||
|
let local_target = format!("{}:{}", local_host, local_port);
|
||||||
|
let local_stream = match TcpStream::connect(&local_target).await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => {
|
||||||
|
error!(
|
||||||
|
"remote forward: failed to connect to local target {}: {}",
|
||||||
|
local_target, e
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let ssh_stream = channel.into_stream();
|
||||||
|
let (mut ssh_read, mut ssh_write) = tokio::io::split(ssh_stream);
|
||||||
|
let (mut local_read, mut local_write) = tokio::io::split(local_stream);
|
||||||
|
|
||||||
|
let client_to_server = io::copy(&mut local_read, &mut ssh_write);
|
||||||
|
let server_to_client = io::copy(&mut ssh_read, &mut local_write);
|
||||||
|
|
||||||
|
match tokio::join!(client_to_server, server_to_client) {
|
||||||
|
(Err(e), _) | (_, Err(e)) => {
|
||||||
|
debug!("remote forward bidirectional copy error: {}", e);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn unregister<H: client::Handler + Send + 'static>(
|
||||||
|
&self,
|
||||||
|
handle: &client::Handle<H>,
|
||||||
|
) -> Result<(), ForwardError> {
|
||||||
|
handle
|
||||||
|
.cancel_tcpip_forward(&self.spec.bind_addr, self.spec.bind_port as u32)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ForwardError::ChannelOpenFailed {
|
||||||
|
source: Box::new(e) as _,
|
||||||
|
})?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stop(&mut self) {
|
||||||
|
if let Some(cancel) = self.cancel.take() {
|
||||||
|
let _ = cancel.send(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_local_forwarders<H: client::Handler + Send + 'static>(
|
||||||
|
forwarders: Vec<LocalForwarder>,
|
||||||
|
handle: Arc<Mutex<client::Handle<H>>>,
|
||||||
|
mut shutdown: tokio::sync::watch::Receiver<bool>,
|
||||||
|
) -> Vec<LocalForwarder> {
|
||||||
|
let mut forwarders = forwarders;
|
||||||
|
let mut tasks = Vec::new();
|
||||||
|
|
||||||
|
for forwarder in forwarders.drain(..) {
|
||||||
|
let handle = handle.clone();
|
||||||
|
let spec = forwarder.spec().clone();
|
||||||
|
let (_cancel_tx, cancel_rx) = tokio::sync::oneshot::channel::<()>();
|
||||||
|
tasks.push(tokio::spawn(async move {
|
||||||
|
let mut fwd = forwarder;
|
||||||
|
tokio::select! {
|
||||||
|
result = fwd.run(handle) => {
|
||||||
|
if let Err(e) = result {
|
||||||
|
error!("local forward {} failed: {}", spec, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = cancel_rx => {
|
||||||
|
fwd.stop().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fwd
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = shutdown.changed().await;
|
||||||
|
|
||||||
|
for task in &tasks {
|
||||||
|
task.abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut results = Vec::new();
|
||||||
|
for task in tasks {
|
||||||
|
match task.await {
|
||||||
|
Ok(fwd) => results.push(fwd),
|
||||||
|
Err(e) => {
|
||||||
|
if !e.is_cancelled() {
|
||||||
|
error!("local forwarder task panicked: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_local_spec() {
|
||||||
|
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||||
|
assert_eq!(spec.kind, PortForwardSpecKind::Local);
|
||||||
|
assert_eq!(spec.bind_addr, "127.0.0.1");
|
||||||
|
assert_eq!(spec.bind_port, 5432);
|
||||||
|
assert_eq!(spec.target_host, "db.internal");
|
||||||
|
assert_eq!(spec.target_port, 5432);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_remote_spec() {
|
||||||
|
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||||
|
assert_eq!(spec.kind, PortForwardSpecKind::Remote);
|
||||||
|
assert_eq!(spec.bind_addr, "0.0.0.0");
|
||||||
|
assert_eq!(spec.bind_port, 8080);
|
||||||
|
assert_eq!(spec.target_host, "127.0.0.1");
|
||||||
|
assert_eq!(spec.target_port, 3000);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_spec_invalid_few_parts() {
|
||||||
|
assert!(PortForwardSpec::local("127.0.0.1:5432:db").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_spec_invalid_many_parts() {
|
||||||
|
assert!(PortForwardSpec::local("a:b:c:d:e").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_spec_invalid_port() {
|
||||||
|
assert!(PortForwardSpec::local("127.0.0.1:abc:db:5432").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_spec_invalid_target_port() {
|
||||||
|
assert!(PortForwardSpec::local("127.0.0.1:5432:db:abc").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spec_display() {
|
||||||
|
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||||
|
assert_eq!(spec.to_string(), "-L 127.0.0.1:5432:db.internal:5432");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spec_display_remote() {
|
||||||
|
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||||
|
assert_eq!(spec.to_string(), "-R 0.0.0.0:8080:127.0.0.1:3000");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn local_forwarder_rejects_remote_spec() {
|
||||||
|
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||||
|
assert!(LocalForwarder::new(spec).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn remote_forwarder_rejects_local_spec() {
|
||||||
|
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||||
|
assert!(RemoteForwarder::new(spec).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn listen_addr_valid() {
|
||||||
|
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||||
|
let addr = spec.listen_addr().unwrap();
|
||||||
|
assert_eq!(addr.port(), 5432);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn listen_addr_invalid_host() {
|
||||||
|
let spec = PortForwardSpec {
|
||||||
|
kind: PortForwardSpecKind::Local,
|
||||||
|
bind_addr: "!!!invalid".to_string(),
|
||||||
|
bind_port: 5432,
|
||||||
|
target_host: "db".to_string(),
|
||||||
|
target_port: 5432,
|
||||||
|
};
|
||||||
|
assert!(spec.listen_addr().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn local_forward_bind_and_accept() {
|
||||||
|
let spec = PortForwardSpec::local(&format!("127.0.0.1:0:remote:5432")).unwrap();
|
||||||
|
let forwarder = LocalForwarder::new(spec).unwrap();
|
||||||
|
|
||||||
|
let listen_addr = forwarder.spec.listen_addr().unwrap();
|
||||||
|
let listener = TcpListener::bind(listen_addr).await.unwrap();
|
||||||
|
let bound_addr = listener.local_addr().unwrap();
|
||||||
|
drop(listener);
|
||||||
|
|
||||||
|
let spec = PortForwardSpec::local(&format!(
|
||||||
|
"127.0.0.1:{}:remote:5432",
|
||||||
|
bound_addr.port()
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
let forwarder = LocalForwarder::new(spec).unwrap();
|
||||||
|
assert_eq!(forwarder.local_port(), bound_addr.port());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn remote_forward_proxy_bidirectional() {
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
|
||||||
|
let echo_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let _echo_addr = echo_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let echo_server = tokio::spawn(async move {
|
||||||
|
let (mut stream, _) = echo_listener.accept().await.unwrap();
|
||||||
|
let mut buf = [0u8; 64];
|
||||||
|
loop {
|
||||||
|
let n = match stream.read(&mut buf).await {
|
||||||
|
Ok(0) => break,
|
||||||
|
Ok(n) => n,
|
||||||
|
Err(_) => break,
|
||||||
|
};
|
||||||
|
if stream.write_all(&buf[..n]).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let local_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let local_addr = local_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let proxy_task = tokio::spawn(async move {
|
||||||
|
let (stream, _) = local_listener.accept().await.unwrap();
|
||||||
|
let (mut read, mut write) = tokio::io::split(stream);
|
||||||
|
let _ = io::copy(&mut read, &mut write).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut local_conn = TcpStream::connect(local_addr).await.unwrap();
|
||||||
|
local_conn.write_all(b"hello").await.unwrap();
|
||||||
|
let mut buf = [0u8; 64];
|
||||||
|
let n = local_conn.read(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf[..n], b"hello");
|
||||||
|
|
||||||
|
echo_server.abort();
|
||||||
|
proxy_task.abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forwarder_spec_access() {
|
||||||
|
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||||
|
let forwarder = LocalForwarder::new(spec.clone()).unwrap();
|
||||||
|
assert_eq!(forwarder.spec(), &spec);
|
||||||
|
assert_eq!(forwarder.local_port(), 5432);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn remote_forwarder_spec_access() {
|
||||||
|
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||||
|
let forwarder = RemoteForwarder::new(spec.clone()).unwrap();
|
||||||
|
assert_eq!(forwarder.spec(), &spec);
|
||||||
|
}
|
||||||
|
}
|
||||||
5
crates/wraith-core/src/client/mod.rs
Normal file
5
crates/wraith-core/src/client/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
pub mod channel_manager;
|
||||||
|
pub mod forward;
|
||||||
|
|
||||||
|
pub use channel_manager::{ChannelManager, ForwardRequest};
|
||||||
|
pub use forward::{LocalForwarder, PortForwardSpec, PortForwardSpecKind, RemoteForwarder};
|
||||||
@@ -60,6 +60,27 @@ pub enum ConfigError {
|
|||||||
IncompatibleOptions,
|
IncompatibleOptions,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ForwardError {
|
||||||
|
#[error("invalid port forward spec: {spec}")]
|
||||||
|
InvalidSpec { spec: String },
|
||||||
|
#[error("bind failed")]
|
||||||
|
BindFailed {
|
||||||
|
#[source]
|
||||||
|
source: io::Error,
|
||||||
|
},
|
||||||
|
#[error("channel open failed")]
|
||||||
|
ChannelOpenFailed {
|
||||||
|
#[source]
|
||||||
|
source: Box<dyn std::error::Error + Send + Sync>,
|
||||||
|
},
|
||||||
|
#[error("connect to local target failed")]
|
||||||
|
LocalConnectFailed {
|
||||||
|
#[source]
|
||||||
|
source: io::Error,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -150,4 +171,36 @@ mod tests {
|
|||||||
let plain = AuthError::KeyRejected;
|
let plain = AuthError::KeyRejected;
|
||||||
assert!(plain.source().is_none());
|
assert!(plain.source().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forward_error_display() {
|
||||||
|
assert_eq!(
|
||||||
|
ForwardError::InvalidSpec { spec: "bad".to_string() }.to_string(),
|
||||||
|
"invalid port forward spec: bad"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ForwardError::BindFailed {
|
||||||
|
source: io::Error::new(io::ErrorKind::AddrInUse, "in use")
|
||||||
|
}
|
||||||
|
.to_string(),
|
||||||
|
"bind failed"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ForwardError::LocalConnectFailed {
|
||||||
|
source: io::Error::new(io::ErrorKind::ConnectionRefused, "refused")
|
||||||
|
}
|
||||||
|
.to_string(),
|
||||||
|
"connect to local target failed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forward_error_source_chaining() {
|
||||||
|
let io_err = io::Error::new(io::ErrorKind::AddrInUse, "in use");
|
||||||
|
let forward_err = ForwardError::BindFailed { source: io_err };
|
||||||
|
assert!(forward_err.source().is_some());
|
||||||
|
|
||||||
|
let plain = ForwardError::InvalidSpec { spec: "bad".to_string() };
|
||||||
|
assert!(plain.source().is_none());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -8,5 +8,6 @@ pub mod error;
|
|||||||
#[cfg(feature = "testutil")]
|
#[cfg(feature = "testutil")]
|
||||||
pub mod testutil;
|
pub mod testutil;
|
||||||
|
|
||||||
pub use error::{AuthError, ChannelError, ConfigError, TransportError};
|
pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError};
|
||||||
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||||
|
pub use client::channel_manager::{ChannelManager, ForwardRequest};
|
||||||
560
crates/wraith-core/src/server/channel_proxy.rs
Normal file
560
crates/wraith-core/src/server/channel_proxy.rs
Normal file
@@ -0,0 +1,560 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
|
||||||
|
use super::handler::{ProxyConfig, ProxyMode};
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ChannelProxyError {
|
||||||
|
#[error("connection refused")]
|
||||||
|
ConnectionRefused,
|
||||||
|
#[error("target unreachable")]
|
||||||
|
TargetUnreachable,
|
||||||
|
#[error("socks5 proxy handshake failed")]
|
||||||
|
Socks5HandshakeFailed,
|
||||||
|
#[error("socks5 proxy rejected connection")]
|
||||||
|
Socks5ProxyRejected,
|
||||||
|
#[error("http connect proxy handshake failed")]
|
||||||
|
HttpConnectHandshakeFailed,
|
||||||
|
#[error("http connect proxy rejected: {0}")]
|
||||||
|
HttpConnectProxyRejected(String),
|
||||||
|
#[error("io error")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn connect_outbound(
|
||||||
|
target: SocketAddr,
|
||||||
|
proxy: &ProxyConfig,
|
||||||
|
) -> Result<TcpStream, ChannelProxyError> {
|
||||||
|
match &proxy.mode {
|
||||||
|
ProxyMode::Direct => connect_direct(target).await,
|
||||||
|
ProxyMode::Socks5(addr) => connect_socks5(target, *addr).await,
|
||||||
|
ProxyMode::HttpConnect(addr) => connect_http_connect(target, *addr).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_direct(target: SocketAddr) -> Result<TcpStream, ChannelProxyError> {
|
||||||
|
TcpStream::connect(target)
|
||||||
|
.await
|
||||||
|
.map_err(|e| map_connection_error(e, target))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_socks5(target: SocketAddr, proxy_addr: SocketAddr) -> Result<TcpStream, ChannelProxyError> {
|
||||||
|
let mut stream = TcpStream::connect(proxy_addr)
|
||||||
|
.await
|
||||||
|
.map_err(ChannelProxyError::from)?;
|
||||||
|
|
||||||
|
stream.write_all(&[0x05, 0x01, 0x00]).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
|
||||||
|
let mut resp = [0u8; 2];
|
||||||
|
stream.read_exact(&mut resp).await?;
|
||||||
|
if resp[0] != 0x05 || resp[1] != 0x00 {
|
||||||
|
return Err(ChannelProxyError::Socks5HandshakeFailed);
|
||||||
|
}
|
||||||
|
|
||||||
|
let ip_bytes = target.ip().to_string();
|
||||||
|
let mut connect_req = vec![0x05, 0x01, 0x00, 0x03];
|
||||||
|
connect_req.push(ip_bytes.len() as u8);
|
||||||
|
connect_req.extend_from_slice(ip_bytes.as_bytes());
|
||||||
|
connect_req.extend_from_slice(&target.port().to_be_bytes());
|
||||||
|
stream.write_all(&connect_req).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
|
||||||
|
let mut reply_header = [0u8; 4];
|
||||||
|
stream.read_exact(&mut reply_header).await?;
|
||||||
|
if reply_header[0] != 0x05 {
|
||||||
|
return Err(ChannelProxyError::Socks5HandshakeFailed);
|
||||||
|
}
|
||||||
|
if reply_header[1] != 0x00 {
|
||||||
|
return Err(ChannelProxyError::Socks5ProxyRejected);
|
||||||
|
}
|
||||||
|
|
||||||
|
let atyp = reply_header[3];
|
||||||
|
match atyp {
|
||||||
|
0x01 => {
|
||||||
|
let mut _addr = [0u8; 4];
|
||||||
|
stream.read_exact(&mut _addr).await?;
|
||||||
|
}
|
||||||
|
0x04 => {
|
||||||
|
let mut _addr = [0u8; 16];
|
||||||
|
stream.read_exact(&mut _addr).await?;
|
||||||
|
}
|
||||||
|
0x03 => {
|
||||||
|
let len = stream.read_u8().await?;
|
||||||
|
let mut _domain = vec![0u8; len as usize];
|
||||||
|
stream.read_exact(&mut _domain).await?;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
return Err(ChannelProxyError::Socks5HandshakeFailed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut _port = [0u8; 2];
|
||||||
|
stream.read_exact(&mut _port).await?;
|
||||||
|
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_http_connect(
|
||||||
|
target: SocketAddr,
|
||||||
|
proxy_addr: SocketAddr,
|
||||||
|
) -> Result<TcpStream, ChannelProxyError> {
|
||||||
|
let mut stream = TcpStream::connect(proxy_addr)
|
||||||
|
.await
|
||||||
|
.map_err(ChannelProxyError::from)?;
|
||||||
|
|
||||||
|
let connect_request = format!(
|
||||||
|
"CONNECT {}:{} HTTP/1.1\r\nHost: {}:{}\r\n\r\n",
|
||||||
|
target.ip(),
|
||||||
|
target.port(),
|
||||||
|
target.ip(),
|
||||||
|
target.port()
|
||||||
|
);
|
||||||
|
stream.write_all(connect_request.as_bytes()).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
|
||||||
|
let mut response = Vec::new();
|
||||||
|
let mut buf = [0u8; 1024];
|
||||||
|
loop {
|
||||||
|
let n = stream.read(&mut buf).await?;
|
||||||
|
if n == 0 {
|
||||||
|
return Err(ChannelProxyError::HttpConnectHandshakeFailed);
|
||||||
|
}
|
||||||
|
response.extend_from_slice(&buf[..n]);
|
||||||
|
if response.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let response_str = String::from_utf8_lossy(&response);
|
||||||
|
let status_line = response_str
|
||||||
|
.lines()
|
||||||
|
.next()
|
||||||
|
.unwrap_or("");
|
||||||
|
|
||||||
|
if status_line.contains("200") {
|
||||||
|
Ok(stream)
|
||||||
|
} else {
|
||||||
|
Err(ChannelProxyError::HttpConnectProxyRejected(
|
||||||
|
status_line.to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_connection_error(e: std::io::Error, target: SocketAddr) -> ChannelProxyError {
|
||||||
|
match e.kind() {
|
||||||
|
std::io::ErrorKind::ConnectionRefused => ChannelProxyError::ConnectionRefused,
|
||||||
|
std::io::ErrorKind::AddrNotAvailable
|
||||||
|
| std::io::ErrorKind::NetworkUnreachable
|
||||||
|
| std::io::ErrorKind::HostUnreachable => ChannelProxyError::TargetUnreachable,
|
||||||
|
_ => {
|
||||||
|
tracing::debug!(error = %e, "outbound connection failed to {:?}", target);
|
||||||
|
ChannelProxyError::Io(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn proxy_channel<S>(channel: S, target: SocketAddr, proxy: &ProxyConfig)
|
||||||
|
where
|
||||||
|
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
if let Ok(outbound) = connect_outbound(target, proxy).await {
|
||||||
|
let (mut read_chan, mut write_chan) = tokio::io::split(channel);
|
||||||
|
let (mut read_out, mut write_out) = outbound.into_split();
|
||||||
|
|
||||||
|
let client_to_target = tokio::spawn(async move {
|
||||||
|
let _ = tokio::io::copy(&mut read_chan, &mut write_out).await;
|
||||||
|
let _ = write_out.shutdown().await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let target_to_client = tokio::spawn(async move {
|
||||||
|
let _ = tokio::io::copy(&mut read_out, &mut write_chan).await;
|
||||||
|
let _ = write_chan.shutdown().await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let _ = client_to_target.await;
|
||||||
|
let _ = target_to_client.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
fn direct_config() -> ProxyConfig {
|
||||||
|
ProxyConfig {
|
||||||
|
mode: ProxyMode::Direct,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn socks5_config(addr: SocketAddr) -> ProxyConfig {
|
||||||
|
ProxyConfig {
|
||||||
|
mode: ProxyMode::Socks5(addr),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn http_connect_config(addr: SocketAddr) -> ProxyConfig {
|
||||||
|
ProxyConfig {
|
||||||
|
mode: ProxyMode::HttpConnect(addr),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn direct_connection_to_echo_server() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let server = tokio::spawn(async move {
|
||||||
|
let (mut sock, _) = listener.accept().await.unwrap();
|
||||||
|
let mut buf = [0u8; 64];
|
||||||
|
let n = sock.read(&mut buf).await.unwrap();
|
||||||
|
sock.write_all(&buf[..n]).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let stream = connect_outbound(addr, &direct_config()).await.unwrap();
|
||||||
|
let (mut read, mut write) = stream.into_split();
|
||||||
|
write.write_all(b"hello").await.unwrap();
|
||||||
|
let mut buf = [0u8; 5];
|
||||||
|
read.read_exact(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf, b"hello");
|
||||||
|
|
||||||
|
let _ = server.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn direct_connection_target_unreachable() {
|
||||||
|
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
|
||||||
|
let result = connect_outbound(target, &direct_config()).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn socks5_proxy_handshake() {
|
||||||
|
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let target_addr = target_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let target_server = tokio::spawn(async move {
|
||||||
|
let (mut sock, _) = target_listener.accept().await.unwrap();
|
||||||
|
let mut buf = [0u8; 64];
|
||||||
|
let n = sock.read(&mut buf).await.unwrap();
|
||||||
|
sock.write_all(&buf[..n]).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let proxy_server = tokio::spawn(async move {
|
||||||
|
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||||
|
|
||||||
|
let mut greeting = [0u8; 3];
|
||||||
|
proxy_sock.read_exact(&mut greeting).await.unwrap();
|
||||||
|
assert_eq!(greeting[0], 0x05);
|
||||||
|
proxy_sock.write_all(&[0x05, 0x00]).await.unwrap();
|
||||||
|
|
||||||
|
let mut req_header = [0u8; 4];
|
||||||
|
proxy_sock.read_exact(&mut req_header).await.unwrap();
|
||||||
|
assert_eq!(req_header[0], 0x05);
|
||||||
|
assert_eq!(req_header[1], 0x01);
|
||||||
|
|
||||||
|
let atyp = req_header[3];
|
||||||
|
assert_eq!(atyp, 0x03);
|
||||||
|
|
||||||
|
let domain_len = proxy_sock.read_u8().await.unwrap() as usize;
|
||||||
|
let mut domain = vec![0u8; domain_len];
|
||||||
|
proxy_sock.read_exact(&mut domain).await.unwrap();
|
||||||
|
let mut port_bytes = [0u8; 2];
|
||||||
|
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
|
||||||
|
|
||||||
|
let target: SocketAddr = format!(
|
||||||
|
"{}:{}",
|
||||||
|
String::from_utf8_lossy(&domain),
|
||||||
|
u16::from_be_bytes(port_bytes)
|
||||||
|
)
|
||||||
|
.parse()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let reply = vec![
|
||||||
|
0x05, 0x00, 0x00, 0x01,
|
||||||
|
0, 0, 0, 0,
|
||||||
|
0, 0,
|
||||||
|
];
|
||||||
|
proxy_sock.write_all(&reply).await.unwrap();
|
||||||
|
|
||||||
|
let mut target_stream = TcpStream::connect(target).await.unwrap();
|
||||||
|
let _ = tokio::io::copy_bidirectional(&mut proxy_sock, &mut target_stream).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let config = socks5_config(proxy_addr);
|
||||||
|
let mut stream = connect_outbound(target_addr, &config).await.unwrap();
|
||||||
|
stream.write_all(b"hello socks").await.unwrap();
|
||||||
|
let mut buf = [0u8; 11];
|
||||||
|
stream.read_exact(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf, b"hello socks");
|
||||||
|
drop(stream);
|
||||||
|
|
||||||
|
let _ = target_server.await;
|
||||||
|
let _ = proxy_server.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn socks5_proxy_rejected() {
|
||||||
|
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let proxy_server = tokio::spawn(async move {
|
||||||
|
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||||
|
|
||||||
|
let mut greeting = [0u8; 3];
|
||||||
|
proxy_sock.read_exact(&mut greeting).await.unwrap();
|
||||||
|
proxy_sock.write_all(&[0x05, 0x00]).await.unwrap();
|
||||||
|
|
||||||
|
let mut req_header = [0u8; 4];
|
||||||
|
proxy_sock.read_exact(&mut req_header).await.unwrap();
|
||||||
|
|
||||||
|
let domain_len = proxy_sock.read_u8().await.unwrap() as usize;
|
||||||
|
let mut domain = vec![0u8; domain_len];
|
||||||
|
proxy_sock.read_exact(&mut domain).await.unwrap();
|
||||||
|
let mut port_bytes = [0u8; 2];
|
||||||
|
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
|
||||||
|
|
||||||
|
let reply = vec![
|
||||||
|
0x05, 0x05, 0x00, 0x01,
|
||||||
|
0, 0, 0, 0,
|
||||||
|
0, 0,
|
||||||
|
];
|
||||||
|
proxy_sock.write_all(&reply).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||||
|
let config = socks5_config(proxy_addr);
|
||||||
|
let result = connect_outbound(target, &config).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(matches!(
|
||||||
|
result.unwrap_err(),
|
||||||
|
ChannelProxyError::Socks5ProxyRejected
|
||||||
|
));
|
||||||
|
|
||||||
|
let _ = proxy_server.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn http_connect_proxy_handshake() {
|
||||||
|
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let target_addr = target_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let target_server = tokio::spawn(async move {
|
||||||
|
let (mut sock, _) = target_listener.accept().await.unwrap();
|
||||||
|
let mut buf = [0u8; 64];
|
||||||
|
let n = sock.read(&mut buf).await.unwrap();
|
||||||
|
sock.write_all(&buf[..n]).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let proxy_server = tokio::spawn(async move {
|
||||||
|
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||||
|
|
||||||
|
let mut request = Vec::new();
|
||||||
|
let mut buf = [0u8; 1024];
|
||||||
|
loop {
|
||||||
|
let n = proxy_sock.read(&mut buf).await.unwrap();
|
||||||
|
request.extend_from_slice(&buf[..n]);
|
||||||
|
if request.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
|
||||||
|
proxy_sock.write_all(response.as_bytes()).await.unwrap();
|
||||||
|
|
||||||
|
let target_str = extract_connect_target(&String::from_utf8_lossy(&request));
|
||||||
|
let mut target_stream = TcpStream::connect(target_str).await.unwrap();
|
||||||
|
let _ = tokio::io::copy_bidirectional(&mut proxy_sock, &mut target_stream).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let config = http_connect_config(proxy_addr);
|
||||||
|
let mut stream = connect_outbound(target_addr, &config).await.unwrap();
|
||||||
|
stream.write_all(b"hello http").await.unwrap();
|
||||||
|
let mut buf = [0u8; 10];
|
||||||
|
stream.read_exact(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf, b"hello http");
|
||||||
|
drop(stream);
|
||||||
|
|
||||||
|
let _ = target_server.await;
|
||||||
|
let _ = proxy_server.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_connect_target(request: &str) -> String {
|
||||||
|
let connect_line = request.lines().next().unwrap_or("");
|
||||||
|
let parts: Vec<&str> = connect_line.split_whitespace().collect();
|
||||||
|
if parts.len() >= 2 {
|
||||||
|
parts[1].to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn http_connect_proxy_rejected() {
|
||||||
|
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let proxy_server = tokio::spawn(async move {
|
||||||
|
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||||
|
|
||||||
|
let mut request = Vec::new();
|
||||||
|
let mut buf = [0u8; 1024];
|
||||||
|
loop {
|
||||||
|
let n = proxy_sock.read(&mut buf).await.unwrap();
|
||||||
|
if n == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
request.extend_from_slice(&buf[..n]);
|
||||||
|
if request.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = "HTTP/1.1 403 Forbidden\r\n\r\n";
|
||||||
|
proxy_sock.write_all(response.as_bytes()).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||||
|
let config = http_connect_config(proxy_addr);
|
||||||
|
let result = connect_outbound(target, &config).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
match result.unwrap_err() {
|
||||||
|
ChannelProxyError::HttpConnectProxyRejected(msg) => {
|
||||||
|
assert!(msg.contains("403"));
|
||||||
|
}
|
||||||
|
other => panic!("expected HttpConnectProxyRejected, got {:?}", other),
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = proxy_server.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn target_unreachable_returns_appropriate_error() {
|
||||||
|
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
|
||||||
|
let result = connect_outbound(target, &direct_config()).await;
|
||||||
|
match result.unwrap_err() {
|
||||||
|
ChannelProxyError::TargetUnreachable
|
||||||
|
| ChannelProxyError::ConnectionRefused
|
||||||
|
| ChannelProxyError::Io(_) => {}
|
||||||
|
other => panic!("unexpected error type: {:?}", other),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn socks5_proxy_unreachable() {
|
||||||
|
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||||
|
let bad_proxy: SocketAddr = "127.0.0.1:1".parse().unwrap();
|
||||||
|
let config = socks5_config(bad_proxy);
|
||||||
|
let result = connect_outbound(target, &config).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn http_connect_proxy_unreachable() {
|
||||||
|
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||||
|
let bad_proxy: SocketAddr = "127.0.0.1:1".parse().unwrap();
|
||||||
|
let config = http_connect_config(bad_proxy);
|
||||||
|
let result = connect_outbound(target, &config).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MockChannel {
|
||||||
|
read_half: tokio::io::ReadHalf<DuplexStream>,
|
||||||
|
write_half: tokio::io::WriteHalf<DuplexStream>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tokio::io::AsyncRead for MockChannel {
|
||||||
|
fn poll_read(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
buf: &mut tokio::io::ReadBuf<'_>,
|
||||||
|
) -> std::task::Poll<std::io::Result<()>> {
|
||||||
|
std::pin::Pin::new(&mut self.get_mut().read_half).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tokio::io::AsyncWrite for MockChannel {
|
||||||
|
fn poll_write(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> std::task::Poll<std::io::Result<usize>> {
|
||||||
|
std::pin::Pin::new(&mut self.get_mut().write_half).poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<std::io::Result<()>> {
|
||||||
|
std::pin::Pin::new(&mut self.get_mut().write_half).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<std::io::Result<()>> {
|
||||||
|
std::pin::Pin::new(&mut self.get_mut().write_half).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_mock_channel() -> (MockChannel, DuplexStream) {
|
||||||
|
let (client, server) = duplex(4096);
|
||||||
|
let (read_half, write_half) = tokio::io::split(client);
|
||||||
|
(
|
||||||
|
MockChannel {
|
||||||
|
read_half,
|
||||||
|
write_half,
|
||||||
|
},
|
||||||
|
server,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn proxy_channel_bidirectional_data_flow() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let target_addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let echo_server = tokio::spawn(async move {
|
||||||
|
let (mut sock, _) = listener.accept().await.unwrap();
|
||||||
|
let mut buf = [0u8; 64];
|
||||||
|
let n = sock.read(&mut buf).await.unwrap();
|
||||||
|
sock.write_all(&buf[..n]).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let (channel, mut channel_peer) = make_mock_channel();
|
||||||
|
|
||||||
|
let target = target_addr;
|
||||||
|
let proxy = direct_config();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
proxy_channel(channel, target, &proxy).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
channel_peer.write_all(b"ping").await.unwrap();
|
||||||
|
channel_peer.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut buf = [0u8; 4];
|
||||||
|
channel_peer.read_exact(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf, b"ping");
|
||||||
|
|
||||||
|
drop(channel_peer);
|
||||||
|
let _ = echo_server.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn proxy_channel_target_unreachable_closes_cleanly() {
|
||||||
|
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
|
||||||
|
let (channel, _channel_peer) = make_mock_channel();
|
||||||
|
|
||||||
|
let proxy = direct_config();
|
||||||
|
proxy_channel(channel, target, &proxy).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
289
crates/wraith-core/src/server/handler.rs
Normal file
289
crates/wraith-core/src/server/handler.rs
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use russh::keys::ssh_key::HashAlg;
|
||||||
|
use russh::server::{Auth, Handler, Msg, Session};
|
||||||
|
use russh::Channel;
|
||||||
|
|
||||||
|
use crate::auth::ServerAuthConfig;
|
||||||
|
|
||||||
|
const WRAITH_PREFIX: &str = "wraith-";
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ProxyMode {
|
||||||
|
Direct,
|
||||||
|
Socks5(SocketAddr),
|
||||||
|
HttpConnect(SocketAddr),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ProxyConfig {
|
||||||
|
pub mode: ProxyMode,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ServerHandler {
|
||||||
|
auth_config: Arc<ServerAuthConfig>,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
outbound_proxy: Option<ProxyConfig>,
|
||||||
|
remote_addr: Option<SocketAddr>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerHandler {
|
||||||
|
pub fn new(
|
||||||
|
auth_config: Arc<ServerAuthConfig>,
|
||||||
|
outbound_proxy: Option<ProxyConfig>,
|
||||||
|
remote_addr: Option<SocketAddr>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
auth_config,
|
||||||
|
outbound_proxy,
|
||||||
|
remote_addr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Handler for ServerHandler {
|
||||||
|
type Error = russh::Error;
|
||||||
|
|
||||||
|
async fn auth_publickey(
|
||||||
|
&mut self,
|
||||||
|
user: &str,
|
||||||
|
public_key: &russh::keys::ssh_key::PublicKey,
|
||||||
|
) -> Result<Auth, Self::Error> {
|
||||||
|
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||||
|
let remote_addr_display = self
|
||||||
|
.remote_addr
|
||||||
|
.map_or("unknown".to_string(), |a| a.to_string());
|
||||||
|
|
||||||
|
let russh_pub = russh::keys::PublicKey::new(public_key.key_data().clone(), user);
|
||||||
|
let result = self.auth_config.authenticate_publickey(&russh_pub);
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(()) => {
|
||||||
|
tracing::info!(
|
||||||
|
remote_addr = %remote_addr_display,
|
||||||
|
key_fingerprint = %fingerprint,
|
||||||
|
result = "accept",
|
||||||
|
"auth attempt"
|
||||||
|
);
|
||||||
|
Ok(Auth::Accept)
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
tracing::info!(
|
||||||
|
remote_addr = %remote_addr_display,
|
||||||
|
key_fingerprint = %fingerprint,
|
||||||
|
result = "reject",
|
||||||
|
"auth attempt"
|
||||||
|
);
|
||||||
|
Ok(Auth::Reject {
|
||||||
|
proceed_with_methods: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn channel_open_direct_tcpip(
|
||||||
|
&mut self,
|
||||||
|
channel: Channel<Msg>,
|
||||||
|
host_to_connect: &str,
|
||||||
|
port_to_connect: u32,
|
||||||
|
originator_address: &str,
|
||||||
|
originator_port: u32,
|
||||||
|
_session: &mut Session,
|
||||||
|
) -> Result<bool, Self::Error> {
|
||||||
|
if host_to_connect.starts_with(WRAITH_PREFIX) {
|
||||||
|
tracing::info!(
|
||||||
|
host = host_to_connect,
|
||||||
|
port = port_to_connect,
|
||||||
|
"routing to internal control channel handler"
|
||||||
|
);
|
||||||
|
return Ok(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = (host_to_connect, port_to_connect, originator_address, originator_port, channel);
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn channel_open_session(
|
||||||
|
&mut self,
|
||||||
|
_channel: Channel<Msg>,
|
||||||
|
_session: &mut Session,
|
||||||
|
) -> Result<bool, Self::Error> {
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn channel_open_x11(
|
||||||
|
&mut self,
|
||||||
|
_channel: Channel<Msg>,
|
||||||
|
_originator_address: &str,
|
||||||
|
_originator_port: u32,
|
||||||
|
_session: &mut Session,
|
||||||
|
) -> Result<bool, Self::Error> {
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn channel_open_forwarded_tcpip(
|
||||||
|
&mut self,
|
||||||
|
_channel: Channel<Msg>,
|
||||||
|
_host_to_connect: &str,
|
||||||
|
_port_to_connect: u32,
|
||||||
|
_originator_address: &str,
|
||||||
|
_originator_port: u32,
|
||||||
|
_session: &mut Session,
|
||||||
|
) -> Result<bool, Self::Error> {
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::auth::keys::KeySource;
|
||||||
|
use russh::keys::{decode_secret_key, PrivateKey};
|
||||||
|
use std::io::Write;
|
||||||
|
|
||||||
|
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||||
|
|
||||||
|
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
|
||||||
|
|
||||||
|
fn make_authorized_keys_file(keys_content: &str) -> tempfile::NamedTempFile {
|
||||||
|
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||||
|
f.write_all(keys_content.as_bytes()).unwrap();
|
||||||
|
f.flush().unwrap();
|
||||||
|
f
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_key() -> PrivateKey {
|
||||||
|
decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_auth_config(keys_content: &str) -> Arc<ServerAuthConfig> {
|
||||||
|
let f = make_authorized_keys_file(keys_content);
|
||||||
|
Arc::new(
|
||||||
|
ServerAuthConfig::from_keys_and_ca(
|
||||||
|
Some(KeySource::File(f.path().to_path_buf())),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_empty_auth_config() -> Arc<ServerAuthConfig> {
|
||||||
|
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn auth_delegation_accepts_known_key() {
|
||||||
|
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||||
|
let mut handler = ServerHandler::new(auth_config, None, None);
|
||||||
|
|
||||||
|
let ssh_key = load_key().public_key().clone();
|
||||||
|
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||||
|
assert_eq!(result, Auth::Accept);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn auth_delegation_rejects_unknown_key() {
|
||||||
|
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||||
|
let mut handler = ServerHandler::new(auth_config, None, None);
|
||||||
|
|
||||||
|
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
||||||
|
let other_ssh_key = russh::keys::parse_public_key_base64(
|
||||||
|
other_key_text.split_whitespace().nth(1).unwrap(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let result = handler
|
||||||
|
.auth_publickey("testuser", &other_ssh_key)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
result,
|
||||||
|
Auth::Reject {
|
||||||
|
proceed_with_methods: None
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn auth_delegation_empty_config_rejects_all() {
|
||||||
|
let auth_config = make_empty_auth_config();
|
||||||
|
let mut handler = ServerHandler::new(auth_config, None, None);
|
||||||
|
|
||||||
|
let ssh_key = load_key().public_key().clone();
|
||||||
|
let result = handler
|
||||||
|
.auth_publickey("testuser", &ssh_key)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
result,
|
||||||
|
Auth::Reject {
|
||||||
|
proceed_with_methods: None
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn auth_logging_includes_remote_addr() {
|
||||||
|
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||||
|
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
|
||||||
|
let mut handler = ServerHandler::new(auth_config, None, Some(remote_addr));
|
||||||
|
|
||||||
|
let ssh_key = load_key().public_key().clone();
|
||||||
|
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn reserved_wraith_destination_routing() {
|
||||||
|
assert!("wraith-control".starts_with(WRAITH_PREFIX));
|
||||||
|
assert!("wraith-status".starts_with(WRAITH_PREFIX));
|
||||||
|
assert!("wraith-events".starts_with(WRAITH_PREFIX));
|
||||||
|
assert!(!"example.com".starts_with(WRAITH_PREFIX));
|
||||||
|
assert!(!"localhost".starts_with(WRAITH_PREFIX));
|
||||||
|
assert!(!"wraith.example.com".starts_with(WRAITH_PREFIX));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn proxy_mode_variants() {
|
||||||
|
let direct = ProxyMode::Direct;
|
||||||
|
let socks5 = ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap());
|
||||||
|
let http = ProxyMode::HttpConnect("127.0.0.1:8080".parse().unwrap());
|
||||||
|
|
||||||
|
match direct {
|
||||||
|
ProxyMode::Direct => {}
|
||||||
|
_ => panic!("expected Direct"),
|
||||||
|
}
|
||||||
|
match socks5 {
|
||||||
|
ProxyMode::Socks5(_) => {}
|
||||||
|
_ => panic!("expected Socks5"),
|
||||||
|
}
|
||||||
|
match http {
|
||||||
|
ProxyMode::HttpConnect(_) => {}
|
||||||
|
_ => panic!("expected HttpConnect"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn server_handler_holds_config() {
|
||||||
|
let auth_config = make_empty_auth_config();
|
||||||
|
let proxy = Some(ProxyConfig {
|
||||||
|
mode: ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap()),
|
||||||
|
});
|
||||||
|
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
|
||||||
|
|
||||||
|
let handler = ServerHandler::new(auth_config, proxy.clone(), remote);
|
||||||
|
assert!(handler.outbound_proxy.is_some());
|
||||||
|
assert!(handler.remote_addr.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn one_handler_per_connection() {
|
||||||
|
let auth_config = make_empty_auth_config();
|
||||||
|
let handler1 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
|
||||||
|
let handler2 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
|
||||||
|
|
||||||
|
assert!(handler1.remote_addr != handler2.remote_addr);
|
||||||
|
}
|
||||||
|
}
|
||||||
5
crates/wraith-core/src/server/mod.rs
Normal file
5
crates/wraith-core/src/server/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
pub mod channel_proxy;
|
||||||
|
pub mod handler;
|
||||||
|
|
||||||
|
pub use channel_proxy::{ChannelProxyError, connect_outbound, proxy_channel};
|
||||||
|
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
|
||||||
490
crates/wraith-core/src/socks5/mod.rs
Normal file
490
crates/wraith-core/src/socks5/mod.rs
Normal file
@@ -0,0 +1,490 @@
|
|||||||
|
mod protocol;
|
||||||
|
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
use protocol::{Socks5Reply, Socks5Request, Socks5VersionMethod};
|
||||||
|
|
||||||
|
pub use protocol::Socks5Address;
|
||||||
|
|
||||||
|
const DEFAULT_SOCKS5_ADDR: &str = "127.0.0.1:1080";
|
||||||
|
|
||||||
|
pub trait ChannelOpener: Send + Sync + 'static {
|
||||||
|
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||||
|
|
||||||
|
fn open_channel(
|
||||||
|
&self,
|
||||||
|
host: String,
|
||||||
|
port: u16,
|
||||||
|
) -> impl std::future::Future<Output = Result<Self::Stream, ChannelOpenError>> + Send;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ChannelOpenError {
|
||||||
|
#[error("session closed")]
|
||||||
|
SessionClosed,
|
||||||
|
#[error("channel open failed")]
|
||||||
|
ChannelOpenFailed,
|
||||||
|
#[error("connection refused")]
|
||||||
|
ConnectionRefused,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Socks5Server<C: ChannelOpener> {
|
||||||
|
listen_addr: SocketAddr,
|
||||||
|
channel_opener: Arc<C>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C: ChannelOpener> Socks5Server<C> {
|
||||||
|
pub fn new(channel_opener: C) -> Self {
|
||||||
|
Self::with_addr(channel_opener, DEFAULT_SOCKS5_ADDR)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_addr(channel_opener: C, addr: &str) -> Self {
|
||||||
|
let listen_addr: SocketAddr = addr
|
||||||
|
.parse()
|
||||||
|
.expect("invalid SOCKS5 listen address");
|
||||||
|
Self {
|
||||||
|
listen_addr,
|
||||||
|
channel_opener: Arc::new(channel_opener),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn listen_addr(&self) -> SocketAddr {
|
||||||
|
self.listen_addr
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(self) -> Result<(), std::io::Error> {
|
||||||
|
let listener = TcpListener::bind(self.listen_addr).await?;
|
||||||
|
debug!("socks5 server listening on {}", self.listen_addr);
|
||||||
|
loop {
|
||||||
|
let (socket, _peer) = listener.accept().await?;
|
||||||
|
let opener = Arc::clone(&self.channel_opener);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = handle_socks5_connection(socket, opener).await {
|
||||||
|
debug!("socks5 connection error: {e}");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_socks5_connection<S, C>(
|
||||||
|
mut socket: S,
|
||||||
|
opener: Arc<C>,
|
||||||
|
) -> Result<(), Socks5Error>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
C: ChannelOpener,
|
||||||
|
{
|
||||||
|
let vm = Socks5VersionMethod::read_from(&mut socket).await?;
|
||||||
|
if vm.version != 0x05 {
|
||||||
|
return Err(Socks5Error::InvalidVersion(vm.version));
|
||||||
|
}
|
||||||
|
if !vm.methods.contains(&0x00) {
|
||||||
|
let reply = [0x05, 0xFF];
|
||||||
|
socket.write_all(&reply).await?;
|
||||||
|
socket.shutdown().await?;
|
||||||
|
return Err(Socks5Error::NoAcceptableAuth);
|
||||||
|
}
|
||||||
|
let reply = [0x05, 0x00];
|
||||||
|
socket.write_all(&reply).await?;
|
||||||
|
|
||||||
|
let request = Socks5Request::read_from(&mut socket).await?;
|
||||||
|
if request.version != 0x05 {
|
||||||
|
return Err(Socks5Error::InvalidVersion(request.version));
|
||||||
|
}
|
||||||
|
if request.command != 0x01 {
|
||||||
|
send_error_reply(&mut socket, Socks5Reply::command_not_supported()).await?;
|
||||||
|
return Err(Socks5Error::UnsupportedCommand(request.command));
|
||||||
|
}
|
||||||
|
|
||||||
|
let (host, port) = match &request.address {
|
||||||
|
Socks5Address::Ipv4(addr) => (addr.to_string(), request.port),
|
||||||
|
Socks5Address::Ipv6(addr) => (addr.to_string(), request.port),
|
||||||
|
Socks5Address::Domain(name) => (name.clone(), request.port),
|
||||||
|
};
|
||||||
|
|
||||||
|
match opener.open_channel(host, port).await {
|
||||||
|
Ok(mut ssh_stream) => {
|
||||||
|
let bind_addr = Socks5Address::Ipv4(std::net::Ipv4Addr::UNSPECIFIED);
|
||||||
|
let reply = Socks5Reply::success(bind_addr, 0);
|
||||||
|
reply.write_to(&mut socket).await?;
|
||||||
|
tokio::io::copy_bidirectional(&mut socket, &mut ssh_stream).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
send_error_reply(&mut socket, Socks5Reply::connection_refused()).await?;
|
||||||
|
Err(Socks5Error::ChannelOpenFailed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_error_reply<S: AsyncRead + AsyncWrite + Unpin>(
|
||||||
|
socket: &mut S,
|
||||||
|
reply: Socks5Reply,
|
||||||
|
) -> Result<(), Socks5Error> {
|
||||||
|
reply.write_to(socket).await?;
|
||||||
|
let _ = socket.shutdown().await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum Socks5Error {
|
||||||
|
#[error("invalid SOCKS version: {0}")]
|
||||||
|
InvalidVersion(u8),
|
||||||
|
#[error("no acceptable auth method")]
|
||||||
|
NoAcceptableAuth,
|
||||||
|
#[error("unsupported command: {0}")]
|
||||||
|
UnsupportedCommand(u8),
|
||||||
|
#[error("channel open failed")]
|
||||||
|
ChannelOpenFailed,
|
||||||
|
#[error("io error")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct HandleChannelOpener<H: russh::client::Handler> {
|
||||||
|
handle: Arc<Mutex<russh::client::Handle<H>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<H: russh::client::Handler> HandleChannelOpener<H> {
|
||||||
|
pub fn new(handle: russh::client::Handle<H>) -> Self {
|
||||||
|
Self {
|
||||||
|
handle: Arc::new(Mutex::new(handle)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_arc(handle: Arc<Mutex<russh::client::Handle<H>>>) -> Self {
|
||||||
|
Self { handle }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<H: russh::client::Handler + Send + Sync + 'static> ChannelOpener for HandleChannelOpener<H> {
|
||||||
|
type Stream = russh::ChannelStream<russh::client::Msg>;
|
||||||
|
|
||||||
|
async fn open_channel(&self, host: String, port: u16) -> Result<Self::Stream, ChannelOpenError> {
|
||||||
|
let handle = self.handle.lock().await;
|
||||||
|
if handle.is_closed() {
|
||||||
|
return Err(ChannelOpenError::SessionClosed);
|
||||||
|
}
|
||||||
|
let channel = handle
|
||||||
|
.channel_open_direct_tcpip(host, port as u32, "127.0.0.1", 0)
|
||||||
|
.await
|
||||||
|
.map_err(|_| ChannelOpenError::ChannelOpenFailed)?;
|
||||||
|
Ok(channel.into_stream())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
|
||||||
|
|
||||||
|
struct MockChannelOpener {
|
||||||
|
fail: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChannelOpener for MockChannelOpener {
|
||||||
|
type Stream = DuplexStream;
|
||||||
|
|
||||||
|
async fn open_channel(
|
||||||
|
&self,
|
||||||
|
_host: String,
|
||||||
|
_port: u16,
|
||||||
|
) -> Result<Self::Stream, ChannelOpenError> {
|
||||||
|
if self.fail {
|
||||||
|
Err(ChannelOpenError::ChannelOpenFailed)
|
||||||
|
} else {
|
||||||
|
let (client, _server) = duplex(4096);
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_socks5_greeting(methods: &[u8]) -> Vec<u8> {
|
||||||
|
let mut buf = vec![0x05, methods.len() as u8];
|
||||||
|
buf.extend_from_slice(methods);
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_socks5_connect_ipv4(addr: [u8; 4], port: u16) -> Vec<u8> {
|
||||||
|
let mut buf = vec![0x05, 0x01, 0x00, 0x01];
|
||||||
|
buf.extend_from_slice(&addr);
|
||||||
|
buf.extend_from_slice(&port.to_be_bytes());
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_socks5_connect_domain(domain: &str, port: u16) -> Vec<u8> {
|
||||||
|
let mut buf = vec![0x05, 0x01, 0x00, 0x03];
|
||||||
|
buf.push(domain.len() as u8);
|
||||||
|
buf.extend_from_slice(domain.as_bytes());
|
||||||
|
buf.extend_from_slice(&port.to_be_bytes());
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_socks5_connect_ipv6(addr: [u8; 16], port: u16) -> Vec<u8> {
|
||||||
|
let mut buf = vec![0x05, 0x01, 0x00, 0x04];
|
||||||
|
buf.extend_from_slice(&addr);
|
||||||
|
buf.extend_from_slice(&port.to_be_bytes());
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn do_handshake(client: &mut DuplexStream) -> [u8; 2] {
|
||||||
|
client.write_all(&build_socks5_greeting(&[0x00])).await.unwrap();
|
||||||
|
client.flush().await.unwrap();
|
||||||
|
let mut resp = [0u8; 2];
|
||||||
|
client.read_exact(&mut resp).await.unwrap();
|
||||||
|
resp
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn do_connect_ipv4(client: &mut DuplexStream, addr: [u8; 4], port: u16) -> Vec<u8> {
|
||||||
|
client
|
||||||
|
.write_all(&build_socks5_connect_ipv4(addr, port))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
client.flush().await.unwrap();
|
||||||
|
let mut reply_buf = [0u8; 10];
|
||||||
|
client.read_exact(&mut reply_buf).await.unwrap();
|
||||||
|
reply_buf.to_vec()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn handshake_no_auth_method() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
let resp = do_handshake(&mut client).await;
|
||||||
|
assert_eq!(resp, [0x05, 0x00]);
|
||||||
|
|
||||||
|
let reply_buf = do_connect_ipv4(&mut client, [127, 0, 0, 1], 80).await;
|
||||||
|
assert_eq!(reply_buf[0], 0x05);
|
||||||
|
assert_eq!(reply_buf[1], 0x00);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn handshake_rejects_no_acceptable_method() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
client
|
||||||
|
.write_all(&build_socks5_greeting(&[0x02]))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
client.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut resp = [0u8; 2];
|
||||||
|
client.read_exact(&mut resp).await.unwrap();
|
||||||
|
assert_eq!(resp, [0x05, 0xFF]);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let result = server_handle.await.unwrap();
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(matches!(
|
||||||
|
result.unwrap_err(),
|
||||||
|
Socks5Error::NoAcceptableAuth
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn address_type_ipv4() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
do_handshake(&mut client).await;
|
||||||
|
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 443).await;
|
||||||
|
assert_eq!(reply_buf[1], 0x00);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn address_type_domain() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
do_handshake(&mut client).await;
|
||||||
|
|
||||||
|
client
|
||||||
|
.write_all(&build_socks5_connect_domain("example.com", 443))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
client.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut reply_buf = [0u8; 10];
|
||||||
|
client.read_exact(&mut reply_buf).await.unwrap();
|
||||||
|
assert_eq!(reply_buf[1], 0x00);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn address_type_ipv6() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
do_handshake(&mut client).await;
|
||||||
|
|
||||||
|
let ipv6_addr = [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
|
||||||
|
client
|
||||||
|
.write_all(&build_socks5_connect_ipv6(ipv6_addr, 443))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
client.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut reply_buf = [0u8; 10];
|
||||||
|
client.read_exact(&mut reply_buf).await.unwrap();
|
||||||
|
assert_eq!(reply_buf[0], 0x05);
|
||||||
|
assert_eq!(reply_buf[1], 0x00);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn channel_open_failure_returns_socks5_error() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: true };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
do_handshake(&mut client).await;
|
||||||
|
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 80).await;
|
||||||
|
assert_eq!(reply_buf[0], 0x05);
|
||||||
|
assert_eq!(reply_buf[1], 0x05);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn unsupported_command_returns_error() {
|
||||||
|
let (mut client, server) = duplex(4096);
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
do_handshake(&mut client).await;
|
||||||
|
|
||||||
|
let mut bind_req = vec![0x05, 0x02, 0x00, 0x01];
|
||||||
|
bind_req.extend_from_slice(&[127, 0, 0, 1]);
|
||||||
|
bind_req.extend_from_slice(&80u16.to_be_bytes());
|
||||||
|
client.write_all(&bind_req).await.unwrap();
|
||||||
|
client.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut reply_buf = [0u8; 10];
|
||||||
|
client.read_exact(&mut reply_buf).await.unwrap();
|
||||||
|
assert_eq!(reply_buf[1], 0x07);
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn bidirectional_proxy_flow() {
|
||||||
|
let (mut client_sock, server_sock) = duplex(4096);
|
||||||
|
let (ssh_client, mut ssh_server) = duplex(4096);
|
||||||
|
|
||||||
|
let ssh_stream = Arc::new(Mutex::new(Some(ssh_client)));
|
||||||
|
|
||||||
|
struct ProxyOpener {
|
||||||
|
stream: Arc<Mutex<Option<DuplexStream>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChannelOpener for ProxyOpener {
|
||||||
|
type Stream = DuplexStream;
|
||||||
|
|
||||||
|
async fn open_channel(
|
||||||
|
&self,
|
||||||
|
_host: String,
|
||||||
|
_port: u16,
|
||||||
|
) -> Result<Self::Stream, ChannelOpenError> {
|
||||||
|
self.stream
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.take()
|
||||||
|
.ok_or(ChannelOpenError::ChannelOpenFailed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let opener = ProxyOpener {
|
||||||
|
stream: Arc::clone(&ssh_stream),
|
||||||
|
};
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
handle_socks5_connection(server_sock, Arc::new(opener)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
do_handshake(&mut client_sock).await;
|
||||||
|
let reply_buf = do_connect_ipv4(&mut client_sock, [127, 0, 0, 1], 80).await;
|
||||||
|
assert_eq!(reply_buf[1], 0x00);
|
||||||
|
|
||||||
|
let test_data = b"hello through tunnel";
|
||||||
|
client_sock.write_all(test_data).await.unwrap();
|
||||||
|
client_sock.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut received = vec![0u8; test_data.len()];
|
||||||
|
AsyncReadExt::read_exact(&mut ssh_server, &mut received)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(&received, test_data);
|
||||||
|
|
||||||
|
let echo_data = b"response from tunnel";
|
||||||
|
ssh_server.write_all(echo_data).await.unwrap();
|
||||||
|
ssh_server.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut received_back = vec![0u8; echo_data.len()];
|
||||||
|
client_sock.read_exact(&mut received_back).await.unwrap();
|
||||||
|
assert_eq!(&received_back, echo_data);
|
||||||
|
|
||||||
|
drop(client_sock);
|
||||||
|
drop(ssh_server);
|
||||||
|
let _ = server_handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn default_listen_address() {
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
let server = Socks5Server::new(opener);
|
||||||
|
assert_eq!(server.listen_addr(), "127.0.0.1:1080".parse().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn custom_listen_address() {
|
||||||
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
let server = Socks5Server::with_addr(opener, "127.0.0.1:9050");
|
||||||
|
assert_eq!(server.listen_addr(), "127.0.0.1:9050".parse().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
304
crates/wraith-core/src/socks5/protocol.rs
Normal file
304
crates/wraith-core/src/socks5/protocol.rs
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||||
|
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub enum Socks5Address {
|
||||||
|
Ipv4(Ipv4Addr),
|
||||||
|
Ipv6(Ipv6Addr),
|
||||||
|
Domain(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Socks5VersionMethod {
|
||||||
|
pub version: u8,
|
||||||
|
pub methods: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Socks5VersionMethod {
|
||||||
|
pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Self> {
|
||||||
|
let version = reader.read_u8().await?;
|
||||||
|
let nmethods = reader.read_u8().await?;
|
||||||
|
let mut methods = vec![0u8; nmethods as usize];
|
||||||
|
reader.read_exact(&mut methods).await?;
|
||||||
|
Ok(Self { version, methods })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Socks5Request {
|
||||||
|
pub version: u8,
|
||||||
|
pub command: u8,
|
||||||
|
pub address: Socks5Address,
|
||||||
|
pub port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Socks5Request {
|
||||||
|
pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Self> {
|
||||||
|
let version = reader.read_u8().await?;
|
||||||
|
let command = reader.read_u8().await?;
|
||||||
|
let _rsv = reader.read_u8().await?;
|
||||||
|
let atyp = reader.read_u8().await?;
|
||||||
|
|
||||||
|
let address = match atyp {
|
||||||
|
0x01 => {
|
||||||
|
let mut octets = [0u8; 4];
|
||||||
|
reader.read_exact(&mut octets).await?;
|
||||||
|
Socks5Address::Ipv4(Ipv4Addr::from(octets))
|
||||||
|
}
|
||||||
|
0x04 => {
|
||||||
|
let mut octets = [0u8; 16];
|
||||||
|
reader.read_exact(&mut octets).await?;
|
||||||
|
Socks5Address::Ipv6(Ipv6Addr::from(octets))
|
||||||
|
}
|
||||||
|
0x03 => {
|
||||||
|
let len = reader.read_u8().await?;
|
||||||
|
let mut domain = vec![0u8; len as usize];
|
||||||
|
reader.read_exact(&mut domain).await?;
|
||||||
|
Socks5Address::Domain(String::from_utf8_lossy(&domain).into_owned())
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("unsupported address type: {atyp}"),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let port = reader.read_u16().await?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
version,
|
||||||
|
command,
|
||||||
|
address,
|
||||||
|
port,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Socks5Reply {
|
||||||
|
pub version: u8,
|
||||||
|
pub reply: u8,
|
||||||
|
pub address: Socks5Address,
|
||||||
|
pub port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Socks5Reply {
|
||||||
|
pub fn success(address: Socks5Address, port: u16) -> Self {
|
||||||
|
Self {
|
||||||
|
version: 0x05,
|
||||||
|
reply: 0x00,
|
||||||
|
address,
|
||||||
|
port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn connection_refused() -> Self {
|
||||||
|
Self {
|
||||||
|
version: 0x05,
|
||||||
|
reply: 0x05,
|
||||||
|
address: Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED),
|
||||||
|
port: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn command_not_supported() -> Self {
|
||||||
|
Self {
|
||||||
|
version: 0x05,
|
||||||
|
reply: 0x07,
|
||||||
|
address: Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED),
|
||||||
|
port: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn write_to<W: AsyncWrite + Unpin>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||||
|
writer.write_u8(self.version).await?;
|
||||||
|
writer.write_u8(self.reply).await?;
|
||||||
|
writer.write_u8(0x00).await?;
|
||||||
|
match &self.address {
|
||||||
|
Socks5Address::Ipv4(addr) => {
|
||||||
|
writer.write_u8(0x01).await?;
|
||||||
|
writer.write_all(&addr.octets()).await?;
|
||||||
|
}
|
||||||
|
Socks5Address::Ipv6(addr) => {
|
||||||
|
writer.write_u8(0x04).await?;
|
||||||
|
writer.write_all(&addr.octets()).await?;
|
||||||
|
}
|
||||||
|
Socks5Address::Domain(name) => {
|
||||||
|
writer.write_u8(0x03).await?;
|
||||||
|
writer.write_u8(name.len() as u8).await?;
|
||||||
|
writer.write_all(name.as_bytes()).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writer.write_u16(self.port).await?;
|
||||||
|
writer.flush().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::io::Cursor;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn parse_version_method_no_auth() {
|
||||||
|
let data = [0x05, 0x01, 0x00];
|
||||||
|
let mut cursor = Cursor::new(&data[..]);
|
||||||
|
let vm = Socks5VersionMethod::read_from(&mut cursor).await.unwrap();
|
||||||
|
assert_eq!(vm.version, 0x05);
|
||||||
|
assert_eq!(vm.methods, vec![0x00]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn parse_version_method_multiple() {
|
||||||
|
let data = [0x05, 0x02, 0x00, 0x02];
|
||||||
|
let mut cursor = Cursor::new(&data[..]);
|
||||||
|
let vm = Socks5VersionMethod::read_from(&mut cursor).await.unwrap();
|
||||||
|
assert_eq!(vm.version, 0x05);
|
||||||
|
assert_eq!(vm.methods, vec![0x00, 0x02]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn parse_request_ipv4() {
|
||||||
|
let mut data = vec![0x05, 0x01, 0x00, 0x01];
|
||||||
|
data.extend_from_slice(&[10, 0, 0, 1]);
|
||||||
|
data.extend_from_slice(&443u16.to_be_bytes());
|
||||||
|
let mut cursor = Cursor::new(&data[..]);
|
||||||
|
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||||
|
assert_eq!(req.version, 0x05);
|
||||||
|
assert_eq!(req.command, 0x01);
|
||||||
|
assert_eq!(
|
||||||
|
req.address,
|
||||||
|
Socks5Address::Ipv4(Ipv4Addr::new(10, 0, 0, 1))
|
||||||
|
);
|
||||||
|
assert_eq!(req.port, 443);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn parse_request_ipv6() {
|
||||||
|
let mut data = vec![0x05, 0x01, 0x00, 0x04];
|
||||||
|
let octets: [u8; 16] = [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
|
||||||
|
data.extend_from_slice(&octets);
|
||||||
|
data.extend_from_slice(&443u16.to_be_bytes());
|
||||||
|
let mut cursor = Cursor::new(&data[..]);
|
||||||
|
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||||
|
assert_eq!(req.version, 0x05);
|
||||||
|
assert_eq!(req.command, 0x01);
|
||||||
|
assert!(matches!(req.address, Socks5Address::Ipv6(_)));
|
||||||
|
assert_eq!(req.port, 443);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn parse_request_domain() {
|
||||||
|
let domain = "example.com";
|
||||||
|
let mut data = vec![0x05, 0x01, 0x00, 0x03];
|
||||||
|
data.push(domain.len() as u8);
|
||||||
|
data.extend_from_slice(domain.as_bytes());
|
||||||
|
data.extend_from_slice(&443u16.to_be_bytes());
|
||||||
|
let mut cursor = Cursor::new(&data[..]);
|
||||||
|
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||||
|
assert_eq!(req.version, 0x05);
|
||||||
|
assert_eq!(req.command, 0x01);
|
||||||
|
assert_eq!(req.address, Socks5Address::Domain("example.com".to_string()));
|
||||||
|
assert_eq!(req.port, 443);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn parse_request_unsupported_address_type() {
|
||||||
|
let data = [0x05, 0x01, 0x00, 0x05];
|
||||||
|
let mut cursor = Cursor::new(&data[..]);
|
||||||
|
let result = Socks5Request::read_from(&mut cursor).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reply_success_ipv4() {
|
||||||
|
let reply = Socks5Reply::success(Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED), 0);
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
reply.write_to(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(buf[0], 0x05);
|
||||||
|
assert_eq!(buf[1], 0x00);
|
||||||
|
assert_eq!(buf[2], 0x00);
|
||||||
|
assert_eq!(buf[3], 0x01);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reply_connection_refused() {
|
||||||
|
let reply = Socks5Reply::connection_refused();
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
reply.write_to(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(buf[0], 0x05);
|
||||||
|
assert_eq!(buf[1], 0x05);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reply_command_not_supported() {
|
||||||
|
let reply = Socks5Reply::command_not_supported();
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
reply.write_to(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(buf[0], 0x05);
|
||||||
|
assert_eq!(buf[1], 0x07);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn roundtrip_ipv4_reply() {
|
||||||
|
let reply = Socks5Reply::success(Socks5Address::Ipv4(Ipv4Addr::new(127, 0, 0, 1)), 1080);
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
reply.write_to(&mut buf).await.unwrap();
|
||||||
|
|
||||||
|
let mut cursor = Cursor::new(&buf[..]);
|
||||||
|
let version = cursor.read_u8().await.unwrap();
|
||||||
|
let _reply_code = cursor.read_u8().await.unwrap();
|
||||||
|
let _rsv = cursor.read_u8().await.unwrap();
|
||||||
|
let atyp = cursor.read_u8().await.unwrap();
|
||||||
|
assert_eq!(version, 0x05);
|
||||||
|
assert_eq!(atyp, 0x01);
|
||||||
|
let mut octets = [0u8; 4];
|
||||||
|
cursor.read_exact(&mut octets).await.unwrap();
|
||||||
|
assert_eq!(Ipv4Addr::from(octets), Ipv4Addr::new(127, 0, 0, 1));
|
||||||
|
let port = cursor.read_u16().await.unwrap();
|
||||||
|
assert_eq!(port, 1080);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn roundtrip_ipv6_reply() {
|
||||||
|
let addr = Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1);
|
||||||
|
let reply = Socks5Reply::success(Socks5Address::Ipv6(addr), 443);
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
reply.write_to(&mut buf).await.unwrap();
|
||||||
|
|
||||||
|
let mut cursor = Cursor::new(&buf[..]);
|
||||||
|
let _version = cursor.read_u8().await.unwrap();
|
||||||
|
let _reply_code = cursor.read_u8().await.unwrap();
|
||||||
|
let _rsv = cursor.read_u8().await.unwrap();
|
||||||
|
let atyp = cursor.read_u8().await.unwrap();
|
||||||
|
assert_eq!(atyp, 0x04);
|
||||||
|
let mut octets = [0u8; 16];
|
||||||
|
cursor.read_exact(&mut octets).await.unwrap();
|
||||||
|
assert_eq!(Ipv6Addr::from(octets), addr);
|
||||||
|
let port = cursor.read_u16().await.unwrap();
|
||||||
|
assert_eq!(port, 443);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn roundtrip_domain_reply() {
|
||||||
|
let reply = Socks5Reply::success(Socks5Address::Domain("example.com".to_string()), 8080);
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
reply.write_to(&mut buf).await.unwrap();
|
||||||
|
|
||||||
|
let mut cursor = Cursor::new(&buf[..]);
|
||||||
|
let _version = cursor.read_u8().await.unwrap();
|
||||||
|
let _reply_code = cursor.read_u8().await.unwrap();
|
||||||
|
let _rsv = cursor.read_u8().await.unwrap();
|
||||||
|
let atyp = cursor.read_u8().await.unwrap();
|
||||||
|
assert_eq!(atyp, 0x03);
|
||||||
|
let len = cursor.read_u8().await.unwrap();
|
||||||
|
let mut domain = vec![0u8; len as usize];
|
||||||
|
cursor.read_exact(&mut domain).await.unwrap();
|
||||||
|
assert_eq!(String::from_utf8(domain).unwrap(), "example.com");
|
||||||
|
let port = cursor.read_u16().await.unwrap();
|
||||||
|
assert_eq!(port, 8080);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
---
|
---
|
||||||
id: client/channel-manager
|
id: client/channel-manager
|
||||||
name: Implement ChannelManager — SSH session management, channel opens, reconnection
|
name: Implement ChannelManager — SSH session management, channel opens, reconnection
|
||||||
status: pending
|
status: done
|
||||||
depends_on:
|
depends_on:
|
||||||
- auth/client-auth-handler
|
- auth/client-auth-handler
|
||||||
- transport/trait-and-types
|
- transport/trait-and-types
|
||||||
@@ -32,18 +32,18 @@ Reconnection is always enabled. The backoff caps at 30 seconds and continues ind
|
|||||||
|
|
||||||
## Acceptance Criteria
|
## Acceptance Criteria
|
||||||
|
|
||||||
- [ ] `crates/wraith-core/src/client/channel_manager.rs` exports `ChannelManager`
|
- [x] `crates/wraith-core/src/client/channel_manager.rs` exports `ChannelManager`
|
||||||
- [ ] `ChannelManager` holds: `Arc<Transport>`, `Arc<ClientAuthConfig>`, `Arc<client::Handle<ClientHandler>>` (behind RwLock for reconnection)
|
- [x] `ChannelManager` holds: `Arc<Transport>`, `Arc<ClientAuthConfig>`, `Arc<client::Handle<ClientHandler>>` (behind RwLock for reconnection)
|
||||||
- [ ] `ChannelManager::new()` establishes initial transport connection, authenticates, returns manager
|
- [x] `ChannelManager::new()` establishes initial transport connection, authenticates, returns manager
|
||||||
- [ ] `open_direct_tcpip(host, port)` — opens SSH channel to target
|
- [x] `open_direct_tcpip(host, port)` — opens SSH channel to target
|
||||||
- [ ] `request_tcpip_forward(addr, port)` — sends `tcpip_forward` request
|
- [x] `request_tcpip_forward(addr, port)` — sends `tcpip_forward` request
|
||||||
- [ ] `cancel_tcpip_forward(addr, port)` — sends `cancel_tcpip_forward` request
|
- [x] `cancel_tcpip_forward(addr, port)` — sends `cancel_tcpip_forward` request
|
||||||
- [ ] Reconnection detection: monitors `handle.is_closed()`, triggers reconnect on failure
|
- [x] Reconnection detection: monitors `handle.is_closed()`, triggers reconnect on failure
|
||||||
- [ ] Exponential backoff: 1s, 2s, 4s, 8s, 16s, 30s (cap), continues indefinitely
|
- [x] Exponential backoff: 1s, 2s, 4s, 8s, 16s, 30s (cap), continues indefinitely
|
||||||
- [ ] Full reconnect: new transport stream, new SSH session over it (ADR-004)
|
- [x] Full reconnect: new transport stream, new SSH session over it (ADR-004)
|
||||||
- [ ] After reconnect: port forward listeners (`-L`, `-R`) re-registered with new session
|
- [x] After reconnect: port forward listeners (`-L`, `-R`) re-registered with new session
|
||||||
- [ ] In-flight connections on old session fail gracefully (channel errors, not session-wide)
|
- [x] In-flight connections on old session fail gracefully (channel errors, not session-wide)
|
||||||
- [ ] Unit tests: channel open, reconnection trigger, backoff timing, forward re-registration
|
- [x] Unit tests: channel open, reconnection trigger, backoff timing, forward re-registration
|
||||||
|
|
||||||
## References
|
## References
|
||||||
|
|
||||||
@@ -52,8 +52,13 @@ Reconnection is always enabled. The backoff caps at 30 seconds and continues ind
|
|||||||
|
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
> To be filled by implementation agent
|
- Converted `client.rs` (single file) to directory module: `client/mod.rs` + `client/channel_manager.rs`
|
||||||
|
- Used `russh::keys::PrivateKey` and `russh::keys::PublicKey` (not the nonexistent `russh::key::KeyPair`)
|
||||||
|
- Reconnection monitor runs as a spawned tokio task that polls `handle.is_closed()` every 1s
|
||||||
|
- On reconnect: creates new transport stream + new SSH session (ADR-004 full reconnect)
|
||||||
|
- `ForwardRequest` struct tracks registered port forwards for re-registration after reconnect
|
||||||
|
- In-flight channels on old session naturally fail with `ChannelError::ChannelClosed` since the handle is replaced
|
||||||
|
|
||||||
## Summary
|
## Summary
|
||||||
|
|
||||||
> To be filled on completion
|
Implemented `ChannelManager` in `crates/wraith-core/src/client/channel_manager.rs` with SSH session management, channel opens (`open_direct_tcpip`), port forward requests (`request_tcpip_forward`/`cancel_tcpip_forward`), and automatic reconnection with exponential backoff (1s→30s cap). Full reconnect per ADR-004 creates new transport stream + new SSH session. Port forwards are re-registered after successful reconnect. 8 unit tests covering backoff timing, forward tracking, transport failure, and reconnection detection.
|
||||||
Reference in New Issue
Block a user