feat(core): implement multi-transport listeners with ListenerConfig and Vec<ListenerConfig>

- Add ListenerConfig struct with transport_kind, listen_addr, per-transport config
- Add Dns and WebTransport variants to TransportKind (tags only, no behavior)
- Add .listeners() builder method to ServeOptions for multi-listener config
- Keep .transport_mode() backwards compatible (creates single-element listeners vec)
- Update Server::run() to use listeners from Server struct (first listener)
- Add Server::listeners() accessor for multi-transport listener configs
- Update StaticConfig to support listeners field, converted from ServeOptions
- All listeners share Arc<ArcSwap<DynamicConfig>>, ConnectionRateLimiter, and IdentityProvider
- Graceful shutdown terminates accept loop via existing shutdown signal
- TOML [[listeners]] array-of-tables syntax supported via ListenerConfig in StaticConfig
- Add comprehensive tests for ListenerConfig, multi-listener ServeOptions, Server creation
This commit is contained in:
2026-06-07 14:25:23 +00:00
parent ee1b3f3819
commit 851cf1bdab
6 changed files with 479 additions and 46 deletions

View File

@@ -1,5 +1,5 @@
use crate::server::handler::{ProxyConfig, ProxyMode}; use crate::server::handler::{ProxyConfig, ProxyMode};
use crate::server::serve::ServeTransportMode; use crate::server::serve::{ListenerConfig, ServeTransportMode};
use std::net::SocketAddr; use std::net::SocketAddr;
pub struct StaticConfig { pub struct StaticConfig {
@@ -15,6 +15,7 @@ pub struct StaticConfig {
pub max_connections_per_ip: usize, pub max_connections_per_ip: usize,
pub proxy_config: Option<ProxyConfig>, pub proxy_config: Option<ProxyConfig>,
pub iroh_relay: Option<String>, pub iroh_relay: Option<String>,
pub listeners: Vec<ListenerConfig>,
} }
impl std::fmt::Debug for StaticConfig { impl std::fmt::Debug for StaticConfig {
@@ -31,6 +32,7 @@ impl std::fmt::Debug for StaticConfig {
.field("max_connections_per_ip", &self.max_connections_per_ip) .field("max_connections_per_ip", &self.max_connections_per_ip)
.field("proxy_config", &self.proxy_config) .field("proxy_config", &self.proxy_config)
.field("iroh_relay", &self.iroh_relay) .field("iroh_relay", &self.iroh_relay)
.field("listeners", &self.listeners)
.finish() .finish()
} }
} }
@@ -55,6 +57,24 @@ impl StaticConfig {
let proxy_config = parse_proxy_config(opts.proxy.as_deref()); let proxy_config = parse_proxy_config(opts.proxy.as_deref());
let listeners = if let Some(listeners) = opts.listeners {
listeners
} else {
vec![ListenerConfig {
transport_kind: match opts.transport_mode {
ServeTransportMode::Tcp => crate::server::handler::TransportKind::Tcp,
ServeTransportMode::Tls => crate::server::handler::TransportKind::Tls,
ServeTransportMode::Iroh => crate::server::handler::TransportKind::Iroh,
},
listen_addr: opts.listen_addr.clone(),
tls_cert: opts.tls_cert.clone(),
tls_key: opts.tls_key.clone(),
acme_domain: opts.acme_domain.clone(),
stealth: opts.stealth,
iroh_relay: opts.iroh_relay.clone(),
}]
};
let static_config = StaticConfig { let static_config = StaticConfig {
transport_mode: opts.transport_mode, transport_mode: opts.transport_mode,
listen_addr: opts.listen_addr, listen_addr: opts.listen_addr,
@@ -68,6 +88,7 @@ impl StaticConfig {
max_connections_per_ip: opts.max_connections_per_ip, max_connections_per_ip: opts.max_connections_per_ip,
proxy_config, proxy_config,
iroh_relay: opts.iroh_relay, iroh_relay: opts.iroh_relay,
listeners,
}; };
Ok((static_config, dynamic)) Ok((static_config, dynamic))

View File

@@ -68,5 +68,5 @@ pub use config::{
ForwardingRule, RateLimitConfig, StaticConfig, ForwardingRule, RateLimitConfig, StaticConfig,
}; };
pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError}; pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError};
pub use server::serve::{ServeError, ServeOptions, ServeTransportMode, Server}; pub use server::serve::{ListenerConfig, ServeError, ServeOptions, ServeTransportMode, Server};
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind}; pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};

