Define Transport trait, TransportAcceptor trait, TransportInfo, and TransportKind types
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -5461,6 +5461,7 @@ name = "wraith-core"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"iroh",
|
||||
"russh",
|
||||
"rustls",
|
||||
@@ -5470,6 +5471,7 @@ dependencies = [
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"wraith-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -11,6 +11,8 @@ default = []
|
||||
tls = ["dep:tokio-rustls", "dep:rustls"]
|
||||
iroh = ["dep:iroh"]
|
||||
acme = ["dep:rustls-acme", "tls"]
|
||||
testutil = []
|
||||
transport-traits = []
|
||||
|
||||
[dependencies]
|
||||
russh = "0.49"
|
||||
@@ -23,3 +25,7 @@ tokio-rustls = { version = "0.26", optional = true }
|
||||
rustls = { version = "0.23", optional = true }
|
||||
rustls-acme = { version = "0.12", optional = true }
|
||||
iroh = { version = "0.34", optional = true }
|
||||
async-trait = "0.1"
|
||||
|
||||
[dev-dependencies]
|
||||
wraith-core = { path = ".", features = ["testutil"] }
|
||||
@@ -9,3 +9,4 @@ pub mod error;
|
||||
pub mod testutil;
|
||||
|
||||
pub use error::{AuthError, ChannelError, ConfigError, TransportError};
|
||||
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
141
crates/wraith-core/src/testutil.rs
Normal file
141
crates/wraith-core/src/testutil.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
use tokio::io::{DuplexStream, AsyncRead, AsyncWrite};
|
||||
use anyhow::Result;
|
||||
|
||||
#[cfg(feature = "transport-traits")]
|
||||
pub use crate::transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
#[cfg(not(feature = "transport-traits"))]
|
||||
pub use local_traits::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
#[cfg(not(feature = "transport-traits"))]
|
||||
mod local_traits {
|
||||
use std::net::SocketAddr;
|
||||
use anyhow::Result;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Transport: Send + Sync + 'static {
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
async fn connect(&self) -> Result<Self::Stream>;
|
||||
fn describe(&self) -> String;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait TransportAcceptor: Send + Sync + 'static {
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TransportInfo {
|
||||
pub remote_addr: Option<SocketAddr>,
|
||||
pub transport_kind: TransportKind,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TransportKind {
|
||||
Tcp,
|
||||
Tls { server_name: Option<String> },
|
||||
Iroh { endpoint_id: String },
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MockStream {
|
||||
inner: DuplexStream,
|
||||
}
|
||||
|
||||
impl MockStream {
|
||||
pub fn new(inner: DuplexStream) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for MockStream {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for MockStream {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<std::io::Result<usize>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpin for MockStream {}
|
||||
|
||||
pub struct MockTransport {
|
||||
buf_size: usize,
|
||||
}
|
||||
|
||||
impl MockTransport {
|
||||
pub fn new(buf_size: usize) -> Self {
|
||||
Self { buf_size }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for MockTransport {
|
||||
type Stream = MockStream;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream> {
|
||||
let (client, _) = tokio::io::duplex(self.buf_size);
|
||||
Ok(MockStream::new(client))
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"mock".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MockTransportAcceptor {
|
||||
buf_size: usize,
|
||||
}
|
||||
|
||||
impl MockTransportAcceptor {
|
||||
pub fn new(buf_size: usize) -> Self {
|
||||
Self { buf_size }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TransportAcceptor for MockTransportAcceptor {
|
||||
type Stream = MockStream;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let (_, server) = tokio::io::duplex(self.buf_size);
|
||||
let info = TransportInfo {
|
||||
remote_addr: None,
|
||||
transport_kind: TransportKind::Tcp,
|
||||
};
|
||||
Ok((MockStream::new(server), info))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mock_pair(buf_size: usize) -> (MockStream, MockStream) {
|
||||
let (client, server) = tokio::io::duplex(buf_size);
|
||||
(MockStream::new(client), MockStream::new(server))
|
||||
}
|
||||
139
crates/wraith-core/src/transport/mod.rs
Normal file
139
crates/wraith-core/src/transport/mod.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
#[async_trait]
|
||||
pub trait Transport: Send + Sync + 'static {
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream>;
|
||||
|
||||
fn describe(&self) -> String;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait TransportAcceptor: Send + Sync + 'static {
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>;
|
||||
}
|
||||
|
||||
/// Metadata about an incoming transport connection.
|
||||
///
|
||||
/// Carries the remote address (if available) and the kind of transport
|
||||
/// used. The server handler uses this for logging and auth decisions.
|
||||
/// See ADR-001 for the pluggable transport rationale and ADR-004
|
||||
/// for why SSH runs entirely over the transport stream.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TransportInfo {
|
||||
pub remote_addr: Option<SocketAddr>,
|
||||
pub transport_kind: TransportKind,
|
||||
}
|
||||
|
||||
/// The kind of transport that produced a connection.
|
||||
///
|
||||
/// Each variant identifies the transport mechanism. Used by the
|
||||
/// server handler for logging and authorization decisions.
|
||||
/// See ADR-001 and ADR-004.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TransportKind {
|
||||
Tcp,
|
||||
Tls {
|
||||
server_name: Option<String>,
|
||||
},
|
||||
Iroh {
|
||||
endpoint_id: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{duplex, DuplexStream};
|
||||
|
||||
struct MockTransport;
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for MockTransport {
|
||||
type Stream = DuplexStream;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream> {
|
||||
let (stream, _) = duplex(1024);
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"mock".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
struct MockAcceptor;
|
||||
|
||||
#[async_trait]
|
||||
impl TransportAcceptor for MockAcceptor {
|
||||
type Stream = DuplexStream;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let (stream, _) = duplex(1024);
|
||||
let info = TransportInfo {
|
||||
remote_addr: None,
|
||||
transport_kind: TransportKind::Tcp,
|
||||
};
|
||||
Ok((stream, info))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_trait_object() {
|
||||
let _boxed: Box<dyn Transport<Stream = DuplexStream>> = Box::new(MockTransport);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_acceptor_trait_object() {
|
||||
let _boxed: Box<dyn TransportAcceptor<Stream = DuplexStream>> = Box::new(MockAcceptor);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_connect_returns_stream() {
|
||||
let t = MockTransport;
|
||||
let _stream = t.connect().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_describe_returns_string() {
|
||||
let t = MockTransport;
|
||||
assert_eq!(t.describe(), "mock");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn acceptor_accept_returns_stream_and_info() {
|
||||
let a = MockAcceptor;
|
||||
let (_, info) = a.accept().await.unwrap();
|
||||
assert!(info.remote_addr.is_none());
|
||||
assert!(matches!(info.transport_kind, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_kind_variants() {
|
||||
let tcp = TransportKind::Tcp;
|
||||
let tls = TransportKind::Tls {
|
||||
server_name: Some("example.com".to_string()),
|
||||
};
|
||||
let iroh = TransportKind::Iroh {
|
||||
endpoint_id: "abc123".to_string(),
|
||||
};
|
||||
|
||||
if let TransportKind::Tcp = tcp {}
|
||||
if let TransportKind::Tls {
|
||||
server_name: Some(name),
|
||||
} = tls
|
||||
{
|
||||
assert_eq!(name, "example.com");
|
||||
}
|
||||
if let TransportKind::Iroh { endpoint_id } = iroh {
|
||||
assert_eq!(endpoint_id, "abc123");
|
||||
}
|
||||
}
|
||||
}
|
||||
2
crates/wraith-core/tests/auth_tests.rs
Normal file
2
crates/wraith-core/tests/auth_tests.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
#[tokio::test]
|
||||
async fn auth_placeholder() {}
|
||||
2
crates/wraith-core/tests/client_tests.rs
Normal file
2
crates/wraith-core/tests/client_tests.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
#[tokio::test]
|
||||
async fn client_placeholder() {}
|
||||
2
crates/wraith-core/tests/server_tests.rs
Normal file
2
crates/wraith-core/tests/server_tests.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
#[tokio::test]
|
||||
async fn server_placeholder() {}
|
||||
26
crates/wraith-core/tests/transport_tests.rs
Normal file
26
crates/wraith-core/tests/transport_tests.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use wraith_core::testutil::{MockTransport, MockTransportAcceptor, Transport, TransportAcceptor, mock_pair};
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_transport_connect() {
|
||||
let transport = MockTransport::new(1024);
|
||||
let stream = transport.connect().await.unwrap();
|
||||
drop(stream);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_transport_acceptor_accept() {
|
||||
let acceptor = MockTransportAcceptor::new(1024);
|
||||
let (stream, info) = acceptor.accept().await.unwrap();
|
||||
drop(stream);
|
||||
drop(info);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_pair_communicates() {
|
||||
let (mut client, mut server) = mock_pair(1024);
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
client.write_all(b"hello").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
server.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello");
|
||||
}
|
||||
Reference in New Issue
Block a user