diff --git a/crates/alknet-core/src/config/static_config.rs b/crates/alknet-core/src/config/static_config.rs index 6e571ca..79d6054 100644 --- a/crates/alknet-core/src/config/static_config.rs +++ b/crates/alknet-core/src/config/static_config.rs @@ -120,3 +120,85 @@ fn parse_proxy_config(proxy: Option<&str>) -> Option { } }) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::auth::keys::KeySource; + use crate::server::handler::TransportKind; + use crate::server::serve::ServeOptions; + + const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n"; + + const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096"; + + fn make_key_source() -> KeySource { + KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec()) + } + + fn make_authorized_keys_source() -> KeySource { + KeySource::Memory(ED25519_PUBLIC_KEY.as_bytes().to_vec()) + } + + #[test] + fn parse_proxy_config_socks5() { + let config = parse_proxy_config(Some("socks5://127.0.0.1:9050")); + assert!(config.is_some()); + match config.unwrap().mode { + ProxyMode::Socks5(addr) => { + assert_eq!(addr, "127.0.0.1:9050".parse().unwrap()); + } + _ => panic!("expected Socks5"), + } + } + + #[test] + fn parse_proxy_config_http() { + let config = parse_proxy_config(Some("http://127.0.0.1:8080")); + assert!(config.is_some()); + match config.unwrap().mode { + ProxyMode::HttpConnect(addr) => { + assert_eq!(addr, "127.0.0.1:8080".parse().unwrap()); + } + _ => panic!("expected HttpConnect"), + } + } + + #[test] + fn parse_proxy_config_none() { + assert!(parse_proxy_config(None).is_none()); + } + + #[test] + fn static_config_from_serve_options_basic() { + let opts = + ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source()); + let (static_config, dynamic) = StaticConfig::from_serve_options(opts).unwrap(); + assert_eq!(static_config.listen_addr, "0.0.0.0:22"); + assert_eq!(static_config.max_auth_attempts, 10); + assert!(dynamic.auth.authorized_keys.len() > 0); + } + + #[test] + fn static_config_from_serve_options_with_proxy() { + let opts = ServeOptions::new(make_key_source()) + .authorized_keys(make_authorized_keys_source()) + .proxy("socks5://127.0.0.1:9050"); + let (static_config, _) = StaticConfig::from_serve_options(opts).unwrap(); + assert!(static_config.proxy_config.is_some()); + } + + #[test] + fn static_config_from_serve_options_with_listeners() { + let listeners = vec![ListenerConfig::tcp("0.0.0.0:22")]; + let opts = ServeOptions::new(make_key_source()) + .authorized_keys(make_authorized_keys_source()) + .listeners(listeners); + let (static_config, _) = StaticConfig::from_serve_options(opts).unwrap(); + assert_eq!(static_config.listeners.len(), 1); + assert_eq!( + static_config.listeners[0].transport_kind, + TransportKind::Tcp + ); + } +} diff --git a/crates/alknet-core/src/server/handler.rs b/crates/alknet-core/src/server/handler.rs index c3ec0f6..10c4401 100644 --- a/crates/alknet-core/src/server/handler.rs +++ b/crates/alknet-core/src/server/handler.rs @@ -49,7 +49,7 @@ impl std::fmt::Display for TransportKind { pub struct ServerHandler { dynamic: Arc>, - identity_provider: Box, + identity_provider: Arc, #[allow(dead_code)] outbound_proxy: Option, remote_addr: Option, @@ -72,8 +72,8 @@ impl ServerHandler { connection_limiter: Arc, max_auth_attempts: usize, ) -> Self { - let identity_provider: Box = - Box::new(ConfigIdentityProvider::new(Arc::clone(&dynamic))); + let identity_provider: Arc = + Arc::new(ConfigIdentityProvider::new(Arc::clone(&dynamic))); let allowed = if let Some(addr) = remote_addr { let ip = addr.ip(); @@ -112,7 +112,7 @@ impl ServerHandler { } } - pub fn with_identity_provider(mut self, provider: Box) -> Self { + pub fn with_identity_provider(mut self, provider: Arc) -> Self { self.identity_provider = provider; self } @@ -818,4 +818,167 @@ mod tests { 10, ); } + + #[tokio::test] + async fn config_reload_new_keys_take_effect() { + let auth_config = make_auth_config(ED25519_PUBLIC_KEY); + let mut handler = ServerHandler::new( + auth_config.clone(), + None, + None, + TransportKind::Tcp, + default_limiter(), + 10, + ); + + let ssh_key = load_key().public_key().clone(); + let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap(); + assert_eq!(result, Auth::Accept); + drop(handler); + + let new_dynamic = DynamicConfig::default(); + auth_config.store(Arc::new(new_dynamic)); + + let mut handler2 = ServerHandler::new( + auth_config.clone(), + None, + None, + TransportKind::Tcp, + default_limiter(), + 10, + ); + + let result2 = handler2.auth_publickey("testuser", &ssh_key).await.unwrap(); + assert_eq!( + result2, + Auth::Reject { + proceed_with_methods: None + } + ); + } + + #[tokio::test] + async fn forwarding_policy_deny_blocks_channel_open() { + use crate::config::forwarding::{ + ForwardingAction, ForwardingPolicy, ForwardingRule, TargetPattern, + }; + + let deny_policy = ForwardingPolicy { + default: ForwardingAction::Deny, + rules: vec![ForwardingRule { + target: TargetPattern::Host("blocked.example.com".to_string()), + action: ForwardingAction::Deny, + principals: vec![], + transports: vec![], + }], + }; + + let auth_config = make_auth_config(ED25519_PUBLIC_KEY); + { + let dynamic = auth_config.load(); + let new_dynamic = DynamicConfig { + auth: dynamic.auth.clone(), + forwarding: deny_policy, + rate_limits: dynamic.rate_limits.clone(), + }; + drop(dynamic); + auth_config.store(Arc::new(new_dynamic)); + } + + let mut handler = ServerHandler::new( + auth_config, + None, + Some("127.0.0.1:12345".parse().unwrap()), + TransportKind::Tcp, + default_limiter(), + 10, + ); + + let ssh_key = load_key().public_key().clone(); + let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap(); + assert_eq!(result, Auth::Accept); + assert!(handler.authenticated_identity().is_some()); + + let identity = handler.authenticated_identity().unwrap(); + let dynamic = handler.dynamic.load(); + assert!(!dynamic.forwarding.check( + "blocked.example.com", + 443, + identity, + TransportKind::Tcp + )); + } + + #[test] + fn forwarding_policy_deny_with_custom_identity() { + use crate::config::forwarding::{ + ForwardingAction, ForwardingPolicy, ForwardingRule, TargetPattern, + }; + use std::collections::HashMap; + + let mut resources = HashMap::new(); + resources.insert("service".to_string(), vec!["gitea".to_string()]); + let identity = Identity { + id: "SHA256:abc123".to_string(), + scopes: vec!["relay:connect".to_string()], + resources, + }; + + let policy = ForwardingPolicy { + default: ForwardingAction::Deny, + rules: vec![ForwardingRule { + target: TargetPattern::Host("allowed.example.com".to_string()), + action: ForwardingAction::Allow, + principals: vec!["SHA256:abc123".to_string()], + transports: vec![TransportKind::Tcp], + }], + }; + + assert!(policy.check("allowed.example.com", 443, &identity, TransportKind::Tcp)); + assert!(!policy.check("denied.example.com", 443, &identity, TransportKind::Tcp)); + } + + #[test] + fn server_handler_with_custom_identity_provider() { + use std::collections::HashMap; + + struct MockIdentityProvider { + identities: HashMap, + } + + impl IdentityProvider for MockIdentityProvider { + fn resolve_from_fingerprint(&self, fingerprint: &str) -> Option { + self.identities.get(fingerprint).cloned() + } + + fn resolve_from_token(&self, _token: &crate::auth::AuthToken) -> Option { + None + } + } + + let mut identities = HashMap::new(); + identities.insert( + "SHA256:testkey".to_string(), + Identity { + id: "SHA256:testkey".to_string(), + scopes: vec!["admin".to_string()], + resources: HashMap::new(), + }, + ); + + let provider = Arc::new(MockIdentityProvider { identities }) as Arc; + let dynamic = make_empty_auth_config(); + + let handler = ServerHandler::new( + dynamic, + None, + Some("10.0.0.1:22".parse().unwrap()), + TransportKind::Tcp, + default_limiter(), + 10, + ) + .with_identity_provider(provider); + + assert!(handler.authenticated_identity().is_none()); + } } diff --git a/crates/alknet-core/src/server/serve.rs b/crates/alknet-core/src/server/serve.rs index 80133fb..757322f 100644 --- a/crates/alknet-core/src/server/serve.rs +++ b/crates/alknet-core/src/server/serve.rs @@ -5,7 +5,6 @@ //! `ServeOptions` provides a builder-pattern API for programmatic configuration. //! Supports multiple listeners via `ListenerConfig` for multi-transport operation. -use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -15,10 +14,9 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{error, info, warn}; use crate::auth::keys::KeySource; -use crate::auth::server_auth::ServerAuthConfig; -use crate::config::{AuthPolicy, ConfigReloadHandle, DynamicConfig}; +use crate::config::{ConfigReloadHandle, DynamicConfig}; use crate::error::ConfigError; -use crate::server::handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind}; +use crate::server::handler::{ProxyConfig, ServerHandler, TransportKind}; use crate::server::rate_limit::ConnectionRateLimiter; use crate::server::stealth::{self, ProtocolDetection}; @@ -387,65 +385,32 @@ pub struct Server { impl Server { pub fn new(opts: ServeOptions) -> Result { - opts.validate().map_err(ServeError::Config)?; + let (static_config, dynamic_config) = + crate::config::StaticConfig::from_serve_options(opts).map_err(ServeError::Config)?; - let private_key = crate::auth::keys::load_private_key(opts.key.clone()) - .map_err(ServeError::KeyLoadFailed)?; - - let auth_config = ServerAuthConfig::from_keys_and_ca( - opts.authorized_keys.clone(), - opts.cert_authority.clone(), - ) - .map_err(ServeError::KeyLoadFailed)?; - - let auth_policy = AuthPolicy::from_server_auth_config(auth_config); - let dynamic_config = DynamicConfig::new(auth_policy); - - let max_auth_attempts = opts.max_auth_attempts; - let max_connections_per_ip = opts.max_connections_per_ip; + let connection_limiter = Arc::new(ConnectionRateLimiter::new( + static_config.max_connections_per_ip, + )); let config = Arc::new(Config { - keys: vec![private_key], - max_auth_attempts, + keys: vec![static_config.host_key], + max_auth_attempts: static_config.max_auth_attempts, methods: russh::MethodSet::PUBLICKEY, preferred: russh::Preferred::DEFAULT, ..Default::default() }); - let outbound_proxy = parse_proxy_config(opts.proxy.as_deref()); - - let connection_limiter = Arc::new(ConnectionRateLimiter::new(max_connections_per_ip)); - let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); let dynamic = Arc::new(ArcSwap::new(Arc::new(dynamic_config))); - let listeners = if let Some(listeners) = opts.listeners { - listeners - } else { - let transport_kind = match opts.transport_mode { - ServeTransportMode::Tcp => TransportKind::Tcp, - ServeTransportMode::Tls => TransportKind::Tls, - ServeTransportMode::Iroh => TransportKind::Iroh, - }; - vec![ListenerConfig { - transport_kind, - listen_addr: opts.listen_addr.clone(), - tls_cert: opts.tls_cert.clone(), - tls_key: opts.tls_key.clone(), - acme_domain: opts.acme_domain.clone(), - stealth: opts.stealth, - iroh_relay: opts.iroh_relay.clone(), - }] - }; - Ok(Self { config, dynamic, connection_limiter, - outbound_proxy, - listeners, - max_auth_attempts, + outbound_proxy: static_config.proxy_config, + listeners: static_config.listeners, + max_auth_attempts: static_config.max_auth_attempts, shutdown_tx, shutdown_rx, sessions: Arc::new(tokio::sync::Mutex::new(Vec::new())), @@ -656,32 +621,6 @@ where Ok(()) } -fn parse_proxy_config(proxy: Option<&str>) -> Option { - proxy.map(|url| { - if url.starts_with("socks5://") { - let addr: SocketAddr = url - .strip_prefix("socks5://") - .unwrap() - .parse() - .expect("invalid socks5 proxy address"); - ProxyConfig { - mode: ProxyMode::Socks5(addr), - } - } else if url.starts_with("http://") { - let addr: SocketAddr = url - .strip_prefix("http://") - .unwrap() - .parse() - .expect("invalid http connect proxy address"); - ProxyConfig { - mode: ProxyMode::HttpConnect(addr), - } - } else { - panic!("unsupported proxy URL scheme: {url}"); - } - }) -} - #[cfg(test)] mod tests { use super::*; @@ -850,35 +789,6 @@ mod tests { assert!(!debug_str.contains("OPENSSH")); } - #[test] - fn parse_proxy_config_socks5() { - let config = parse_proxy_config(Some("socks5://127.0.0.1:9050")); - assert!(config.is_some()); - match config.unwrap().mode { - ProxyMode::Socks5(addr) => { - assert_eq!(addr, "127.0.0.1:9050".parse().unwrap()); - } - _ => panic!("expected Socks5"), - } - } - - #[test] - fn parse_proxy_config_http() { - let config = parse_proxy_config(Some("http://127.0.0.1:8080")); - assert!(config.is_some()); - match config.unwrap().mode { - ProxyMode::HttpConnect(addr) => { - assert_eq!(addr, "127.0.0.1:8080".parse().unwrap()); - } - _ => panic!("expected HttpConnect"), - } - } - - #[test] - fn parse_proxy_config_none() { - assert!(parse_proxy_config(None).is_none()); - } - #[test] fn serve_error_variants() { assert_eq!(ServeError::AcceptFailed.to_string(), "accept failed");