diff --git a/crates/alknet-core/src/config/forwarding.rs b/crates/alknet-core/src/config/forwarding.rs index 9cea36f..43a34fd 100644 --- a/crates/alknet-core/src/config/forwarding.rs +++ b/crates/alknet-core/src/config/forwarding.rs @@ -5,7 +5,7 @@ use std::str::FromStr; use ipnetwork::IpNetwork; use crate::auth::identity::Identity; -use crate::server::handler::TransportKind; +use crate::transport::TransportKind; #[derive(Debug, Clone, PartialEq)] pub enum ForwardingAction { @@ -79,11 +79,11 @@ impl ForwardingRule { .any(|p| p == &identity.id || identity.scopes.contains(p)) } - fn matches_transport(&self, transport: TransportKind) -> bool { + fn matches_transport(&self, transport: &TransportKind) -> bool { if self.transports.is_empty() { return true; } - self.transports.contains(&transport) + self.transports.contains(transport) } } @@ -118,7 +118,7 @@ impl ForwardingPolicy { for rule in &self.rules { if rule.target.matches(target, port) && rule.matches_principal(identity) - && rule.matches_transport(transport) + && rule.matches_transport(&transport) { return rule.action == ForwardingAction::Allow; } @@ -152,7 +152,12 @@ mod tests { let policy = ForwardingPolicy::allow_all(); let identity = make_identity("user1", vec![]); assert!(policy.check("example.com", 80, &identity, TransportKind::Tcp)); - assert!(policy.check("10.0.0.1", 22, &identity, TransportKind::Tls)); + assert!(policy.check( + "10.0.0.1", + 22, + &identity, + TransportKind::Tls { server_name: None } + )); } #[test] @@ -160,7 +165,12 @@ mod tests { let policy = ForwardingPolicy::deny_all(); let identity = make_identity("user1", vec![]); assert!(!policy.check("example.com", 80, &identity, TransportKind::Tcp)); - assert!(!policy.check("10.0.0.1", 22, &identity, TransportKind::Tls)); + assert!(!policy.check( + "10.0.0.1", + 22, + &identity, + TransportKind::Tls { server_name: None } + )); } #[test] @@ -282,8 +292,20 @@ mod tests { }; let identity = make_identity("user1", vec![]); assert!(policy.check("example.com", 80, &identity, TransportKind::Tcp)); - assert!(policy.check("example.com", 80, &identity, TransportKind::Tls)); - assert!(policy.check("example.com", 80, &identity, TransportKind::Iroh)); + assert!(policy.check( + "example.com", + 80, + &identity, + TransportKind::Tls { server_name: None } + )); + assert!(policy.check( + "example.com", + 80, + &identity, + TransportKind::Iroh { + endpoint_id: String::new() + } + )); } #[test] @@ -294,12 +316,17 @@ mod tests { target: TargetPattern::Any, action: ForwardingAction::Allow, principals: vec![], - transports: vec![TransportKind::Tls], + transports: vec![TransportKind::Tls { server_name: None }], }], }; let identity = make_identity("user1", vec![]); assert!(!policy.check("example.com", 443, &identity, TransportKind::Tcp)); - assert!(policy.check("example.com", 443, &identity, TransportKind::Tls)); + assert!(policy.check( + "example.com", + 443, + &identity, + TransportKind::Tls { server_name: None } + )); } #[test] @@ -420,14 +447,24 @@ mod tests { target: TargetPattern::Host("restricted.example.com".to_string()), action: ForwardingAction::Allow, principals: vec!["admin".to_string()], - transports: vec![TransportKind::Tls], + transports: vec![TransportKind::Tls { server_name: None }], }], }; let admin = make_identity("admin-user", vec!["admin"]); let viewer = make_identity("viewer-user", vec!["viewer"]); - assert!(policy.check("restricted.example.com", 443, &admin, TransportKind::Tls)); + assert!(policy.check( + "restricted.example.com", + 443, + &admin, + TransportKind::Tls { server_name: None } + )); assert!(!policy.check("restricted.example.com", 443, &admin, TransportKind::Tcp)); - assert!(!policy.check("restricted.example.com", 443, &viewer, TransportKind::Tls)); + assert!(!policy.check( + "restricted.example.com", + 443, + &viewer, + TransportKind::Tls { server_name: None } + )); } #[test] @@ -439,19 +476,37 @@ mod tests { target: TargetPattern::AlknetPrefix, action: ForwardingAction::Allow, principals: vec![], - transports: vec![TransportKind::WebTransport], + transports: vec![TransportKind::WebTransport { + host: String::new(), + }], }, ForwardingRule { target: TargetPattern::Any, action: ForwardingAction::Deny, principals: vec![], - transports: vec![TransportKind::WebTransport], + transports: vec![TransportKind::WebTransport { + host: String::new(), + }], }, ], }; let identity = make_identity("user1", vec![]); - assert!(policy.check("alknet-control", 0, &identity, TransportKind::WebTransport)); - assert!(!policy.check("example.com", 443, &identity, TransportKind::WebTransport)); + assert!(policy.check( + "alknet-control", + 0, + &identity, + TransportKind::WebTransport { + host: String::new() + } + )); + assert!(!policy.check( + "example.com", + 443, + &identity, + TransportKind::WebTransport { + host: String::new() + } + )); assert!(policy.check("example.com", 443, &identity, TransportKind::Tcp)); } diff --git a/crates/alknet-core/src/config/static_config.rs b/crates/alknet-core/src/config/static_config.rs index 79d6054..69398b9 100644 --- a/crates/alknet-core/src/config/static_config.rs +++ b/crates/alknet-core/src/config/static_config.rs @@ -1,5 +1,7 @@ +use crate::interface::InterfaceKind; use crate::server::handler::{ProxyConfig, ProxyMode}; use crate::server::serve::{ListenerConfig, ServeTransportMode}; +use crate::transport::TransportKind; use std::net::SocketAddr; pub struct StaticConfig { @@ -62,10 +64,13 @@ impl StaticConfig { } else { vec![ListenerConfig { transport_kind: match opts.transport_mode { - ServeTransportMode::Tcp => crate::server::handler::TransportKind::Tcp, - ServeTransportMode::Tls => crate::server::handler::TransportKind::Tls, - ServeTransportMode::Iroh => crate::server::handler::TransportKind::Iroh, + ServeTransportMode::Tcp => TransportKind::Tcp, + ServeTransportMode::Tls => TransportKind::Tls { server_name: None }, + ServeTransportMode::Iroh => TransportKind::Iroh { + endpoint_id: String::new(), + }, }, + interface_kind: InterfaceKind::Ssh, listen_addr: opts.listen_addr.clone(), tls_cert: opts.tls_cert.clone(), tls_key: opts.tls_key.clone(), @@ -125,8 +130,8 @@ fn parse_proxy_config(proxy: Option<&str>) -> Option { mod tests { use super::*; use crate::auth::keys::KeySource; - use crate::server::handler::TransportKind; use crate::server::serve::ServeOptions; + use crate::transport::TransportKind; const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n"; diff --git a/crates/alknet-core/src/interface/mod.rs b/crates/alknet-core/src/interface/mod.rs index 3dc29e8..c6f6d1a 100644 --- a/crates/alknet-core/src/interface/mod.rs +++ b/crates/alknet-core/src/interface/mod.rs @@ -23,7 +23,9 @@ pub mod config; pub mod pairs; +pub mod raw_framing; pub mod session; +pub mod ssh; use anyhow::Result; use async_trait::async_trait; @@ -31,7 +33,9 @@ use tokio::io::{AsyncRead, AsyncWrite}; pub use config::{InterfaceConfig, InterfaceKind, RawFramingConfig, SshInterfaceConfig}; pub use pairs::{is_valid_pair, TransportKindBase, VALID_TRANSPORT_INTERFACE_PAIRS}; +pub use raw_framing::{RawFramingInterface, RawFramingSession}; pub use session::{InterfaceEvent, InterfaceSession}; +pub use ssh::{SshInterface, SshSession}; pub trait TransportStream: AsyncRead + AsyncWrite + Unpin + Send + 'static {} diff --git a/crates/alknet-core/src/interface/raw_framing.rs b/crates/alknet-core/src/interface/raw_framing.rs new file mode 100644 index 0000000..9f509bf --- /dev/null +++ b/crates/alknet-core/src/interface/raw_framing.rs @@ -0,0 +1,62 @@ +use anyhow::Result; +use async_trait::async_trait; + +use crate::interface::session::{InterfaceEvent, InterfaceSession}; +use crate::interface::{Interface, InterfaceConfig, TransportStream}; + +pub struct RawFramingInterface; + +pub struct RawFramingSession; + +#[async_trait] +impl Interface for RawFramingInterface { + type Session = RawFramingSession; + + async fn accept( + &self, + _stream: Box, + _config: &InterfaceConfig, + ) -> Result { + Err(anyhow::anyhow!( + "RawFramingInterface is not yet implemented (Phase 4+)" + )) + } +} + +#[async_trait] +impl InterfaceSession for RawFramingSession { + async fn recv(&mut self) -> Option { + None + } + + async fn send(&mut self, _envelope: crate::call::EventEnvelope) -> Result<()> { + Err(anyhow::anyhow!( + "RawFramingSession is not yet implemented (Phase 4+)" + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn raw_framing_interface_type_exists() { + let _iface = RawFramingInterface; + } + + #[test] + fn raw_framing_session_type_exists() { + let _session = RawFramingSession; + } + + #[tokio::test] + async fn raw_framing_interface_accept_returns_error() { + let iface = RawFramingInterface; + let (_client, server) = tokio::io::duplex(1024); + let stream: Box = Box::new(server); + let config = InterfaceConfig::RawFraming(crate::interface::RawFramingConfig {}); + let result = iface.accept(stream, &config).await; + assert!(result.is_err()); + } +} diff --git a/crates/alknet-core/src/interface/ssh.rs b/crates/alknet-core/src/interface/ssh.rs new file mode 100644 index 0000000..e46f6a7 --- /dev/null +++ b/crates/alknet-core/src/interface/ssh.rs @@ -0,0 +1,733 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Instant; + +use anyhow::Result; +use arc_swap::ArcSwap; +use async_trait::async_trait; +use russh::keys::ssh_key::HashAlg; +use russh::server::{self, Config}; +use russh::Channel; +use russh::ChannelId; + +use crate::auth::identity::{Identity, IdentityProvider}; +use crate::call::EventEnvelope; +use crate::config::DynamicConfig; +use crate::interface::session::{InterfaceEvent, InterfaceSession}; +use crate::interface::{Interface, InterfaceConfig, TransportStream}; +use crate::server::control_channel::{ControlChannelRouter, ALKNET_PREFIX}; +use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter}; +use crate::transport::TransportKind; + +struct SshHandler { + dynamic: Arc>, + identity_provider: Arc, + outbound_proxy: Option, + remote_addr: Option, + transport: TransportKind, + connection_limiter: Arc, + connection_allowed: bool, + auth_limiter: AuthAttemptLimiter, + authenticated_identity: Option, + control_channel_router: ControlChannelRouter, + connected_at: Instant, +} + +impl SshHandler { + fn new( + dynamic: Arc>, + identity_provider: Arc, + outbound_proxy: Option, + remote_addr: Option, + transport: TransportKind, + connection_limiter: Arc, + max_auth_attempts: usize, + ) -> Self { + let allowed = if let Some(addr) = remote_addr { + let ip = addr.ip(); + if connection_limiter.check(ip) { + connection_limiter.on_connect(ip); + tracing::info!( + remote_addr = %addr, + transport = %transport, + "connection opened" + ); + true + } else { + tracing::info!( + remote_addr = %addr, + transport = %transport, + "connection rejected" + ); + false + } + } else { + true + }; + + Self { + dynamic, + identity_provider, + outbound_proxy, + remote_addr, + transport, + connection_limiter, + connection_allowed: allowed, + auth_limiter: AuthAttemptLimiter::new(max_auth_attempts), + authenticated_identity: None, + control_channel_router: ControlChannelRouter::without_handler(), + connected_at: Instant::now(), + } + } + + #[allow(dead_code)] + fn with_control_channel_router(mut self, router: ControlChannelRouter) -> Self { + self.control_channel_router = router; + self + } +} + +impl Drop for SshHandler { + fn drop(&mut self) { + if let Some(addr) = self.remote_addr { + if self.connection_allowed { + self.connection_limiter.on_disconnect(addr.ip()); + let duration = self.connected_at.elapsed(); + tracing::info!( + remote_addr = %addr, + duration_secs = duration.as_secs_f64(), + "connection closed" + ); + } + } + } +} + +#[async_trait] +impl server::Handler for SshHandler { + type Error = russh::Error; + + async fn auth_publickey( + &mut self, + user: &str, + public_key: &russh::keys::ssh_key::PublicKey, + ) -> Result { + if !self.auth_limiter.check() { + let remote_addr_display = self + .remote_addr + .map_or("unknown".to_string(), |a| a.to_string()); + let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256)); + tracing::info!( + remote_addr = %remote_addr_display, + user = user, + key_fingerprint = %fingerprint, + result = "reject", + "auth attempt" + ); + return Ok(server::Auth::Reject { + proceed_with_methods: None, + }); + } + + let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256)); + let remote_addr_display = self + .remote_addr + .map_or("unknown".to_string(), |a| a.to_string()); + + let identity = self + .identity_provider + .resolve_from_fingerprint(&fingerprint); + + match identity { + Some(id) => { + self.authenticated_identity = Some(id); + tracing::info!( + remote_addr = %remote_addr_display, + user = user, + key_fingerprint = %fingerprint, + result = "accept", + "auth attempt" + ); + Ok(server::Auth::Accept) + } + None => { + self.auth_limiter.on_failure(); + tracing::info!( + remote_addr = %remote_addr_display, + user = user, + key_fingerprint = %fingerprint, + result = "reject", + "auth attempt" + ); + Ok(server::Auth::Reject { + proceed_with_methods: None, + }) + } + } + } + + async fn channel_open_direct_tcpip( + &mut self, + channel: Channel, + host_to_connect: &str, + port_to_connect: u32, + originator_address: &str, + originator_port: u32, + _session: &mut server::Session, + ) -> Result { + if host_to_connect.starts_with(ALKNET_PREFIX) { + if !self.control_channel_router.has_handler() { + return Ok(false); + } + let _ = channel; + return Ok(true); + } + + let identity = self + .authenticated_identity + .clone() + .unwrap_or_else(|| Identity { + id: String::new(), + scopes: vec![], + resources: std::collections::HashMap::new(), + }); + + let policy = self.dynamic.load(); + let allowed = policy.forwarding.check( + host_to_connect, + port_to_connect as u16, + &identity, + self.transport.clone(), + ); + + if !allowed { + tracing::info!( + remote_addr = ?self.remote_addr, + target = %format!("{host_to_connect}:{port_to_connect}"), + identity = %identity.id, + transport = %self.transport, + "forwarding denied by policy" + ); + return Ok(false); + } + + let target_host = host_to_connect.to_string(); + let target_port = port_to_connect; + let proxy_config = + self.outbound_proxy + .clone() + .unwrap_or(crate::server::handler::ProxyConfig { + mode: crate::server::handler::ProxyMode::Direct, + }); + + tokio::spawn(async move { + let target = match format!("{target_host}:{target_port}") + .parse::() + { + Ok(addr) => addr, + Err(_) => { + match tokio::net::lookup_host((&target_host[..], target_port as u16)).await { + Ok(mut addrs) => match addrs.next() { + Some(addr) => addr, + None => return, + }, + Err(_) => return, + } + } + }; + crate::server::channel_proxy::proxy_channel( + channel.into_stream(), + target, + &proxy_config, + ) + .await; + }); + + let _ = (originator_address, originator_port); + Ok(true) + } + + async fn channel_open_session( + &mut self, + _channel: Channel, + session: &mut server::Session, + ) -> Result { + tracing::warn!( + remote_addr = ?self.remote_addr, + "rejected session channel (shell/exec not supported)" + ); + let _ = session; + Ok(false) + } + + async fn channel_open_x11( + &mut self, + _channel: Channel, + _originator_address: &str, + _originator_port: u32, + session: &mut server::Session, + ) -> Result { + tracing::warn!( + remote_addr = ?self.remote_addr, + "rejected x11 channel" + ); + let _ = session; + Ok(false) + } + + async fn channel_open_forwarded_tcpip( + &mut self, + _channel: Channel, + host_to_connect: &str, + port_to_connect: u32, + _originator_address: &str, + _originator_port: u32, + session: &mut server::Session, + ) -> Result { + tracing::warn!( + remote_addr = ?self.remote_addr, + target = %format!("{host_to_connect}:{port_to_connect}"), + "rejected forwarded-tcpip channel (remote port forwarding not supported)" + ); + let _ = session; + Ok(false) + } + + async fn exec_request( + &mut self, + channel: ChannelId, + data: &[u8], + session: &mut server::Session, + ) -> Result<(), Self::Error> { + tracing::warn!( + remote_addr = ?self.remote_addr, + channel = %channel, + data_len = data.len(), + "rejected exec request on channel (shell/exec not supported)" + ); + let _ = session.channel_failure(channel); + Ok(()) + } + + async fn shell_request( + &mut self, + channel: ChannelId, + session: &mut server::Session, + ) -> Result<(), Self::Error> { + tracing::warn!( + remote_addr = ?self.remote_addr, + channel = %channel, + "rejected shell request on channel" + ); + let _ = session.channel_failure(channel); + Ok(()) + } + + async fn subsystem_request( + &mut self, + channel: ChannelId, + name: &str, + session: &mut server::Session, + ) -> Result<(), Self::Error> { + tracing::warn!( + remote_addr = ?self.remote_addr, + channel = %channel, + subsystem = name, + "rejected subsystem request on channel" + ); + let _ = session.channel_failure(channel); + Ok(()) + } + + async fn pty_request( + &mut self, + channel: ChannelId, + term: &str, + col_width: u32, + row_height: u32, + pix_width: u32, + pix_height: u32, + modes: &[(russh::Pty, u32)], + session: &mut server::Session, + ) -> Result<(), Self::Error> { + tracing::warn!( + remote_addr = ?self.remote_addr, + channel = %channel, + term = term, + "rejected pty request on channel" + ); + let _ = (col_width, row_height, pix_width, pix_height, modes); + let _ = session.channel_failure(channel); + Ok(()) + } + + async fn env_request( + &mut self, + channel: ChannelId, + variable_name: &str, + variable_value: &str, + session: &mut server::Session, + ) -> Result<(), Self::Error> { + tracing::warn!( + remote_addr = ?self.remote_addr, + channel = %channel, + variable = variable_name, + "rejected env request on channel" + ); + let _ = variable_value; + let _ = session.channel_failure(channel); + Ok(()) + } + + async fn x11_request( + &mut self, + channel: ChannelId, + single_connection: bool, + x11_auth_protocol: &str, + x11_auth_cookie: &str, + x11_screen_number: u32, + session: &mut server::Session, + ) -> Result<(), Self::Error> { + tracing::warn!( + remote_addr = ?self.remote_addr, + channel = %channel, + "rejected x11 request on channel" + ); + let _ = ( + single_connection, + x11_auth_protocol, + x11_auth_cookie, + x11_screen_number, + ); + let _ = session.channel_failure(channel); + Ok(()) + } + + async fn agent_request( + &mut self, + channel: ChannelId, + session: &mut server::Session, + ) -> Result { + tracing::warn!( + remote_addr = ?self.remote_addr, + channel = %channel, + "rejected agent forwarding request on channel" + ); + let _ = session; + Ok(false) + } + + async fn tcpip_forward( + &mut self, + address: &str, + port: &mut u32, + session: &mut server::Session, + ) -> Result { + tracing::warn!( + remote_addr = ?self.remote_addr, + address = address, + port = *port, + "rejected tcpip-forward request (remote port forwarding not supported)" + ); + let _ = session; + Ok(false) + } + + async fn cancel_tcpip_forward( + &mut self, + address: &str, + port: u32, + session: &mut server::Session, + ) -> Result { + let _ = (address, port, session); + Ok(false) + } + + async fn streamlocal_forward( + &mut self, + socket_path: &str, + session: &mut server::Session, + ) -> Result { + tracing::warn!( + remote_addr = ?self.remote_addr, + socket_path = socket_path, + "rejected streamlocal-forward request" + ); + let _ = session; + Ok(false) + } + + async fn signal( + &mut self, + channel: ChannelId, + signal: russh::Sig, + session: &mut server::Session, + ) -> Result<(), Self::Error> { + tracing::debug!( + remote_addr = ?self.remote_addr, + channel = %channel, + signal = ?signal, + "received signal on channel (ignored)" + ); + let _ = session; + Ok(()) + } +} + +pub struct SshInterface { + config: Arc, + dynamic: Arc>, + connection_limiter: Arc, + outbound_proxy: Option, + max_auth_attempts: usize, +} + +impl SshInterface { + pub fn new(config: Arc, dynamic: Arc>) -> Self { + Self { + config, + dynamic, + connection_limiter: Arc::new(ConnectionRateLimiter::new(0)), + outbound_proxy: None, + max_auth_attempts: 10, + } + } + + pub fn with_connection_limiter(mut self, limiter: Arc) -> Self { + self.connection_limiter = limiter; + self + } + + pub fn with_outbound_proxy( + mut self, + proxy: Option, + ) -> Self { + self.outbound_proxy = proxy; + self + } + + pub fn with_max_auth_attempts(mut self, max: usize) -> Self { + self.max_auth_attempts = max; + self + } + + pub fn config(&self) -> &Arc { + &self.config + } + + pub fn dynamic(&self) -> &Arc> { + &self.dynamic + } + + async fn accept_inner( + &self, + stream: Box, + ssh_config: &crate::interface::SshInterfaceConfig, + remote_addr: Option, + transport: TransportKind, + ) -> Result { + let identity_provider = Arc::clone(&ssh_config.auth); + let _forwarding = Arc::clone(&ssh_config.forwarding); + + let handler = SshHandler::new( + Arc::clone(&self.dynamic), + identity_provider, + self.outbound_proxy.clone(), + remote_addr, + transport, + Arc::clone(&self.connection_limiter), + self.max_auth_attempts, + ); + + let running = server::run_stream(Arc::clone(&self.config), stream, handler).await?; + let handle = running.handle(); + let join = tokio::spawn(async { + let _ = running.await; + }); + + Ok(SshSession { + handle, + _join: join, + }) + } +} + +#[async_trait] +impl Interface for SshInterface { + type Session = SshSession; + + async fn accept( + &self, + stream: Box, + config: &InterfaceConfig, + ) -> Result { + let ssh_config = match config { + InterfaceConfig::Ssh(c) => c, + InterfaceConfig::RawFraming(_) => { + return Err(anyhow::anyhow!("SshInterface received RawFramingConfig")); + } + }; + + self.accept_inner(stream, ssh_config, None, TransportKind::Tcp) + .await + } +} + +pub struct SshSession { + handle: server::Handle, + _join: tokio::task::JoinHandle<()>, +} + +impl SshSession { + pub fn handle(&self) -> &server::Handle { + &self.handle + } +} + +#[async_trait] +impl InterfaceSession for SshSession { + async fn recv(&mut self) -> Option { + None + } + + async fn send(&mut self, _envelope: EventEnvelope) -> Result<()> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ssh_interface_constructs_with_config() { + let config = Arc::new(Config { + keys: vec![russh::keys::PrivateKey::random( + &mut rand_core::OsRng, + russh::keys::Algorithm::Ed25519, + ) + .unwrap()], + ..Default::default() + }); + let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default()))); + + let iface = SshInterface::new(config, dynamic); + assert!(iface.config().keys.len() >= 1); + } + + #[test] + fn ssh_interface_builder_pattern() { + let config = Arc::new(Config { + keys: vec![russh::keys::PrivateKey::random( + &mut rand_core::OsRng, + russh::keys::Algorithm::Ed25519, + ) + .unwrap()], + ..Default::default() + }); + let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default()))); + let limiter = Arc::new(ConnectionRateLimiter::new(5)); + + let iface = SshInterface::new(config, dynamic) + .with_connection_limiter(limiter) + .with_max_auth_attempts(3); + + assert!(iface.config().keys.len() >= 1); + } + + #[test] + fn ssh_handler_auth_delegates_to_identity_provider() { + use std::collections::HashMap; + + struct MockProvider { + identities: HashMap, + } + + impl IdentityProvider for MockProvider { + fn resolve_from_fingerprint(&self, fp: &str) -> Option { + self.identities.get(fp).cloned() + } + fn resolve_from_token(&self, _t: &crate::auth::AuthToken) -> Option { + None + } + } + + let mut ids = HashMap::new(); + ids.insert( + "SHA256:testkey".to_string(), + Identity { + id: "SHA256:testkey".to_string(), + scopes: vec!["admin".to_string()], + resources: HashMap::new(), + }, + ); + + let provider: Arc = Arc::new(MockProvider { identities: ids }); + let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default()))); + let limiter = Arc::new(ConnectionRateLimiter::new(0)); + + let handler = SshHandler::new( + dynamic, + provider, + None, + None, + TransportKind::Tcp, + limiter, + 10, + ); + + assert!(handler.authenticated_identity.is_none()); + } + + #[test] + fn ssh_handler_connection_rate_limiting() { + let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default()))); + let provider: Arc = Arc::new( + crate::auth::identity::ConfigIdentityProvider::new(Arc::clone(&dynamic)), + ); + let limiter = Arc::new(ConnectionRateLimiter::new(1)); + let addr: SocketAddr = "10.0.0.1:22".parse().unwrap(); + + let h1 = SshHandler::new( + Arc::clone(&dynamic), + Arc::clone(&provider), + None, + Some(addr), + TransportKind::Tcp, + Arc::clone(&limiter), + 10, + ); + assert!(h1.connection_allowed); + + let h2 = SshHandler::new( + dynamic, + provider, + None, + Some(addr), + TransportKind::Tcp, + limiter, + 10, + ); + assert!(!h2.connection_allowed); + } + + #[tokio::test] + async fn ssh_interface_rejects_raw_framing_config() { + let config = Arc::new(Config { + keys: vec![russh::keys::PrivateKey::random( + &mut rand_core::OsRng, + russh::keys::Algorithm::Ed25519, + ) + .unwrap()], + ..Default::default() + }); + let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default()))); + let iface = SshInterface::new(config, dynamic); + let (_client, server) = tokio::io::duplex(1024); + let stream: Box = Box::new(server); + + let raw_config = InterfaceConfig::RawFraming(crate::interface::RawFramingConfig {}); + let result = iface.accept(stream, &raw_config).await; + assert!(result.is_err()); + } +} diff --git a/crates/alknet-core/src/lib.rs b/crates/alknet-core/src/lib.rs index 5fe70d4..9686e9d 100644 --- a/crates/alknet-core/src/lib.rs +++ b/crates/alknet-core/src/lib.rs @@ -87,8 +87,8 @@ pub use config::{ pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError}; pub use interface::{ is_valid_pair, Interface, InterfaceConfig, InterfaceEvent, InterfaceKind, InterfaceSession, - RawFramingConfig, SshInterfaceConfig, TransportKindBase, TransportStream, - VALID_TRANSPORT_INTERFACE_PAIRS, + RawFramingConfig, RawFramingInterface, RawFramingSession, SshInterface, SshInterfaceConfig, + SshSession, TransportKindBase, TransportStream, VALID_TRANSPORT_INTERFACE_PAIRS, }; pub use server::serve::{ListenerConfig, ServeError, ServeOptions, ServeTransportMode, Server}; pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind}; diff --git a/crates/alknet-core/src/server/handler.rs b/crates/alknet-core/src/server/handler.rs index 10c4401..d39a4d8 100644 --- a/crates/alknet-core/src/server/handler.rs +++ b/crates/alknet-core/src/server/handler.rs @@ -14,6 +14,8 @@ use crate::config::DynamicConfig; use crate::server::control_channel::{ControlChannelHandler, ControlChannelRouter, ALKNET_PREFIX}; use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter}; +pub use crate::transport::TransportKind; + #[derive(Debug, Clone)] pub enum ProxyMode { Direct, @@ -26,27 +28,6 @@ pub struct ProxyConfig { pub mode: ProxyMode, } -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum TransportKind { - Tcp, - Tls, - Iroh, - Dns, - WebTransport, -} - -impl std::fmt::Display for TransportKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - TransportKind::Tcp => write!(f, "tcp"), - TransportKind::Tls => write!(f, "tls"), - TransportKind::Iroh => write!(f, "iroh"), - TransportKind::Dns => write!(f, "dns"), - TransportKind::WebTransport => write!(f, "webtransport"), - } - } -} - pub struct ServerHandler { dynamic: Arc>, identity_provider: Arc, @@ -252,7 +233,7 @@ impl Handler for ServerHandler { host_to_connect, port_to_connect as u16, &identity, - self.transport, + self.transport.clone(), ); if !allowed { @@ -784,10 +765,28 @@ mod tests { #[test] fn transport_kind_display() { assert_eq!(TransportKind::Tcp.to_string(), "tcp"); - assert_eq!(TransportKind::Tls.to_string(), "tls"); - assert_eq!(TransportKind::Iroh.to_string(), "iroh"); - assert_eq!(TransportKind::Dns.to_string(), "dns"); - assert_eq!(TransportKind::WebTransport.to_string(), "webtransport"); + assert_eq!(TransportKind::Tls { server_name: None }.to_string(), "tls"); + assert_eq!( + TransportKind::Iroh { + endpoint_id: String::new() + } + .to_string(), + "iroh" + ); + assert_eq!( + TransportKind::Dns { + domain: String::new() + } + .to_string(), + "dns" + ); + assert_eq!( + TransportKind::WebTransport { + host: String::new() + } + .to_string(), + "webtransport" + ); } #[tokio::test] @@ -797,7 +796,7 @@ mod tests { auth_config, None, Some("203.0.113.50:12345".parse().unwrap()), - TransportKind::Tls, + TransportKind::Tls { server_name: None }, Arc::new(ConnectionRateLimiter::new(0)), 10, ); diff --git a/crates/alknet-core/src/server/mod.rs b/crates/alknet-core/src/server/mod.rs index 901f9f6..993e2b3 100644 --- a/crates/alknet-core/src/server/mod.rs +++ b/crates/alknet-core/src/server/mod.rs @@ -19,9 +19,11 @@ pub use control_channel::{ is_reserved_destination, ControlChannelHandler, ControlChannelRouter, DuplexStream, ALKNET_CONTROL_DESTINATION, ALKNET_PREFIX, }; -pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind}; +pub use handler::{ProxyConfig, ProxyMode, ServerHandler}; pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter}; pub use serve::{ListenerConfig, ServeError, ServeOptions, ServeTransportMode, Server}; + +pub use crate::transport::TransportKind; pub use stealth::{ detect_protocol, send_fake_nginx_404, validate_stealth_config, ProtocolDetection, }; diff --git a/crates/alknet-core/src/server/serve.rs b/crates/alknet-core/src/server/serve.rs index 757322f..85188f4 100644 --- a/crates/alknet-core/src/server/serve.rs +++ b/crates/alknet-core/src/server/serve.rs @@ -16,9 +16,11 @@ use tracing::{error, info, warn}; use crate::auth::keys::KeySource; use crate::config::{ConfigReloadHandle, DynamicConfig}; use crate::error::ConfigError; -use crate::server::handler::{ProxyConfig, ServerHandler, TransportKind}; +use crate::interface::InterfaceKind; +use crate::server::handler::{ProxyConfig, ServerHandler}; use crate::server::rate_limit::ConnectionRateLimiter; use crate::server::stealth::{self, ProtocolDetection}; +use crate::transport::TransportKind; const DEFAULT_LISTEN_ADDR: &str = "0.0.0.0:22"; const DRAIN_TIMEOUT: Duration = Duration::from_secs(2); @@ -43,6 +45,7 @@ impl std::fmt::Display for ServeTransportMode { #[derive(Debug, Clone, PartialEq)] pub struct ListenerConfig { pub transport_kind: TransportKind, + pub interface_kind: InterfaceKind, pub listen_addr: String, pub tls_cert: Option, pub tls_key: Option, @@ -55,6 +58,7 @@ impl ListenerConfig { pub fn tcp(addr: impl Into) -> Self { Self { transport_kind: TransportKind::Tcp, + interface_kind: InterfaceKind::Ssh, listen_addr: addr.into(), tls_cert: None, tls_key: None, @@ -66,7 +70,8 @@ impl ListenerConfig { pub fn tls(addr: impl Into) -> Self { Self { - transport_kind: TransportKind::Tls, + transport_kind: TransportKind::Tls { server_name: None }, + interface_kind: InterfaceKind::Ssh, listen_addr: addr.into(), tls_cert: None, tls_key: None, @@ -78,7 +83,10 @@ impl ListenerConfig { pub fn iroh(addr: impl Into) -> Self { Self { - transport_kind: TransportKind::Iroh, + transport_kind: TransportKind::Iroh { + endpoint_id: String::new(), + }, + interface_kind: InterfaceKind::Ssh, listen_addr: addr.into(), tls_cert: None, tls_key: None, @@ -90,7 +98,10 @@ impl ListenerConfig { pub fn dns(domain: impl Into) -> Self { Self { - transport_kind: TransportKind::Dns, + transport_kind: TransportKind::Dns { + domain: String::new(), + }, + interface_kind: InterfaceKind::RawFraming, listen_addr: domain.into(), tls_cert: None, tls_key: None, @@ -102,7 +113,10 @@ impl ListenerConfig { pub fn webtransport(host: impl Into) -> Self { Self { - transport_kind: TransportKind::WebTransport, + transport_kind: TransportKind::WebTransport { + host: String::new(), + }, + interface_kind: InterfaceKind::Ssh, listen_addr: host.into(), tls_cert: None, tls_key: None, @@ -138,14 +152,14 @@ impl ListenerConfig { } pub fn validate(&self) -> Result<(), ConfigError> { - if self.stealth && self.transport_kind != TransportKind::Tls { + if self.stealth && !matches!(self.transport_kind, TransportKind::Tls { .. }) { return Err(ConfigError::InvalidFlag { name: "stealth mode requires TLS transport".to_string(), }); } match self.transport_kind { - TransportKind::Tls => { + TransportKind::Tls { .. } => { if self.tls_cert.is_none() && self.acme_domain.is_none() { return Err(ConfigError::InvalidFlag { name: "TLS transport requires tls_cert/tls_key or acme_domain".to_string(), @@ -163,9 +177,9 @@ impl ListenerConfig { } } TransportKind::Tcp - | TransportKind::Iroh - | TransportKind::Dns - | TransportKind::WebTransport => { + | TransportKind::Iroh { .. } + | TransportKind::Dns { .. } + | TransportKind::WebTransport { .. } => { if self.tls_cert.is_some() || self.tls_key.is_some() || self.acme_domain.is_some() { return Err(ConfigError::IncompatibleOptions); } @@ -179,9 +193,9 @@ impl ListenerConfig { impl std::fmt::Display for ListenerConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.transport_kind { - TransportKind::Iroh => write!(f, "{} (iroh)", self.listen_addr), - TransportKind::Dns => write!(f, "{} (dns)", self.listen_addr), - TransportKind::WebTransport => write!(f, "{} (webtransport)", self.listen_addr), + TransportKind::Iroh { .. } => write!(f, "{} (iroh)", self.listen_addr), + TransportKind::Dns { .. } => write!(f, "{} (dns)", self.listen_addr), + TransportKind::WebTransport { .. } => write!(f, "{} (webtransport)", self.listen_addr), _ => write!(f, "{} ({})", self.listen_addr, self.transport_kind), } } @@ -474,11 +488,11 @@ impl Server { .first() .expect("at least one listener required"); - let transport_kind = listener.transport_kind; + let transport_kind = listener.transport_kind.clone(); let stealth = listener.stealth; let listen_addr = listener.listen_addr.clone(); - if matches!(transport_kind, TransportKind::Iroh) { + if matches!(transport_kind, TransportKind::Iroh { .. }) { if let Some(id) = endpoint_info { info!("alknet server running: transport=iroh endpoint_id={}", id); } else { @@ -538,7 +552,7 @@ impl Server { }; let remote_addr = info.remote_addr; - let handler_transport_kind = transport_kind; + let handler_transport_kind = transport_kind.clone(); let handler = ServerHandler::new( Arc::clone(&server.dynamic), @@ -555,7 +569,7 @@ impl Server { let config = Arc::clone(&server.config); let sessions = Arc::clone(&server.sessions); - let transport_is_tls = matches!(transport_kind, TransportKind::Tls); + let transport_is_tls = matches!(transport_kind, TransportKind::Tls { .. }); tokio::spawn(async move { let result = @@ -830,7 +844,7 @@ mod tests { .tls_cert("/cert.pem") .tls_key("/key.pem") .stealth(true); - assert_eq!(lc.transport_kind, TransportKind::Tls); + assert_eq!(lc.transport_kind, TransportKind::Tls { server_name: None }); assert_eq!(lc.listen_addr, "0.0.0.0:443"); assert!(lc.stealth); assert_eq!(lc.tls_cert.as_deref(), Some("/cert.pem")); @@ -840,21 +854,36 @@ mod tests { #[test] fn listener_config_iroh_constructor() { let lc = ListenerConfig::iroh("0.0.0.0:0").iroh_relay("https://relay.example.com"); - assert_eq!(lc.transport_kind, TransportKind::Iroh); + assert_eq!( + lc.transport_kind, + TransportKind::Iroh { + endpoint_id: String::new() + } + ); assert_eq!(lc.iroh_relay.as_deref(), Some("https://relay.example.com")); } #[test] fn listener_config_dns_constructor() { let lc = ListenerConfig::dns("example.com"); - assert_eq!(lc.transport_kind, TransportKind::Dns); + assert_eq!( + lc.transport_kind, + TransportKind::Dns { + domain: String::new() + } + ); assert_eq!(lc.listen_addr, "example.com"); } #[test] fn listener_config_webtransport_constructor() { let lc = ListenerConfig::webtransport("example.com"); - assert_eq!(lc.transport_kind, TransportKind::WebTransport); + assert_eq!( + lc.transport_kind, + TransportKind::WebTransport { + host: String::new() + } + ); assert_eq!(lc.listen_addr, "example.com"); } @@ -1006,7 +1035,10 @@ mod tests { .stealth(true); let server = Server::new(opts).unwrap(); assert_eq!(server.listeners.len(), 1); - assert_eq!(server.listeners[0].transport_kind, TransportKind::Tls); + assert_eq!( + server.listeners[0].transport_kind, + TransportKind::Tls { server_name: None } + ); assert!(server.listeners[0].stealth); assert_eq!(server.listeners[0].tls_cert.as_deref(), Some("/cert.pem")); } @@ -1025,7 +1057,10 @@ mod tests { let server = Server::new(opts).unwrap(); assert_eq!(server.listeners().len(), 2); assert_eq!(server.listeners()[0].transport_kind, TransportKind::Tcp); - assert_eq!(server.listeners()[1].transport_kind, TransportKind::Tls); + assert_eq!( + server.listeners()[1].transport_kind, + TransportKind::Tls { server_name: None } + ); } #[test] diff --git a/crates/alknet-core/src/transport/mod.rs b/crates/alknet-core/src/transport/mod.rs index dd9c4d6..5ab93b2 100644 --- a/crates/alknet-core/src/transport/mod.rs +++ b/crates/alknet-core/src/transport/mod.rs @@ -86,7 +86,7 @@ pub struct TransportInfo { /// Each variant identifies the transport mechanism. Used by the /// server handler for logging and authorization decisions. /// See ADR-001 and ADR-004. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum TransportKind { Tcp, Tls { server_name: Option }, @@ -95,6 +95,18 @@ pub enum TransportKind { WebTransport { host: String }, } +impl std::fmt::Display for TransportKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TransportKind::Tcp => write!(f, "tcp"), + TransportKind::Tls { .. } => write!(f, "tls"), + TransportKind::Iroh { .. } => write!(f, "iroh"), + TransportKind::Dns { .. } => write!(f, "dns"), + TransportKind::WebTransport { .. } => write!(f, "webtransport"), + } + } +} + #[cfg(test)] mod tests { use super::*;