feat(core): wire IdentityProvider and ForwardingPolicy into ServerHandler

- Change ServerHandler to hold Arc<dyn IdentityProvider> instead of Box<dyn IdentityProvider>
- Refactor Server::new() to use StaticConfig::from_serve_options() producing (StaticConfig, DynamicConfig)
- Remove duplicate parse_proxy_config from serve.rs (now in static_config.rs)
- Add with_identity_provider() accepting Arc<dyn IdentityProvider>
- Add integration tests for DynamicConfig reload and ForwardingPolicy deny
- Add test for custom IdentityProvider injection via with_identity_provider
- Move parse_proxy_config tests to static_config.rs module
This commit is contained in:
2026-06-07 15:12:38 +00:00
parent ee1cee6004
commit fe53300956
3 changed files with 261 additions and 106 deletions

View File

@@ -120,3 +120,85 @@ fn parse_proxy_config(proxy: Option<&str>) -> Option<ProxyConfig> {
} }
}) })
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::keys::KeySource;
use crate::server::handler::TransportKind;
use crate::server::serve::ServeOptions;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
fn make_key_source() -> KeySource {
KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec())
}
fn make_authorized_keys_source() -> KeySource {
KeySource::Memory(ED25519_PUBLIC_KEY.as_bytes().to_vec())
}
#[test]
fn parse_proxy_config_socks5() {
let config = parse_proxy_config(Some("socks5://127.0.0.1:9050"));
assert!(config.is_some());
match config.unwrap().mode {
ProxyMode::Socks5(addr) => {
assert_eq!(addr, "127.0.0.1:9050".parse().unwrap());
}
_ => panic!("expected Socks5"),
}
}
#[test]
fn parse_proxy_config_http() {
let config = parse_proxy_config(Some("http://127.0.0.1:8080"));
assert!(config.is_some());
match config.unwrap().mode {
ProxyMode::HttpConnect(addr) => {
assert_eq!(addr, "127.0.0.1:8080".parse().unwrap());
}
_ => panic!("expected HttpConnect"),
}
}
#[test]
fn parse_proxy_config_none() {
assert!(parse_proxy_config(None).is_none());
}
#[test]
fn static_config_from_serve_options_basic() {
let opts =
ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source());
let (static_config, dynamic) = StaticConfig::from_serve_options(opts).unwrap();
assert_eq!(static_config.listen_addr, "0.0.0.0:22");
assert_eq!(static_config.max_auth_attempts, 10);
assert!(dynamic.auth.authorized_keys.len() > 0);
}
#[test]
fn static_config_from_serve_options_with_proxy() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.proxy("socks5://127.0.0.1:9050");
let (static_config, _) = StaticConfig::from_serve_options(opts).unwrap();
assert!(static_config.proxy_config.is_some());
}
#[test]
fn static_config_from_serve_options_with_listeners() {
let listeners = vec![ListenerConfig::tcp("0.0.0.0:22")];
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source())
.listeners(listeners);
let (static_config, _) = StaticConfig::from_serve_options(opts).unwrap();
assert_eq!(static_config.listeners.len(), 1);
assert_eq!(
static_config.listeners[0].transport_kind,
TransportKind::Tcp
);
}
}

View File

