Define Transport trait, TransportAcceptor trait, TransportInfo, and TransportKind types
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -5461,6 +5461,7 @@ name = "wraith-core"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"async-trait",
|
||||||
"iroh",
|
"iroh",
|
||||||
"russh",
|
"russh",
|
||||||
"rustls",
|
"rustls",
|
||||||
@@ -5470,6 +5471,7 @@ dependencies = [
|
|||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
"tokio-util",
|
"tokio-util",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
"wraith-core",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ default = []
|
|||||||
tls = ["dep:tokio-rustls", "dep:rustls"]
|
tls = ["dep:tokio-rustls", "dep:rustls"]
|
||||||
iroh = ["dep:iroh"]
|
iroh = ["dep:iroh"]
|
||||||
acme = ["dep:rustls-acme", "tls"]
|
acme = ["dep:rustls-acme", "tls"]
|
||||||
|
testutil = []
|
||||||
|
transport-traits = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
russh = "0.49"
|
russh = "0.49"
|
||||||
@@ -22,4 +24,8 @@ tokio-util = { version = "0.7", features = ["compat"] }
|
|||||||
tokio-rustls = { version = "0.26", optional = true }
|
tokio-rustls = { version = "0.26", optional = true }
|
||||||
rustls = { version = "0.23", optional = true }
|
rustls = { version = "0.23", optional = true }
|
||||||
rustls-acme = { version = "0.12", optional = true }
|
rustls-acme = { version = "0.12", optional = true }
|
||||||
iroh = { version = "0.34", optional = true }
|
iroh = { version = "0.34", optional = true }
|
||||||
|
async-trait = "0.1"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
wraith-core = { path = ".", features = ["testutil"] }
|
||||||
@@ -8,4 +8,5 @@ pub mod error;
|
|||||||
#[cfg(feature = "testutil")]
|
#[cfg(feature = "testutil")]
|
||||||
pub mod testutil;
|
pub mod testutil;
|
||||||
|
|
||||||
pub use error::{AuthError, ChannelError, ConfigError, TransportError};
|
pub use error::{AuthError, ChannelError, ConfigError, TransportError};
|
||||||
|
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||||
141
crates/wraith-core/src/testutil.rs
Normal file
141
crates/wraith-core/src/testutil.rs
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
use tokio::io::{DuplexStream, AsyncRead, AsyncWrite};
|
||||||
|
use anyhow::Result;
|
||||||
|
|
||||||
|
#[cfg(feature = "transport-traits")]
|
||||||
|
pub use crate::transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||||
|
|
||||||
|
#[cfg(not(feature = "transport-traits"))]
|
||||||
|
pub use local_traits::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||||
|
|
||||||
|
#[cfg(not(feature = "transport-traits"))]
|
||||||
|
mod local_traits {
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use anyhow::Result;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Transport: Send + Sync + 'static {
|
||||||
|
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||||
|
async fn connect(&self) -> Result<Self::Stream>;
|
||||||
|
fn describe(&self) -> String;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait TransportAcceptor: Send + Sync + 'static {
|
||||||
|
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||||
|
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct TransportInfo {
|
||||||
|
pub remote_addr: Option<SocketAddr>,
|
||||||
|
pub transport_kind: TransportKind,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum TransportKind {
|
||||||
|
Tcp,
|
||||||
|
Tls { server_name: Option<String> },
|
||||||
|
Iroh { endpoint_id: String },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MockStream {
|
||||||
|
inner: DuplexStream,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MockStream {
|
||||||
|
pub fn new(inner: DuplexStream) -> Self {
|
||||||
|
Self { inner }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for MockStream {
|
||||||
|
fn poll_read(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
buf: &mut tokio::io::ReadBuf<'_>,
|
||||||
|
) -> std::task::Poll<std::io::Result<()>> {
|
||||||
|
std::pin::Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for MockStream {
|
||||||
|
fn poll_write(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> std::task::Poll<std::io::Result<usize>> {
|
||||||
|
std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<std::io::Result<()>> {
|
||||||
|
std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<std::io::Result<()>> {
|
||||||
|
std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Unpin for MockStream {}
|
||||||
|
|
||||||
|
pub struct MockTransport {
|
||||||
|
buf_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MockTransport {
|
||||||
|
pub fn new(buf_size: usize) -> Self {
|
||||||
|
Self { buf_size }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Transport for MockTransport {
|
||||||
|
type Stream = MockStream;
|
||||||
|
|
||||||
|
async fn connect(&self) -> Result<Self::Stream> {
|
||||||
|
let (client, _) = tokio::io::duplex(self.buf_size);
|
||||||
|
Ok(MockStream::new(client))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn describe(&self) -> String {
|
||||||
|
"mock".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MockTransportAcceptor {
|
||||||
|
buf_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MockTransportAcceptor {
|
||||||
|
pub fn new(buf_size: usize) -> Self {
|
||||||
|
Self { buf_size }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl TransportAcceptor for MockTransportAcceptor {
|
||||||
|
type Stream = MockStream;
|
||||||
|
|
||||||
|
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||||
|
let (_, server) = tokio::io::duplex(self.buf_size);
|
||||||
|
let info = TransportInfo {
|
||||||
|
remote_addr: None,
|
||||||
|
transport_kind: TransportKind::Tcp,
|
||||||
|
};
|
||||||
|
Ok((MockStream::new(server), info))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mock_pair(buf_size: usize) -> (MockStream, MockStream) {
|
||||||
|
let (client, server) = tokio::io::duplex(buf_size);
|
||||||
|
(MockStream::new(client), MockStream::new(server))
|
||||||
|
}
|
||||||
139
crates/wraith-core/src/transport/mod.rs
Normal file
139
crates/wraith-core/src/transport/mod.rs
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Transport: Send + Sync + 'static {
|
||||||
|
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||||
|
|
||||||
|
async fn connect(&self) -> Result<Self::Stream>;
|
||||||
|
|
||||||
|
fn describe(&self) -> String;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait TransportAcceptor: Send + Sync + 'static {
|
||||||
|
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||||
|
|
||||||
|
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Metadata about an incoming transport connection.
|
||||||
|
///
|
||||||
|
/// Carries the remote address (if available) and the kind of transport
|
||||||
|
/// used. The server handler uses this for logging and auth decisions.
|
||||||
|
/// See ADR-001 for the pluggable transport rationale and ADR-004
|
||||||
|
/// for why SSH runs entirely over the transport stream.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct TransportInfo {
|
||||||
|
pub remote_addr: Option<SocketAddr>,
|
||||||
|
pub transport_kind: TransportKind,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The kind of transport that produced a connection.
|
||||||
|
///
|
||||||
|
/// Each variant identifies the transport mechanism. Used by the
|
||||||
|
/// server handler for logging and authorization decisions.
|
||||||
|
/// See ADR-001 and ADR-004.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum TransportKind {
|
||||||
|
Tcp,
|
||||||
|
Tls {
|
||||||
|
server_name: Option<String>,
|
||||||
|
},
|
||||||
|
Iroh {
|
||||||
|
endpoint_id: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio::io::{duplex, DuplexStream};
|
||||||
|
|
||||||
|
struct MockTransport;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Transport for MockTransport {
|
||||||
|
type Stream = DuplexStream;
|
||||||
|
|
||||||
|
async fn connect(&self) -> Result<Self::Stream> {
|
||||||
|
let (stream, _) = duplex(1024);
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn describe(&self) -> String {
|
||||||
|
"mock".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MockAcceptor;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl TransportAcceptor for MockAcceptor {
|
||||||
|
type Stream = DuplexStream;
|
||||||
|
|
||||||
|
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||||
|
let (stream, _) = duplex(1024);
|
||||||
|
let info = TransportInfo {
|
||||||
|
remote_addr: None,
|
||||||
|
transport_kind: TransportKind::Tcp,
|
||||||
|
};
|
||||||
|
Ok((stream, info))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn transport_trait_object() {
|
||||||
|
let _boxed: Box<dyn Transport<Stream = DuplexStream>> = Box::new(MockTransport);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn transport_acceptor_trait_object() {
|
||||||
|
let _boxed: Box<dyn TransportAcceptor<Stream = DuplexStream>> = Box::new(MockAcceptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn transport_connect_returns_stream() {
|
||||||
|
let t = MockTransport;
|
||||||
|
let _stream = t.connect().await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn transport_describe_returns_string() {
|
||||||
|
let t = MockTransport;
|
||||||
|
assert_eq!(t.describe(), "mock");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn acceptor_accept_returns_stream_and_info() {
|
||||||
|
let a = MockAcceptor;
|
||||||
|
let (_, info) = a.accept().await.unwrap();
|
||||||
|
assert!(info.remote_addr.is_none());
|
||||||
|
assert!(matches!(info.transport_kind, TransportKind::Tcp));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn transport_kind_variants() {
|
||||||
|
let tcp = TransportKind::Tcp;
|
||||||
|
let tls = TransportKind::Tls {
|
||||||
|
server_name: Some("example.com".to_string()),
|
||||||
|
};
|
||||||
|
let iroh = TransportKind::Iroh {
|
||||||
|
endpoint_id: "abc123".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
if let TransportKind::Tcp = tcp {}
|
||||||
|
if let TransportKind::Tls {
|
||||||
|
server_name: Some(name),
|
||||||
|
} = tls
|
||||||
|
{
|
||||||
|
assert_eq!(name, "example.com");
|
||||||
|
}
|
||||||
|
if let TransportKind::Iroh { endpoint_id } = iroh {
|
||||||
|
assert_eq!(endpoint_id, "abc123");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
2
crates/wraith-core/tests/auth_tests.rs
Normal file
2
crates/wraith-core/tests/auth_tests.rs
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
#[tokio::test]
|
||||||
|
async fn auth_placeholder() {}
|
||||||
2
crates/wraith-core/tests/client_tests.rs
Normal file
2
crates/wraith-core/tests/client_tests.rs
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
#[tokio::test]
|
||||||
|
async fn client_placeholder() {}
|
||||||
2
crates/wraith-core/tests/server_tests.rs
Normal file
2
crates/wraith-core/tests/server_tests.rs
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
#[tokio::test]
|
||||||
|
async fn server_placeholder() {}
|
||||||
26
crates/wraith-core/tests/transport_tests.rs
Normal file
26
crates/wraith-core/tests/transport_tests.rs
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
use wraith_core::testutil::{MockTransport, MockTransportAcceptor, Transport, TransportAcceptor, mock_pair};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn mock_transport_connect() {
|
||||||
|
let transport = MockTransport::new(1024);
|
||||||
|
let stream = transport.connect().await.unwrap();
|
||||||
|
drop(stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn mock_transport_acceptor_accept() {
|
||||||
|
let acceptor = MockTransportAcceptor::new(1024);
|
||||||
|
let (stream, info) = acceptor.accept().await.unwrap();
|
||||||
|
drop(stream);
|
||||||
|
drop(info);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn mock_pair_communicates() {
|
||||||
|
let (mut client, mut server) = mock_pair(1024);
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
client.write_all(b"hello").await.unwrap();
|
||||||
|
let mut buf = [0u8; 5];
|
||||||
|
server.read_exact(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf, b"hello");
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user