diff --git a/crates/alknet-core/src/config/forwarding.rs b/crates/alknet-core/src/config/forwarding.rs index 553c0e3..8a1899a 100644 --- a/crates/alknet-core/src/config/forwarding.rs +++ b/crates/alknet-core/src/config/forwarding.rs @@ -499,17 +499,13 @@ mod tests { target: TargetPattern::AlknetPrefix, action: ForwardingAction::Allow, principals: vec![], - transports: vec![TransportKind::WebTransport { - host: String::new(), - }], + transports: vec![TransportKind::WebTransport { server_name: None }], }, ForwardingRule { target: TargetPattern::Any, action: ForwardingAction::Deny, principals: vec![], - transports: vec![TransportKind::WebTransport { - host: String::new(), - }], + transports: vec![TransportKind::WebTransport { server_name: None }], }, ], }; @@ -518,17 +514,13 @@ mod tests { "alknet-control", 0, &identity, - TransportKind::WebTransport { - host: String::new() - } + TransportKind::WebTransport { server_name: None } )); assert!(!policy.check( "example.com", 443, &identity, - TransportKind::WebTransport { - host: String::new() - } + TransportKind::WebTransport { server_name: None } )); assert!(policy.check("example.com", 443, &identity, TransportKind::Tcp)); } diff --git a/crates/alknet-core/src/config/static_config.rs b/crates/alknet-core/src/config/static_config.rs index f8eb54c..c412c30 100644 --- a/crates/alknet-core/src/config/static_config.rs +++ b/crates/alknet-core/src/config/static_config.rs @@ -2,9 +2,9 @@ //! //! See [ADR-030](docs/architecture/decisions/030-dynamic-config.md). -use crate::interface::InterfaceKind; +use crate::interface::StreamInterfaceKind; use crate::server::handler::{ProxyConfig, ProxyMode}; -use crate::server::serve::{ListenerConfig, ServeTransportMode}; +use crate::server::serve::{ListenerConfig, ServeTransportMode, StreamListenerConfig}; use crate::transport::TransportKind; use std::net::SocketAddr; @@ -66,21 +66,23 @@ impl StaticConfig { let listeners = if let Some(listeners) = opts.listeners { listeners } else { - vec![ListenerConfig { - transport_kind: match opts.transport_mode { - ServeTransportMode::Tcp => TransportKind::Tcp, - ServeTransportMode::Tls => TransportKind::Tls { server_name: None }, - ServeTransportMode::Iroh => TransportKind::Iroh { - endpoint_id: String::new(), + vec![ListenerConfig::Stream { + config: StreamListenerConfig { + transport_kind: match opts.transport_mode { + ServeTransportMode::Tcp => TransportKind::Tcp, + ServeTransportMode::Tls => TransportKind::Tls { server_name: None }, + ServeTransportMode::Iroh => TransportKind::Iroh { + endpoint_id: String::new(), + }, }, + interface: StreamInterfaceKind::Ssh, + listen_addr: opts.listen_addr.clone(), + tls_cert: opts.tls_cert.clone(), + tls_key: opts.tls_key.clone(), + acme_domain: opts.acme_domain.clone(), + stealth: opts.stealth, + iroh_relay: opts.iroh_relay.clone(), }, - interface_kind: InterfaceKind::Ssh, - listen_addr: opts.listen_addr.clone(), - tls_cert: opts.tls_cert.clone(), - tls_key: opts.tls_key.clone(), - acme_domain: opts.acme_domain.clone(), - stealth: opts.stealth, - iroh_relay: opts.iroh_relay.clone(), }] }; @@ -111,23 +113,23 @@ fn parse_proxy_config( None => Ok(None), Some(url) => { if let Some(rest) = url.strip_prefix("socks5://") { - let addr: SocketAddr = rest.parse().map_err(|e| { - crate::error::ConfigError::ProxyConfigInvalid { - message: format!("invalid socks5 proxy address '{}': {}", rest, e), - } - })?; + let addr: SocketAddr = + rest.parse() + .map_err(|e| crate::error::ConfigError::ProxyConfigInvalid { + message: format!("invalid socks5 proxy address '{}': {}", rest, e), + })?; Ok(Some(ProxyConfig { mode: ProxyMode::Socks5(addr), })) } else if let Some(rest) = url.strip_prefix("http://") { - let addr: SocketAddr = rest.parse().map_err(|e| { - crate::error::ConfigError::ProxyConfigInvalid { - message: format!( - "invalid http connect proxy address '{}': {}", - rest, e - ), - } - })?; + let addr: SocketAddr = + rest.parse() + .map_err(|e| crate::error::ConfigError::ProxyConfigInvalid { + message: format!( + "invalid http connect proxy address '{}': {}", + rest, e + ), + })?; Ok(Some(ProxyConfig { mode: ProxyMode::HttpConnect(addr), })) @@ -239,10 +241,12 @@ mod tests { .listeners(listeners); let (static_config, _) = StaticConfig::from_serve_options(opts).unwrap(); assert_eq!(static_config.listeners.len(), 1); - assert_eq!( - static_config.listeners[0].transport_kind, - TransportKind::Tcp - ); + match &static_config.listeners[0] { + ListenerConfig::Stream { config } => { + assert_eq!(config.transport_kind, TransportKind::Tcp); + } + _ => panic!("expected Stream variant"), + } } #[test] diff --git a/crates/alknet-core/src/interface/config.rs b/crates/alknet-core/src/interface/config.rs index 9e8cf09..250a5af 100644 --- a/crates/alknet-core/src/interface/config.rs +++ b/crates/alknet-core/src/interface/config.rs @@ -2,22 +2,39 @@ use std::sync::Arc; use arc_swap::ArcSwap; use russh::keys::PrivateKey; +use serde::{Deserialize, Serialize}; use crate::auth::IdentityProvider; use crate::config::DynamicConfig; #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[non_exhaustive] -pub enum InterfaceKind { +pub enum StreamInterfaceKind { Ssh, RawFraming, } -impl std::fmt::Display for InterfaceKind { +impl std::fmt::Display for StreamInterfaceKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - InterfaceKind::Ssh => write!(f, "ssh"), - InterfaceKind::RawFraming => write!(f, "raw-framing"), + StreamInterfaceKind::Ssh => write!(f, "ssh"), + StreamInterfaceKind::RawFraming => write!(f, "raw-framing"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[non_exhaustive] +pub enum MessageInterfaceKind { + Http, + Dns, +} + +impl std::fmt::Display for MessageInterfaceKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MessageInterfaceKind::Http => write!(f, "http"), + MessageInterfaceKind::Dns => write!(f, "dns"), } } } @@ -29,12 +46,61 @@ pub enum InterfaceConfig { } impl InterfaceConfig { - pub fn kind(&self) -> InterfaceKind { + pub fn kind(&self) -> StreamInterfaceKind { #[allow(unreachable_patterns)] match self { - InterfaceConfig::Ssh(_) => InterfaceKind::Ssh, - InterfaceConfig::RawFraming(_) => InterfaceKind::RawFraming, - _ => InterfaceKind::Ssh, + InterfaceConfig::Ssh(_) => StreamInterfaceKind::Ssh, + InterfaceConfig::RawFraming(_) => StreamInterfaceKind::RawFraming, + _ => StreamInterfaceKind::Ssh, + } + } +} + +#[non_exhaustive] +pub enum StreamInterfaceConfig { + Ssh(SshInterfaceConfig), + RawFraming(RawFramingConfig), +} + +impl StreamInterfaceConfig { + pub fn kind(&self) -> StreamInterfaceKind { + match self { + StreamInterfaceConfig::Ssh(_) => StreamInterfaceKind::Ssh, + StreamInterfaceConfig::RawFraming(_) => StreamInterfaceKind::RawFraming, + } + } +} + +impl std::fmt::Display for StreamInterfaceConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + StreamInterfaceConfig::Ssh(_) => write!(f, "ssh"), + StreamInterfaceConfig::RawFraming(_) => write!(f, "raw-framing"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum MessageInterfaceConfig { + Http(HttpInterfaceConfig), + Dns(DnsInterfaceConfig), +} + +impl MessageInterfaceConfig { + pub fn kind(&self) -> MessageInterfaceKind { + match self { + MessageInterfaceConfig::Http(_) => MessageInterfaceKind::Http, + MessageInterfaceConfig::Dns(_) => MessageInterfaceKind::Dns, + } + } +} + +impl std::fmt::Display for MessageInterfaceConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MessageInterfaceConfig::Http(_) => write!(f, "http"), + MessageInterfaceConfig::Dns(_) => write!(f, "dns"), } } } @@ -47,22 +113,53 @@ pub struct SshInterfaceConfig { pub struct RawFramingConfig {} +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct HttpInterfaceConfig { + pub bind_addr: std::net::SocketAddr, + pub tls: bool, + pub stealth: bool, +} + +impl std::fmt::Display for HttpInterfaceConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "http {}", self.bind_addr) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct DnsInterfaceConfig { + pub bind_addr: std::net::SocketAddr, + pub tls: bool, +} + +impl std::fmt::Display for DnsInterfaceConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "dns {}", self.bind_addr) + } +} + #[cfg(test)] mod tests { use super::*; #[test] - fn interface_kind_display() { - assert_eq!(InterfaceKind::Ssh.to_string(), "ssh"); - assert_eq!(InterfaceKind::RawFraming.to_string(), "raw-framing"); + fn stream_interface_kind_display() { + assert_eq!(StreamInterfaceKind::Ssh.to_string(), "ssh"); + assert_eq!(StreamInterfaceKind::RawFraming.to_string(), "raw-framing"); } #[test] - fn interface_kind_from_config() { + fn message_interface_kind_display() { + assert_eq!(MessageInterfaceKind::Http.to_string(), "http"); + assert_eq!(MessageInterfaceKind::Dns.to_string(), "dns"); + } + + #[test] + fn stream_interface_config_kind() { let auth = Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new( ArcSwap::new(Arc::new(DynamicConfig::default())), ))); - let ssh_config = InterfaceConfig::Ssh(SshInterfaceConfig { + let ssh_config = StreamInterfaceConfig::Ssh(SshInterfaceConfig { auth, forwarding: Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default()))), host_key: Arc::new( @@ -73,21 +170,91 @@ mod tests { .unwrap(), ), }); - assert_eq!(ssh_config.kind(), InterfaceKind::Ssh); + assert_eq!(ssh_config.kind(), StreamInterfaceKind::Ssh); - let raw_config = InterfaceConfig::RawFraming(RawFramingConfig {}); - assert_eq!(raw_config.kind(), InterfaceKind::RawFraming); + let raw_config = StreamInterfaceConfig::RawFraming(RawFramingConfig {}); + assert_eq!(raw_config.kind(), StreamInterfaceKind::RawFraming); } #[test] - fn interface_kind_equality() { - assert_eq!(InterfaceKind::Ssh, InterfaceKind::Ssh); - assert_eq!(InterfaceKind::RawFraming, InterfaceKind::RawFraming); - assert_ne!(InterfaceKind::Ssh, InterfaceKind::RawFraming); + fn message_interface_config_kind() { + let http_config = MessageInterfaceConfig::Http(HttpInterfaceConfig { + bind_addr: "127.0.0.1:8080".parse().unwrap(), + tls: false, + stealth: false, + }); + assert_eq!(http_config.kind(), MessageInterfaceKind::Http); + + let dns_config = MessageInterfaceConfig::Dns(DnsInterfaceConfig { + bind_addr: "127.0.0.1:53".parse().unwrap(), + tls: false, + }); + assert_eq!(dns_config.kind(), MessageInterfaceKind::Dns); + } + + #[test] + fn stream_interface_kind_equality() { + assert_eq!(StreamInterfaceKind::Ssh, StreamInterfaceKind::Ssh); + assert_eq!( + StreamInterfaceKind::RawFraming, + StreamInterfaceKind::RawFraming + ); + assert_ne!(StreamInterfaceKind::Ssh, StreamInterfaceKind::RawFraming); + } + + #[test] + fn message_interface_kind_equality() { + assert_eq!(MessageInterfaceKind::Http, MessageInterfaceKind::Http); + assert_eq!(MessageInterfaceKind::Dns, MessageInterfaceKind::Dns); + assert_ne!(MessageInterfaceKind::Http, MessageInterfaceKind::Dns); } #[test] fn raw_framing_config_minimal() { let _config = RawFramingConfig {}; } + + #[test] + fn http_interface_config_display() { + let config = HttpInterfaceConfig { + bind_addr: "127.0.0.1:8080".parse().unwrap(), + tls: true, + stealth: true, + }; + assert_eq!(config.to_string(), "http 127.0.0.1:8080"); + } + + #[test] + fn dns_interface_config_display() { + let config = DnsInterfaceConfig { + bind_addr: "127.0.0.1:53".parse().unwrap(), + tls: false, + }; + assert_eq!(config.to_string(), "dns 127.0.0.1:53"); + } + + #[test] + fn http_interface_config_serialization() { + let config = HttpInterfaceConfig { + bind_addr: "127.0.0.1:8080".parse().unwrap(), + tls: true, + stealth: false, + }; + let serialized = serde_json::to_string(&config).unwrap(); + let deserialized: HttpInterfaceConfig = serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized.bind_addr, config.bind_addr); + assert_eq!(deserialized.tls, config.tls); + } + + #[test] + fn dns_interface_config_serialization() { + let config = DnsInterfaceConfig { + bind_addr: "0.0.0.0:53".parse().unwrap(), + tls: true, + }; + let serialized = serde_json::to_string(&config).unwrap(); + let deserialized: DnsInterfaceConfig = serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized.bind_addr, config.bind_addr); + assert_eq!(deserialized.tls, config.tls); + } } diff --git a/crates/alknet-core/src/interface/dns.rs b/crates/alknet-core/src/interface/dns.rs new file mode 100644 index 0000000..da6413d --- /dev/null +++ b/crates/alknet-core/src/interface/dns.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use anyhow::Result; +use async_trait::async_trait; + +use crate::call::OperationEnv; +use crate::interface::{InterfaceRequest, InterfaceResponse, MessageInterface}; + +pub struct DnsInterface { + pub domain: String, + pub identity_provider: Arc, + pub registry: Arc, + pub env: OperationEnv, +} + +#[async_trait] +impl MessageInterface for DnsInterface { + async fn handle_request(&self, _request: InterfaceRequest) -> Result { + Ok(InterfaceResponse { + result: Err(crate::call::CallError::new( + "NOT_IMPLEMENTED", + "DnsInterface is not yet implemented", + false, + )), + status: 501, + headers: std::collections::HashMap::new(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn dns_interface_type_exists() { + let registry = Arc::new(crate::call::OperationRegistry::new()); + let _iface = DnsInterface { + domain: "alk.dev".to_string(), + identity_provider: Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new( + arc_swap::ArcSwap::new(Arc::new(crate::config::DynamicConfig::default())), + ))), + env: OperationEnv::local(crate::call::OperationRegistry::new()), + registry, + }; + } +} diff --git a/crates/alknet-core/src/interface/http.rs b/crates/alknet-core/src/interface/http.rs new file mode 100644 index 0000000..109ed2f --- /dev/null +++ b/crates/alknet-core/src/interface/http.rs @@ -0,0 +1,45 @@ +use std::sync::Arc; + +use anyhow::Result; +use async_trait::async_trait; + +use crate::call::OperationEnv; +use crate::interface::{InterfaceRequest, InterfaceResponse, MessageInterface}; + +pub struct HttpInterface { + pub identity_provider: Arc, + pub registry: Arc, + pub env: OperationEnv, +} + +#[async_trait] +impl MessageInterface for HttpInterface { + async fn handle_request(&self, _request: InterfaceRequest) -> Result { + Ok(InterfaceResponse { + result: Err(crate::call::CallError::new( + "NOT_IMPLEMENTED", + "HttpInterface is not yet implemented", + false, + )), + status: 501, + headers: std::collections::HashMap::new(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn http_interface_type_exists() { + let registry = Arc::new(crate::call::OperationRegistry::new()); + let _iface = HttpInterface { + identity_provider: Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new( + arc_swap::ArcSwap::new(Arc::new(crate::config::DynamicConfig::default())), + ))), + env: OperationEnv::local(crate::call::OperationRegistry::new()), + registry, + }; + } +} diff --git a/crates/alknet-core/src/interface/mod.rs b/crates/alknet-core/src/interface/mod.rs index c6f6d1a..0184b1b 100644 --- a/crates/alknet-core/src/interface/mod.rs +++ b/crates/alknet-core/src/interface/mod.rs @@ -1,37 +1,37 @@ -//! Interface layer (Layer 2) of the three-layer model (ADR-026). +//! Interface layer (Layer 2) of the three-layer model (ADR-026, ADR-035). //! //! The Interface layer sits between Transport (Layer 1) and Protocol (Layer 3). -//! An Interface consumes a `TransportStream` and produces call protocol sessions -//! that yield `EventEnvelope` frames. This enables the call protocol handler to be -//! interface-agnostic — it receives `InterfaceEvent` frames from any interface. +//! It has two distinct patterns: //! -//! SSH is an interface, not a transport. It wraps a byte stream in session -//! semantics (handshake, auth, channel multiplexing). Raw framing (4-byte length -//! prefix + JSON `EventEnvelope`) is another interface, one without SSH overhead. +//! - **StreamInterface** — consumes a `TransportStream`, produces a long-lived +//! `Session` that yields `InterfaceEvent` frames. SSH and raw framing are +//! `StreamInterface` implementations. //! -//! # OQ-IF-01 Resolution -//! -//! Every Interface session implements the `InterfaceSession` trait, which provides -//! `recv()` and `send()` methods producing and consuming `InterfaceEvent` frames. -//! Each `InterfaceEvent` carries an `EventEnvelope` and an optional `Identity` -//! (authenticated by the interface layer, e.g., via SSH public key auth or -//! transport-level token auth). -//! -//! This means the call protocol handler (Layer 3) is completely interface-agnostic: -//! it receives `InterfaceEvent` frames and processes them uniformly, regardless -//! of whether they arrived over SSH or raw framing. +//! - **MessageInterface** — handles individual `InterfaceRequest` → +//! `InterfaceResponse` pairs. Manages its own transport (HTTP server, DNS +//! server). HTTP and DNS are `MessageInterface` implementations. pub mod config; +pub mod dns; +pub mod http; pub mod pairs; pub mod raw_framing; pub mod session; pub mod ssh; +use std::collections::HashMap; + use anyhow::Result; use async_trait::async_trait; use tokio::io::{AsyncRead, AsyncWrite}; -pub use config::{InterfaceConfig, InterfaceKind, RawFramingConfig, SshInterfaceConfig}; +pub use config::{ + DnsInterfaceConfig, HttpInterfaceConfig, InterfaceConfig, MessageInterfaceConfig, + MessageInterfaceKind, RawFramingConfig, SshInterfaceConfig, StreamInterfaceConfig, + StreamInterfaceKind, +}; +pub use dns::DnsInterface; +pub use http::HttpInterface; pub use pairs::{is_valid_pair, TransportKindBase, VALID_TRANSPORT_INTERFACE_PAIRS}; pub use raw_framing::{RawFramingInterface, RawFramingSession}; pub use session::{InterfaceEvent, InterfaceSession}; @@ -42,16 +42,36 @@ pub trait TransportStream: AsyncRead + AsyncWrite + Unpin + Send + 'static {} impl TransportStream for T {} #[async_trait] -pub trait Interface: Send + Sync + 'static { +pub trait StreamInterface: Send + Sync + 'static { type Session: InterfaceSession; async fn accept( &self, stream: Box, - config: &InterfaceConfig, + config: &StreamInterfaceConfig, ) -> Result; } +#[async_trait] +pub trait MessageInterface: Send + Sync + 'static { + async fn handle_request(&self, request: InterfaceRequest) -> Result; +} + +#[derive(Debug, Clone)] +pub struct InterfaceRequest { + pub operation_path: String, + pub input: serde_json::Value, + pub auth_token: Option, + pub metadata: HashMap, +} + +#[derive(Debug, Clone)] +pub struct InterfaceResponse { + pub result: Result, + pub status: u16, + pub headers: HashMap, +} + #[cfg(test)] mod tests { use super::*; @@ -69,4 +89,52 @@ mod tests { let _boxed: Box = Box::new(server); let _: Box = Box::new(client); } + + #[test] + fn interface_request_fields() { + let req = InterfaceRequest { + operation_path: "/v1/head/auth/verify".to_string(), + input: serde_json::json!({"key": "value"}), + auth_token: None, + metadata: HashMap::new(), + }; + assert_eq!(req.operation_path, "/v1/head/auth/verify"); + assert!(req.auth_token.is_none()); + } + + #[test] + fn interface_response_fields() { + let resp = InterfaceResponse { + result: Ok(serde_json::json!({"status": "ok"})), + status: 200, + headers: HashMap::new(), + }; + assert_eq!(resp.status, 200); + } + + struct MockMessageInterface; + + #[async_trait] + impl MessageInterface for MockMessageInterface { + async fn handle_request(&self, _request: InterfaceRequest) -> Result { + Ok(InterfaceResponse { + result: Ok(serde_json::json!({})), + status: 200, + headers: HashMap::new(), + }) + } + } + + #[tokio::test] + async fn message_interface_trait_compiles() { + let iface = MockMessageInterface; + let req = InterfaceRequest { + operation_path: "/test".to_string(), + input: serde_json::json!({}), + auth_token: None, + metadata: HashMap::new(), + }; + let resp = iface.handle_request(req).await.unwrap(); + assert_eq!(resp.status, 200); + } } diff --git a/crates/alknet-core/src/interface/pairs.rs b/crates/alknet-core/src/interface/pairs.rs index 3047c8b..8d2c00b 100644 --- a/crates/alknet-core/src/interface/pairs.rs +++ b/crates/alknet-core/src/interface/pairs.rs @@ -1,13 +1,12 @@ use crate::transport::TransportKind; -use super::config::InterfaceKind; +use super::config::StreamInterfaceKind; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TransportKindBase { Tcp, Tls, Iroh, - Dns, WebTransport, } @@ -16,33 +15,36 @@ fn transport_base(kind: &TransportKind) -> TransportKindBase { TransportKind::Tcp => TransportKindBase::Tcp, TransportKind::Tls { .. } => TransportKindBase::Tls, TransportKind::Iroh { .. } => TransportKindBase::Iroh, - TransportKind::Dns { .. } => TransportKindBase::Dns, TransportKind::WebTransport { .. } => TransportKindBase::WebTransport, } } -pub fn is_valid_pair(transport: &TransportKind, interface: InterfaceKind) -> bool { +pub fn is_valid_pair(transport: &TransportKind, interface: StreamInterfaceKind) -> bool { let base = transport_base(transport); matches!( (base, interface), - (TransportKindBase::Tcp, InterfaceKind::Ssh) - | (TransportKindBase::Tls, InterfaceKind::Ssh) - | (TransportKindBase::Iroh, InterfaceKind::Ssh) - | (TransportKindBase::Dns, InterfaceKind::RawFraming) - | (TransportKindBase::WebTransport, InterfaceKind::Ssh) - | (TransportKindBase::WebTransport, InterfaceKind::RawFraming) - | (TransportKindBase::Tcp, InterfaceKind::RawFraming) + (TransportKindBase::Tcp, StreamInterfaceKind::Ssh) + | (TransportKindBase::Tls, StreamInterfaceKind::Ssh) + | (TransportKindBase::Iroh, StreamInterfaceKind::Ssh) + | (TransportKindBase::WebTransport, StreamInterfaceKind::Ssh) + | ( + TransportKindBase::WebTransport, + StreamInterfaceKind::RawFraming + ) + | (TransportKindBase::Tcp, StreamInterfaceKind::RawFraming) ) } -pub const VALID_TRANSPORT_INTERFACE_PAIRS: &[(TransportKindBase, InterfaceKind)] = &[ - (TransportKindBase::Tcp, InterfaceKind::Ssh), - (TransportKindBase::Tls, InterfaceKind::Ssh), - (TransportKindBase::Iroh, InterfaceKind::Ssh), - (TransportKindBase::Dns, InterfaceKind::RawFraming), - (TransportKindBase::WebTransport, InterfaceKind::Ssh), - (TransportKindBase::WebTransport, InterfaceKind::RawFraming), - (TransportKindBase::Tcp, InterfaceKind::RawFraming), +pub const VALID_TRANSPORT_INTERFACE_PAIRS: &[(TransportKindBase, StreamInterfaceKind)] = &[ + (TransportKindBase::Tcp, StreamInterfaceKind::Ssh), + (TransportKindBase::Tls, StreamInterfaceKind::Ssh), + (TransportKindBase::Iroh, StreamInterfaceKind::Ssh), + (TransportKindBase::WebTransport, StreamInterfaceKind::Ssh), + ( + TransportKindBase::WebTransport, + StreamInterfaceKind::RawFraming, + ), + (TransportKindBase::Tcp, StreamInterfaceKind::RawFraming), ]; #[cfg(test)] @@ -51,22 +53,20 @@ mod tests { #[test] fn valid_ssh_pairs() { - assert!(is_valid_pair(&TransportKind::Tcp, InterfaceKind::Ssh)); + assert!(is_valid_pair(&TransportKind::Tcp, StreamInterfaceKind::Ssh)); assert!(is_valid_pair( &TransportKind::Tls { server_name: None }, - InterfaceKind::Ssh + StreamInterfaceKind::Ssh )); assert!(is_valid_pair( &TransportKind::Iroh { endpoint_id: String::new() }, - InterfaceKind::Ssh + StreamInterfaceKind::Ssh )); assert!(is_valid_pair( - &TransportKind::WebTransport { - host: String::new() - }, - InterfaceKind::Ssh + &TransportKind::WebTransport { server_name: None }, + StreamInterfaceKind::Ssh )); } @@ -74,35 +74,21 @@ mod tests { fn valid_raw_framing_pairs() { assert!(is_valid_pair( &TransportKind::Tcp, - InterfaceKind::RawFraming + StreamInterfaceKind::RawFraming )); assert!(is_valid_pair( - &TransportKind::Dns { - domain: String::new() - }, - InterfaceKind::RawFraming - )); - assert!(is_valid_pair( - &TransportKind::WebTransport { - host: String::new() - }, - InterfaceKind::RawFraming + &TransportKind::WebTransport { server_name: None }, + StreamInterfaceKind::RawFraming )); } #[test] fn invalid_pairs() { - assert!(!is_valid_pair( - &TransportKind::Dns { - domain: String::new() - }, - InterfaceKind::Ssh - )); assert!(!is_valid_pair( &TransportKind::Iroh { endpoint_id: String::new() }, - InterfaceKind::RawFraming + StreamInterfaceKind::RawFraming )); } @@ -121,15 +107,9 @@ mod tests { }), TransportKindBase::Iroh ); - assert_eq!( - transport_base(&TransportKind::Dns { - domain: "example.com".to_string() - }), - TransportKindBase::Dns - ); assert_eq!( transport_base(&TransportKind::WebTransport { - host: "example.com".to_string() + server_name: Some("example.com".to_string()) }), TransportKindBase::WebTransport ); @@ -137,6 +117,6 @@ mod tests { #[test] fn valid_pairs_table_complete() { - assert_eq!(VALID_TRANSPORT_INTERFACE_PAIRS.len(), 7); + assert_eq!(VALID_TRANSPORT_INTERFACE_PAIRS.len(), 6); } } diff --git a/crates/alknet-core/src/interface/raw_framing.rs b/crates/alknet-core/src/interface/raw_framing.rs index 9f509bf..3563a36 100644 --- a/crates/alknet-core/src/interface/raw_framing.rs +++ b/crates/alknet-core/src/interface/raw_framing.rs @@ -2,20 +2,20 @@ use anyhow::Result; use async_trait::async_trait; use crate::interface::session::{InterfaceEvent, InterfaceSession}; -use crate::interface::{Interface, InterfaceConfig, TransportStream}; +use crate::interface::{StreamInterface, StreamInterfaceConfig, TransportStream}; pub struct RawFramingInterface; pub struct RawFramingSession; #[async_trait] -impl Interface for RawFramingInterface { +impl StreamInterface for RawFramingInterface { type Session = RawFramingSession; async fn accept( &self, _stream: Box, - _config: &InterfaceConfig, + _config: &StreamInterfaceConfig, ) -> Result { Err(anyhow::anyhow!( "RawFramingInterface is not yet implemented (Phase 4+)" @@ -55,7 +55,7 @@ mod tests { let iface = RawFramingInterface; let (_client, server) = tokio::io::duplex(1024); let stream: Box = Box::new(server); - let config = InterfaceConfig::RawFraming(crate::interface::RawFramingConfig {}); + let config = StreamInterfaceConfig::RawFraming(crate::interface::RawFramingConfig {}); let result = iface.accept(stream, &config).await; assert!(result.is_err()); } diff --git a/crates/alknet-core/src/interface/ssh.rs b/crates/alknet-core/src/interface/ssh.rs index 4974c2f..8b67241 100644 --- a/crates/alknet-core/src/interface/ssh.rs +++ b/crates/alknet-core/src/interface/ssh.rs @@ -14,7 +14,7 @@ use crate::auth::identity::{Identity, IdentityProvider}; use crate::call::EventEnvelope; use crate::config::DynamicConfig; use crate::interface::session::{InterfaceEvent, InterfaceSession}; -use crate::interface::{Interface, InterfaceConfig, TransportStream}; +use crate::interface::{StreamInterface, StreamInterfaceConfig, TransportStream}; use crate::server::control_channel::{ControlChannelRouter, ALKNET_PREFIX}; use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter}; use crate::transport::TransportKind; @@ -553,17 +553,17 @@ impl SshInterface { } #[async_trait] -impl Interface for SshInterface { +impl StreamInterface for SshInterface { type Session = SshSession; async fn accept( &self, stream: Box, - config: &InterfaceConfig, + config: &StreamInterfaceConfig, ) -> Result { let ssh_config = match config { - InterfaceConfig::Ssh(c) => c, - InterfaceConfig::RawFraming(_) => { + StreamInterfaceConfig::Ssh(c) => c, + StreamInterfaceConfig::RawFraming(_) => { return Err(anyhow::anyhow!("SshInterface received RawFramingConfig")); } }; @@ -734,7 +734,7 @@ mod tests { let (_client, server) = tokio::io::duplex(1024); let stream: Box = Box::new(server); - let raw_config = InterfaceConfig::RawFraming(crate::interface::RawFramingConfig {}); + let raw_config = StreamInterfaceConfig::RawFraming(crate::interface::RawFramingConfig {}); let result = iface.accept(stream, &raw_config).await; assert!(result.is_err()); } diff --git a/crates/alknet-core/src/lib.rs b/crates/alknet-core/src/lib.rs index 9686e9d..7dbbdf9 100644 --- a/crates/alknet-core/src/lib.rs +++ b/crates/alknet-core/src/lib.rs @@ -86,9 +86,15 @@ pub use config::{ }; pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError}; pub use interface::{ - is_valid_pair, Interface, InterfaceConfig, InterfaceEvent, InterfaceKind, InterfaceSession, - RawFramingConfig, RawFramingInterface, RawFramingSession, SshInterface, SshInterfaceConfig, - SshSession, TransportKindBase, TransportStream, VALID_TRANSPORT_INTERFACE_PAIRS, + is_valid_pair, DnsInterface, DnsInterfaceConfig, HttpInterface, HttpInterfaceConfig, + InterfaceConfig, InterfaceEvent, InterfaceRequest, InterfaceResponse, InterfaceSession, + MessageInterface, MessageInterfaceConfig, MessageInterfaceKind, RawFramingConfig, + RawFramingInterface, RawFramingSession, SshInterface, SshInterfaceConfig, SshSession, + StreamInterface, StreamInterfaceConfig, StreamInterfaceKind, TransportKindBase, + TransportStream, VALID_TRANSPORT_INTERFACE_PAIRS, +}; +pub use server::serve::{ + DnsListenerConfig, HttpListenerConfig, ListenerConfig, ServeError, ServeOptions, + ServeTransportMode, Server, StreamListenerConfig, }; -pub use server::serve::{ListenerConfig, ServeError, ServeOptions, ServeTransportMode, Server}; pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind}; diff --git a/crates/alknet-core/src/server/handler.rs b/crates/alknet-core/src/server/handler.rs index d39a4d8..e4f0076 100644 --- a/crates/alknet-core/src/server/handler.rs +++ b/crates/alknet-core/src/server/handler.rs @@ -774,17 +774,7 @@ mod tests { "iroh" ); assert_eq!( - TransportKind::Dns { - domain: String::new() - } - .to_string(), - "dns" - ); - assert_eq!( - TransportKind::WebTransport { - host: String::new() - } - .to_string(), + TransportKind::WebTransport { server_name: None }.to_string(), "webtransport" ); } diff --git a/crates/alknet-core/src/server/mod.rs b/crates/alknet-core/src/server/mod.rs index 993e2b3..90cb6c1 100644 --- a/crates/alknet-core/src/server/mod.rs +++ b/crates/alknet-core/src/server/mod.rs @@ -21,7 +21,10 @@ pub use control_channel::{ }; pub use handler::{ProxyConfig, ProxyMode, ServerHandler}; pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter}; -pub use serve::{ListenerConfig, ServeError, ServeOptions, ServeTransportMode, Server}; +pub use serve::{ + DnsListenerConfig, HttpListenerConfig, ListenerConfig, ServeError, ServeOptions, + ServeTransportMode, Server, StreamListenerConfig, +}; pub use crate::transport::TransportKind; pub use stealth::{ diff --git a/crates/alknet-core/src/server/serve.rs b/crates/alknet-core/src/server/serve.rs index 85188f4..ae73cee 100644 --- a/crates/alknet-core/src/server/serve.rs +++ b/crates/alknet-core/src/server/serve.rs @@ -5,18 +5,20 @@ //! `ServeOptions` provides a builder-pattern API for programmatic configuration. //! Supports multiple listeners via `ListenerConfig` for multi-transport operation. +use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use arc_swap::ArcSwap; use russh::server::{self, Config}; +use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{error, info, warn}; use crate::auth::keys::KeySource; use crate::config::{ConfigReloadHandle, DynamicConfig}; use crate::error::ConfigError; -use crate::interface::InterfaceKind; +use crate::interface::StreamInterfaceKind; use crate::server::handler::{ProxyConfig, ServerHandler}; use crate::server::rate_limit::ConnectionRateLimiter; use crate::server::stealth::{self, ProtocolDetection}; @@ -42,10 +44,10 @@ impl std::fmt::Display for ServeTransportMode { } } -#[derive(Debug, Clone, PartialEq)] -pub struct ListenerConfig { +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamListenerConfig { pub transport_kind: TransportKind, - pub interface_kind: InterfaceKind, + pub interface: StreamInterfaceKind, pub listen_addr: String, pub tls_cert: Option, pub tls_key: Option, @@ -54,103 +56,7 @@ pub struct ListenerConfig { pub iroh_relay: Option, } -impl ListenerConfig { - pub fn tcp(addr: impl Into) -> Self { - Self { - transport_kind: TransportKind::Tcp, - interface_kind: InterfaceKind::Ssh, - listen_addr: addr.into(), - tls_cert: None, - tls_key: None, - acme_domain: None, - stealth: false, - iroh_relay: None, - } - } - - pub fn tls(addr: impl Into) -> Self { - Self { - transport_kind: TransportKind::Tls { server_name: None }, - interface_kind: InterfaceKind::Ssh, - listen_addr: addr.into(), - tls_cert: None, - tls_key: None, - acme_domain: None, - stealth: false, - iroh_relay: None, - } - } - - pub fn iroh(addr: impl Into) -> Self { - Self { - transport_kind: TransportKind::Iroh { - endpoint_id: String::new(), - }, - interface_kind: InterfaceKind::Ssh, - listen_addr: addr.into(), - tls_cert: None, - tls_key: None, - acme_domain: None, - stealth: false, - iroh_relay: None, - } - } - - pub fn dns(domain: impl Into) -> Self { - Self { - transport_kind: TransportKind::Dns { - domain: String::new(), - }, - interface_kind: InterfaceKind::RawFraming, - listen_addr: domain.into(), - tls_cert: None, - tls_key: None, - acme_domain: None, - stealth: false, - iroh_relay: None, - } - } - - pub fn webtransport(host: impl Into) -> Self { - Self { - transport_kind: TransportKind::WebTransport { - host: String::new(), - }, - interface_kind: InterfaceKind::Ssh, - listen_addr: host.into(), - tls_cert: None, - tls_key: None, - acme_domain: None, - stealth: false, - iroh_relay: None, - } - } - - pub fn tls_cert(mut self, path: impl Into) -> Self { - self.tls_cert = Some(path.into()); - self - } - - pub fn tls_key(mut self, path: impl Into) -> Self { - self.tls_key = Some(path.into()); - self - } - - pub fn acme_domain(mut self, domain: impl Into) -> Self { - self.acme_domain = Some(domain.into()); - self - } - - pub fn stealth(mut self, enabled: bool) -> Self { - self.stealth = enabled; - self - } - - pub fn iroh_relay(mut self, url: impl Into) -> Self { - self.iroh_relay = Some(url.into()); - self - } - +impl StreamListenerConfig { pub fn validate(&self) -> Result<(), ConfigError> { if self.stealth && !matches!(self.transport_kind, TransportKind::Tls { .. }) { return Err(ConfigError::InvalidFlag { @@ -178,7 +84,6 @@ impl ListenerConfig { } TransportKind::Tcp | TransportKind::Iroh { .. } - | TransportKind::Dns { .. } | TransportKind::WebTransport { .. } => { if self.tls_cert.is_some() || self.tls_key.is_some() || self.acme_domain.is_some() { return Err(ConfigError::IncompatibleOptions); @@ -190,13 +95,190 @@ impl ListenerConfig { } } -impl std::fmt::Display for ListenerConfig { +impl std::fmt::Display for StreamListenerConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.transport_kind { - TransportKind::Iroh { .. } => write!(f, "{} (iroh)", self.listen_addr), - TransportKind::Dns { .. } => write!(f, "{} (dns)", self.listen_addr), - TransportKind::WebTransport { .. } => write!(f, "{} (webtransport)", self.listen_addr), - _ => write!(f, "{} ({})", self.listen_addr, self.transport_kind), + TransportKind::Iroh { .. } => { + write!(f, "{} (iroh/{})", self.listen_addr, self.interface) + } + TransportKind::WebTransport { .. } => { + write!(f, "{} (webtransport/{})", self.listen_addr, self.interface) + } + _ => write!( + f, + "{} ({}/{})", + self.listen_addr, self.transport_kind, self.interface + ), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct HttpListenerConfig { + pub bind_addr: SocketAddr, + pub tls: bool, + pub stealth: bool, +} + +impl std::fmt::Display for HttpListenerConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} (http)", self.bind_addr) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct DnsListenerConfig { + pub bind_addr: SocketAddr, + pub tls: bool, +} + +impl std::fmt::Display for DnsListenerConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} (dns)", self.bind_addr) + } +} + +#[derive(Debug, Clone, PartialEq)] +#[non_exhaustive] +pub enum ListenerConfig { + Stream { config: StreamListenerConfig }, + Http { config: HttpListenerConfig }, + Dns { config: DnsListenerConfig }, +} + +impl ListenerConfig { + pub fn tcp(addr: impl Into) -> Self { + Self::Stream { + config: StreamListenerConfig { + transport_kind: TransportKind::Tcp, + interface: StreamInterfaceKind::Ssh, + listen_addr: addr.into(), + tls_cert: None, + tls_key: None, + acme_domain: None, + stealth: false, + iroh_relay: None, + }, + } + } + + pub fn tls(addr: impl Into) -> Self { + Self::Stream { + config: StreamListenerConfig { + transport_kind: TransportKind::Tls { server_name: None }, + interface: StreamInterfaceKind::Ssh, + listen_addr: addr.into(), + tls_cert: None, + tls_key: None, + acme_domain: None, + stealth: false, + iroh_relay: None, + }, + } + } + + pub fn iroh(addr: impl Into) -> Self { + Self::Stream { + config: StreamListenerConfig { + transport_kind: TransportKind::Iroh { + endpoint_id: String::new(), + }, + interface: StreamInterfaceKind::Ssh, + listen_addr: addr.into(), + tls_cert: None, + tls_key: None, + acme_domain: None, + stealth: false, + iroh_relay: None, + }, + } + } + + pub fn webtransport(addr: impl Into) -> Self { + Self::Stream { + config: StreamListenerConfig { + transport_kind: TransportKind::WebTransport { server_name: None }, + interface: StreamInterfaceKind::Ssh, + listen_addr: addr.into(), + tls_cert: None, + tls_key: None, + acme_domain: None, + stealth: false, + iroh_relay: None, + }, + } + } + + pub fn http(bind_addr: SocketAddr) -> Self { + Self::Http { + config: HttpListenerConfig { + bind_addr, + tls: false, + stealth: false, + }, + } + } + + pub fn dns(bind_addr: SocketAddr) -> Self { + Self::Dns { + config: DnsListenerConfig { + bind_addr, + tls: false, + }, + } + } + + pub fn tls_cert(mut self, path: impl Into) -> Self { + if let ListenerConfig::Stream { ref mut config } = self { + config.tls_cert = Some(path.into()); + } + self + } + + pub fn tls_key(mut self, path: impl Into) -> Self { + if let ListenerConfig::Stream { ref mut config } = self { + config.tls_key = Some(path.into()); + } + self + } + + pub fn acme_domain(mut self, domain: impl Into) -> Self { + if let ListenerConfig::Stream { ref mut config } = self { + config.acme_domain = Some(domain.into()); + } + self + } + + pub fn stealth(mut self, enabled: bool) -> Self { + match &mut self { + ListenerConfig::Stream { ref mut config } => config.stealth = enabled, + ListenerConfig::Http { ref mut config } => config.stealth = enabled, + ListenerConfig::Dns { .. } => {} + } + self + } + + pub fn iroh_relay(mut self, url: impl Into) -> Self { + if let ListenerConfig::Stream { ref mut config } = self { + config.iroh_relay = Some(url.into()); + } + self + } + + pub fn validate(&self) -> Result<(), ConfigError> { + match self { + ListenerConfig::Stream { config } => config.validate(), + ListenerConfig::Http { .. } | ListenerConfig::Dns { .. } => Ok(()), + } + } +} + +impl std::fmt::Display for ListenerConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ListenerConfig::Stream { config } => write!(f, "{}", config), + ListenerConfig::Http { config } => write!(f, "{}", config), + ListenerConfig::Dns { config } => write!(f, "{}", config), } } } @@ -488,9 +570,21 @@ impl Server { .first() .expect("at least one listener required"); - let transport_kind = listener.transport_kind.clone(); - let stealth = listener.stealth; - let listen_addr = listener.listen_addr.clone(); + let (transport_kind, stealth, listen_addr) = match listener { + ListenerConfig::Stream { config } => ( + config.transport_kind.clone(), + config.stealth, + config.listen_addr.clone(), + ), + ListenerConfig::Http { config } => ( + TransportKind::Tcp, + config.stealth, + config.bind_addr.to_string(), + ), + ListenerConfig::Dns { config } => { + (TransportKind::Tcp, false, config.bind_addr.to_string()) + } + }; if matches!(transport_kind, TransportKind::Iroh { .. }) { if let Some(id) = endpoint_info { @@ -832,10 +926,15 @@ mod tests { #[test] fn listener_config_tcp_constructor() { let lc = ListenerConfig::tcp("0.0.0.0:22"); - assert_eq!(lc.transport_kind, TransportKind::Tcp); - assert_eq!(lc.listen_addr, "0.0.0.0:22"); - assert!(!lc.stealth); - assert!(lc.tls_cert.is_none()); + match &lc { + ListenerConfig::Stream { config } => { + assert_eq!(config.transport_kind, TransportKind::Tcp); + assert_eq!(config.listen_addr, "0.0.0.0:22"); + assert!(!config.stealth); + assert!(config.tls_cert.is_none()); + } + _ => panic!("expected Stream variant"), + } } #[test] @@ -844,47 +943,85 @@ mod tests { .tls_cert("/cert.pem") .tls_key("/key.pem") .stealth(true); - assert_eq!(lc.transport_kind, TransportKind::Tls { server_name: None }); - assert_eq!(lc.listen_addr, "0.0.0.0:443"); - assert!(lc.stealth); - assert_eq!(lc.tls_cert.as_deref(), Some("/cert.pem")); - assert_eq!(lc.tls_key.as_deref(), Some("/key.pem")); + match &lc { + ListenerConfig::Stream { config } => { + assert_eq!( + config.transport_kind, + TransportKind::Tls { server_name: None } + ); + assert_eq!(config.listen_addr, "0.0.0.0:443"); + assert!(config.stealth); + assert_eq!(config.tls_cert.as_deref(), Some("/cert.pem")); + assert_eq!(config.tls_key.as_deref(), Some("/key.pem")); + } + _ => panic!("expected Stream variant"), + } } #[test] fn listener_config_iroh_constructor() { let lc = ListenerConfig::iroh("0.0.0.0:0").iroh_relay("https://relay.example.com"); - assert_eq!( - lc.transport_kind, - TransportKind::Iroh { - endpoint_id: String::new() + match &lc { + ListenerConfig::Stream { config } => { + assert_eq!( + config.transport_kind, + TransportKind::Iroh { + endpoint_id: String::new() + } + ); + assert_eq!( + config.iroh_relay.as_deref(), + Some("https://relay.example.com") + ); } - ); - assert_eq!(lc.iroh_relay.as_deref(), Some("https://relay.example.com")); + _ => panic!("expected Stream variant"), + } + } + + #[test] + fn listener_config_http_constructor() { + let lc = ListenerConfig::http("127.0.0.1:8080".parse().unwrap()); + match &lc { + ListenerConfig::Http { config } => { + assert_eq!( + config.bind_addr, + "127.0.0.1:8080".parse::().unwrap() + ); + assert!(!config.tls); + assert!(!config.stealth); + } + _ => panic!("expected Http variant"), + } } #[test] fn listener_config_dns_constructor() { - let lc = ListenerConfig::dns("example.com"); - assert_eq!( - lc.transport_kind, - TransportKind::Dns { - domain: String::new() + let lc = ListenerConfig::dns("127.0.0.1:53".parse().unwrap()); + match &lc { + ListenerConfig::Dns { config } => { + assert_eq!( + config.bind_addr, + "127.0.0.1:53".parse::().unwrap() + ); + assert!(!config.tls); } - ); - assert_eq!(lc.listen_addr, "example.com"); + _ => panic!("expected Dns variant"), + } } #[test] fn listener_config_webtransport_constructor() { let lc = ListenerConfig::webtransport("example.com"); - assert_eq!( - lc.transport_kind, - TransportKind::WebTransport { - host: String::new() + match &lc { + ListenerConfig::Stream { config } => { + assert_eq!( + config.transport_kind, + TransportKind::WebTransport { server_name: None } + ); + assert_eq!(config.listen_addr, "example.com"); } - ); - assert_eq!(lc.listen_addr, "example.com"); + _ => panic!("expected Stream variant"), + } } #[test] @@ -922,19 +1059,19 @@ mod tests { #[test] fn listener_config_display() { let tcp = ListenerConfig::tcp("0.0.0.0:22"); - assert_eq!(format!("{}", tcp), "0.0.0.0:22 (tcp)"); + assert_eq!(format!("{}", tcp), "0.0.0.0:22 (tcp/ssh)"); let tls = ListenerConfig::tls("0.0.0.0:443"); - assert_eq!(format!("{}", tls), "0.0.0.0:443 (tls)"); + assert_eq!(format!("{}", tls), "0.0.0.0:443 (tls/ssh)"); let iroh = ListenerConfig::iroh("0.0.0.0:0"); - assert_eq!(format!("{}", iroh), "0.0.0.0:0 (iroh)"); + assert_eq!(format!("{}", iroh), "0.0.0.0:0 (iroh/ssh)"); - let dns = ListenerConfig::dns("example.com"); - assert_eq!(format!("{}", dns), "example.com (dns)"); + let http = ListenerConfig::http("0.0.0.0:8080".parse().unwrap()); + assert_eq!(format!("{}", http), "0.0.0.0:8080 (http)"); - let wt = ListenerConfig::webtransport("example.com"); - assert_eq!(format!("{}", wt), "example.com (webtransport)"); + let dns = ListenerConfig::dns("0.0.0.0:53".parse().unwrap()); + assert_eq!(format!("{}", dns), "0.0.0.0:53 (dns)"); } #[test] @@ -1011,7 +1148,6 @@ mod tests { .listeners(listeners); let server = Server::new(opts).unwrap(); assert_eq!(server.listeners.len(), 1); - assert_eq!(server.listeners[0].transport_kind, TransportKind::Tcp); } #[test] @@ -1020,8 +1156,13 @@ mod tests { ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source()); let server = Server::new(opts).unwrap(); assert_eq!(server.listeners.len(), 1); - assert_eq!(server.listeners[0].transport_kind, TransportKind::Tcp); - assert_eq!(server.listeners[0].listen_addr, "0.0.0.0:22"); + match &server.listeners[0] { + ListenerConfig::Stream { config } => { + assert_eq!(config.transport_kind, TransportKind::Tcp); + assert_eq!(config.listen_addr, "0.0.0.0:22"); + } + _ => panic!("expected Stream variant"), + } } #[test] @@ -1035,12 +1176,17 @@ mod tests { .stealth(true); let server = Server::new(opts).unwrap(); assert_eq!(server.listeners.len(), 1); - assert_eq!( - server.listeners[0].transport_kind, - TransportKind::Tls { server_name: None } - ); - assert!(server.listeners[0].stealth); - assert_eq!(server.listeners[0].tls_cert.as_deref(), Some("/cert.pem")); + match &server.listeners[0] { + ListenerConfig::Stream { config } => { + assert_eq!( + config.transport_kind, + TransportKind::Tls { server_name: None } + ); + assert!(config.stealth); + assert_eq!(config.tls_cert.as_deref(), Some("/cert.pem")); + } + _ => panic!("expected Stream variant"), + } } #[test] @@ -1056,11 +1202,6 @@ mod tests { .listeners(listeners); let server = Server::new(opts).unwrap(); assert_eq!(server.listeners().len(), 2); - assert_eq!(server.listeners()[0].transport_kind, TransportKind::Tcp); - assert_eq!( - server.listeners()[1].transport_kind, - TransportKind::Tls { server_name: None } - ); } #[test] @@ -1113,4 +1254,48 @@ mod tests { "server should have shut down within timeout" ); } + + #[test] + fn http_listener_config_display() { + let config = HttpListenerConfig { + bind_addr: "127.0.0.1:8080".parse().unwrap(), + tls: true, + stealth: false, + }; + assert_eq!(config.to_string(), "127.0.0.1:8080 (http)"); + } + + #[test] + fn dns_listener_config_display() { + let config = DnsListenerConfig { + bind_addr: "0.0.0.0:53".parse().unwrap(), + tls: true, + }; + assert_eq!(config.to_string(), "0.0.0.0:53 (dns)"); + } + + #[test] + fn http_listener_config_serialization() { + let config = HttpListenerConfig { + bind_addr: "127.0.0.1:8080".parse().unwrap(), + tls: true, + stealth: false, + }; + let serialized = serde_json::to_string(&config).unwrap(); + let deserialized: HttpListenerConfig = serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized.bind_addr, config.bind_addr); + assert_eq!(deserialized.tls, config.tls); + } + + #[test] + fn dns_listener_config_serialization() { + let config = DnsListenerConfig { + bind_addr: "0.0.0.0:53".parse().unwrap(), + tls: true, + }; + let serialized = serde_json::to_string(&config).unwrap(); + let deserialized: DnsListenerConfig = serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized.bind_addr, config.bind_addr); + assert_eq!(deserialized.tls, config.tls); + } } diff --git a/crates/alknet-core/src/transport/mod.rs b/crates/alknet-core/src/transport/mod.rs index 5ab93b2..e46fdee 100644 --- a/crates/alknet-core/src/transport/mod.rs +++ b/crates/alknet-core/src/transport/mod.rs @@ -91,8 +91,7 @@ pub enum TransportKind { Tcp, Tls { server_name: Option }, Iroh { endpoint_id: String }, - Dns { domain: String }, - WebTransport { host: String }, + WebTransport { server_name: Option }, } impl std::fmt::Display for TransportKind { @@ -101,7 +100,7 @@ impl std::fmt::Display for TransportKind { TransportKind::Tcp => write!(f, "tcp"), TransportKind::Tls { .. } => write!(f, "tls"), TransportKind::Iroh { .. } => write!(f, "iroh"), - TransportKind::Dns { .. } => write!(f, "dns"), + TransportKind::WebTransport { .. } => write!(f, "webtransport"), } } @@ -183,11 +182,8 @@ mod tests { let iroh = TransportKind::Iroh { endpoint_id: "abc123".to_string(), }; - let dns = TransportKind::Dns { - domain: "example.com".to_string(), - }; let wt = TransportKind::WebTransport { - host: "example.com".to_string(), + server_name: Some("example.com".to_string()), }; if let TransportKind::Tcp = tcp {} @@ -200,11 +196,8 @@ mod tests { if let TransportKind::Iroh { endpoint_id } = iroh { assert_eq!(endpoint_id, "abc123"); } - if let TransportKind::Dns { domain } = dns { - assert_eq!(domain, "example.com"); - } - if let TransportKind::WebTransport { host } = wt { - assert_eq!(host, "example.com"); + if let TransportKind::WebTransport { server_name } = wt { + assert_eq!(server_name, Some("example.com".to_string())); } } } diff --git a/crates/alknet-napi/src/serve.rs b/crates/alknet-napi/src/serve.rs index 59411eb..c3eaa16 100644 --- a/crates/alknet-napi/src/serve.rs +++ b/crates/alknet-napi/src/serve.rs @@ -306,7 +306,9 @@ impl russh::server::Handler for NapiServerHandler { } let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256)); - let identity = self.identity_provider.resolve_from_fingerprint(&fingerprint); + let identity = self + .identity_provider + .resolve_from_fingerprint(&fingerprint); match identity { Some(id) => { @@ -339,11 +341,14 @@ impl russh::server::Handler for NapiServerHandler { return Ok(true); } - let identity = self.authenticated_identity.clone().unwrap_or_else(|| Identity { - id: String::new(), - scopes: vec![], - resources: std::collections::HashMap::new(), - }); + let identity = self + .authenticated_identity + .clone() + .unwrap_or_else(|| Identity { + id: String::new(), + scopes: vec![], + resources: std::collections::HashMap::new(), + }); let policy = self.dynamic.load(); let allowed = policy.forwarding.check( @@ -664,11 +669,8 @@ impl AlknetServer { let new_auth_policy = build_auth_policy_from_napi(&auth)?; let new_forwarding = build_forwarding_policy(&forwarding)?; let current = self.reload_handle.dynamic(); - let new_config = DynamicConfig::from_parts( - new_auth_policy, - new_forwarding, - current.rate_limits.clone(), - ); + let new_config = + DynamicConfig::from_parts(new_auth_policy, new_forwarding, current.rate_limits.clone()); self.reload_handle.reload(new_config); Ok(()) }