@@ -49,7 +49,7 @@ impl std::fmt::Display for TransportKind {
pub struct ServerHandler { pub struct ServerHandler {
dynamic: Arc<ArcSwap<DynamicConfig>>, dynamic: Arc<ArcSwap<DynamicConfig>>,
identity_provider: Box<dyn IdentityProvider>, identity_provider: Arc<dyn IdentityProvider>,
#[allow(dead_code)] #[allow(dead_code)]
outbound_proxy: Option<ProxyConfig>, outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>, remote_addr: Option<SocketAddr>,
@@ -72,8 +72,8 @@ impl ServerHandler {
connection_limiter: Arc<ConnectionRateLimiter>, connection_limiter: Arc<ConnectionRateLimiter>,
max_auth_attempts: usize, max_auth_attempts: usize,
) -> Self { ) -> Self {
let identity_provider: Box<dyn IdentityProvider> = let identity_provider: Arc<dyn IdentityProvider> =
Box::new(ConfigIdentityProvider::new(Arc::clone(&dynamic))); Arc::new(ConfigIdentityProvider::new(Arc::clone(&dynamic)));
let allowed = if let Some(addr) = remote_addr { let allowed = if let Some(addr) = remote_addr {
let ip = addr.ip(); let ip = addr.ip();
@@ -112,7 +112,7 @@ impl ServerHandler {
} }
} }
pub fn with_identity_provider(mut self, provider: Box<dyn IdentityProvider>) -> Self { pub fn with_identity_provider(mut self, provider: Arc<dyn IdentityProvider>) -> Self {
self.identity_provider = provider; self.identity_provider = provider;
self self
} }
@@ -818,4 +818,167 @@ mod tests {
10, 10,
); );
} }
#[tokio::test]
async fn config_reload_new_keys_take_effect() {
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
let mut handler = ServerHandler::new(
auth_config.clone(),
None,
None,
TransportKind::Tcp,
default_limiter(),
10,
);
let ssh_key = load_key().public_key().clone();
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
assert_eq!(result, Auth::Accept);
drop(handler);
let new_dynamic = DynamicConfig::default();
auth_config.store(Arc::new(new_dynamic));
let mut handler2 = ServerHandler::new(
auth_config.clone(),
None,
None,
TransportKind::Tcp,
default_limiter(),
10,
);
let result2 = handler2.auth_publickey("testuser", &ssh_key).await.unwrap();
assert_eq!(
result2,
Auth::Reject {
proceed_with_methods: None
}
);
}
#[tokio::test]
async fn forwarding_policy_deny_blocks_channel_open() {
use crate::config::forwarding::{
ForwardingAction, ForwardingPolicy, ForwardingRule, TargetPattern,
};
let deny_policy = ForwardingPolicy {
default: ForwardingAction::Deny,
rules: vec![ForwardingRule {
target: TargetPattern::Host("blocked.example.com".to_string()),
action: ForwardingAction::Deny,
principals: vec![],
transports: vec![],
}],
};
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
{
let dynamic = auth_config.load();
let new_dynamic = DynamicConfig {
auth: dynamic.auth.clone(),
forwarding: deny_policy,
rate_limits: dynamic.rate_limits.clone(),
};
drop(dynamic);
auth_config.store(Arc::new(new_dynamic));
}
let mut handler = ServerHandler::new(
auth_config,
None,
Some("127.0.0.1:12345".parse().unwrap()),
TransportKind::Tcp,
default_limiter(),
10,
);
let ssh_key = load_key().public_key().clone();
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
assert_eq!(result, Auth::Accept);
assert!(handler.authenticated_identity().is_some());
let identity = handler.authenticated_identity().unwrap();
let dynamic = handler.dynamic.load();
assert!(!dynamic.forwarding.check(
"blocked.example.com",
443,
identity,
TransportKind::Tcp
));
}
#[test]
fn forwarding_policy_deny_with_custom_identity() {
use crate::config::forwarding::{
ForwardingAction, ForwardingPolicy, ForwardingRule, TargetPattern,
};
use std::collections::HashMap;
let mut resources = HashMap::new();
resources.insert("service".to_string(), vec!["gitea".to_string()]);
let identity = Identity {
id: "SHA256:abc123".to_string(),
scopes: vec!["relay:connect".to_string()],
resources,
};
let policy = ForwardingPolicy {
default: ForwardingAction::Deny,
rules: vec![ForwardingRule {
target: TargetPattern::Host("allowed.example.com".to_string()),
action: ForwardingAction::Allow,
principals: vec!["SHA256:abc123".to_string()],
transports: vec![TransportKind::Tcp],
}],
};
assert!(policy.check("allowed.example.com", 443, &identity, TransportKind::Tcp));
assert!(!policy.check("denied.example.com", 443, &identity, TransportKind::Tcp));
}
#[test]
fn server_handler_with_custom_identity_provider() {
use std::collections::HashMap;
struct MockIdentityProvider {
identities: HashMap<String, Identity>,
}
impl IdentityProvider for MockIdentityProvider {
fn resolve_from_fingerprint(&self, fingerprint: &str) -> Option<Identity> {
self.identities.get(fingerprint).cloned()
}
fn resolve_from_token(&self, _token: &crate::auth::AuthToken) -> Option<Identity> {
None
}
}
let mut identities = HashMap::new();
identities.insert(
"SHA256:testkey".to_string(),
Identity {
id: "SHA256:testkey".to_string(),
scopes: vec!["admin".to_string()],
resources: HashMap::new(),
},
);
let provider = Arc::new(MockIdentityProvider { identities }) as Arc<dyn IdentityProvider>;
let dynamic = make_empty_auth_config();
let handler = ServerHandler::new(
dynamic,
None,
Some("10.0.0.1:22".parse().unwrap()),
TransportKind::Tcp,
default_limiter(),
10,
)
.with_identity_provider(provider);
assert!(handler.authenticated_identity().is_none());
}
} }

View File