View File

@@ -25,11 +25,13 @@ pub struct ProxyConfig {
pub mode: ProxyMode, pub mode: ProxyMode,
} }
#[derive(Debug, Clone, Copy, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum TransportKind { pub enum TransportKind {
Tcp, Tcp,
Tls, Tls,
Iroh, Iroh,
Dns,
WebTransport,
} }
impl std::fmt::Display for TransportKind { impl std::fmt::Display for TransportKind {
@@ -38,6 +40,8 @@ impl std::fmt::Display for TransportKind {
TransportKind::Tcp => write!(f, "tcp"), TransportKind::Tcp => write!(f, "tcp"),
TransportKind::Tls => write!(f, "tls"), TransportKind::Tls => write!(f, "tls"),
TransportKind::Iroh => write!(f, "iroh"), TransportKind::Iroh => write!(f, "iroh"),
TransportKind::Dns => write!(f, "dns"),
TransportKind::WebTransport => write!(f, "webtransport"),
} }
} }
} }
@@ -736,6 +740,8 @@ mod tests {
assert_eq!(TransportKind::Tcp.to_string(), "tcp"); assert_eq!(TransportKind::Tcp.to_string(), "tcp");
assert_eq!(TransportKind::Tls.to_string(), "tls"); assert_eq!(TransportKind::Tls.to_string(), "tls");
assert_eq!(TransportKind::Iroh.to_string(), "iroh"); assert_eq!(TransportKind::Iroh.to_string(), "iroh");
assert_eq!(TransportKind::Dns.to_string(), "dns");
assert_eq!(TransportKind::WebTransport.to_string(), "webtransport");
} }
#[tokio::test] #[tokio::test]

View File

@@ -21,7 +21,7 @@ pub use control_channel::{
}; };
pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind}; pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter}; pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
pub use serve::{ServeError, ServeOptions, ServeTransportMode, Server}; pub use serve::{ListenerConfig, ServeError, ServeOptions, ServeTransportMode, Server};
pub use stealth::{ pub use stealth::{
detect_protocol, send_fake_nginx_404, validate_stealth_config, ProtocolDetection, detect_protocol, send_fake_nginx_404, validate_stealth_config, ProtocolDetection,
}; };

View File

