diff --git a/Cargo.lock b/Cargo.lock index 4c33ef4..fa16080 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5461,6 +5461,7 @@ name = "wraith-core" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "iroh", "russh", "rustls", @@ -5470,6 +5471,7 @@ dependencies = [ "tokio-rustls", "tokio-util", "tracing", + "wraith-core", ] [[package]] diff --git a/crates/wraith-core/Cargo.toml b/crates/wraith-core/Cargo.toml index cbe38a5..86ccdc0 100644 --- a/crates/wraith-core/Cargo.toml +++ b/crates/wraith-core/Cargo.toml @@ -11,6 +11,8 @@ default = [] tls = ["dep:tokio-rustls", "dep:rustls"] iroh = ["dep:iroh"] acme = ["dep:rustls-acme", "tls"] +testutil = [] +transport-traits = [] [dependencies] russh = "0.49" @@ -22,4 +24,8 @@ tokio-util = { version = "0.7", features = ["compat"] } tokio-rustls = { version = "0.26", optional = true } rustls = { version = "0.23", optional = true } rustls-acme = { version = "0.12", optional = true } -iroh = { version = "0.34", optional = true } \ No newline at end of file +iroh = { version = "0.34", optional = true } +async-trait = "0.1" + +[dev-dependencies] +wraith-core = { path = ".", features = ["testutil"] } \ No newline at end of file diff --git a/crates/wraith-core/src/lib.rs b/crates/wraith-core/src/lib.rs index 88111fc..349b0ca 100644 --- a/crates/wraith-core/src/lib.rs +++ b/crates/wraith-core/src/lib.rs @@ -8,4 +8,5 @@ pub mod error; #[cfg(feature = "testutil")] pub mod testutil; -pub use error::{AuthError, ChannelError, ConfigError, TransportError}; \ No newline at end of file +pub use error::{AuthError, ChannelError, ConfigError, TransportError}; +pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind}; \ No newline at end of file diff --git a/crates/wraith-core/src/testutil.rs b/crates/wraith-core/src/testutil.rs new file mode 100644 index 0000000..62e8ae4 --- /dev/null +++ b/crates/wraith-core/src/testutil.rs @@ -0,0 +1,141 @@ +use tokio::io::{DuplexStream, AsyncRead, AsyncWrite}; +use anyhow::Result; + +#[cfg(feature = "transport-traits")] +pub use crate::transport::{Transport, TransportAcceptor, TransportInfo, TransportKind}; + +#[cfg(not(feature = "transport-traits"))] +pub use local_traits::{Transport, TransportAcceptor, TransportInfo, TransportKind}; + +#[cfg(not(feature = "transport-traits"))] +mod local_traits { + use std::net::SocketAddr; + use anyhow::Result; + use tokio::io::{AsyncRead, AsyncWrite}; + use async_trait::async_trait; + + #[async_trait] + pub trait Transport: Send + Sync + 'static { + type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static; + async fn connect(&self) -> Result; + fn describe(&self) -> String; + } + + #[async_trait] + pub trait TransportAcceptor: Send + Sync + 'static { + type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static; + async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>; + } + + #[derive(Debug, Clone)] + pub struct TransportInfo { + pub remote_addr: Option, + pub transport_kind: TransportKind, + } + + #[derive(Debug, Clone)] + pub enum TransportKind { + Tcp, + Tls { server_name: Option }, + Iroh { endpoint_id: String }, + } +} + +pub struct MockStream { + inner: DuplexStream, +} + +impl MockStream { + pub fn new(inner: DuplexStream) -> Self { + Self { inner } + } +} + +impl AsyncRead for MockStream { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.get_mut().inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for MockStream { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx) + } +} + +impl Unpin for MockStream {} + +pub struct MockTransport { + buf_size: usize, +} + +impl MockTransport { + pub fn new(buf_size: usize) -> Self { + Self { buf_size } + } +} + +#[async_trait::async_trait] +impl Transport for MockTransport { + type Stream = MockStream; + + async fn connect(&self) -> Result { + let (client, _) = tokio::io::duplex(self.buf_size); + Ok(MockStream::new(client)) + } + + fn describe(&self) -> String { + "mock".to_string() + } +} + +pub struct MockTransportAcceptor { + buf_size: usize, +} + +impl MockTransportAcceptor { + pub fn new(buf_size: usize) -> Self { + Self { buf_size } + } +} + +#[async_trait::async_trait] +impl TransportAcceptor for MockTransportAcceptor { + type Stream = MockStream; + + async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> { + let (_, server) = tokio::io::duplex(self.buf_size); + let info = TransportInfo { + remote_addr: None, + transport_kind: TransportKind::Tcp, + }; + Ok((MockStream::new(server), info)) + } +} + +pub fn mock_pair(buf_size: usize) -> (MockStream, MockStream) { + let (client, server) = tokio::io::duplex(buf_size); + (MockStream::new(client), MockStream::new(server)) +} \ No newline at end of file diff --git a/crates/wraith-core/src/transport.rs b/crates/wraith-core/src/transport.rs deleted file mode 100644 index e69de29..0000000 diff --git a/crates/wraith-core/src/transport/mod.rs b/crates/wraith-core/src/transport/mod.rs new file mode 100644 index 0000000..fcdcaf4 --- /dev/null +++ b/crates/wraith-core/src/transport/mod.rs @@ -0,0 +1,139 @@ +use std::net::SocketAddr; + +use anyhow::Result; +use async_trait::async_trait; +use tokio::io::{AsyncRead, AsyncWrite}; + +#[async_trait] +pub trait Transport: Send + Sync + 'static { + type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static; + + async fn connect(&self) -> Result; + + fn describe(&self) -> String; +} + +#[async_trait] +pub trait TransportAcceptor: Send + Sync + 'static { + type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static; + + async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>; +} + +/// Metadata about an incoming transport connection. +/// +/// Carries the remote address (if available) and the kind of transport +/// used. The server handler uses this for logging and auth decisions. +/// See ADR-001 for the pluggable transport rationale and ADR-004 +/// for why SSH runs entirely over the transport stream. +#[derive(Debug, Clone)] +pub struct TransportInfo { + pub remote_addr: Option, + pub transport_kind: TransportKind, +} + +/// The kind of transport that produced a connection. +/// +/// Each variant identifies the transport mechanism. Used by the +/// server handler for logging and authorization decisions. +/// See ADR-001 and ADR-004. +#[derive(Debug, Clone)] +pub enum TransportKind { + Tcp, + Tls { + server_name: Option, + }, + Iroh { + endpoint_id: String, + }, +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::{duplex, DuplexStream}; + + struct MockTransport; + + #[async_trait] + impl Transport for MockTransport { + type Stream = DuplexStream; + + async fn connect(&self) -> Result { + let (stream, _) = duplex(1024); + Ok(stream) + } + + fn describe(&self) -> String { + "mock".to_string() + } + } + + struct MockAcceptor; + + #[async_trait] + impl TransportAcceptor for MockAcceptor { + type Stream = DuplexStream; + + async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> { + let (stream, _) = duplex(1024); + let info = TransportInfo { + remote_addr: None, + transport_kind: TransportKind::Tcp, + }; + Ok((stream, info)) + } + } + + #[tokio::test] + async fn transport_trait_object() { + let _boxed: Box> = Box::new(MockTransport); + } + + #[tokio::test] + async fn transport_acceptor_trait_object() { + let _boxed: Box> = Box::new(MockAcceptor); + } + + #[tokio::test] + async fn transport_connect_returns_stream() { + let t = MockTransport; + let _stream = t.connect().await.unwrap(); + } + + #[tokio::test] + async fn transport_describe_returns_string() { + let t = MockTransport; + assert_eq!(t.describe(), "mock"); + } + + #[tokio::test] + async fn acceptor_accept_returns_stream_and_info() { + let a = MockAcceptor; + let (_, info) = a.accept().await.unwrap(); + assert!(info.remote_addr.is_none()); + assert!(matches!(info.transport_kind, TransportKind::Tcp)); + } + + #[test] + fn transport_kind_variants() { + let tcp = TransportKind::Tcp; + let tls = TransportKind::Tls { + server_name: Some("example.com".to_string()), + }; + let iroh = TransportKind::Iroh { + endpoint_id: "abc123".to_string(), + }; + + if let TransportKind::Tcp = tcp {} + if let TransportKind::Tls { + server_name: Some(name), + } = tls + { + assert_eq!(name, "example.com"); + } + if let TransportKind::Iroh { endpoint_id } = iroh { + assert_eq!(endpoint_id, "abc123"); + } + } +} \ No newline at end of file diff --git a/crates/wraith-core/tests/auth_tests.rs b/crates/wraith-core/tests/auth_tests.rs new file mode 100644 index 0000000..4504cca --- /dev/null +++ b/crates/wraith-core/tests/auth_tests.rs @@ -0,0 +1,2 @@ +#[tokio::test] +async fn auth_placeholder() {} \ No newline at end of file diff --git a/crates/wraith-core/tests/client_tests.rs b/crates/wraith-core/tests/client_tests.rs new file mode 100644 index 0000000..2276dcd --- /dev/null +++ b/crates/wraith-core/tests/client_tests.rs @@ -0,0 +1,2 @@ +#[tokio::test] +async fn client_placeholder() {} \ No newline at end of file diff --git a/crates/wraith-core/tests/server_tests.rs b/crates/wraith-core/tests/server_tests.rs new file mode 100644 index 0000000..5de0852 --- /dev/null +++ b/crates/wraith-core/tests/server_tests.rs @@ -0,0 +1,2 @@ +#[tokio::test] +async fn server_placeholder() {} \ No newline at end of file diff --git a/crates/wraith-core/tests/transport_tests.rs b/crates/wraith-core/tests/transport_tests.rs new file mode 100644 index 0000000..10548cc --- /dev/null +++ b/crates/wraith-core/tests/transport_tests.rs @@ -0,0 +1,26 @@ +use wraith_core::testutil::{MockTransport, MockTransportAcceptor, Transport, TransportAcceptor, mock_pair}; + +#[tokio::test] +async fn mock_transport_connect() { + let transport = MockTransport::new(1024); + let stream = transport.connect().await.unwrap(); + drop(stream); +} + +#[tokio::test] +async fn mock_transport_acceptor_accept() { + let acceptor = MockTransportAcceptor::new(1024); + let (stream, info) = acceptor.accept().await.unwrap(); + drop(stream); + drop(info); +} + +#[tokio::test] +async fn mock_pair_communicates() { + let (mut client, mut server) = mock_pair(1024); + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + client.write_all(b"hello").await.unwrap(); + let mut buf = [0u8; 5]; + server.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"hello"); +} \ No newline at end of file