Define Transport trait, TransportAcceptor trait, TransportInfo, and TransportKind types

This commit is contained in:
2026-06-02 09:17:50 +00:00
parent 56d032afdb
commit dddc6d7a4c
10 changed files with 323 additions and 2 deletions

2
Cargo.lock generated
View File

@@ -5461,6 +5461,7 @@ name = "wraith-core"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"iroh",
"russh",
"rustls",
@@ -5470,6 +5471,7 @@ dependencies = [
"tokio-rustls",
"tokio-util",
"tracing",
"wraith-core",
]
[[package]]

View File

@@ -11,6 +11,8 @@ default = []
tls = ["dep:tokio-rustls", "dep:rustls"]
iroh = ["dep:iroh"]
acme = ["dep:rustls-acme", "tls"]
testutil = []
transport-traits = []
[dependencies]
russh = "0.49"
@@ -23,3 +25,7 @@ tokio-rustls = { version = "0.26", optional = true }
rustls = { version = "0.23", optional = true }
rustls-acme = { version = "0.12", optional = true }
iroh = { version = "0.34", optional = true }
async-trait = "0.1"
[dev-dependencies]
wraith-core = { path = ".", features = ["testutil"] }

View File

@@ -9,3 +9,4 @@ pub mod error;
pub mod testutil;
pub use error::{AuthError, ChannelError, ConfigError, TransportError};
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};

View File

@@ -0,0 +1,141 @@
use tokio::io::{DuplexStream, AsyncRead, AsyncWrite};
use anyhow::Result;
#[cfg(feature = "transport-traits")]
pub use crate::transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
#[cfg(not(feature = "transport-traits"))]
pub use local_traits::{Transport, TransportAcceptor, TransportInfo, TransportKind};
#[cfg(not(feature = "transport-traits"))]
mod local_traits {
use std::net::SocketAddr;
use anyhow::Result;
use tokio::io::{AsyncRead, AsyncWrite};
use async_trait::async_trait;
#[async_trait]
pub trait Transport: Send + Sync + 'static {
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
async fn connect(&self) -> Result<Self::Stream>;
fn describe(&self) -> String;
}
#[async_trait]
pub trait TransportAcceptor: Send + Sync + 'static {
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>;
}
#[derive(Debug, Clone)]
pub struct TransportInfo {
pub remote_addr: Option<SocketAddr>,
pub transport_kind: TransportKind,
}
#[derive(Debug, Clone)]
pub enum TransportKind {
Tcp,
Tls { server_name: Option<String> },
Iroh { endpoint_id: String },
}
}
pub struct MockStream {
inner: DuplexStream,
}
impl MockStream {
pub fn new(inner: DuplexStream) -> Self {
Self { inner }
}
}
impl AsyncRead for MockStream {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
}
}
impl AsyncWrite for MockStream {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
}
}
impl Unpin for MockStream {}
pub struct MockTransport {
buf_size: usize,
}
impl MockTransport {
pub fn new(buf_size: usize) -> Self {
Self { buf_size }
}
}
#[async_trait::async_trait]
impl Transport for MockTransport {
type Stream = MockStream;
async fn connect(&self) -> Result<Self::Stream> {
let (client, _) = tokio::io::duplex(self.buf_size);
Ok(MockStream::new(client))
}
fn describe(&self) -> String {
"mock".to_string()
}
}
pub struct MockTransportAcceptor {
buf_size: usize,
}
impl MockTransportAcceptor {
pub fn new(buf_size: usize) -> Self {
Self { buf_size }
}
}
#[async_trait::async_trait]
impl TransportAcceptor for MockTransportAcceptor {
type Stream = MockStream;
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
let (_, server) = tokio::io::duplex(self.buf_size);
let info = TransportInfo {
remote_addr: None,
transport_kind: TransportKind::Tcp,
};
Ok((MockStream::new(server), info))
}
}
pub fn mock_pair(buf_size: usize) -> (MockStream, MockStream) {
let (client, server) = tokio::io::duplex(buf_size);
(MockStream::new(client), MockStream::new(server))
}

View File