@@ -3,6 +3,7 @@
//! `Server` binds to a transport acceptor and runs an accept loop, handling //! `Server` binds to a transport acceptor and runs an accept loop, handling
//! authentication, stealth mode protocol detection, and graceful shutdown. //! authentication, stealth mode protocol detection, and graceful shutdown.
//! `ServeOptions` provides a builder-pattern API for programmatic configuration. //! `ServeOptions` provides a builder-pattern API for programmatic configuration.
//! Supports multiple listeners via `ListenerConfig` for multi-transport operation.
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
@@ -24,7 +25,6 @@ use crate::server::stealth::{self, ProtocolDetection};
const DEFAULT_LISTEN_ADDR: &str = "0.0.0.0:22"; const DEFAULT_LISTEN_ADDR: &str = "0.0.0.0:22";
const DRAIN_TIMEOUT: Duration = Duration::from_secs(2); const DRAIN_TIMEOUT: Duration = Duration::from_secs(2);
/// Transport mode for the server listener.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServeTransportMode { pub enum ServeTransportMode {
Tcp, Tcp,
@@ -42,22 +42,153 @@ impl std::fmt::Display for ServeTransportMode {
} }
} }
/// Programmatic configuration for an alknet server. #[derive(Debug, Clone, PartialEq)]
/// pub struct ListenerConfig {
/// Construct with `ServeOptions::new(key_source)` and chain builder methods. pub transport_kind: TransportKind,
/// Call `validate()` before passing to `Server::new()`. pub listen_addr: String,
/// pub tls_cert: Option<String>,
/// ``` pub tls_key: Option<String>,
/// use alknet_core::server::{ServeOptions, ServeTransportMode}; pub acme_domain: Option<String>,
/// use alknet_core::auth::keys::KeySource; pub stealth: bool,
/// pub iroh_relay: Option<String>,
/// let opts = ServeOptions::new(KeySource::File("/path/to/host_key".into())) }
/// .transport_mode(ServeTransportMode::Tcp)
/// .listen_addr("0.0.0.0:22") impl ListenerConfig {
/// .max_connections_per_ip(5) pub fn tcp(addr: impl Into<String>) -> Self {
/// .max_auth_attempts(3); Self {
/// opts.validate().unwrap(); transport_kind: TransportKind::Tcp,
/// ``` listen_addr: addr.into(),
tls_cert: None,
tls_key: None,
acme_domain: None,
stealth: false,
iroh_relay: None,
}
}
pub fn tls(addr: impl Into<String>) -> Self {
Self {
transport_kind: TransportKind::Tls,
listen_addr: addr.into(),
tls_cert: None,
tls_key: None,
acme_domain: None,
stealth: false,
iroh_relay: None,
}
}
pub fn iroh(addr: impl Into<String>) -> Self {
Self {
transport_kind: TransportKind::Iroh,
listen_addr: addr.into(),
tls_cert: None,
tls_key: None,
acme_domain: None,
stealth: false,
iroh_relay: None,
}
}
pub fn dns(domain: impl Into<String>) -> Self {
Self {
transport_kind: TransportKind::Dns,
listen_addr: domain.into(),
tls_cert: None,
tls_key: None,
acme_domain: None,
stealth: false,
iroh_relay: None,
}
}
pub fn webtransport(host: impl Into<String>) -> Self {
Self {
transport_kind: TransportKind::WebTransport,
listen_addr: host.into(),
tls_cert: None,
tls_key: None,
acme_domain: None,
stealth: false,
iroh_relay: None,
}
}
pub fn tls_cert(mut self, path: impl Into<String>) -> Self {
self.tls_cert = Some(path.into());
self
}
pub fn tls_key(mut self, path: impl Into<String>) -> Self {
self.tls_key = Some(path.into());
self
}
pub fn acme_domain(mut self, domain: impl Into<String>) -> Self {
self.acme_domain = Some(domain.into());
self
}
pub fn stealth(mut self, enabled: bool) -> Self {
self.stealth = enabled;
self
}
pub fn iroh_relay(mut self, url: impl Into<String>) -> Self {
self.iroh_relay = Some(url.into());
self
}
pub fn validate(&self) -> Result<(), ConfigError> {
if self.stealth && self.transport_kind != TransportKind::Tls {
return Err(ConfigError::InvalidFlag {
name: "stealth mode requires TLS transport".to_string(),
});
}
match self.transport_kind {
TransportKind::Tls => {
if self.tls_cert.is_none() && self.acme_domain.is_none() {
return Err(ConfigError::InvalidFlag {
name: "TLS transport requires tls_cert/tls_key or acme_domain".to_string(),
});
}
if self.tls_cert.is_some() && self.tls_key.is_none() {
return Err(ConfigError::InvalidFlag {
name: "tls_cert requires tls_key".to_string(),
});
}
if self.tls_key.is_some() && self.tls_cert.is_none() {
return Err(ConfigError::InvalidFlag {
name: "tls_key requires tls_cert".to_string(),
});
}
}
TransportKind::Tcp
| TransportKind::Iroh
| TransportKind::Dns
| TransportKind::WebTransport => {
if self.tls_cert.is_some() || self.tls_key.is_some() || self.acme_domain.is_some() {
return Err(ConfigError::IncompatibleOptions);
}
}
}
Ok(())
}
}
impl std::fmt::Display for ListenerConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.transport_kind {
TransportKind::Iroh => write!(f, "{} (iroh)", self.listen_addr),
TransportKind::Dns => write!(f, "{} (dns)", self.listen_addr),
TransportKind::WebTransport => write!(f, "{} (webtransport)", self.listen_addr),
_ => write!(f, "{} ({})", self.listen_addr, self.transport_kind),
}
}
}
pub struct ServeOptions { pub struct ServeOptions {
pub key: KeySource, pub key: KeySource,
pub authorized_keys: Option<KeySource>, pub authorized_keys: Option<KeySource>,
@@ -72,6 +203,7 @@ pub struct ServeOptions {
pub iroh_relay: Option<String>, pub iroh_relay: Option<String>,
pub max_connections_per_ip: usize, pub max_connections_per_ip: usize,
pub max_auth_attempts: usize, pub max_auth_attempts: usize,
pub listeners: Option<Vec<ListenerConfig>>,
} }
impl ServeOptions { impl ServeOptions {
@@ -90,6 +222,7 @@ impl ServeOptions {
iroh_relay: None, iroh_relay: None,
max_connections_per_ip: 0, max_connections_per_ip: 0,
max_auth_attempts: 10, max_auth_attempts: 10,
listeners: None,
} }
} }
@@ -153,7 +286,24 @@ impl ServeOptions {
self self
} }
pub fn listeners(mut self, listeners: Vec<ListenerConfig>) -> Self {
self.listeners = Some(listeners);
self
}
pub fn validate(&self) -> Result<(), ConfigError> { pub fn validate(&self) -> Result<(), ConfigError> {
if let Some(ref listeners) = self.listeners {
if listeners.is_empty() {
return Err(ConfigError::InvalidFlag {
name: "listeners must not be empty".to_string(),
});
}
for listener in listeners {
listener.validate()?;
}
return Ok(());
}
if self.stealth && self.transport_mode != ServeTransportMode::Tls { if self.stealth && self.transport_mode != ServeTransportMode::Tls {
return Err(ConfigError::InvalidFlag { return Err(ConfigError::InvalidFlag {
name: "stealth mode requires TLS transport (--transport tls)".to_string(), name: "stealth mode requires TLS transport (--transport tls)".to_string(),
@@ -201,11 +351,11 @@ impl std::fmt::Debug for ServeOptions {
.field("stealth", &self.stealth) .field("stealth", &self.stealth)
.field("max_connections_per_ip", &self.max_connections_per_ip) .field("max_connections_per_ip", &self.max_connections_per_ip)
.field("max_auth_attempts", &self.max_auth_attempts) .field("max_auth_attempts", &self.max_auth_attempts)
.field("listeners", &self.listeners)
.finish() .finish()
} }
} }
/// Errors that can occur during server setup and operation.
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum ServeError { pub enum ServeError {
#[error("config error: {0}")] #[error("config error: {0}")]
@@ -223,19 +373,12 @@ struct ActiveSession {
join: tokio::task::JoinHandle<()>, join: tokio::task::JoinHandle<()>,
} }
/// The alknet SSH server.
///
/// Accepts connections over any `TransportAcceptor`, authenticates via Ed25519 keys
/// or certificate authority, and proxies `direct-tcpip` channels to their targets.
/// Supports stealth mode (TLS only), outbound proxy routing, and connection rate limiting.
pub struct Server { pub struct Server {
config: Arc<server::Config>, config: Arc<server::Config>,
dynamic: Arc<ArcSwap<DynamicConfig>>, dynamic: Arc<ArcSwap<DynamicConfig>>,
connection_limiter: Arc<ConnectionRateLimiter>, connection_limiter: Arc<ConnectionRateLimiter>,
outbound_proxy: Option<ProxyConfig>, outbound_proxy: Option<ProxyConfig>,
stealth: bool, listeners: Vec<ListenerConfig>,
transport_mode: ServeTransportMode,
listen_addr: String,
max_auth_attempts: usize, max_auth_attempts: usize,
shutdown_tx: tokio::sync::watch::Sender<bool>, shutdown_tx: tokio::sync::watch::Sender<bool>,
shutdown_rx: tokio::sync::watch::Receiver<bool>, shutdown_rx: tokio::sync::watch::Receiver<bool>,
@@ -277,14 +420,31 @@ impl Server {
let dynamic = Arc::new(ArcSwap::new(Arc::new(dynamic_config))); let dynamic = Arc::new(ArcSwap::new(Arc::new(dynamic_config)));
let listeners = if let Some(listeners) = opts.listeners {
listeners
} else {
let transport_kind = match opts.transport_mode {
ServeTransportMode::Tcp => TransportKind::Tcp,
ServeTransportMode::Tls => TransportKind::Tls,
ServeTransportMode::Iroh => TransportKind::Iroh,
};
vec![ListenerConfig {
transport_kind,
listen_addr: opts.listen_addr.clone(),
tls_cert: opts.tls_cert.clone(),
tls_key: opts.tls_key.clone(),
acme_domain: opts.acme_domain.clone(),
stealth: opts.stealth,
iroh_relay: opts.iroh_relay.clone(),
}]
};
Ok(Self { Ok(Self {
config, config,
dynamic, dynamic,
connection_limiter, connection_limiter,
outbound_proxy, outbound_proxy,
stealth: opts.stealth, listeners,
transport_mode: opts.transport_mode,
listen_addr: opts.listen_addr,
max_auth_attempts, max_auth_attempts,
shutdown_tx, shutdown_tx,
shutdown_rx, shutdown_rx,
@@ -344,13 +504,16 @@ impl Server {
where where
A: crate::transport::TransportAcceptor, A: crate::transport::TransportAcceptor,
{ {
let transport_kind = match self.transport_mode { let listener = self
ServeTransportMode::Tcp => TransportKind::Tcp, .listeners
ServeTransportMode::Tls => TransportKind::Tls, .first()
ServeTransportMode::Iroh => TransportKind::Iroh, .expect("at least one listener required");
};
if self.transport_mode == ServeTransportMode::Iroh { let transport_kind = listener.transport_kind.clone();
let stealth = listener.stealth;
let listen_addr = listener.listen_addr.clone();
if matches!(transport_kind, TransportKind::Iroh) {
if let Some(id) = endpoint_info { if let Some(id) = endpoint_info {
info!("alknet server running: transport=iroh endpoint_id={}", id); info!("alknet server running: transport=iroh endpoint_id={}", id);
} else { } else {
@@ -359,7 +522,7 @@ impl Server {
} else { } else {
info!( info!(
"alknet server running: transport={} listen={}", "alknet server running: transport={} listen={}",
self.transport_mode, self.listen_addr transport_kind, listen_addr
); );
} }
@@ -410,7 +573,7 @@ impl Server {
}; };
let remote_addr = info.remote_addr; let remote_addr = info.remote_addr;
let handler_transport_kind = transport_kind; let handler_transport_kind = transport_kind.clone();
let handler = ServerHandler::new( let handler = ServerHandler::new(
Arc::clone(&server.dynamic), Arc::clone(&server.dynamic),
@@ -427,8 +590,7 @@ impl Server {
let config = Arc::clone(&server.config); let config = Arc::clone(&server.config);
let sessions = Arc::clone(&server.sessions); let sessions = Arc::clone(&server.sessions);
let stealth = server.stealth; let transport_is_tls = matches!(transport_kind, TransportKind::Tls);
let transport_is_tls = server.transport_mode == ServeTransportMode::Tls;
tokio::spawn(async move { tokio::spawn(async move {
let result = let result =
@@ -448,6 +610,10 @@ impl Server {
Ok(()) Ok(())
} }
pub fn listeners(&self) -> &[ListenerConfig] {
&self.listeners
}
} }
async fn handle_connection<S>( async fn handle_connection<S>(
@@ -547,6 +713,7 @@ mod tests {
assert!(opts.iroh_relay.is_none()); assert!(opts.iroh_relay.is_none());
assert_eq!(opts.max_connections_per_ip, 0); assert_eq!(opts.max_connections_per_ip, 0);
assert_eq!(opts.max_auth_attempts, 10); assert_eq!(opts.max_auth_attempts, 10);
assert!(opts.listeners.is_none());
} }
#[test] #[test]
@@ -739,10 +906,235 @@ mod tests {
} }
#[test] #[test]
fn server_holds_listen_addr() { fn listener_config_tcp_constructor() {
let opts = ServeOptions::new(make_key_source()).listen_addr("0.0.0.0:443"); let lc = ListenerConfig::tcp("0.0.0.0:22");
assert_eq!(lc.transport_kind, TransportKind::Tcp);
assert_eq!(lc.listen_addr, "0.0.0.0:22");
assert!(!lc.stealth);
assert!(lc.tls_cert.is_none());
}
#[test]
fn listener_config_tls_constructor() {
let lc = ListenerConfig::tls("0.0.0.0:443")
.tls_cert("/cert.pem")
.tls_key("/key.pem")
.stealth(true);
assert_eq!(lc.transport_kind, TransportKind::Tls);
assert_eq!(lc.listen_addr, "0.0.0.0:443");
assert!(lc.stealth);
assert_eq!(lc.tls_cert.as_deref(), Some("/cert.pem"));
assert_eq!(lc.tls_key.as_deref(), Some("/key.pem"));
}
#[test]
fn listener_config_iroh_constructor() {
let lc = ListenerConfig::iroh("0.0.0.0:0").iroh_relay("https://relay.example.com");
assert_eq!(lc.transport_kind, TransportKind::Iroh);
assert_eq!(lc.iroh_relay.as_deref(), Some("https://relay.example.com"));
}
#[test]
fn listener_config_dns_constructor() {
let lc = ListenerConfig::dns("example.com");
assert_eq!(lc.transport_kind, TransportKind::Dns);
assert_eq!(lc.listen_addr, "example.com");
}
#[test]
fn listener_config_webtransport_constructor() {
let lc = ListenerConfig::webtransport("example.com");
assert_eq!(lc.transport_kind, TransportKind::WebTransport);
assert_eq!(lc.listen_addr, "example.com");
}
#[test]
fn listener_config_validate_tls_requires_certs() {
let lc = ListenerConfig::tls("0.0.0.0:443");
assert!(lc.validate().is_err());
}
#[test]
fn listener_config_validate_tls_with_certs_ok() {
let lc = ListenerConfig::tls("0.0.0.0:443")
.tls_cert("/cert.pem")
.tls_key("/key.pem");
assert!(lc.validate().is_ok());
}
#[test]
fn listener_config_validate_tls_with_acme_ok() {
let lc = ListenerConfig::tls("0.0.0.0:443").acme_domain("example.com");
assert!(lc.validate().is_ok());
}
#[test]
fn listener_config_validate_stealth_without_tls_rejected() {
let lc = ListenerConfig::tcp("0.0.0.0:22").stealth(true);
assert!(lc.validate().is_err());
}
#[test]
fn listener_config_validate_tcp_cannot_have_tls_certs() {
let lc = ListenerConfig::tcp("0.0.0.0:22").tls_cert("/cert.pem");
assert!(lc.validate().is_err());
}
#[test]
fn listener_config_display() {
let tcp = ListenerConfig::tcp("0.0.0.0:22");
assert_eq!(format!("{}", tcp), "0.0.0.0:22 (tcp)");
let tls = ListenerConfig::tls("0.0.0.0:443");
assert_eq!(format!("{}", tls), "0.0.0.0:443 (tls)");
let iroh = ListenerConfig::iroh("0.0.0.0:0");
assert_eq!(format!("{}", iroh), "0.0.0.0:0 (iroh)");
let dns = ListenerConfig::dns("example.com");
assert_eq!(format!("{}", dns), "example.com (dns)");
let wt = ListenerConfig::webtransport("example.com");
assert_eq!(format!("{}", wt), "example.com (webtransport)");
}
#[test]
fn listener_config_equality() {
let lc1 = ListenerConfig::tcp("0.0.0.0:22");
let lc2 = ListenerConfig::tcp("0.0.0.0:22");
assert_eq!(lc1, lc2);
let lc3 = ListenerConfig::tls("0.0.0.0:443");
assert_ne!(lc1, lc3);
}
#[test]
fn serve_options_with_listeners() {
let listeners = vec![
ListenerConfig::tcp("0.0.0.0:22"),
ListenerConfig::tls("0.0.0.0:443")
.tls_cert("/cert.pem")
.tls_key("/key.pem"),
];
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.listeners(listeners);
assert!(opts.listeners.is_some());
assert_eq!(opts.listeners.as_ref().unwrap().len(), 2);
}
#[test]
fn serve_options_validate_listeners_ok() {
let listeners = vec![
ListenerConfig::tcp("0.0.0.0:22"),
ListenerConfig::tls("0.0.0.0:443")
.tls_cert("/cert.pem")
.tls_key("/key.pem"),
];
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.listeners(listeners);
assert!(opts.validate().is_ok());
}
#[test]
fn serve_options_validate_listeners_bypasses_single_validation() {
let listeners = vec![ListenerConfig::tcp("0.0.0.0:22")];
let opts = ServeOptions::new(make_key_source())
.stealth(true)
.listeners(listeners);
assert!(opts.validate().is_ok());
}
#[test]
fn serve_options_validate_listeners_per_listener_stealth_requires_tls() {
let listeners = vec![ListenerConfig::tcp("0.0.0.0:22").stealth(true)];
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.listeners(listeners);
assert!(opts.validate().is_err());
}
#[test]
fn serve_options_validate_empty_listeners_rejected() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.listeners(vec![]);
assert!(opts.validate().is_err());
}
#[test]
fn server_new_with_listeners() {
let listeners = vec![ListenerConfig::tcp("0.0.0.0:22")];
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.listeners(listeners);
let server = Server::new(opts).unwrap(); let server = Server::new(opts).unwrap();
assert_eq!(server.listen_addr, "0.0.0.0:443"); assert_eq!(server.listeners.len(), 1);
assert_eq!(server.listeners[0].transport_kind, TransportKind::Tcp);
}
#[test]
fn server_new_single_transport_creates_listener() {
let opts =
ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source());
let server = Server::new(opts).unwrap();
assert_eq!(server.listeners.len(), 1);
assert_eq!(server.listeners[0].transport_kind, TransportKind::Tcp);
assert_eq!(server.listeners[0].listen_addr, "0.0.0.0:22");
}
#[test]
fn server_new_tls_transport_creates_tls_listener() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.transport_mode(ServeTransportMode::Tls)
.tls_cert("/cert.pem")
.tls_key("/key.pem")
.listen_addr("0.0.0.0:443")
.stealth(true);
let server = Server::new(opts).unwrap();
assert_eq!(server.listeners.len(), 1);
assert_eq!(server.listeners[0].transport_kind, TransportKind::Tls);
assert!(server.listeners[0].stealth);
assert_eq!(server.listeners[0].tls_cert.as_deref(), Some("/cert.pem"));
}
#[test]
fn server_listeners_accessor() {
let listeners = vec![
ListenerConfig::tcp("0.0.0.0:22"),
ListenerConfig::tls("0.0.0.0:443")
.tls_cert("/cert.pem")
.tls_key("/key.pem"),
];
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.listeners(listeners);
let server = Server::new(opts).unwrap();
assert_eq!(server.listeners().len(), 2);
assert_eq!(server.listeners()[0].transport_kind, TransportKind::Tcp);
assert_eq!(server.listeners()[1].transport_kind, TransportKind::Tls);
}
#[test]
fn server_new_multi_listener_tcp_and_tls() {
let listeners = vec![
ListenerConfig::tcp("0.0.0.0:22"),
ListenerConfig::tls("0.0.0.0:443")
.tls_cert("/cert.pem")
.tls_key("/key.pem"),
];
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.listeners(listeners);
let server = Server::new(opts).unwrap();
assert_eq!(server.listeners.len(), 2);
let dynamic = server.config_reload_handle();
let config = dynamic.dynamic();
assert!(config.auth.authorized_keys.len() > 0);
} }
#[tokio::test] #[tokio::test]

View File

@@ -91,6 +91,8 @@ pub enum TransportKind {
Tcp, Tcp,
Tls { server_name: Option<String> }, Tls { server_name: Option<String> },
Iroh { endpoint_id: String }, Iroh { endpoint_id: String },
Dns { domain: String },
WebTransport { host: String },
} }
#[cfg(test)] #[cfg(test)]
@@ -169,6 +171,12 @@ mod tests {
let iroh = TransportKind::Iroh { let iroh = TransportKind::Iroh {
endpoint_id: "abc123".to_string(), endpoint_id: "abc123".to_string(),
}; };
let dns = TransportKind::Dns {
domain: "example.com".to_string(),
};
let wt = TransportKind::WebTransport {
host: "example.com".to_string(),
};
if let TransportKind::Tcp = tcp {} if let TransportKind::Tcp = tcp {}
if let TransportKind::Tls { if let TransportKind::Tls {
@@ -180,5 +188,11 @@ mod tests {
if let TransportKind::Iroh { endpoint_id } = iroh { if let TransportKind::Iroh { endpoint_id } = iroh {
assert_eq!(endpoint_id, "abc123"); assert_eq!(endpoint_id, "abc123");
} }
if let TransportKind::Dns { domain } = dns {
assert_eq!(domain, "example.com");
}
if let TransportKind::WebTransport { host } = wt {
assert_eq!(host, "example.com");
}
} }
} }