@@ -5,7 +5,6 @@
//! `ServeOptions` provides a builder-pattern API for programmatic configuration. //! `ServeOptions` provides a builder-pattern API for programmatic configuration.
//! Supports multiple listeners via `ListenerConfig` for multi-transport operation. //! Supports multiple listeners via `ListenerConfig` for multi-transport operation.
use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@@ -15,10 +14,9 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use crate::auth::keys::KeySource; use crate::auth::keys::KeySource;
use crate::auth::server_auth::ServerAuthConfig; use crate::config::{ConfigReloadHandle, DynamicConfig};
use crate::config::{AuthPolicy, ConfigReloadHandle, DynamicConfig};
use crate::error::ConfigError; use crate::error::ConfigError;
use crate::server::handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind}; use crate::server::handler::{ProxyConfig, ServerHandler, TransportKind};
use crate::server::rate_limit::ConnectionRateLimiter; use crate::server::rate_limit::ConnectionRateLimiter;
use crate::server::stealth::{self, ProtocolDetection}; use crate::server::stealth::{self, ProtocolDetection};
@@ -387,65 +385,32 @@ pub struct Server {
impl Server { impl Server {
pub fn new(opts: ServeOptions) -> Result<Self, ServeError> { pub fn new(opts: ServeOptions) -> Result<Self, ServeError> {
opts.validate().map_err(ServeError::Config)?; let (static_config, dynamic_config) =
crate::config::StaticConfig::from_serve_options(opts).map_err(ServeError::Config)?;
let private_key = crate::auth::keys::load_private_key(opts.key.clone()) let connection_limiter = Arc::new(ConnectionRateLimiter::new(
.map_err(ServeError::KeyLoadFailed)?; static_config.max_connections_per_ip,
));
let auth_config = ServerAuthConfig::from_keys_and_ca(
opts.authorized_keys.clone(),
opts.cert_authority.clone(),
)
.map_err(ServeError::KeyLoadFailed)?;
let auth_policy = AuthPolicy::from_server_auth_config(auth_config);
let dynamic_config = DynamicConfig::new(auth_policy);
let max_auth_attempts = opts.max_auth_attempts;
let max_connections_per_ip = opts.max_connections_per_ip;
let config = Arc::new(Config { let config = Arc::new(Config {
keys: vec![private_key], keys: vec![static_config.host_key],
max_auth_attempts, max_auth_attempts: static_config.max_auth_attempts,
methods: russh::MethodSet::PUBLICKEY, methods: russh::MethodSet::PUBLICKEY,
preferred: russh::Preferred::DEFAULT, preferred: russh::Preferred::DEFAULT,
..Default::default() ..Default::default()
}); });
let outbound_proxy = parse_proxy_config(opts.proxy.as_deref());
let connection_limiter = Arc::new(ConnectionRateLimiter::new(max_connections_per_ip));
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
let dynamic = Arc::new(ArcSwap::new(Arc::new(dynamic_config))); let dynamic = Arc::new(ArcSwap::new(Arc::new(dynamic_config)));
let listeners = if let Some(listeners) = opts.listeners {
listeners
} else {
let transport_kind = match opts.transport_mode {
ServeTransportMode::Tcp => TransportKind::Tcp,
ServeTransportMode::Tls => TransportKind::Tls,
ServeTransportMode::Iroh => TransportKind::Iroh,
};
vec![ListenerConfig {
transport_kind,
listen_addr: opts.listen_addr.clone(),
tls_cert: opts.tls_cert.clone(),
tls_key: opts.tls_key.clone(),
acme_domain: opts.acme_domain.clone(),
stealth: opts.stealth,
iroh_relay: opts.iroh_relay.clone(),
}]
};
Ok(Self { Ok(Self {
config, config,
dynamic, dynamic,
connection_limiter, connection_limiter,
outbound_proxy, outbound_proxy: static_config.proxy_config,
listeners, listeners: static_config.listeners,
max_auth_attempts, max_auth_attempts: static_config.max_auth_attempts,
shutdown_tx, shutdown_tx,
shutdown_rx, shutdown_rx,
sessions: Arc::new(tokio::sync::Mutex::new(Vec::new())), sessions: Arc::new(tokio::sync::Mutex::new(Vec::new())),
@@ -656,32 +621,6 @@ where
Ok(()) Ok(())
} }
fn parse_proxy_config(proxy: Option<&str>) -> Option<ProxyConfig> {
proxy.map(|url| {
if url.starts_with("socks5://") {
let addr: SocketAddr = url
.strip_prefix("socks5://")
.unwrap()
.parse()
.expect("invalid socks5 proxy address");
ProxyConfig {
mode: ProxyMode::Socks5(addr),
}
} else if url.starts_with("http://") {
let addr: SocketAddr = url
.strip_prefix("http://")
.unwrap()
.parse()
.expect("invalid http connect proxy address");
ProxyConfig {
mode: ProxyMode::HttpConnect(addr),
}
} else {
panic!("unsupported proxy URL scheme: {url}");
}
})
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@@ -850,35 +789,6 @@ mod tests {
assert!(!debug_str.contains("OPENSSH")); assert!(!debug_str.contains("OPENSSH"));
} }
#[test]
fn parse_proxy_config_socks5() {
let config = parse_proxy_config(Some("socks5://127.0.0.1:9050"));
assert!(config.is_some());
match config.unwrap().mode {
ProxyMode::Socks5(addr) => {
assert_eq!(addr, "127.0.0.1:9050".parse().unwrap());
}
_ => panic!("expected Socks5"),
}
}
#[test]
fn parse_proxy_config_http() {
let config = parse_proxy_config(Some("http://127.0.0.1:8080"));
assert!(config.is_some());
match config.unwrap().mode {
ProxyMode::HttpConnect(addr) => {
assert_eq!(addr, "127.0.0.1:8080".parse().unwrap());
}
_ => panic!("expected HttpConnect"),
}
}
#[test]
fn parse_proxy_config_none() {
assert!(parse_proxy_config(None).is_none());
}
#[test] #[test]
fn serve_error_variants() { fn serve_error_variants() {
assert_eq!(ServeError::AcceptFailed.to_string(), "accept failed"); assert_eq!(ServeError::AcceptFailed.to_string(), "accept failed");