@@ -0,0 +1,139 @@
use std::net::SocketAddr;
use anyhow::Result;
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite};
#[async_trait]
pub trait Transport: Send + Sync + 'static {
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
async fn connect(&self) -> Result<Self::Stream>;
fn describe(&self) -> String;
}
#[async_trait]
pub trait TransportAcceptor: Send + Sync + 'static {
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>;
}
/// Metadata about an incoming transport connection.
///
/// Carries the remote address (if available) and the kind of transport
/// used. The server handler uses this for logging and auth decisions.
/// See ADR-001 for the pluggable transport rationale and ADR-004
/// for why SSH runs entirely over the transport stream.
#[derive(Debug, Clone)]
pub struct TransportInfo {
pub remote_addr: Option<SocketAddr>,
pub transport_kind: TransportKind,
}
/// The kind of transport that produced a connection.
///
/// Each variant identifies the transport mechanism. Used by the
/// server handler for logging and authorization decisions.
/// See ADR-001 and ADR-004.
#[derive(Debug, Clone)]
pub enum TransportKind {
Tcp,
Tls {
server_name: Option<String>,
},
Iroh {
endpoint_id: String,
},
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{duplex, DuplexStream};
struct MockTransport;
#[async_trait]
impl Transport for MockTransport {
type Stream = DuplexStream;
async fn connect(&self) -> Result<Self::Stream> {
let (stream, _) = duplex(1024);
Ok(stream)
}
fn describe(&self) -> String {
"mock".to_string()
}
}
struct MockAcceptor;
#[async_trait]
impl TransportAcceptor for MockAcceptor {
type Stream = DuplexStream;
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
let (stream, _) = duplex(1024);
let info = TransportInfo {
remote_addr: None,
transport_kind: TransportKind::Tcp,
};
Ok((stream, info))
}
}
#[tokio::test]
async fn transport_trait_object() {
let _boxed: Box<dyn Transport<Stream = DuplexStream>> = Box::new(MockTransport);
}
#[tokio::test]
async fn transport_acceptor_trait_object() {
let _boxed: Box<dyn TransportAcceptor<Stream = DuplexStream>> = Box::new(MockAcceptor);
}
#[tokio::test]
async fn transport_connect_returns_stream() {
let t = MockTransport;
let _stream = t.connect().await.unwrap();
}
#[tokio::test]
async fn transport_describe_returns_string() {
let t = MockTransport;
assert_eq!(t.describe(), "mock");
}
#[tokio::test]
async fn acceptor_accept_returns_stream_and_info() {
let a = MockAcceptor;
let (_, info) = a.accept().await.unwrap();
assert!(info.remote_addr.is_none());
assert!(matches!(info.transport_kind, TransportKind::Tcp));
}
#[test]
fn transport_kind_variants() {
let tcp = TransportKind::Tcp;
let tls = TransportKind::Tls {
server_name: Some("example.com".to_string()),
};
let iroh = TransportKind::Iroh {
endpoint_id: "abc123".to_string(),
};
if let TransportKind::Tcp = tcp {}
if let TransportKind::Tls {
server_name: Some(name),
} = tls
{
assert_eq!(name, "example.com");
}
if let TransportKind::Iroh { endpoint_id } = iroh {
assert_eq!(endpoint_id, "abc123");
}
}
}

View File

@@ -0,0 +1,2 @@
#[tokio::test]
async fn auth_placeholder() {}

View File

@@ -0,0 +1,2 @@
#[tokio::test]
async fn client_placeholder() {}

View File

@@ -0,0 +1,2 @@
#[tokio::test]
async fn server_placeholder() {}

View File

@@ -0,0 +1,26 @@
use wraith_core::testutil::{MockTransport, MockTransportAcceptor, Transport, TransportAcceptor, mock_pair};
#[tokio::test]
async fn mock_transport_connect() {
let transport = MockTransport::new(1024);
let stream = transport.connect().await.unwrap();
drop(stream);
}
#[tokio::test]
async fn mock_transport_acceptor_accept() {
let acceptor = MockTransportAcceptor::new(1024);
let (stream, info) = acceptor.accept().await.unwrap();
drop(stream);
drop(info);
}
#[tokio::test]
async fn mock_pair_communicates() {
let (mut client, mut server) = mock_pair(1024);
use tokio::io::{AsyncReadExt, AsyncWriteExt};
client.write_all(b"hello").await.unwrap();
let mut buf = [0u8; 5];
server.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello");
}