From 9478e2911d282e1596f37c83d47e3a2d0b6564a2 Mon Sep 17 00:00:00 2001 From: "glm-5.1" Date: Sun, 7 Jun 2026 14:47:44 +0000 Subject: [PATCH] feat(core): implement ForwardingPolicy with rule-based allow/deny Add ForwardingPolicy, ForwardingAction, ForwardingRule, and TargetPattern types in config/forwarding.rs. Implement policy evaluation with first-match wins semantics, principal and transport matching, CIDR and glob patterns. Modify ServerHandler to check ForwardingPolicy before proxying in channel_open_direct_tcpip. Reserved alknet-* destinations bypass policy. Preserve existing behavior with default allow_all() policy. --- .../alknet-core/src/config/dynamic_config.rs | 37 +- crates/alknet-core/src/config/forwarding.rs | 464 ++++++++++++++++++ crates/alknet-core/src/config/mod.rs | 5 +- crates/alknet-core/src/lib.rs | 2 +- crates/alknet-core/src/server/handler.rs | 32 +- crates/alknet-core/src/server/serve.rs | 4 +- 6 files changed, 503 insertions(+), 41 deletions(-) create mode 100644 crates/alknet-core/src/config/forwarding.rs diff --git a/crates/alknet-core/src/config/dynamic_config.rs b/crates/alknet-core/src/config/dynamic_config.rs index 0a67236..a6e1774 100644 --- a/crates/alknet-core/src/config/dynamic_config.rs +++ b/crates/alknet-core/src/config/dynamic_config.rs @@ -6,6 +6,7 @@ use russh::keys::ssh_key::HashAlg; use crate::auth::identity::Identity; use crate::auth::ServerAuthConfig; +use crate::config::forwarding::ForwardingPolicy; pub struct AuthPolicy { pub authorized_keys: std::collections::HashSet, @@ -212,41 +213,6 @@ impl Clone for AuthPolicy { } } -#[derive(Debug, Clone, PartialEq)] -pub enum ForwardingAction { - Allow, - Deny, -} - -#[derive(Debug, Clone)] -pub struct ForwardingRule { - pub action: ForwardingAction, - pub principals: Vec, - pub transports: Vec, -} - -#[derive(Debug, Clone)] -pub struct ForwardingPolicy { - pub default: ForwardingAction, - pub rules: Vec, -} - -impl ForwardingPolicy { - pub fn allow_all() -> Self { - Self { - default: ForwardingAction::Allow, - rules: Vec::new(), - } - } - - pub fn deny_all() -> Self { - Self { - default: ForwardingAction::Deny, - rules: Vec::new(), - } - } -} - #[derive(Debug, Clone)] pub struct RateLimitConfig { pub max_connections_per_ip: usize, @@ -330,6 +296,7 @@ pub fn new_dynamic_config() -> (Arc>, ConfigReloadHandle) #[cfg(test)] mod tests { use super::*; + use crate::config::forwarding::ForwardingAction; #[test] fn forwarding_policy_allow_all_default() { diff --git a/crates/alknet-core/src/config/forwarding.rs b/crates/alknet-core/src/config/forwarding.rs new file mode 100644 index 0000000..9cea36f --- /dev/null +++ b/crates/alknet-core/src/config/forwarding.rs @@ -0,0 +1,464 @@ +use std::net::IpAddr; +use std::ops::Range; +use std::str::FromStr; + +use ipnetwork::IpNetwork; + +use crate::auth::identity::Identity; +use crate::server::handler::TransportKind; + +#[derive(Debug, Clone, PartialEq)] +pub enum ForwardingAction { + Allow, + Deny, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum TargetPattern { + Any, + Host(String), + Cidr(IpNetwork), + PortRange(String, Range), + AlknetPrefix, +} + +impl TargetPattern { + pub fn matches(&self, target: &str, port: u16) -> bool { + match self { + TargetPattern::Any => true, + TargetPattern::Host(pattern) => match_host_pattern(pattern, target), + TargetPattern::Cidr(network) => match_cidr(network, target), + TargetPattern::PortRange(host_pattern, port_range) => { + match_host_pattern(host_pattern, target) && port_range.contains(&port) + } + TargetPattern::AlknetPrefix => { + target.starts_with(crate::server::control_channel::ALKNET_PREFIX) + } + } + } +} + +fn match_host_pattern(pattern: &str, target: &str) -> bool { + if pattern == target { + return true; + } + if pattern.contains('*') { + if let Some(pos) = pattern.find('*') { + let prefix = &pattern[..pos]; + let suffix = &pattern[pos + 1..]; + return target.starts_with(prefix) + && target.ends_with(suffix) + && target.len() >= prefix.len() + suffix.len(); + } + } + false +} + +fn match_cidr(network: &IpNetwork, target: &str) -> bool { + let Ok(addr) = IpAddr::from_str(target) else { + return false; + }; + network.contains(addr) +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ForwardingRule { + pub target: TargetPattern, + pub action: ForwardingAction, + pub principals: Vec, + pub transports: Vec, +} + +impl ForwardingRule { + fn matches_principal(&self, identity: &Identity) -> bool { + if self.principals.is_empty() { + return true; + } + self.principals + .iter() + .any(|p| p == &identity.id || identity.scopes.contains(p)) + } + + fn matches_transport(&self, transport: TransportKind) -> bool { + if self.transports.is_empty() { + return true; + } + self.transports.contains(&transport) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ForwardingPolicy { + pub default: ForwardingAction, + pub rules: Vec, +} + +impl ForwardingPolicy { + pub fn allow_all() -> Self { + Self { + default: ForwardingAction::Allow, + rules: Vec::new(), + } + } + + pub fn deny_all() -> Self { + Self { + default: ForwardingAction::Deny, + rules: Vec::new(), + } + } + + pub fn check( + &self, + target: &str, + port: u16, + identity: &Identity, + transport: TransportKind, + ) -> bool { + for rule in &self.rules { + if rule.target.matches(target, port) + && rule.matches_principal(identity) + && rule.matches_transport(transport) + { + return rule.action == ForwardingAction::Allow; + } + } + self.default == ForwardingAction::Allow + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + fn make_identity(id: &str, scopes: Vec<&str>) -> Identity { + Identity { + id: id.to_string(), + scopes: scopes.into_iter().map(|s| s.to_string()).collect(), + resources: HashMap::new(), + } + } + + #[test] + fn forwarding_action_equality() { + assert_eq!(ForwardingAction::Allow, ForwardingAction::Allow); + assert_eq!(ForwardingAction::Deny, ForwardingAction::Deny); + assert_ne!(ForwardingAction::Allow, ForwardingAction::Deny); + } + + #[test] + fn allow_all_allows_everything() { + let policy = ForwardingPolicy::allow_all(); + let identity = make_identity("user1", vec![]); + assert!(policy.check("example.com", 80, &identity, TransportKind::Tcp)); + assert!(policy.check("10.0.0.1", 22, &identity, TransportKind::Tls)); + } + + #[test] + fn deny_all_denies_everything() { + let policy = ForwardingPolicy::deny_all(); + let identity = make_identity("user1", vec![]); + assert!(!policy.check("example.com", 80, &identity, TransportKind::Tcp)); + assert!(!policy.check("10.0.0.1", 22, &identity, TransportKind::Tls)); + } + + #[test] + fn first_match_wins_allowlist() { + let policy = ForwardingPolicy { + default: ForwardingAction::Deny, + rules: vec![ForwardingRule { + target: TargetPattern::Host("allowed.example.com".to_string()), + action: ForwardingAction::Allow, + principals: vec![], + transports: vec![], + }], + }; + let identity = make_identity("user1", vec![]); + assert!(policy.check("allowed.example.com", 80, &identity, TransportKind::Tcp)); + assert!(!policy.check("denied.example.com", 80, &identity, TransportKind::Tcp)); + } + + #[test] + fn first_match_wins_blocklist() { + let policy = ForwardingPolicy { + default: ForwardingAction::Allow, + rules: vec![ForwardingRule { + target: TargetPattern::Host("blocked.example.com".to_string()), + action: ForwardingAction::Deny, + principals: vec![], + transports: vec![], + }], + }; + let identity = make_identity("user1", vec![]); + assert!(!policy.check("blocked.example.com", 80, &identity, TransportKind::Tcp)); + assert!(policy.check("allowed.example.com", 80, &identity, TransportKind::Tcp)); + } + + #[test] + fn first_match_wins_ordering() { + let policy = ForwardingPolicy { + default: ForwardingAction::Deny, + rules: vec![ + ForwardingRule { + target: TargetPattern::Any, + action: ForwardingAction::Allow, + principals: vec![], + transports: vec![], + }, + ForwardingRule { + target: TargetPattern::Host("blocked.example.com".to_string()), + action: ForwardingAction::Deny, + principals: vec![], + transports: vec![], + }, + ], + }; + let identity = make_identity("user1", vec![]); + assert!(policy.check("blocked.example.com", 80, &identity, TransportKind::Tcp)); + } + + #[test] + fn empty_principals_matches_all() { + let policy = ForwardingPolicy { + default: ForwardingAction::Deny, + rules: vec![ForwardingRule { + target: TargetPattern::Any, + action: ForwardingAction::Allow, + principals: vec![], + transports: vec![], + }], + }; + let identity1 = make_identity("user1", vec![]); + let identity2 = make_identity("user2", vec![]); + assert!(policy.check("example.com", 80, &identity1, TransportKind::Tcp)); + assert!(policy.check("example.com", 80, &identity2, TransportKind::Tcp)); + } + + #[test] + fn principal_matching_by_id() { + let policy = ForwardingPolicy { + default: ForwardingAction::Deny, + rules: vec![ForwardingRule { + target: TargetPattern::Any, + action: ForwardingAction::Allow, + principals: vec!["SHA256:abc123".to_string()], + transports: vec![], + }], + }; + let allowed = make_identity("SHA256:abc123", vec![]); + let denied = make_identity("SHA256:other", vec![]); + assert!(policy.check("example.com", 80, &allowed, TransportKind::Tcp)); + assert!(!policy.check("example.com", 80, &denied, TransportKind::Tcp)); + } + + #[test] + fn principal_matching_by_scope() { + let policy = ForwardingPolicy { + default: ForwardingAction::Deny, + rules: vec![ForwardingRule { + target: TargetPattern::Any, + action: ForwardingAction::Allow, + principals: vec!["admin".to_string()], + transports: vec![], + }], + }; + let allowed = make_identity("user1", vec!["admin"]); + let denied = make_identity("user2", vec!["viewer"]); + assert!(policy.check("example.com", 80, &allowed, TransportKind::Tcp)); + assert!(!policy.check("example.com", 80, &denied, TransportKind::Tcp)); + } + + #[test] + fn empty_transports_matches_all() { + let policy = ForwardingPolicy { + default: ForwardingAction::Deny, + rules: vec![ForwardingRule { + target: TargetPattern::Any, + action: ForwardingAction::Allow, + principals: vec![], + transports: vec![], + }], + }; + let identity = make_identity("user1", vec![]); + assert!(policy.check("example.com", 80, &identity, TransportKind::Tcp)); + assert!(policy.check("example.com", 80, &identity, TransportKind::Tls)); + assert!(policy.check("example.com", 80, &identity, TransportKind::Iroh)); + } + + #[test] + fn transport_matching() { + let policy = ForwardingPolicy { + default: ForwardingAction::Deny, + rules: vec![ForwardingRule { + target: TargetPattern::Any, + action: ForwardingAction::Allow, + principals: vec![], + transports: vec![TransportKind::Tls], + }], + }; + let identity = make_identity("user1", vec![]); + assert!(!policy.check("example.com", 443, &identity, TransportKind::Tcp)); + assert!(policy.check("example.com", 443, &identity, TransportKind::Tls)); + } + + #[test] + fn target_pattern_any_matches_all() { + let pattern = TargetPattern::Any; + assert!(pattern.matches("example.com", 80)); + assert!(pattern.matches("10.0.0.1", 22)); + assert!(pattern.matches("alknet-control", 0)); + } + + #[test] + fn target_pattern_host_exact_match() { + let pattern = TargetPattern::Host("example.com".to_string()); + assert!(pattern.matches("example.com", 80)); + assert!(!pattern.matches("other.com", 80)); + assert!(!pattern.matches("sub.example.com", 80)); + } + + #[test] + fn target_pattern_host_glob_match() { + let pattern = TargetPattern::Host("*.example.com".to_string()); + assert!(pattern.matches("sub.example.com", 80)); + assert!(pattern.matches("a.example.com", 443)); + assert!(!pattern.matches("example.com", 80)); + assert!(!pattern.matches("xsub.example.com.org", 80)); + } + + #[test] + fn target_pattern_host_glob_prefix() { + let pattern = TargetPattern::Host("db-*".to_string()); + assert!(pattern.matches("db-primary", 5432)); + assert!(pattern.matches("db-replica", 5432)); + assert!(!pattern.matches("web-primary", 5432)); + } + + #[test] + fn target_pattern_host_glob_suffix() { + let pattern = TargetPattern::Host("*.internal".to_string()); + assert!(pattern.matches("app.internal", 8080)); + assert!(pattern.matches("db.internal", 5432)); + assert!(!pattern.matches("app.external", 80)); + } + + #[test] + fn target_pattern_cidr_matches_ip() { + let network: IpNetwork = "10.0.0.0/8".parse().unwrap(); + let pattern = TargetPattern::Cidr(network); + assert!(pattern.matches("10.0.0.1", 22)); + assert!(pattern.matches("10.255.255.255", 22)); + assert!(!pattern.matches("192.168.1.1", 22)); + assert!(!pattern.matches("not-an-ip", 22)); + } + + #[test] + fn target_pattern_cidr_ipv6() { + let network: IpNetwork = "fd00::/8".parse().unwrap(); + let pattern = TargetPattern::Cidr(network); + assert!(pattern.matches("fd00::1", 22)); + assert!(!pattern.matches("10.0.0.1", 22)); + } + + #[test] + fn target_pattern_port_range_matches() { + let pattern = TargetPattern::PortRange("localhost".to_string(), 8080..8090); + assert!(pattern.matches("localhost", 8080)); + assert!(pattern.matches("localhost", 8085)); + assert!(pattern.matches("localhost", 8089)); + assert!(!pattern.matches("localhost", 8079)); + assert!(!pattern.matches("localhost", 8090)); + assert!(!pattern.matches("otherhost", 8080)); + } + + #[test] + fn target_pattern_port_range_with_glob() { + let pattern = TargetPattern::PortRange("*.internal".to_string(), 3000..4000); + assert!(pattern.matches("app.internal", 3000)); + assert!(pattern.matches("app.internal", 3999)); + assert!(!pattern.matches("app.internal", 2999)); + assert!(!pattern.matches("app.internal", 4000)); + assert!(!pattern.matches("app.external", 3000)); + } + + #[test] + fn target_pattern_alknet_prefix() { + let pattern = TargetPattern::AlknetPrefix; + assert!(pattern.matches("alknet-control", 0)); + assert!(pattern.matches("alknet-status", 0)); + assert!(pattern.matches("alknet-", 0)); + assert!(!pattern.matches("example.com", 0)); + assert!(!pattern.matches("alknet.example.com", 0)); + } + + #[test] + fn default_fallthrough_allow() { + let policy = ForwardingPolicy { + default: ForwardingAction::Allow, + rules: vec![], + }; + let identity = make_identity("user1", vec![]); + assert!(policy.check("anything", 80, &identity, TransportKind::Tcp)); + } + + #[test] + fn default_fallthrough_deny() { + let policy = ForwardingPolicy { + default: ForwardingAction::Deny, + rules: vec![], + }; + let identity = make_identity("user1", vec![]); + assert!(!policy.check("anything", 80, &identity, TransportKind::Tcp)); + } + + #[test] + fn combined_principal_and_transport_matching() { + let policy = ForwardingPolicy { + default: ForwardingAction::Deny, + rules: vec![ForwardingRule { + target: TargetPattern::Host("restricted.example.com".to_string()), + action: ForwardingAction::Allow, + principals: vec!["admin".to_string()], + transports: vec![TransportKind::Tls], + }], + }; + let admin = make_identity("admin-user", vec!["admin"]); + let viewer = make_identity("viewer-user", vec!["viewer"]); + assert!(policy.check("restricted.example.com", 443, &admin, TransportKind::Tls)); + assert!(!policy.check("restricted.example.com", 443, &admin, TransportKind::Tcp)); + assert!(!policy.check("restricted.example.com", 443, &viewer, TransportKind::Tls)); + } + + #[test] + fn webtransport_restricted_to_alknet() { + let policy = ForwardingPolicy { + default: ForwardingAction::Allow, + rules: vec![ + ForwardingRule { + target: TargetPattern::AlknetPrefix, + action: ForwardingAction::Allow, + principals: vec![], + transports: vec![TransportKind::WebTransport], + }, + ForwardingRule { + target: TargetPattern::Any, + action: ForwardingAction::Deny, + principals: vec![], + transports: vec![TransportKind::WebTransport], + }, + ], + }; + let identity = make_identity("user1", vec![]); + assert!(policy.check("alknet-control", 0, &identity, TransportKind::WebTransport)); + assert!(!policy.check("example.com", 443, &identity, TransportKind::WebTransport)); + assert!(policy.check("example.com", 443, &identity, TransportKind::Tcp)); + } + + #[test] + fn cidr_does_not_match_hostname() { + let network: IpNetwork = "10.0.0.0/8".parse().unwrap(); + let pattern = TargetPattern::Cidr(network); + assert!(!pattern.matches("example.com", 22)); + } +} diff --git a/crates/alknet-core/src/config/mod.rs b/crates/alknet-core/src/config/mod.rs index e3d0f30..a3048be 100644 --- a/crates/alknet-core/src/config/mod.rs +++ b/crates/alknet-core/src/config/mod.rs @@ -1,10 +1,11 @@ pub mod config_service; pub mod dynamic_config; +pub mod forwarding; pub mod static_config; pub use config_service::ConfigServiceImpl; pub use dynamic_config::{ - new_dynamic_config, AuthPolicy, ConfigReloadHandle, DynamicConfig, ForwardingAction, - ForwardingPolicy, ForwardingRule, RateLimitConfig, + new_dynamic_config, AuthPolicy, ConfigReloadHandle, DynamicConfig, RateLimitConfig, }; +pub use forwarding::{ForwardingAction, ForwardingPolicy, ForwardingRule, TargetPattern}; pub use static_config::StaticConfig; diff --git a/crates/alknet-core/src/lib.rs b/crates/alknet-core/src/lib.rs index d497f25..403228b 100644 --- a/crates/alknet-core/src/lib.rs +++ b/crates/alknet-core/src/lib.rs @@ -66,7 +66,7 @@ pub use client::channel_manager::{ChannelManager, ForwardRequest}; pub use client::connect::{ClientSession, ConnectError, ConnectOptions, TransportMode}; pub use config::{ AuthPolicy, ConfigReloadHandle, ConfigServiceImpl, DynamicConfig, ForwardingAction, - ForwardingPolicy, ForwardingRule, RateLimitConfig, StaticConfig, + ForwardingPolicy, ForwardingRule, RateLimitConfig, StaticConfig, TargetPattern, }; pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError}; pub use server::serve::{ListenerConfig, ServeError, ServeOptions, ServeTransportMode, Server}; diff --git a/crates/alknet-core/src/server/handler.rs b/crates/alknet-core/src/server/handler.rs index cfa53d3..c3ec0f6 100644 --- a/crates/alknet-core/src/server/handler.rs +++ b/crates/alknet-core/src/server/handler.rs @@ -26,7 +26,7 @@ pub struct ProxyConfig { pub mode: ProxyMode, } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum TransportKind { Tcp, Tls, @@ -48,6 +48,7 @@ impl std::fmt::Display for TransportKind { } pub struct ServerHandler { + dynamic: Arc>, identity_provider: Box, #[allow(dead_code)] outbound_proxy: Option, @@ -97,6 +98,7 @@ impl ServerHandler { }; Self { + dynamic, identity_provider, outbound_proxy, remote_addr, @@ -236,6 +238,34 @@ impl Handler for ServerHandler { return Ok(true); } + let identity = self + .authenticated_identity + .clone() + .unwrap_or_else(|| Identity { + id: String::new(), + scopes: vec![], + resources: std::collections::HashMap::new(), + }); + + let policy = self.dynamic.load(); + let allowed = policy.forwarding.check( + host_to_connect, + port_to_connect as u16, + &identity, + self.transport, + ); + + if !allowed { + tracing::info!( + remote_addr = ?self.remote_addr, + target = %format!("{host_to_connect}:{port_to_connect}"), + identity = %identity.id, + transport = %self.transport, + "forwarding denied by policy" + ); + return Ok(false); + } + let target_host = host_to_connect.to_string(); let target_port = port_to_connect; let proxy_config = self.outbound_proxy.clone().unwrap_or(ProxyConfig { diff --git a/crates/alknet-core/src/server/serve.rs b/crates/alknet-core/src/server/serve.rs index a30086d..80133fb 100644 --- a/crates/alknet-core/src/server/serve.rs +++ b/crates/alknet-core/src/server/serve.rs @@ -509,7 +509,7 @@ impl Server { .first() .expect("at least one listener required"); - let transport_kind = listener.transport_kind.clone(); + let transport_kind = listener.transport_kind; let stealth = listener.stealth; let listen_addr = listener.listen_addr.clone(); @@ -573,7 +573,7 @@ impl Server { }; let remote_addr = info.remote_addr; - let handler_transport_kind = transport_kind.clone(); + let handler_transport_kind = transport_kind; let handler = ServerHandler::new( Arc::clone(&server.dynamic),