Compare commits
1 Commits
feat/clien
...
feat/trans
| Author | SHA1 | Date | |
|---|---|---|---|
| e3f33a24c3 |
@@ -10,7 +10,7 @@ name = "wraith_core"
|
|||||||
default = []
|
default = []
|
||||||
tls = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-pki-types", "dep:webpki-roots"]
|
tls = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-pki-types", "dep:webpki-roots"]
|
||||||
iroh = ["dep:iroh", "dep:url"]
|
iroh = ["dep:iroh", "dep:url"]
|
||||||
acme = ["dep:rustls-acme", "tls"]
|
acme = ["dep:rustls-acme", "dep:futures", "tls"]
|
||||||
testutil = []
|
testutil = []
|
||||||
transport-traits = []
|
transport-traits = []
|
||||||
|
|
||||||
@@ -25,6 +25,7 @@ tokio-rustls = { version = "0.26", optional = true }
|
|||||||
rustls = { version = "0.23", optional = true, features = ["aws_lc_rs"] }
|
rustls = { version = "0.23", optional = true, features = ["aws_lc_rs"] }
|
||||||
rustls-pki-types = { version = "1", optional = true }
|
rustls-pki-types = { version = "1", optional = true }
|
||||||
rustls-acme = { version = "0.12", optional = true }
|
rustls-acme = { version = "0.12", optional = true }
|
||||||
|
futures = { version = "0.3", optional = true }
|
||||||
webpki-roots = { version = "0.26", optional = true }
|
webpki-roots = { version = "0.26", optional = true }
|
||||||
iroh = { version = "0.34", optional = true }
|
iroh = { version = "0.34", optional = true }
|
||||||
url = { version = "2", optional = true }
|
url = { version = "2", optional = true }
|
||||||
|
|||||||
0
crates/wraith-core/src/socks5.rs
Normal file
0
crates/wraith-core/src/socks5.rs
Normal file
@@ -1,490 +0,0 @@
|
|||||||
mod protocol;
|
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
|
||||||
use tokio::net::TcpListener;
|
|
||||||
use tokio::sync::Mutex;
|
|
||||||
use tracing::debug;
|
|
||||||
|
|
||||||
use protocol::{Socks5Reply, Socks5Request, Socks5VersionMethod};
|
|
||||||
|
|
||||||
pub use protocol::Socks5Address;
|
|
||||||
|
|
||||||
const DEFAULT_SOCKS5_ADDR: &str = "127.0.0.1:1080";
|
|
||||||
|
|
||||||
pub trait ChannelOpener: Send + Sync + 'static {
|
|
||||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
|
||||||
|
|
||||||
fn open_channel(
|
|
||||||
&self,
|
|
||||||
host: String,
|
|
||||||
port: u16,
|
|
||||||
) -> impl std::future::Future<Output = Result<Self::Stream, ChannelOpenError>> + Send;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
|
||||||
pub enum ChannelOpenError {
|
|
||||||
#[error("session closed")]
|
|
||||||
SessionClosed,
|
|
||||||
#[error("channel open failed")]
|
|
||||||
ChannelOpenFailed,
|
|
||||||
#[error("connection refused")]
|
|
||||||
ConnectionRefused,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Socks5Server<C: ChannelOpener> {
|
|
||||||
listen_addr: SocketAddr,
|
|
||||||
channel_opener: Arc<C>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<C: ChannelOpener> Socks5Server<C> {
|
|
||||||
pub fn new(channel_opener: C) -> Self {
|
|
||||||
Self::with_addr(channel_opener, DEFAULT_SOCKS5_ADDR)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_addr(channel_opener: C, addr: &str) -> Self {
|
|
||||||
let listen_addr: SocketAddr = addr
|
|
||||||
.parse()
|
|
||||||
.expect("invalid SOCKS5 listen address");
|
|
||||||
Self {
|
|
||||||
listen_addr,
|
|
||||||
channel_opener: Arc::new(channel_opener),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn listen_addr(&self) -> SocketAddr {
|
|
||||||
self.listen_addr
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run(self) -> Result<(), std::io::Error> {
|
|
||||||
let listener = TcpListener::bind(self.listen_addr).await?;
|
|
||||||
debug!("socks5 server listening on {}", self.listen_addr);
|
|
||||||
loop {
|
|
||||||
let (socket, _peer) = listener.accept().await?;
|
|
||||||
let opener = Arc::clone(&self.channel_opener);
|
|
||||||
tokio::spawn(async move {
|
|
||||||
if let Err(e) = handle_socks5_connection(socket, opener).await {
|
|
||||||
debug!("socks5 connection error: {e}");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_socks5_connection<S, C>(
|
|
||||||
mut socket: S,
|
|
||||||
opener: Arc<C>,
|
|
||||||
) -> Result<(), Socks5Error>
|
|
||||||
where
|
|
||||||
S: AsyncRead + AsyncWrite + Unpin,
|
|
||||||
C: ChannelOpener,
|
|
||||||
{
|
|
||||||
let vm = Socks5VersionMethod::read_from(&mut socket).await?;
|
|
||||||
if vm.version != 0x05 {
|
|
||||||
return Err(Socks5Error::InvalidVersion(vm.version));
|
|
||||||
}
|
|
||||||
if !vm.methods.contains(&0x00) {
|
|
||||||
let reply = [0x05, 0xFF];
|
|
||||||
socket.write_all(&reply).await?;
|
|
||||||
socket.shutdown().await?;
|
|
||||||
return Err(Socks5Error::NoAcceptableAuth);
|
|
||||||
}
|
|
||||||
let reply = [0x05, 0x00];
|
|
||||||
socket.write_all(&reply).await?;
|
|
||||||
|
|
||||||
let request = Socks5Request::read_from(&mut socket).await?;
|
|
||||||
if request.version != 0x05 {
|
|
||||||
return Err(Socks5Error::InvalidVersion(request.version));
|
|
||||||
}
|
|
||||||
if request.command != 0x01 {
|
|
||||||
send_error_reply(&mut socket, Socks5Reply::command_not_supported()).await?;
|
|
||||||
return Err(Socks5Error::UnsupportedCommand(request.command));
|
|
||||||
}
|
|
||||||
|
|
||||||
let (host, port) = match &request.address {
|
|
||||||
Socks5Address::Ipv4(addr) => (addr.to_string(), request.port),
|
|
||||||
Socks5Address::Ipv6(addr) => (addr.to_string(), request.port),
|
|
||||||
Socks5Address::Domain(name) => (name.clone(), request.port),
|
|
||||||
};
|
|
||||||
|
|
||||||
match opener.open_channel(host, port).await {
|
|
||||||
Ok(mut ssh_stream) => {
|
|
||||||
let bind_addr = Socks5Address::Ipv4(std::net::Ipv4Addr::UNSPECIFIED);
|
|
||||||
let reply = Socks5Reply::success(bind_addr, 0);
|
|
||||||
reply.write_to(&mut socket).await?;
|
|
||||||
tokio::io::copy_bidirectional(&mut socket, &mut ssh_stream).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
send_error_reply(&mut socket, Socks5Reply::connection_refused()).await?;
|
|
||||||
Err(Socks5Error::ChannelOpenFailed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_error_reply<S: AsyncRead + AsyncWrite + Unpin>(
|
|
||||||
socket: &mut S,
|
|
||||||
reply: Socks5Reply,
|
|
||||||
) -> Result<(), Socks5Error> {
|
|
||||||
reply.write_to(socket).await?;
|
|
||||||
let _ = socket.shutdown().await;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
|
||||||
pub enum Socks5Error {
|
|
||||||
#[error("invalid SOCKS version: {0}")]
|
|
||||||
InvalidVersion(u8),
|
|
||||||
#[error("no acceptable auth method")]
|
|
||||||
NoAcceptableAuth,
|
|
||||||
#[error("unsupported command: {0}")]
|
|
||||||
UnsupportedCommand(u8),
|
|
||||||
#[error("channel open failed")]
|
|
||||||
ChannelOpenFailed,
|
|
||||||
#[error("io error")]
|
|
||||||
Io(#[from] std::io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct HandleChannelOpener<H: russh::client::Handler> {
|
|
||||||
handle: Arc<Mutex<russh::client::Handle<H>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<H: russh::client::Handler> HandleChannelOpener<H> {
|
|
||||||
pub fn new(handle: russh::client::Handle<H>) -> Self {
|
|
||||||
Self {
|
|
||||||
handle: Arc::new(Mutex::new(handle)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_arc(handle: Arc<Mutex<russh::client::Handle<H>>>) -> Self {
|
|
||||||
Self { handle }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<H: russh::client::Handler + Send + Sync + 'static> ChannelOpener for HandleChannelOpener<H> {
|
|
||||||
type Stream = russh::ChannelStream<russh::client::Msg>;
|
|
||||||
|
|
||||||
async fn open_channel(&self, host: String, port: u16) -> Result<Self::Stream, ChannelOpenError> {
|
|
||||||
let handle = self.handle.lock().await;
|
|
||||||
if handle.is_closed() {
|
|
||||||
return Err(ChannelOpenError::SessionClosed);
|
|
||||||
}
|
|
||||||
let channel = handle
|
|
||||||
.channel_open_direct_tcpip(host, port as u32, "127.0.0.1", 0)
|
|
||||||
.await
|
|
||||||
.map_err(|_| ChannelOpenError::ChannelOpenFailed)?;
|
|
||||||
Ok(channel.into_stream())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
|
|
||||||
|
|
||||||
struct MockChannelOpener {
|
|
||||||
fail: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChannelOpener for MockChannelOpener {
|
|
||||||
type Stream = DuplexStream;
|
|
||||||
|
|
||||||
async fn open_channel(
|
|
||||||
&self,
|
|
||||||
_host: String,
|
|
||||||
_port: u16,
|
|
||||||
) -> Result<Self::Stream, ChannelOpenError> {
|
|
||||||
if self.fail {
|
|
||||||
Err(ChannelOpenError::ChannelOpenFailed)
|
|
||||||
} else {
|
|
||||||
let (client, _server) = duplex(4096);
|
|
||||||
Ok(client)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_socks5_greeting(methods: &[u8]) -> Vec<u8> {
|
|
||||||
let mut buf = vec![0x05, methods.len() as u8];
|
|
||||||
buf.extend_from_slice(methods);
|
|
||||||
buf
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_socks5_connect_ipv4(addr: [u8; 4], port: u16) -> Vec<u8> {
|
|
||||||
let mut buf = vec![0x05, 0x01, 0x00, 0x01];
|
|
||||||
buf.extend_from_slice(&addr);
|
|
||||||
buf.extend_from_slice(&port.to_be_bytes());
|
|
||||||
buf
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_socks5_connect_domain(domain: &str, port: u16) -> Vec<u8> {
|
|
||||||
let mut buf = vec![0x05, 0x01, 0x00, 0x03];
|
|
||||||
buf.push(domain.len() as u8);
|
|
||||||
buf.extend_from_slice(domain.as_bytes());
|
|
||||||
buf.extend_from_slice(&port.to_be_bytes());
|
|
||||||
buf
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_socks5_connect_ipv6(addr: [u8; 16], port: u16) -> Vec<u8> {
|
|
||||||
let mut buf = vec![0x05, 0x01, 0x00, 0x04];
|
|
||||||
buf.extend_from_slice(&addr);
|
|
||||||
buf.extend_from_slice(&port.to_be_bytes());
|
|
||||||
buf
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn do_handshake(client: &mut DuplexStream) -> [u8; 2] {
|
|
||||||
client.write_all(&build_socks5_greeting(&[0x00])).await.unwrap();
|
|
||||||
client.flush().await.unwrap();
|
|
||||||
let mut resp = [0u8; 2];
|
|
||||||
client.read_exact(&mut resp).await.unwrap();
|
|
||||||
resp
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn do_connect_ipv4(client: &mut DuplexStream, addr: [u8; 4], port: u16) -> Vec<u8> {
|
|
||||||
client
|
|
||||||
.write_all(&build_socks5_connect_ipv4(addr, port))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
client.flush().await.unwrap();
|
|
||||||
let mut reply_buf = [0u8; 10];
|
|
||||||
client.read_exact(&mut reply_buf).await.unwrap();
|
|
||||||
reply_buf.to_vec()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn handshake_no_auth_method() {
|
|
||||||
let (mut client, server) = duplex(4096);
|
|
||||||
let opener = MockChannelOpener { fail: false };
|
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
|
||||||
});
|
|
||||||
|
|
||||||
let resp = do_handshake(&mut client).await;
|
|
||||||
assert_eq!(resp, [0x05, 0x00]);
|
|
||||||
|
|
||||||
let reply_buf = do_connect_ipv4(&mut client, [127, 0, 0, 1], 80).await;
|
|
||||||
assert_eq!(reply_buf[0], 0x05);
|
|
||||||
assert_eq!(reply_buf[1], 0x00);
|
|
||||||
|
|
||||||
drop(client);
|
|
||||||
let _ = server_handle.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn handshake_rejects_no_acceptable_method() {
|
|
||||||
let (mut client, server) = duplex(4096);
|
|
||||||
let opener = MockChannelOpener { fail: false };
|
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
|
||||||
});
|
|
||||||
|
|
||||||
client
|
|
||||||
.write_all(&build_socks5_greeting(&[0x02]))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
client.flush().await.unwrap();
|
|
||||||
|
|
||||||
let mut resp = [0u8; 2];
|
|
||||||
client.read_exact(&mut resp).await.unwrap();
|
|
||||||
assert_eq!(resp, [0x05, 0xFF]);
|
|
||||||
|
|
||||||
drop(client);
|
|
||||||
let result = server_handle.await.unwrap();
|
|
||||||
assert!(result.is_err());
|
|
||||||
assert!(matches!(
|
|
||||||
result.unwrap_err(),
|
|
||||||
Socks5Error::NoAcceptableAuth
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn address_type_ipv4() {
|
|
||||||
let (mut client, server) = duplex(4096);
|
|
||||||
let opener = MockChannelOpener { fail: false };
|
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
|
||||||
});
|
|
||||||
|
|
||||||
do_handshake(&mut client).await;
|
|
||||||
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 443).await;
|
|
||||||
assert_eq!(reply_buf[1], 0x00);
|
|
||||||
|
|
||||||
drop(client);
|
|
||||||
let _ = server_handle.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn address_type_domain() {
|
|
||||||
let (mut client, server) = duplex(4096);
|
|
||||||
let opener = MockChannelOpener { fail: false };
|
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
|
||||||
});
|
|
||||||
|
|
||||||
do_handshake(&mut client).await;
|
|
||||||
|
|
||||||
client
|
|
||||||
.write_all(&build_socks5_connect_domain("example.com", 443))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
client.flush().await.unwrap();
|
|
||||||
|
|
||||||
let mut reply_buf = [0u8; 10];
|
|
||||||
client.read_exact(&mut reply_buf).await.unwrap();
|
|
||||||
assert_eq!(reply_buf[1], 0x00);
|
|
||||||
|
|
||||||
drop(client);
|
|
||||||
let _ = server_handle.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn address_type_ipv6() {
|
|
||||||
let (mut client, server) = duplex(4096);
|
|
||||||
let opener = MockChannelOpener { fail: false };
|
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
|
||||||
});
|
|
||||||
|
|
||||||
do_handshake(&mut client).await;
|
|
||||||
|
|
||||||
let ipv6_addr = [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
|
|
||||||
client
|
|
||||||
.write_all(&build_socks5_connect_ipv6(ipv6_addr, 443))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
client.flush().await.unwrap();
|
|
||||||
|
|
||||||
let mut reply_buf = [0u8; 10];
|
|
||||||
client.read_exact(&mut reply_buf).await.unwrap();
|
|
||||||
assert_eq!(reply_buf[0], 0x05);
|
|
||||||
assert_eq!(reply_buf[1], 0x00);
|
|
||||||
|
|
||||||
drop(client);
|
|
||||||
let _ = server_handle.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn channel_open_failure_returns_socks5_error() {
|
|
||||||
let (mut client, server) = duplex(4096);
|
|
||||||
let opener = MockChannelOpener { fail: true };
|
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
|
||||||
});
|
|
||||||
|
|
||||||
do_handshake(&mut client).await;
|
|
||||||
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 80).await;
|
|
||||||
assert_eq!(reply_buf[0], 0x05);
|
|
||||||
assert_eq!(reply_buf[1], 0x05);
|
|
||||||
|
|
||||||
drop(client);
|
|
||||||
let _ = server_handle.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn unsupported_command_returns_error() {
|
|
||||||
let (mut client, server) = duplex(4096);
|
|
||||||
let opener = MockChannelOpener { fail: false };
|
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
|
||||||
});
|
|
||||||
|
|
||||||
do_handshake(&mut client).await;
|
|
||||||
|
|
||||||
let mut bind_req = vec![0x05, 0x02, 0x00, 0x01];
|
|
||||||
bind_req.extend_from_slice(&[127, 0, 0, 1]);
|
|
||||||
bind_req.extend_from_slice(&80u16.to_be_bytes());
|
|
||||||
client.write_all(&bind_req).await.unwrap();
|
|
||||||
client.flush().await.unwrap();
|
|
||||||
|
|
||||||
let mut reply_buf = [0u8; 10];
|
|
||||||
client.read_exact(&mut reply_buf).await.unwrap();
|
|
||||||
assert_eq!(reply_buf[1], 0x07);
|
|
||||||
|
|
||||||
drop(client);
|
|
||||||
let _ = server_handle.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn bidirectional_proxy_flow() {
|
|
||||||
let (mut client_sock, server_sock) = duplex(4096);
|
|
||||||
let (ssh_client, mut ssh_server) = duplex(4096);
|
|
||||||
|
|
||||||
let ssh_stream = Arc::new(Mutex::new(Some(ssh_client)));
|
|
||||||
|
|
||||||
struct ProxyOpener {
|
|
||||||
stream: Arc<Mutex<Option<DuplexStream>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChannelOpener for ProxyOpener {
|
|
||||||
type Stream = DuplexStream;
|
|
||||||
|
|
||||||
async fn open_channel(
|
|
||||||
&self,
|
|
||||||
_host: String,
|
|
||||||
_port: u16,
|
|
||||||
) -> Result<Self::Stream, ChannelOpenError> {
|
|
||||||
self.stream
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.take()
|
|
||||||
.ok_or(ChannelOpenError::ChannelOpenFailed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let opener = ProxyOpener {
|
|
||||||
stream: Arc::clone(&ssh_stream),
|
|
||||||
};
|
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
|
||||||
handle_socks5_connection(server_sock, Arc::new(opener)).await
|
|
||||||
});
|
|
||||||
|
|
||||||
do_handshake(&mut client_sock).await;
|
|
||||||
let reply_buf = do_connect_ipv4(&mut client_sock, [127, 0, 0, 1], 80).await;
|
|
||||||
assert_eq!(reply_buf[1], 0x00);
|
|
||||||
|
|
||||||
let test_data = b"hello through tunnel";
|
|
||||||
client_sock.write_all(test_data).await.unwrap();
|
|
||||||
client_sock.flush().await.unwrap();
|
|
||||||
|
|
||||||
let mut received = vec![0u8; test_data.len()];
|
|
||||||
AsyncReadExt::read_exact(&mut ssh_server, &mut received)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(&received, test_data);
|
|
||||||
|
|
||||||
let echo_data = b"response from tunnel";
|
|
||||||
ssh_server.write_all(echo_data).await.unwrap();
|
|
||||||
ssh_server.flush().await.unwrap();
|
|
||||||
|
|
||||||
let mut received_back = vec![0u8; echo_data.len()];
|
|
||||||
client_sock.read_exact(&mut received_back).await.unwrap();
|
|
||||||
assert_eq!(&received_back, echo_data);
|
|
||||||
|
|
||||||
drop(client_sock);
|
|
||||||
drop(ssh_server);
|
|
||||||
let _ = server_handle.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn default_listen_address() {
|
|
||||||
let opener = MockChannelOpener { fail: false };
|
|
||||||
let server = Socks5Server::new(opener);
|
|
||||||
assert_eq!(server.listen_addr(), "127.0.0.1:1080".parse().unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn custom_listen_address() {
|
|
||||||
let opener = MockChannelOpener { fail: false };
|
|
||||||
let server = Socks5Server::with_addr(opener, "127.0.0.1:9050");
|
|
||||||
assert_eq!(server.listen_addr(), "127.0.0.1:9050".parse().unwrap());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,304 +0,0 @@
|
|||||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
|
||||||
|
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
|
||||||
pub enum Socks5Address {
|
|
||||||
Ipv4(Ipv4Addr),
|
|
||||||
Ipv6(Ipv6Addr),
|
|
||||||
Domain(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Socks5VersionMethod {
|
|
||||||
pub version: u8,
|
|
||||||
pub methods: Vec<u8>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Socks5VersionMethod {
|
|
||||||
pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Self> {
|
|
||||||
let version = reader.read_u8().await?;
|
|
||||||
let nmethods = reader.read_u8().await?;
|
|
||||||
let mut methods = vec![0u8; nmethods as usize];
|
|
||||||
reader.read_exact(&mut methods).await?;
|
|
||||||
Ok(Self { version, methods })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Socks5Request {
|
|
||||||
pub version: u8,
|
|
||||||
pub command: u8,
|
|
||||||
pub address: Socks5Address,
|
|
||||||
pub port: u16,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Socks5Request {
|
|
||||||
pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Self> {
|
|
||||||
let version = reader.read_u8().await?;
|
|
||||||
let command = reader.read_u8().await?;
|
|
||||||
let _rsv = reader.read_u8().await?;
|
|
||||||
let atyp = reader.read_u8().await?;
|
|
||||||
|
|
||||||
let address = match atyp {
|
|
||||||
0x01 => {
|
|
||||||
let mut octets = [0u8; 4];
|
|
||||||
reader.read_exact(&mut octets).await?;
|
|
||||||
Socks5Address::Ipv4(Ipv4Addr::from(octets))
|
|
||||||
}
|
|
||||||
0x04 => {
|
|
||||||
let mut octets = [0u8; 16];
|
|
||||||
reader.read_exact(&mut octets).await?;
|
|
||||||
Socks5Address::Ipv6(Ipv6Addr::from(octets))
|
|
||||||
}
|
|
||||||
0x03 => {
|
|
||||||
let len = reader.read_u8().await?;
|
|
||||||
let mut domain = vec![0u8; len as usize];
|
|
||||||
reader.read_exact(&mut domain).await?;
|
|
||||||
Socks5Address::Domain(String::from_utf8_lossy(&domain).into_owned())
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
return Err(std::io::Error::new(
|
|
||||||
std::io::ErrorKind::InvalidData,
|
|
||||||
format!("unsupported address type: {atyp}"),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let port = reader.read_u16().await?;
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
version,
|
|
||||||
command,
|
|
||||||
address,
|
|
||||||
port,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Socks5Reply {
|
|
||||||
pub version: u8,
|
|
||||||
pub reply: u8,
|
|
||||||
pub address: Socks5Address,
|
|
||||||
pub port: u16,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Socks5Reply {
|
|
||||||
pub fn success(address: Socks5Address, port: u16) -> Self {
|
|
||||||
Self {
|
|
||||||
version: 0x05,
|
|
||||||
reply: 0x00,
|
|
||||||
address,
|
|
||||||
port,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn connection_refused() -> Self {
|
|
||||||
Self {
|
|
||||||
version: 0x05,
|
|
||||||
reply: 0x05,
|
|
||||||
address: Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED),
|
|
||||||
port: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn command_not_supported() -> Self {
|
|
||||||
Self {
|
|
||||||
version: 0x05,
|
|
||||||
reply: 0x07,
|
|
||||||
address: Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED),
|
|
||||||
port: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn write_to<W: AsyncWrite + Unpin>(&self, writer: &mut W) -> std::io::Result<()> {
|
|
||||||
writer.write_u8(self.version).await?;
|
|
||||||
writer.write_u8(self.reply).await?;
|
|
||||||
writer.write_u8(0x00).await?;
|
|
||||||
match &self.address {
|
|
||||||
Socks5Address::Ipv4(addr) => {
|
|
||||||
writer.write_u8(0x01).await?;
|
|
||||||
writer.write_all(&addr.octets()).await?;
|
|
||||||
}
|
|
||||||
Socks5Address::Ipv6(addr) => {
|
|
||||||
writer.write_u8(0x04).await?;
|
|
||||||
writer.write_all(&addr.octets()).await?;
|
|
||||||
}
|
|
||||||
Socks5Address::Domain(name) => {
|
|
||||||
writer.write_u8(0x03).await?;
|
|
||||||
writer.write_u8(name.len() as u8).await?;
|
|
||||||
writer.write_all(name.as_bytes()).await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
writer.write_u16(self.port).await?;
|
|
||||||
writer.flush().await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use std::io::Cursor;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn parse_version_method_no_auth() {
|
|
||||||
let data = [0x05, 0x01, 0x00];
|
|
||||||
let mut cursor = Cursor::new(&data[..]);
|
|
||||||
let vm = Socks5VersionMethod::read_from(&mut cursor).await.unwrap();
|
|
||||||
assert_eq!(vm.version, 0x05);
|
|
||||||
assert_eq!(vm.methods, vec![0x00]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn parse_version_method_multiple() {
|
|
||||||
let data = [0x05, 0x02, 0x00, 0x02];
|
|
||||||
let mut cursor = Cursor::new(&data[..]);
|
|
||||||
let vm = Socks5VersionMethod::read_from(&mut cursor).await.unwrap();
|
|
||||||
assert_eq!(vm.version, 0x05);
|
|
||||||
assert_eq!(vm.methods, vec![0x00, 0x02]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn parse_request_ipv4() {
|
|
||||||
let mut data = vec![0x05, 0x01, 0x00, 0x01];
|
|
||||||
data.extend_from_slice(&[10, 0, 0, 1]);
|
|
||||||
data.extend_from_slice(&443u16.to_be_bytes());
|
|
||||||
let mut cursor = Cursor::new(&data[..]);
|
|
||||||
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
|
||||||
assert_eq!(req.version, 0x05);
|
|
||||||
assert_eq!(req.command, 0x01);
|
|
||||||
assert_eq!(
|
|
||||||
req.address,
|
|
||||||
Socks5Address::Ipv4(Ipv4Addr::new(10, 0, 0, 1))
|
|
||||||
);
|
|
||||||
assert_eq!(req.port, 443);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn parse_request_ipv6() {
|
|
||||||
let mut data = vec![0x05, 0x01, 0x00, 0x04];
|
|
||||||
let octets: [u8; 16] = [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
|
|
||||||
data.extend_from_slice(&octets);
|
|
||||||
data.extend_from_slice(&443u16.to_be_bytes());
|
|
||||||
let mut cursor = Cursor::new(&data[..]);
|
|
||||||
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
|
||||||
assert_eq!(req.version, 0x05);
|
|
||||||
assert_eq!(req.command, 0x01);
|
|
||||||
assert!(matches!(req.address, Socks5Address::Ipv6(_)));
|
|
||||||
assert_eq!(req.port, 443);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn parse_request_domain() {
|
|
||||||
let domain = "example.com";
|
|
||||||
let mut data = vec![0x05, 0x01, 0x00, 0x03];
|
|
||||||
data.push(domain.len() as u8);
|
|
||||||
data.extend_from_slice(domain.as_bytes());
|
|
||||||
data.extend_from_slice(&443u16.to_be_bytes());
|
|
||||||
let mut cursor = Cursor::new(&data[..]);
|
|
||||||
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
|
||||||
assert_eq!(req.version, 0x05);
|
|
||||||
assert_eq!(req.command, 0x01);
|
|
||||||
assert_eq!(req.address, Socks5Address::Domain("example.com".to_string()));
|
|
||||||
assert_eq!(req.port, 443);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn parse_request_unsupported_address_type() {
|
|
||||||
let data = [0x05, 0x01, 0x00, 0x05];
|
|
||||||
let mut cursor = Cursor::new(&data[..]);
|
|
||||||
let result = Socks5Request::read_from(&mut cursor).await;
|
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn reply_success_ipv4() {
|
|
||||||
let reply = Socks5Reply::success(Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED), 0);
|
|
||||||
let mut buf = Vec::new();
|
|
||||||
reply.write_to(&mut buf).await.unwrap();
|
|
||||||
assert_eq!(buf[0], 0x05);
|
|
||||||
assert_eq!(buf[1], 0x00);
|
|
||||||
assert_eq!(buf[2], 0x00);
|
|
||||||
assert_eq!(buf[3], 0x01);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn reply_connection_refused() {
|
|
||||||
let reply = Socks5Reply::connection_refused();
|
|
||||||
let mut buf = Vec::new();
|
|
||||||
reply.write_to(&mut buf).await.unwrap();
|
|
||||||
assert_eq!(buf[0], 0x05);
|
|
||||||
assert_eq!(buf[1], 0x05);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn reply_command_not_supported() {
|
|
||||||
let reply = Socks5Reply::command_not_supported();
|
|
||||||
let mut buf = Vec::new();
|
|
||||||
reply.write_to(&mut buf).await.unwrap();
|
|
||||||
assert_eq!(buf[0], 0x05);
|
|
||||||
assert_eq!(buf[1], 0x07);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn roundtrip_ipv4_reply() {
|
|
||||||
let reply = Socks5Reply::success(Socks5Address::Ipv4(Ipv4Addr::new(127, 0, 0, 1)), 1080);
|
|
||||||
let mut buf = Vec::new();
|
|
||||||
reply.write_to(&mut buf).await.unwrap();
|
|
||||||
|
|
||||||
let mut cursor = Cursor::new(&buf[..]);
|
|
||||||
let version = cursor.read_u8().await.unwrap();
|
|
||||||
let _reply_code = cursor.read_u8().await.unwrap();
|
|
||||||
let _rsv = cursor.read_u8().await.unwrap();
|
|
||||||
let atyp = cursor.read_u8().await.unwrap();
|
|
||||||
assert_eq!(version, 0x05);
|
|
||||||
assert_eq!(atyp, 0x01);
|
|
||||||
let mut octets = [0u8; 4];
|
|
||||||
cursor.read_exact(&mut octets).await.unwrap();
|
|
||||||
assert_eq!(Ipv4Addr::from(octets), Ipv4Addr::new(127, 0, 0, 1));
|
|
||||||
let port = cursor.read_u16().await.unwrap();
|
|
||||||
assert_eq!(port, 1080);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn roundtrip_ipv6_reply() {
|
|
||||||
let addr = Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1);
|
|
||||||
let reply = Socks5Reply::success(Socks5Address::Ipv6(addr), 443);
|
|
||||||
let mut buf = Vec::new();
|
|
||||||
reply.write_to(&mut buf).await.unwrap();
|
|
||||||
|
|
||||||
let mut cursor = Cursor::new(&buf[..]);
|
|
||||||
let _version = cursor.read_u8().await.unwrap();
|
|
||||||
let _reply_code = cursor.read_u8().await.unwrap();
|
|
||||||
let _rsv = cursor.read_u8().await.unwrap();
|
|
||||||
let atyp = cursor.read_u8().await.unwrap();
|
|
||||||
assert_eq!(atyp, 0x04);
|
|
||||||
let mut octets = [0u8; 16];
|
|
||||||
cursor.read_exact(&mut octets).await.unwrap();
|
|
||||||
assert_eq!(Ipv6Addr::from(octets), addr);
|
|
||||||
let port = cursor.read_u16().await.unwrap();
|
|
||||||
assert_eq!(port, 443);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn roundtrip_domain_reply() {
|
|
||||||
let reply = Socks5Reply::success(Socks5Address::Domain("example.com".to_string()), 8080);
|
|
||||||
let mut buf = Vec::new();
|
|
||||||
reply.write_to(&mut buf).await.unwrap();
|
|
||||||
|
|
||||||
let mut cursor = Cursor::new(&buf[..]);
|
|
||||||
let _version = cursor.read_u8().await.unwrap();
|
|
||||||
let _reply_code = cursor.read_u8().await.unwrap();
|
|
||||||
let _rsv = cursor.read_u8().await.unwrap();
|
|
||||||
let atyp = cursor.read_u8().await.unwrap();
|
|
||||||
assert_eq!(atyp, 0x03);
|
|
||||||
let len = cursor.read_u8().await.unwrap();
|
|
||||||
let mut domain = vec![0u8; len as usize];
|
|
||||||
cursor.read_exact(&mut domain).await.unwrap();
|
|
||||||
assert_eq!(String::from_utf8(domain).unwrap(), "example.com");
|
|
||||||
let port = cursor.read_u16().await.unwrap();
|
|
||||||
assert_eq!(port, 8080);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
362
crates/wraith-core/src/transport/acme.rs
Normal file
362
crates/wraith-core/src/transport/acme.rs
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use rustls::crypto::aws_lc_rs::default_provider;
|
||||||
|
use rustls::ServerConfig;
|
||||||
|
use rustls_acme::caches::DirCache;
|
||||||
|
use rustls_acme::{AcmeConfig, AcmeState, ResolvesServerCertAcme};
|
||||||
|
use tracing::{error, info};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use tokio_rustls::TlsAcceptor as TokioTlsAcceptor;
|
||||||
|
|
||||||
|
use super::{TransportAcceptor, TransportInfo, TransportKind};
|
||||||
|
|
||||||
|
const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1";
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum AcmeMode {
|
||||||
|
Domain { domain: String },
|
||||||
|
Ip,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AcmeCertProvider {
|
||||||
|
mode: AcmeMode,
|
||||||
|
cache_dir: Option<PathBuf>,
|
||||||
|
directory_url: String,
|
||||||
|
contact: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for AcmeCertProvider {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("AcmeCertProvider")
|
||||||
|
.field("mode", &self.mode)
|
||||||
|
.field("cache_dir", &self.cache_dir)
|
||||||
|
.field("directory_url", &self.directory_url)
|
||||||
|
.field("contact", &self.contact)
|
||||||
|
.finish_non_exhaustive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AcmeCertProvider {
|
||||||
|
pub fn new(mode: AcmeMode) -> Self {
|
||||||
|
Self {
|
||||||
|
mode,
|
||||||
|
cache_dir: None,
|
||||||
|
directory_url: rustls_acme::acme::LETS_ENCRYPT_STAGING_DIRECTORY.to_string(),
|
||||||
|
contact: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn domain(domain: impl Into<String>) -> Self {
|
||||||
|
Self::new(AcmeMode::Domain {
|
||||||
|
domain: domain.into(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ip() -> Self {
|
||||||
|
Self::new(AcmeMode::Ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_cache_dir(mut self, dir: impl Into<PathBuf>) -> Self {
|
||||||
|
self.cache_dir = Some(dir.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_directory(mut self, url: impl Into<String>) -> Self {
|
||||||
|
self.directory_url = url.into();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_production_directory(mut self) -> Self {
|
||||||
|
self.directory_url = rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY.to_string();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_contact(mut self, contact: impl Into<String>) -> Self {
|
||||||
|
self.contact.push(contact.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mode(&self) -> &AcmeMode {
|
||||||
|
&self.mode
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_acme_state(&self) -> (AcmeState<std::io::Error>, Arc<ResolvesServerCertAcme>) {
|
||||||
|
let domains: Vec<String> = match &self.mode {
|
||||||
|
AcmeMode::Domain { domain } => vec![domain.clone()],
|
||||||
|
AcmeMode::Ip => vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
let base_config = AcmeConfig::new(domains)
|
||||||
|
.directory(&self.directory_url)
|
||||||
|
.contact(self.contact.clone());
|
||||||
|
|
||||||
|
let state = match &self.cache_dir {
|
||||||
|
Some(cache_dir) => {
|
||||||
|
base_config.cache(DirCache::new(cache_dir.clone())).state()
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
base_config
|
||||||
|
.cache(rustls_acme::caches::NoCache::default())
|
||||||
|
.state()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let resolver = state.resolver();
|
||||||
|
(state, resolver)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_server_config_with_resolver(
|
||||||
|
&self,
|
||||||
|
resolver: Arc<ResolvesServerCertAcme>,
|
||||||
|
) -> Result<Arc<ServerConfig>> {
|
||||||
|
let provider = default_provider().into();
|
||||||
|
let mut config = ServerConfig::builder_with_provider(provider)
|
||||||
|
.with_safe_default_protocol_versions()
|
||||||
|
.map_err(|e| anyhow!("failed to set protocol versions: {}", e))?
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_cert_resolver(resolver);
|
||||||
|
config.alpn_protocols.push(ACME_TLS_ALPN_NAME.to_vec());
|
||||||
|
Ok(Arc::new(config))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AcmeTlsAcceptor {
|
||||||
|
listener: TcpListener,
|
||||||
|
listen_addr: SocketAddr,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
server_config: Arc<ServerConfig>,
|
||||||
|
tokio_acceptor: TokioTlsAcceptor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AcmeTlsAcceptor {
|
||||||
|
pub async fn bind_acme(
|
||||||
|
addr: SocketAddr,
|
||||||
|
provider: Arc<AcmeCertProvider>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let (state, resolver) = provider.build_acme_state();
|
||||||
|
|
||||||
|
let server_config = provider.build_server_config_with_resolver(resolver.clone())?;
|
||||||
|
|
||||||
|
Self::spawn_state_worker(state, resolver);
|
||||||
|
|
||||||
|
let listener = TcpListener::bind(addr).await?;
|
||||||
|
let listen_addr = listener.local_addr()?;
|
||||||
|
|
||||||
|
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
listener,
|
||||||
|
listen_addr,
|
||||||
|
server_config,
|
||||||
|
tokio_acceptor,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn listen_addr(&self) -> SocketAddr {
|
||||||
|
self.listen_addr
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spawn_state_worker(state: AcmeState<std::io::Error>, resolver: Arc<ResolvesServerCertAcme>) {
|
||||||
|
use futures::StreamExt;
|
||||||
|
|
||||||
|
let task = async move {
|
||||||
|
let mut state = state;
|
||||||
|
while let Some(event) = state.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(ok) => {
|
||||||
|
if let rustls_acme::EventOk::DeployedNewCert = ok {
|
||||||
|
info!("ACME: new certificate deployed");
|
||||||
|
} else {
|
||||||
|
info!("ACME event: {:?}", ok);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => error!("ACME event error: {:?}", err),
|
||||||
|
}
|
||||||
|
if Arc::strong_count(&resolver) == 1 {
|
||||||
|
info!("ACME resolver dropped, stopping background task");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
tokio::spawn(task);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl TransportAcceptor for AcmeTlsAcceptor {
|
||||||
|
type Stream = tokio_rustls::server::TlsStream<tokio::net::TcpStream>;
|
||||||
|
|
||||||
|
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||||
|
let (tcp_stream, remote_addr) = self.listener.accept().await?;
|
||||||
|
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
|
||||||
|
|
||||||
|
let server_name = tls_stream
|
||||||
|
.get_ref()
|
||||||
|
.1
|
||||||
|
.server_name()
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
let info = TransportInfo {
|
||||||
|
remote_addr: Some(remote_addr),
|
||||||
|
transport_kind: TransportKind::Tls { server_name },
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((tls_stream, info))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_domain_mode() {
|
||||||
|
let provider = AcmeCertProvider::domain("example.com");
|
||||||
|
assert!(matches!(provider.mode(), AcmeMode::Domain { .. }));
|
||||||
|
if let AcmeMode::Domain { domain } = provider.mode() {
|
||||||
|
assert_eq!(domain, "example.com");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_ip_mode() {
|
||||||
|
let provider = AcmeCertProvider::ip();
|
||||||
|
assert!(matches!(provider.mode(), AcmeMode::Ip));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_default_staging_directory() {
|
||||||
|
let provider = AcmeCertProvider::domain("example.com");
|
||||||
|
assert_eq!(
|
||||||
|
provider.directory_url,
|
||||||
|
rustls_acme::acme::LETS_ENCRYPT_STAGING_DIRECTORY
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_production_directory() {
|
||||||
|
let provider = AcmeCertProvider::domain("example.com").with_production_directory();
|
||||||
|
assert_eq!(
|
||||||
|
provider.directory_url,
|
||||||
|
rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_custom_directory() {
|
||||||
|
let provider =
|
||||||
|
AcmeCertProvider::domain("example.com").with_directory("https://custom.acme.dir/");
|
||||||
|
assert_eq!(provider.directory_url, "https://custom.acme.dir/");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_with_cache_dir() {
|
||||||
|
let provider = AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/acme_cache");
|
||||||
|
assert_eq!(provider.cache_dir, Some(PathBuf::from("/tmp/acme_cache")));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_with_contact() {
|
||||||
|
let provider =
|
||||||
|
AcmeCertProvider::domain("example.com").with_contact("mailto:admin@example.com");
|
||||||
|
assert_eq!(
|
||||||
|
provider.contact,
|
||||||
|
vec!["mailto:admin@example.com".to_string()]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_build_state_domain() {
|
||||||
|
let provider = AcmeCertProvider::domain("example.com");
|
||||||
|
let (_state, resolver) = provider.build_acme_state();
|
||||||
|
assert!(Arc::strong_count(&resolver) >= 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_build_state_with_cache() {
|
||||||
|
let provider =
|
||||||
|
AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/test_cache");
|
||||||
|
let (_state, resolver) = provider.build_acme_state();
|
||||||
|
assert!(Arc::strong_count(&resolver) >= 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_build_server_config() {
|
||||||
|
let _ = default_provider().install_default();
|
||||||
|
let provider = AcmeCertProvider::domain("example.com");
|
||||||
|
let (_, resolver) = provider.build_acme_state();
|
||||||
|
let config = provider.build_server_config_with_resolver(resolver).unwrap();
|
||||||
|
assert!(!config.alpn_protocols.is_empty());
|
||||||
|
assert!(config
|
||||||
|
.alpn_protocols
|
||||||
|
.iter()
|
||||||
|
.any(|p| p == ACME_TLS_ALPN_NAME));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_mode_domain_debug() {
|
||||||
|
let mode = AcmeMode::Domain {
|
||||||
|
domain: "test.example.com".to_string(),
|
||||||
|
};
|
||||||
|
let debug_str = format!("{:?}", mode);
|
||||||
|
assert!(debug_str.contains("test.example.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_mode_ip_debug() {
|
||||||
|
let mode = AcmeMode::Ip;
|
||||||
|
let debug_str = format!("{:?}", mode);
|
||||||
|
assert!(debug_str.contains("Ip"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn acme_cert_provider_builder_chain() {
|
||||||
|
let provider = AcmeCertProvider::domain("test.example.com")
|
||||||
|
.with_production_directory()
|
||||||
|
.with_cache_dir("/tmp/cache")
|
||||||
|
.with_contact("mailto:admin@test.example.com");
|
||||||
|
assert!(matches!(provider.mode(), AcmeMode::Domain { .. }));
|
||||||
|
assert_eq!(
|
||||||
|
provider.directory_url,
|
||||||
|
rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY
|
||||||
|
);
|
||||||
|
assert_eq!(provider.cache_dir, Some(PathBuf::from("/tmp/cache")));
|
||||||
|
assert_eq!(provider.contact.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn acme_tls_acceptor_bind_acme() {
|
||||||
|
let _ = default_provider().install_default();
|
||||||
|
let provider = Arc::new(AcmeCertProvider::domain("example.com"));
|
||||||
|
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
|
||||||
|
let acceptor = AcmeTlsAcceptor::bind_acme(addr, provider).await.unwrap();
|
||||||
|
assert_ne!(acceptor.listen_addr().port(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore]
|
||||||
|
async fn acme_staging_domain_cert_provisioning() {
|
||||||
|
let _ = default_provider().install_default();
|
||||||
|
|
||||||
|
let cache_dir = tempfile::tempdir().unwrap();
|
||||||
|
let provider = Arc::new(
|
||||||
|
AcmeCertProvider::domain("acme-test.example.com")
|
||||||
|
.with_cache_dir(cache_dir.path())
|
||||||
|
.with_contact("mailto:admin@example.com"),
|
||||||
|
);
|
||||||
|
|
||||||
|
let addr: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
||||||
|
let result = AcmeTlsAcceptor::bind_acme(addr, provider).await;
|
||||||
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"ACME TlsAcceptor should bind: {:?}",
|
||||||
|
result.err()
|
||||||
|
);
|
||||||
|
|
||||||
|
let acceptor = result.unwrap();
|
||||||
|
assert_eq!(acceptor.listen_addr().port(), 443);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,12 @@ mod tls;
|
|||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
pub use tls::{AcmeConfig, TlsAcceptor, TlsTransport};
|
pub use tls::{AcmeConfig, TlsAcceptor, TlsTransport};
|
||||||
|
|
||||||
|
#[cfg(feature = "acme")]
|
||||||
|
mod acme;
|
||||||
|
|
||||||
|
#[cfg(feature = "acme")]
|
||||||
|
pub use acme::{AcmeCertProvider, AcmeMode, AcmeTlsAcceptor};
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
|||||||
@@ -9,8 +9,16 @@ use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig};
|
|||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio_rustls::{client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector};
|
use tokio_rustls::{client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector};
|
||||||
|
|
||||||
|
#[cfg(feature = "acme")]
|
||||||
|
use rustls::crypto::aws_lc_rs::default_provider;
|
||||||
|
#[cfg(feature = "acme")]
|
||||||
|
use rustls_acme::ResolvesServerCertAcme;
|
||||||
|
|
||||||
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||||
|
|
||||||
|
#[cfg(feature = "acme")]
|
||||||
|
const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1";
|
||||||
|
|
||||||
/// A TLS-based client transport that connects to a remote address over TLS.
|
/// A TLS-based client transport that connects to a remote address over TLS.
|
||||||
///
|
///
|
||||||
/// Wraps a TCP connection with a TLS client session via `tokio_rustls::TlsConnector`.
|
/// Wraps a TCP connection with a TLS client session via `tokio_rustls::TlsConnector`.
|
||||||
@@ -110,8 +118,10 @@ pub struct AcmeConfig {
|
|||||||
/// A TLS-based server transport acceptor that accepts TCP connections
|
/// A TLS-based server transport acceptor that accepts TCP connections
|
||||||
/// and wraps them with TLS server sessions via `tokio_rustls::TlsAcceptor`.
|
/// and wraps them with TLS server sessions via `tokio_rustls::TlsAcceptor`.
|
||||||
///
|
///
|
||||||
/// Requires certificate and private key configuration. Supports manual
|
/// Supports three certificate modes (ADR-008):
|
||||||
/// cert/key paths and an ACME config stub (ADR-008).
|
/// - Manual certs via `bind()` with explicit cert/key
|
||||||
|
/// - ACME certs via `bind_acme()` with an `AcmeCertProvider`
|
||||||
|
/// - The stub `AcmeConfig` parameter in `bind()` is kept for backward compat
|
||||||
pub struct TlsAcceptor {
|
pub struct TlsAcceptor {
|
||||||
listener: TcpListener,
|
listener: TcpListener,
|
||||||
listen_addr: SocketAddr,
|
listen_addr: SocketAddr,
|
||||||
@@ -145,6 +155,33 @@ impl TlsAcceptor {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "acme")]
|
||||||
|
pub async fn bind_acme(
|
||||||
|
addr: SocketAddr,
|
||||||
|
acme_resolver: Arc<ResolvesServerCertAcme>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let listener = TcpListener::bind(addr).await?;
|
||||||
|
let listen_addr = listener.local_addr()?;
|
||||||
|
|
||||||
|
let provider = default_provider().into();
|
||||||
|
let mut server_config = ServerConfig::builder_with_provider(provider)
|
||||||
|
.with_safe_default_protocol_versions()
|
||||||
|
.map_err(|e| anyhow!("failed to set protocol versions: {}", e))?
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_cert_resolver(acme_resolver);
|
||||||
|
server_config.alpn_protocols.push(ACME_TLS_ALPN_NAME.to_vec());
|
||||||
|
|
||||||
|
let server_config = Arc::new(server_config);
|
||||||
|
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
listener,
|
||||||
|
listen_addr,
|
||||||
|
server_config,
|
||||||
|
tokio_acceptor,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn listen_addr(&self) -> SocketAddr {
|
pub fn listen_addr(&self) -> SocketAddr {
|
||||||
self.listen_addr
|
self.listen_addr
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,8 +43,14 @@ This integrates with `TlsAcceptor` by providing ACME-resolved certificates inste
|
|||||||
|
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
> To be filled by implementation agent
|
- `AcmeCertProvider` is the main entry point. It creates `AcmeState` and `ResolvesServerCertAcme` from `rustls-acme`.
|
||||||
|
- The `ResolvesServerCertAcme` resolver is shared between the `AcmeState` background task and the `ServerConfig`, so cert updates propagate automatically.
|
||||||
|
- `AcmeTlsAcceptor::bind_acme()` creates a TLS acceptor that uses ACME-provisioned certs and spawns a background tokio task for auto-renewal.
|
||||||
|
- `TlsAcceptor::bind_acme()` also added for users who want to use ACME with the standard `TlsAcceptor` type directly.
|
||||||
|
- The `AcmeConfig` stub in `tls.rs` is retained for backward compat with existing `TlsAcceptor::bind()`.
|
||||||
|
- `acme` feature implies `tls` and adds `rustls-acme` + `futures` dependencies.
|
||||||
|
- TLS-ALPN-01 challenge handling works via the `acme-tls/1` ALPN protocol registered in `ServerConfig` — the resolver dispatches challenge vs regular certs automatically.
|
||||||
|
|
||||||
## Summary
|
## Summary
|
||||||
|
|
||||||
> To be filled on completion
|
Implemented ACME/Let's Encrypt certificate provisioning (ADR-008) behind the `acme` feature flag. `AcmeCertProvider` supports domain-based and IP-based modes using `rustls-acme`. `AcmeTlsAcceptor::bind_acme()` and `TlsAcceptor::bind_acme()` provide ACME-integrated TLS acceptance with automatic certificate renewal via a background tokio task. Unit tests cover config construction, builder patterns, and server config generation. Integration test for LE staging is marked `#[ignore]`.
|
||||||
Reference in New Issue
Block a user