diff --git a/Cargo.lock b/Cargo.lock index 9eea6b9..22a6600 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -75,6 +75,7 @@ version = "0.1.0" dependencies = [ "alknet-core", "anyhow", + "arc-swap", "async-trait", "futures", "ipnetwork", @@ -185,6 +186,15 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "arc-swap" +version = "1.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a3a1fd6f75306b68087b831f025c712524bcb19aad54e557b1129cfa0a2b207" +dependencies = [ + "rustversion", +] + [[package]] name = "asn1-rs" version = "0.6.2" diff --git a/crates/alknet-core/Cargo.toml b/crates/alknet-core/Cargo.toml index 0f9f423..8b7ac8c 100644 --- a/crates/alknet-core/Cargo.toml +++ b/crates/alknet-core/Cargo.toml @@ -34,6 +34,7 @@ iroh = { version = "0.34", optional = true } url = { version = "2", optional = true } async-trait = "0.1" ipnetwork = "0.21.1" +arc-swap = "1" [dev-dependencies] alknet-core = { path = ".", features = ["testutil", "tls", "iroh"] } diff --git a/crates/alknet-core/src/auth/client_auth.rs b/crates/alknet-core/src/auth/client_auth.rs index 1e83e9f..8e22f70 100644 --- a/crates/alknet-core/src/auth/client_auth.rs +++ b/crates/alknet-core/src/auth/client_auth.rs @@ -173,4 +173,4 @@ mod tests { let key2 = config.private_key(); assert!(Arc::ptr_eq(&key1, &key2)); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/auth/keys.rs b/crates/alknet-core/src/auth/keys.rs index 096899b..883c4c2 100644 --- a/crates/alknet-core/src/auth/keys.rs +++ b/crates/alknet-core/src/auth/keys.rs @@ -6,7 +6,7 @@ use std::path::PathBuf; -use russh::keys::{PrivateKey, PublicKey, decode_secret_key, parse_public_key_base64}; +use russh::keys::{decode_secret_key, parse_public_key_base64, PrivateKey, PublicKey}; use crate::error::ConfigError; @@ -98,10 +98,7 @@ fn parse_authorized_keys_line(line: &str) -> Option HashSet::new(), }; - let encoded_keys: HashSet> = authorized_keys - .iter() - .map(encode_key_data) - .collect(); + let encoded_keys: HashSet> = authorized_keys.iter().map(encode_key_data).collect(); let cert_authorities = match cert_authority_source { Some(src) => load_cert_authority_entries(src)?, @@ -135,10 +132,7 @@ fn check_critical_options( Ok(()) } -fn check_extensions( - cert: &Certificate, - ca_entry: &CertAuthorityEntry, -) -> Result<(), AuthError> { +fn check_extensions(cert: &Certificate, ca_entry: &CertAuthorityEntry) -> Result<(), AuthError> { let ca_permit_port_forwarding = ca_entry .options .iter() @@ -188,8 +182,8 @@ fn check_source_address(allowed: &str, client_ip: Option) -> bool { mod tests { use super::*; use rand_core::OsRng; - use russh::keys::{Certificate, PrivateKey, decode_secret_key}; use russh::keys::ssh_key::certificate::{Builder, CertType}; + use russh::keys::{decode_secret_key, Certificate, PrivateKey}; use std::io::Write; const CA_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+gAAAJjP22Bpz9tg\naQAAAAtzc2gtZWQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+g\nAAAEBcRrWyUU+lLpjHbaaYN5YeOlvz6HnuBndUWevEmHk00jqkUoEjfbsmxEWZlQtqU2Om\nhQ8kxXHOyT1sZsMHJq36AAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n"; @@ -218,13 +212,9 @@ mod tests { principals: Vec<&str>, ) -> Certificate { let key_data: russh::keys::ssh_key::public::KeyData = user_pub.into(); - let mut builder = Builder::new_with_random_nonce( - &mut OsRng, - key_data, - valid_after, - valid_before, - ) - .unwrap(); + let mut builder = + Builder::new_with_random_nonce(&mut OsRng, key_data, valid_after, valid_before) + .unwrap(); builder.cert_type(CertType::User).unwrap(); @@ -252,11 +242,7 @@ mod tests { } else { format!("cert-authority,{}", options.join(",")) }; - let line = format!( - "{} {} CA\n", - opts, - ca_pub.to_openssh().unwrap() - ); + let line = format!("{} {} CA\n", opts, ca_pub.to_openssh().unwrap()); f.write_all(line.as_bytes()).unwrap(); f.flush().unwrap(); f @@ -357,13 +343,8 @@ mod tests { let user_pub = user_key.public_key().clone(); let now = now_secs(); let key_data: russh::keys::ssh_key::public::KeyData = (&user_pub).into(); - let mut builder = Builder::new_with_random_nonce( - &mut OsRng, - key_data, - now - 60, - now + 3600, - ) - .unwrap(); + let mut builder = + Builder::new_with_random_nonce(&mut OsRng, key_data, now - 60, now + 3600).unwrap(); builder.cert_type(CertType::User).unwrap(); builder.all_principals_valid().unwrap(); let cert = builder.sign(&ca_key).unwrap(); @@ -383,7 +364,13 @@ mod tests { let other_ca_key = load_other_key(); let user_pub = user_key.public_key().clone(); let now = now_secs(); - let cert = make_cert(&other_ca_key, &user_pub, now - 60, now + 3600, vec!["testuser"]); + let cert = make_cert( + &other_ca_key, + &user_pub, + now - 60, + now + 3600, + vec!["testuser"], + ); let ca_key = load_ca_key(); let ca_pub = ca_key.public_key().clone(); let f = make_ca_file(&ca_pub, &[]); @@ -398,12 +385,11 @@ mod tests { #[test] fn no_config_accepts_nothing() { - let config = - ServerAuthConfig::from_keys_and_ca(None, None).unwrap(); + let config = ServerAuthConfig::from_keys_and_ca(None, None).unwrap(); let other_pub = load_other_key().public_key().clone(); assert_eq!( config.authenticate_publickey(&other_pub), Err(AuthError::KeyRejected) ); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/client/channel_manager.rs b/crates/alknet-core/src/client/channel_manager.rs index 792de6f..50d12d3 100644 --- a/crates/alknet-core/src/client/channel_manager.rs +++ b/crates/alknet-core/src/client/channel_manager.rs @@ -113,14 +113,10 @@ impl ChannelManager { .await .map_err(|_| ChannelError::ChannelClosed)?; - self.inner - .forwards - .write() - .await - .insert(ForwardRequest { - addr: addr.to_string(), - port, - }); + self.inner.forwards.write().await.insert(ForwardRequest { + addr: addr.to_string(), + port, + }); Ok(result) } @@ -132,14 +128,10 @@ impl ChannelManager { .await .map_err(|_| ChannelError::ChannelClosed)?; - self.inner - .forwards - .write() - .await - .remove(&ForwardRequest { - addr: addr.to_string(), - port, - }); + self.inner.forwards.write().await.remove(&ForwardRequest { + addr: addr.to_string(), + port, + }); Ok(()) } @@ -226,10 +218,7 @@ impl ChannelManager { for fwd in forwards.iter() { match handle.tcpip_forward(&fwd.addr, fwd.port).await { Ok(_) => { - debug!( - "re-registered tcpip_forward: {}:{}", - fwd.addr, fwd.port - ); + debug!("re-registered tcpip_forward: {}:{}", fwd.addr, fwd.port); } Err(e) => { warn!( @@ -476,4 +465,4 @@ mod tests { assert!(duration >= Duration::from_secs(1)); } } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/client/connect.rs b/crates/alknet-core/src/client/connect.rs index e2a485c..9ce38cf 100644 --- a/crates/alknet-core/src/client/connect.rs +++ b/crates/alknet-core/src/client/connect.rs @@ -197,10 +197,7 @@ pub struct ClientSession { } impl ClientSession { - pub async fn new( - opts: ConnectOptions, - transport: Arc, - ) -> Result { + pub async fn new(opts: ConnectOptions, transport: Arc) -> Result { opts.validate().map_err(ConnectError::Config)?; let auth_config = Arc::new( @@ -283,16 +280,13 @@ impl ClientSession { let remote_specs = build_remote_specs(&self.opts)?; for spec in &remote_specs { - let remote_forwarder = RemoteForwarder::new(spec.clone()) - .map_err(|_| ConnectError::ForwardFailed)?; + let remote_forwarder = + RemoteForwarder::new(spec.clone()).map_err(|_| ConnectError::ForwardFailed)?; let mut h = self.handle.lock().await; - remote_forwarder - .register(&mut h) - .await - .map_err(|_| { - warn!("failed to register remote forward {}", spec); - ConnectError::ForwardFailed - })?; + remote_forwarder.register(&mut h).await.map_err(|_| { + warn!("failed to register remote forward {}", spec); + ConnectError::ForwardFailed + })?; info!("registered remote forward: {}", spec); } @@ -307,7 +301,9 @@ impl ClientSession { let fwd_shutdown = self.shutdown_rx.clone(); let forward_task = tokio::spawn(async move { crate::client::forward::run_local_forwarders( - local_forwarders, fwd_handle, fwd_shutdown, + local_forwarders, + fwd_handle, + fwd_shutdown, ) .await; }); @@ -358,7 +354,14 @@ impl ClientSession { let handler = ClientHandler::from_config(&reconnect_auth); let username = reconnect_username.clone(); - match establish_session(&*reconnect_transport, handler, &reconnect_auth, &username).await { + match establish_session( + &*reconnect_transport, + handler, + &reconnect_auth, + &username, + ) + .await + { Ok(new_handle) => { info!("reconnection successful"); { @@ -370,8 +373,13 @@ impl ClientSession { Ok(rf) => { let mut h = reconnect_handle.lock().await; match rf.register(&mut h).await { - Ok(_) => debug!("re-registered remote forward: {}", spec), - Err(e) => warn!("failed to re-register remote forward {}: {e}", spec), + Ok(_) => { + debug!("re-registered remote forward: {}", spec) + } + Err(e) => warn!( + "failed to re-register remote forward {}: {e}", + spec + ), } } Err(e) => warn!("failed to create remote forwarder: {e}"), @@ -493,12 +501,10 @@ fn build_local_forwarders(opts: &ConnectOptions) -> Result, name: format!("invalid forward spec: {}", spec_str), }) })?; - forwarders.push( - LocalForwarder::new(spec).map_err(|e| { - warn!("failed to create local forwarder: {}", e); - ConnectError::ForwardFailed - })?, - ); + forwarders.push(LocalForwarder::new(spec).map_err(|e| { + warn!("failed to create local forwarder: {}", e); + ConnectError::ForwardFailed + })?); } Ok(forwarders) } @@ -576,7 +582,10 @@ mod tests { assert_eq!(opts.forwards.len(), 1); assert_eq!(opts.remote_forwards.len(), 1); assert_eq!(opts.proxy.as_deref(), Some("socks5://127.0.0.1:1080")); - assert_eq!(opts.iroh_relay.as_deref(), Some("https://relay.example.com")); + assert_eq!( + opts.iroh_relay.as_deref(), + Some("https://relay.example.com") + ); assert_eq!(opts.tls_server_name.as_deref(), Some("alknet.test")); assert!(opts.insecure); } @@ -650,9 +659,18 @@ mod tests { #[test] fn connect_error_variants() { - assert_eq!(ConnectError::ConnectionFailed.to_string(), "connection failed"); - assert_eq!(ConnectError::AuthFailed.to_string(), "authentication failed"); - assert_eq!(ConnectError::ForwardFailed.to_string(), "forward setup failed"); + assert_eq!( + ConnectError::ConnectionFailed.to_string(), + "connection failed" + ); + assert_eq!( + ConnectError::AuthFailed.to_string(), + "authentication failed" + ); + assert_eq!( + ConnectError::ForwardFailed.to_string(), + "forward setup failed" + ); } #[test] @@ -703,7 +721,10 @@ mod tests { let transport = Arc::new(FailTransport); let result = ClientSession::new(opts, transport).await; assert!(result.is_err()); - assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed)); + assert!(matches!( + result.err().unwrap(), + ConnectError::ConnectionFailed + )); } #[tokio::test] @@ -714,7 +735,10 @@ mod tests { let opts = ConnectOptions::new(make_identity()).server("example.com:22"); let result = ClientSession::new(opts, transport).await; assert!(result.is_err()); - assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed)); + assert!(matches!( + result.err().unwrap(), + ConnectError::ConnectionFailed + )); } #[test] @@ -750,7 +774,8 @@ mod tests { #[test] fn build_remote_specs_valid() { - let opts = ConnectOptions::new(make_identity()).remote_forward("0.0.0.0:8080:127.0.0.1:3000"); + let opts = + ConnectOptions::new(make_identity()).remote_forward("0.0.0.0:8080:127.0.0.1:3000"); let result = build_remote_specs(&opts); assert!(result.is_ok()); assert_eq!(result.unwrap().len(), 1); @@ -798,8 +823,8 @@ mod tests { #[tokio::test] async fn integration_mock_transport_session() { - use crate::socks5::{ChannelOpener, ChannelOpenError}; - use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; + use crate::socks5::{ChannelOpenError, ChannelOpener}; + use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; struct MockOpener; @@ -839,9 +864,7 @@ mod tests { conn.read_exact(&mut auth_resp).await.unwrap(); assert_eq!(auth_resp, [0x05, 0x00]); - let connect_req = [ - 0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80, - ]; + let connect_req = [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80]; conn.write_all(&connect_req).await.unwrap(); let mut reply = [0u8; 10]; @@ -851,4 +874,4 @@ mod tests { conn.write_all(b"test data").await.unwrap(); conn.shutdown().await.unwrap(); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/client/forward.rs b/crates/alknet-core/src/client/forward.rs index b3eac6f..0950e75 100644 --- a/crates/alknet-core/src/client/forward.rs +++ b/crates/alknet-core/src/client/forward.rs @@ -205,12 +205,7 @@ async fn proxy_local_to_remote( let handle_guard = handle.lock().await; let channel = handle_guard - .channel_open_direct_tcpip( - remote_host, - remote_port as u32, - &local_addr, - 0, - ) + .channel_open_direct_tcpip(remote_host, remote_port as u32, &local_addr, 0) .await .map_err(|e| ForwardError::ChannelOpenFailed { source: Box::new(e) as _, @@ -470,11 +465,8 @@ mod tests { let bound_addr = listener.local_addr().unwrap(); drop(listener); - let spec = PortForwardSpec::local(&format!( - "127.0.0.1:{}:remote:5432", - bound_addr.port() - )) - .unwrap(); + let spec = PortForwardSpec::local(&format!("127.0.0.1:{}:remote:5432", bound_addr.port())) + .unwrap(); let forwarder = LocalForwarder::new(spec).unwrap(); assert_eq!(forwarder.local_port(), bound_addr.port()); } @@ -534,4 +526,4 @@ mod tests { let forwarder = RemoteForwarder::new(spec.clone()).unwrap(); assert_eq!(forwarder.spec(), &spec); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/client/mod.rs b/crates/alknet-core/src/client/mod.rs index 284b13d..837f2bb 100644 --- a/crates/alknet-core/src/client/mod.rs +++ b/crates/alknet-core/src/client/mod.rs @@ -14,4 +14,4 @@ pub mod forward; pub use channel_manager::{ChannelManager, ForwardRequest}; pub use connect::{ClientSession, ConnectError, ConnectOptions, TransportMode}; -pub use forward::{LocalForwarder, PortForwardSpec, PortForwardSpecKind, RemoteForwarder}; \ No newline at end of file +pub use forward::{LocalForwarder, PortForwardSpec, PortForwardSpecKind, RemoteForwarder}; diff --git a/crates/alknet-core/src/config/dynamic_config.rs b/crates/alknet-core/src/config/dynamic_config.rs new file mode 100644 index 0000000..bdb8130 --- /dev/null +++ b/crates/alknet-core/src/config/dynamic_config.rs @@ -0,0 +1,395 @@ +use std::sync::Arc; + +use arc_swap::ArcSwap; + +use crate::auth::ServerAuthConfig; + +pub struct AuthPolicy { + pub authorized_keys: std::collections::HashSet, + pub cert_authorities: Vec, + encoded_keys: std::collections::HashSet>, +} + +fn encode_key_data(key: &russh::keys::PublicKey) -> Vec { + use russh::keys::helpers::EncodedExt; + key.key_data().encoded().unwrap_or_default() +} + +impl AuthPolicy { + pub fn new( + authorized_keys: std::collections::HashSet, + cert_authorities: Vec, + ) -> Self { + let encoded_keys = authorized_keys.iter().map(encode_key_data).collect(); + + Self { + authorized_keys, + cert_authorities, + encoded_keys, + } + } + + pub fn from_server_auth_config(config: ServerAuthConfig) -> Self { + Self::new(config.authorized_keys, config.cert_authorities) + } + + pub fn empty() -> Self { + Self::new(std::collections::HashSet::new(), Vec::new()) + } + + pub fn authenticate_publickey( + &self, + key: &russh::keys::PublicKey, + ) -> Result<(), crate::error::AuthError> { + let encoded = encode_key_data(key); + if self.encoded_keys.contains(&encoded) { + return Ok(()); + } + Err(crate::error::AuthError::KeyRejected) + } + + pub fn authenticate_certificate( + &self, + cert: &russh::keys::Certificate, + user: &str, + client_ip: Option, + ) -> Result<(), crate::error::AuthError> { + use std::time::SystemTime; + + let matching_ca = self + .cert_authorities + .iter() + .find(|ca| cert.signature_key() == ca.public_key.key_data()); + + let ca_entry = match matching_ca { + Some(entry) => entry, + None => return Err(crate::error::AuthError::CertInvalid), + }; + + if cert.verify_signature().is_err() { + return Err(crate::error::AuthError::CertInvalid); + } + + let now = SystemTime::now(); + let now_secs = now + .duration_since(SystemTime::UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + + if now_secs < cert.valid_after() || now_secs >= cert.valid_before() { + return Err(crate::error::AuthError::CertExpired); + } + + let principals = cert.valid_principals(); + if !principals.is_empty() && !principals.iter().any(|p| p == user) { + return Err(crate::error::AuthError::CertPrincipalMismatch); + } + + check_critical_options(cert, ca_entry, client_ip)?; + check_extensions(cert, ca_entry)?; + + Ok(()) + } +} + +fn check_critical_options( + cert: &russh::keys::Certificate, + ca_entry: &crate::auth::keys::CertAuthorityEntry, + client_ip: Option, +) -> Result<(), crate::error::AuthError> { + let ca_has_no_pty = ca_entry.options.iter().any(|o| o == "no-pty"); + + for (name, data) in cert.critical_options().iter() { + match name.as_str() { + "source-address" => { + if !check_source_address(data, client_ip) { + return Err(crate::error::AuthError::CertInvalid); + } + } + "force-command" => {} + "no-pty" => {} + _ => { + let _ = ca_has_no_pty; + return Err(crate::error::AuthError::CertInvalid); + } + } + } + + Ok(()) +} + +fn check_extensions( + cert: &russh::keys::Certificate, + ca_entry: &crate::auth::keys::CertAuthorityEntry, +) -> Result<(), crate::error::AuthError> { + let ca_permit_port_forwarding = ca_entry + .options + .iter() + .any(|o| o == "permit-port-forwarding"); + + if ca_permit_port_forwarding { + let cert_allows = cert + .extensions() + .iter() + .any(|(n, _)| n == "permit-port-forwarding"); + if !cert_allows { + return Err(crate::error::AuthError::CertInvalid); + } + } + + Ok(()) +} + +fn check_source_address(allowed: &str, client_ip: Option) -> bool { + use ipnetwork::IpNetwork; + use std::net::IpAddr; + use std::str::FromStr; + + let Some(ip) = client_ip else { + return false; + }; + + for pattern in allowed.split(',') { + let pattern = pattern.trim(); + if pattern.is_empty() { + continue; + } + + if let Ok(cidr) = IpNetwork::from_str(pattern) { + if cidr.contains(ip) { + return true; + } + } + + if let Ok(net_ip) = IpAddr::from_str(pattern) { + if net_ip == ip { + return true; + } + } + } + + false +} + +impl std::fmt::Debug for AuthPolicy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AuthPolicy") + .field("authorized_keys_count", &self.authorized_keys.len()) + .field("cert_authorities_count", &self.cert_authorities.len()) + .finish() + } +} + +impl Clone for AuthPolicy { + fn clone(&self) -> Self { + Self { + authorized_keys: self.authorized_keys.clone(), + cert_authorities: self.cert_authorities.clone(), + encoded_keys: self.encoded_keys.clone(), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum ForwardingAction { + Allow, + Deny, +} + +#[derive(Debug, Clone)] +pub struct ForwardingRule { + pub action: ForwardingAction, + pub principals: Vec, + pub transports: Vec, +} + +#[derive(Debug, Clone)] +pub struct ForwardingPolicy { + pub default: ForwardingAction, + pub rules: Vec, +} + +impl ForwardingPolicy { + pub fn allow_all() -> Self { + Self { + default: ForwardingAction::Allow, + rules: Vec::new(), + } + } + + pub fn deny_all() -> Self { + Self { + default: ForwardingAction::Deny, + rules: Vec::new(), + } + } +} + +#[derive(Debug, Clone)] +pub struct RateLimitConfig { + pub max_connections_per_ip: usize, + pub max_auth_attempts: usize, +} + +impl Default for RateLimitConfig { + fn default() -> Self { + Self { + max_connections_per_ip: 0, + max_auth_attempts: 10, + } + } +} + +#[derive(Debug, Clone)] +pub struct DynamicConfig { + pub auth: AuthPolicy, + pub forwarding: ForwardingPolicy, + pub rate_limits: RateLimitConfig, +} + +impl DynamicConfig { + pub fn new(auth: AuthPolicy) -> Self { + Self { + auth, + forwarding: ForwardingPolicy::allow_all(), + rate_limits: RateLimitConfig::default(), + } + } + + pub fn with_forwarding_policy(mut self, policy: ForwardingPolicy) -> Self { + self.forwarding = policy; + self + } + + pub fn with_rate_limits(mut self, limits: RateLimitConfig) -> Self { + self.rate_limits = limits; + self + } +} + +impl Default for DynamicConfig { + fn default() -> Self { + Self { + auth: AuthPolicy::empty(), + forwarding: ForwardingPolicy::allow_all(), + rate_limits: RateLimitConfig::default(), + } + } +} + +pub struct ConfigReloadHandle { + pub(crate) dynamic: Arc>, +} + +impl ConfigReloadHandle { + pub fn reload(&self, new_config: DynamicConfig) { + self.dynamic.store(Arc::new(new_config)); + } + + pub fn dynamic(&self) -> Arc { + self.dynamic.load_full() + } +} + +impl std::fmt::Debug for ConfigReloadHandle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConfigReloadHandle").finish() + } +} + +pub fn new_dynamic_config() -> (Arc>, ConfigReloadHandle) { + let inner = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default()))); + let handle = ConfigReloadHandle { + dynamic: Arc::clone(&inner), + }; + (inner, handle) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn forwarding_policy_allow_all_default() { + let policy = ForwardingPolicy::allow_all(); + assert_eq!(policy.default, ForwardingAction::Allow); + assert!(policy.rules.is_empty()); + } + + #[test] + fn forwarding_policy_deny_all() { + let policy = ForwardingPolicy::deny_all(); + assert_eq!(policy.default, ForwardingAction::Deny); + assert!(policy.rules.is_empty()); + } + + #[test] + fn dynamic_config_default() { + let config = DynamicConfig::default(); + assert_eq!(config.forwarding.default, ForwardingAction::Allow); + assert_eq!(config.rate_limits.max_connections_per_ip, 0); + assert_eq!(config.rate_limits.max_auth_attempts, 10); + } + + #[test] + fn config_reload_handle_updates_dynamic() { + let (arc_swap, handle) = new_dynamic_config(); + let initial = arc_swap.load(); + assert_eq!(initial.forwarding.default, ForwardingAction::Allow); + + let new_config = DynamicConfig { + auth: AuthPolicy::empty(), + forwarding: ForwardingPolicy::deny_all(), + rate_limits: RateLimitConfig::default(), + }; + handle.reload(new_config); + + let updated = arc_swap.load(); + assert_eq!(updated.forwarding.default, ForwardingAction::Deny); + } + + #[test] + fn dynamic_config_with_forwarding_policy_builder() { + let config = DynamicConfig::new(AuthPolicy::empty()) + .with_forwarding_policy(ForwardingPolicy::deny_all()); + assert_eq!(config.forwarding.default, ForwardingAction::Deny); + } + + #[test] + fn rate_limit_config_custom() { + let limits = RateLimitConfig { + max_connections_per_ip: 5, + max_auth_attempts: 3, + }; + assert_eq!(limits.max_connections_per_ip, 5); + assert_eq!(limits.max_auth_attempts, 3); + } + + #[test] + fn forwarding_action_equality() { + assert_eq!(ForwardingAction::Allow, ForwardingAction::Allow); + assert_eq!(ForwardingAction::Deny, ForwardingAction::Deny); + assert_ne!(ForwardingAction::Allow, ForwardingAction::Deny); + } + + #[test] + fn auth_policy_empty_rejects_all() { + let policy = AuthPolicy::empty(); + let key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host"; + let other_ssh_key = + russh::keys::parse_public_key_base64(key_text.split_whitespace().nth(1).unwrap()) + .unwrap(); + assert_eq!( + policy.authenticate_publickey(&other_ssh_key), + Err(crate::error::AuthError::KeyRejected) + ); + } + + #[test] + fn auth_policy_debug_redacts_keys() { + let policy = AuthPolicy::empty(); + let debug_str = format!("{:?}", policy); + assert!(debug_str.contains("authorized_keys_count")); + assert!(debug_str.contains("cert_authorities_count")); + } +} diff --git a/crates/alknet-core/src/config/mod.rs b/crates/alknet-core/src/config/mod.rs new file mode 100644 index 0000000..dd4879c --- /dev/null +++ b/crates/alknet-core/src/config/mod.rs @@ -0,0 +1,8 @@ +pub mod dynamic_config; +pub mod static_config; + +pub use dynamic_config::{ + new_dynamic_config, AuthPolicy, ConfigReloadHandle, DynamicConfig, ForwardingAction, + ForwardingPolicy, ForwardingRule, RateLimitConfig, +}; +pub use static_config::StaticConfig; diff --git a/crates/alknet-core/src/config/static_config.rs b/crates/alknet-core/src/config/static_config.rs new file mode 100644 index 0000000..f846a79 --- /dev/null +++ b/crates/alknet-core/src/config/static_config.rs @@ -0,0 +1,101 @@ +use crate::server::handler::{ProxyConfig, ProxyMode}; +use crate::server::serve::ServeTransportMode; +use std::net::SocketAddr; + +pub struct StaticConfig { + pub transport_mode: ServeTransportMode, + pub listen_addr: String, + pub tls_cert: Option, + pub tls_key: Option, + pub acme_domain: Option, + pub stealth: bool, + pub host_key: russh::keys::PrivateKey, + pub host_key_algorithm: russh::keys::Algorithm, + pub max_auth_attempts: usize, + pub max_connections_per_ip: usize, + pub proxy_config: Option, + pub iroh_relay: Option, +} + +impl std::fmt::Debug for StaticConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StaticConfig") + .field("transport_mode", &self.transport_mode) + .field("listen_addr", &self.listen_addr) + .field("tls_cert", &self.tls_cert.as_ref().map(|_| "")) + .field("tls_key", &self.tls_key.as_ref().map(|_| "")) + .field("acme_domain", &self.acme_domain) + .field("stealth", &self.stealth) + .field("host_key_algorithm", &self.host_key_algorithm) + .field("max_auth_attempts", &self.max_auth_attempts) + .field("max_connections_per_ip", &self.max_connections_per_ip) + .field("proxy_config", &self.proxy_config) + .field("iroh_relay", &self.iroh_relay) + .finish() + } +} + +impl StaticConfig { + pub fn from_serve_options( + opts: crate::server::serve::ServeOptions, + ) -> Result<(Self, crate::config::DynamicConfig), crate::error::ConfigError> { + opts.validate()?; + + let host_key = crate::auth::keys::load_private_key(opts.key.clone())?; + let host_key_algorithm = host_key.algorithm(); + + let auth_config = crate::auth::ServerAuthConfig::from_keys_and_ca( + opts.authorized_keys.clone(), + opts.cert_authority.clone(), + )?; + + let auth_policy = crate::config::AuthPolicy::from_server_auth_config(auth_config); + + let dynamic = crate::config::DynamicConfig::new(auth_policy); + + let proxy_config = parse_proxy_config(opts.proxy.as_deref()); + + let static_config = StaticConfig { + transport_mode: opts.transport_mode, + listen_addr: opts.listen_addr, + tls_cert: opts.tls_cert, + tls_key: opts.tls_key, + acme_domain: opts.acme_domain, + stealth: opts.stealth, + host_key, + host_key_algorithm, + max_auth_attempts: opts.max_auth_attempts, + max_connections_per_ip: opts.max_connections_per_ip, + proxy_config, + iroh_relay: opts.iroh_relay, + }; + + Ok((static_config, dynamic)) + } +} + +fn parse_proxy_config(proxy: Option<&str>) -> Option { + proxy.map(|url| { + if url.starts_with("socks5://") { + let addr: SocketAddr = url + .strip_prefix("socks5://") + .unwrap() + .parse() + .expect("invalid socks5 proxy address"); + ProxyConfig { + mode: ProxyMode::Socks5(addr), + } + } else if url.starts_with("http://") { + let addr: SocketAddr = url + .strip_prefix("http://") + .unwrap() + .parse() + .expect("invalid http connect proxy address"); + ProxyConfig { + mode: ProxyMode::HttpConnect(addr), + } + } else { + panic!("unsupported proxy URL scheme: {url}"); + } + }) +} diff --git a/crates/alknet-core/src/error.rs b/crates/alknet-core/src/error.rs index 556e1f4..8c5b04f 100644 --- a/crates/alknet-core/src/error.rs +++ b/crates/alknet-core/src/error.rs @@ -97,7 +97,10 @@ mod tests { #[test] fn transport_error_display() { - assert_eq!(TransportError::ConnectionFailed.to_string(), "connection failed"); + assert_eq!( + TransportError::ConnectionFailed.to_string(), + "connection failed" + ); assert_eq!( TransportError::HandshakeFailed { source: io::Error::new(io::ErrorKind::ConnectionRefused, "tls failed") @@ -120,13 +123,19 @@ mod tests { assert_eq!(AuthError::KeyRejected.to_string(), "key rejected"); assert_eq!(AuthError::CertInvalid.to_string(), "certificate invalid"); assert_eq!(AuthError::CertExpired.to_string(), "certificate expired"); - assert_eq!(AuthError::CertPrincipalMismatch.to_string(), "certificate principal mismatch"); + assert_eq!( + AuthError::CertPrincipalMismatch.to_string(), + "certificate principal mismatch" + ); assert_eq!(AuthError::NoMatchingKey.to_string(), "no matching key"); } #[test] fn channel_error_display() { - assert_eq!(ChannelError::TargetUnreachable.to_string(), "target unreachable"); + assert_eq!( + ChannelError::TargetUnreachable.to_string(), + "target unreachable" + ); assert_eq!( ChannelError::ProxyConnectFailed { source: io::Error::new(io::ErrorKind::ConnectionRefused, "refused") @@ -160,7 +169,10 @@ mod tests { .to_string(), "bind failed" ); - assert_eq!(ConfigError::IncompatibleOptions.to_string(), "incompatible options"); + assert_eq!( + ConfigError::IncompatibleOptions.to_string(), + "incompatible options" + ); } #[test] @@ -184,7 +196,10 @@ mod tests { #[test] fn forward_error_display() { assert_eq!( - ForwardError::InvalidSpec { spec: "bad".to_string() }.to_string(), + ForwardError::InvalidSpec { + spec: "bad".to_string() + } + .to_string(), "invalid port forward spec: bad" ); assert_eq!( @@ -209,7 +224,9 @@ mod tests { let forward_err = ForwardError::BindFailed { source: io_err }; assert!(forward_err.source().is_some()); - let plain = ForwardError::InvalidSpec { spec: "bad".to_string() }; + let plain = ForwardError::InvalidSpec { + spec: "bad".to_string(), + }; assert!(plain.source().is_none()); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/lib.rs b/crates/alknet-core/src/lib.rs index 72496c1..6566828 100644 --- a/crates/alknet-core/src/lib.rs +++ b/crates/alknet-core/src/lib.rs @@ -50,18 +50,23 @@ //! } //! ``` -pub mod transport; -pub mod client; -pub mod server; pub mod auth; -pub mod socks5; +pub mod client; +pub mod config; pub mod error; +pub mod server; +pub mod socks5; +pub mod transport; #[cfg(feature = "testutil")] pub mod testutil; -pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError}; -pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind}; pub use client::channel_manager::{ChannelManager, ForwardRequest}; pub use client::connect::{ClientSession, ConnectError, ConnectOptions, TransportMode}; -pub use server::serve::{Server, ServeError, ServeOptions, ServeTransportMode}; \ No newline at end of file +pub use config::{ + AuthPolicy, ConfigReloadHandle, DynamicConfig, ForwardingAction, ForwardingPolicy, + ForwardingRule, RateLimitConfig, StaticConfig, +}; +pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError}; +pub use server::serve::{ServeError, ServeOptions, ServeTransportMode, Server}; +pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind}; diff --git a/crates/alknet-core/src/server/channel_proxy.rs b/crates/alknet-core/src/server/channel_proxy.rs index c00a156..cb5b4d5 100644 --- a/crates/alknet-core/src/server/channel_proxy.rs +++ b/crates/alknet-core/src/server/channel_proxy.rs @@ -46,7 +46,10 @@ async fn connect_direct(target: SocketAddr) -> Result Result { +async fn connect_socks5( + target: SocketAddr, + proxy_addr: SocketAddr, +) -> Result { let mut stream = TcpStream::connect(proxy_addr) .await .map_err(ChannelProxyError::from)?; @@ -134,10 +137,7 @@ async fn connect_http_connect( } let response_str = String::from_utf8_lossy(&response); - let status_line = response_str - .lines() - .next() - .unwrap_or(""); + let status_line = response_str.lines().next().unwrap_or(""); if status_line.contains("200") { Ok(stream) @@ -279,11 +279,7 @@ mod tests { .parse() .unwrap(); - let reply = vec![ - 0x05, 0x00, 0x00, 0x01, - 0, 0, 0, 0, - 0, 0, - ]; + let reply = vec![0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0]; proxy_sock.write_all(&reply).await.unwrap(); let mut target_stream = TcpStream::connect(target).await.unwrap(); @@ -323,11 +319,7 @@ mod tests { let mut port_bytes = [0u8; 2]; proxy_sock.read_exact(&mut port_bytes).await.unwrap(); - let reply = vec![ - 0x05, 0x05, 0x00, 0x01, - 0, 0, 0, 0, - 0, 0, - ]; + let reply = vec![0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0]; proxy_sock.write_all(&reply).await.unwrap(); }); @@ -560,4 +552,4 @@ mod tests { let proxy = direct_config(); proxy_channel(channel, target, &proxy).await; } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/server/control_channel.rs b/crates/alknet-core/src/server/control_channel.rs index cb232c7..9ba316a 100644 --- a/crates/alknet-core/src/server/control_channel.rs +++ b/crates/alknet-core/src/server/control_channel.rs @@ -189,4 +189,4 @@ mod tests { fn control_channel_destination_matches_prefix() { assert!(is_reserved_destination(ALKNET_CONTROL_DESTINATION)); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/server/handler.rs b/crates/alknet-core/src/server/handler.rs index e53d70e..23ce48c 100644 --- a/crates/alknet-core/src/server/handler.rs +++ b/crates/alknet-core/src/server/handler.rs @@ -2,16 +2,15 @@ use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Instant; +use arc_swap::ArcSwap; use async_trait::async_trait; use russh::keys::ssh_key::HashAlg; use russh::server::{Auth, Handler, Msg, Session}; use russh::Channel; use russh::ChannelId; -use crate::auth::ServerAuthConfig; -use crate::server::control_channel::{ - ControlChannelHandler, ControlChannelRouter, ALKNET_PREFIX, -}; +use crate::config::DynamicConfig; +use crate::server::control_channel::{ControlChannelHandler, ControlChannelRouter, ALKNET_PREFIX}; use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter}; #[derive(Debug, Clone)] @@ -44,7 +43,7 @@ impl std::fmt::Display for TransportKind { } pub struct ServerHandler { - auth_config: Arc, + dynamic: Arc>, #[allow(dead_code)] outbound_proxy: Option, remote_addr: Option, @@ -59,7 +58,7 @@ pub struct ServerHandler { impl ServerHandler { pub fn new( - auth_config: Arc, + dynamic: Arc>, outbound_proxy: Option, remote_addr: Option, transport: TransportKind, @@ -89,7 +88,7 @@ impl ServerHandler { }; Self { - auth_config, + dynamic, outbound_proxy, remote_addr, control_channel_router: ControlChannelRouter::without_handler(), @@ -127,10 +126,7 @@ impl Drop for ServerHandler { } impl ServerHandler { - pub fn with_control_channel_handler( - mut self, - handler: Box, - ) -> Self { + pub fn with_control_channel_handler(mut self, handler: Box) -> Self { self.control_channel_router = ControlChannelRouter::with_handler(handler); self } @@ -172,7 +168,8 @@ impl Handler for ServerHandler { .map_or("unknown".to_string(), |a| a.to_string()); let russh_pub = russh::keys::PublicKey::new(public_key.key_data().clone(), user); - let result = self.auth_config.authenticate_publickey(&russh_pub); + let auth_config = self.dynamic.load(); + let result = auth_config.auth.authenticate_publickey(&russh_pub); match result { Ok(()) => { @@ -226,17 +223,25 @@ impl Handler for ServerHandler { }); tokio::spawn(async move { - let target = match format!("{target_host}:{target_port}").parse::() { - Ok(addr) => addr, - Err(_) => match tokio::net::lookup_host((&target_host[..], target_port as u16)).await { - Ok(mut addrs) => match addrs.next() { - Some(addr) => addr, - None => return, + let target = + match format!("{target_host}:{target_port}").parse::() { + Ok(addr) => addr, + Err(_) => match tokio::net::lookup_host((&target_host[..], target_port as u16)) + .await + { + Ok(mut addrs) => match addrs.next() { + Some(addr) => addr, + None => return, + }, + Err(_) => return, }, - Err(_) => return, - }, - }; - crate::server::channel_proxy::proxy_channel(channel.into_stream(), target, &proxy_config).await; + }; + crate::server::channel_proxy::proxy_channel( + channel.into_stream(), + target, + &proxy_config, + ) + .await; }); let _ = (originator_address, originator_port); @@ -389,7 +394,12 @@ impl Handler for ServerHandler { channel = %channel, "rejected x11 request on channel" ); - let _ = (single_connection, x11_auth_protocol, x11_auth_cookie, x11_screen_number); + let _ = ( + single_connection, + x11_auth_protocol, + x11_auth_cookie, + x11_screen_number, + ); let _ = session.channel_failure(channel); Ok(()) } @@ -469,6 +479,8 @@ impl Handler for ServerHandler { mod tests { use super::*; use crate::auth::keys::KeySource; + use crate::auth::ServerAuthConfig; + use crate::config::AuthPolicy; use russh::keys::{decode_secret_key, PrivateKey}; use std::io::Write; @@ -487,19 +499,19 @@ mod tests { decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap() } - fn make_auth_config(keys_content: &str) -> Arc { + fn make_auth_config(keys_content: &str) -> Arc> { let f = make_authorized_keys_file(keys_content); - Arc::new( - ServerAuthConfig::from_keys_and_ca( - Some(KeySource::File(f.path().to_path_buf())), - None, - ) - .unwrap(), - ) + let server_auth = + ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None) + .unwrap(); + let auth_policy = AuthPolicy::from_server_auth_config(server_auth); + let dynamic = DynamicConfig::new(auth_policy); + Arc::new(ArcSwap::new(Arc::new(dynamic))) } - fn make_empty_auth_config() -> Arc { - Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap()) + fn make_empty_auth_config() -> Arc> { + let dynamic = DynamicConfig::default(); + Arc::new(ArcSwap::new(Arc::new(dynamic))) } fn default_limiter() -> Arc { @@ -507,11 +519,18 @@ mod tests { } fn make_handler( - auth_config: Arc, + dynamic: Arc>, outbound_proxy: Option, remote_addr: Option, ) -> ServerHandler { - ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10) + ServerHandler::new( + dynamic, + outbound_proxy, + remote_addr, + TransportKind::Tcp, + default_limiter(), + 10, + ) } #[tokio::test] @@ -530,10 +549,9 @@ mod tests { let mut handler = make_handler(auth_config, None, None); let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host"; - let other_ssh_key = russh::keys::parse_public_key_base64( - other_key_text.split_whitespace().nth(1).unwrap(), - ) - .unwrap(); + let other_ssh_key = + russh::keys::parse_public_key_base64(other_key_text.split_whitespace().nth(1).unwrap()) + .unwrap(); let result = handler .auth_publickey("testuser", &other_ssh_key) @@ -553,10 +571,7 @@ mod tests { let mut handler = make_handler(auth_config, None, None); let ssh_key = load_key().public_key().clone(); - let result = handler - .auth_publickey("testuser", &ssh_key) - .await - .unwrap(); + let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap(); assert_eq!( result, Auth::Reject { @@ -629,8 +644,16 @@ mod tests { #[test] fn one_handler_per_connection() { let auth_config = make_empty_auth_config(); - let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap())); - let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap())); + let handler1 = make_handler( + auth_config.clone(), + None, + Some("10.0.0.1:22".parse().unwrap()), + ); + let handler2 = make_handler( + auth_config.clone(), + None, + Some("10.0.0.2:22".parse().unwrap()), + ); assert!(handler1.remote_addr != handler2.remote_addr); } @@ -651,10 +674,20 @@ mod tests { let ssh_key = load_key().public_key().clone(); let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap(); - assert_eq!(r1, Auth::Reject { proceed_with_methods: None }); + assert_eq!( + r1, + Auth::Reject { + proceed_with_methods: None + } + ); let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap(); - assert_eq!(r2, Auth::Reject { proceed_with_methods: None }); + assert_eq!( + r2, + Auth::Reject { + proceed_with_methods: None + } + ); assert!(!handler.auth_limiter.check()); } @@ -733,4 +766,4 @@ mod tests { 10, ); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/server/mod.rs b/crates/alknet-core/src/server/mod.rs index 682fd4b..1624c85 100644 --- a/crates/alknet-core/src/server/mod.rs +++ b/crates/alknet-core/src/server/mod.rs @@ -16,10 +16,12 @@ pub mod stealth; pub use channel_proxy::{connect_outbound, proxy_channel}; pub use control_channel::{ - ControlChannelHandler, ControlChannelRouter, DuplexStream, ALKNET_CONTROL_DESTINATION, - ALKNET_PREFIX, is_reserved_destination, + is_reserved_destination, ControlChannelHandler, ControlChannelRouter, DuplexStream, + ALKNET_CONTROL_DESTINATION, ALKNET_PREFIX, }; pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind}; pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter}; -pub use serve::{Server, ServeError, ServeOptions, ServeTransportMode}; -pub use stealth::{ProtocolDetection, detect_protocol, send_fake_nginx_404, validate_stealth_config}; \ No newline at end of file +pub use serve::{ServeError, ServeOptions, ServeTransportMode, Server}; +pub use stealth::{ + detect_protocol, send_fake_nginx_404, validate_stealth_config, ProtocolDetection, +}; diff --git a/crates/alknet-core/src/server/rate_limit.rs b/crates/alknet-core/src/server/rate_limit.rs index b4cc308..9eeb686 100644 --- a/crates/alknet-core/src/server/rate_limit.rs +++ b/crates/alknet-core/src/server/rate_limit.rs @@ -197,4 +197,4 @@ mod tests { h.join().unwrap(); } } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/server/serve.rs b/crates/alknet-core/src/server/serve.rs index cb33879..1a586aa 100644 --- a/crates/alknet-core/src/server/serve.rs +++ b/crates/alknet-core/src/server/serve.rs @@ -8,12 +8,14 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; +use arc_swap::ArcSwap; use russh::server::{self, Config}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{error, info, warn}; use crate::auth::keys::KeySource; use crate::auth::server_auth::ServerAuthConfig; +use crate::config::{AuthPolicy, ConfigReloadHandle, DynamicConfig}; use crate::error::ConfigError; use crate::server::handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind}; use crate::server::rate_limit::ConnectionRateLimiter; @@ -228,7 +230,7 @@ struct ActiveSession { /// Supports stealth mode (TLS only), outbound proxy routing, and connection rate limiting. pub struct Server { config: Arc, - auth_config: Arc, + dynamic: Arc>, connection_limiter: Arc, outbound_proxy: Option, stealth: bool, @@ -244,17 +246,24 @@ impl Server { pub fn new(opts: ServeOptions) -> Result { opts.validate().map_err(ServeError::Config)?; - let private_key = - crate::auth::keys::load_private_key(opts.key.clone()).map_err(ServeError::KeyLoadFailed)?; + let private_key = crate::auth::keys::load_private_key(opts.key.clone()) + .map_err(ServeError::KeyLoadFailed)?; - let auth_config = Arc::new( - ServerAuthConfig::from_keys_and_ca(opts.authorized_keys.clone(), opts.cert_authority.clone()) - .map_err(ServeError::KeyLoadFailed)?, - ); + let auth_config = ServerAuthConfig::from_keys_and_ca( + opts.authorized_keys.clone(), + opts.cert_authority.clone(), + ) + .map_err(ServeError::KeyLoadFailed)?; + + let auth_policy = AuthPolicy::from_server_auth_config(auth_config); + let dynamic_config = DynamicConfig::new(auth_policy); + + let max_auth_attempts = opts.max_auth_attempts; + let max_connections_per_ip = opts.max_connections_per_ip; let config = Arc::new(Config { keys: vec![private_key], - max_auth_attempts: opts.max_auth_attempts, + max_auth_attempts, methods: russh::MethodSet::PUBLICKEY, preferred: russh::Preferred::DEFAULT, ..Default::default() @@ -262,19 +271,21 @@ impl Server { let outbound_proxy = parse_proxy_config(opts.proxy.as_deref()); - let connection_limiter = Arc::new(ConnectionRateLimiter::new(opts.max_connections_per_ip)); + let connection_limiter = Arc::new(ConnectionRateLimiter::new(max_connections_per_ip)); let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); + let dynamic = Arc::new(ArcSwap::new(Arc::new(dynamic_config))); + Ok(Self { config, - auth_config, + dynamic, connection_limiter, outbound_proxy, stealth: opts.stealth, transport_mode: opts.transport_mode, listen_addr: opts.listen_addr, - max_auth_attempts: opts.max_auth_attempts, + max_auth_attempts, shutdown_tx, shutdown_rx, sessions: Arc::new(tokio::sync::Mutex::new(Vec::new())), @@ -285,6 +296,12 @@ impl Server { self.shutdown_tx.clone() } + pub fn config_reload_handle(&self) -> ConfigReloadHandle { + ConfigReloadHandle { + dynamic: Arc::clone(&self.dynamic), + } + } + pub async fn shutdown(&self) -> Result<(), ServeError> { info!("initiating graceful shutdown"); let _ = self.shutdown_tx.send(true); @@ -292,11 +309,15 @@ impl Server { { let sessions = self.sessions.lock().await; for session in sessions.iter() { - if let Err(e) = session.handle.disconnect( - russh::Disconnect::ByApplication, - "shutdown".to_string(), - String::new(), - ).await { + if let Err(e) = session + .handle + .disconnect( + russh::Disconnect::ByApplication, + "shutdown".to_string(), + String::new(), + ) + .await + { warn!("failed to send SSH disconnect: {e}"); } } @@ -392,7 +413,7 @@ impl Server { let handler_transport_kind = transport_kind; let handler = ServerHandler::new( - Arc::clone(&server.auth_config), + Arc::clone(&server.dynamic), server.outbound_proxy.clone(), remote_addr, handler_transport_kind, @@ -410,15 +431,9 @@ impl Server { let transport_is_tls = server.transport_mode == ServeTransportMode::Tls; tokio::spawn(async move { - let result = handle_connection( - stream, - config, - handler, - sessions, - stealth, - transport_is_tls, - ) - .await; + let result = + handle_connection(stream, config, handler, sessions, stealth, transport_is_tls) + .await; if let Err(e) = result { warn!("connection error: {e}"); @@ -611,8 +626,7 @@ mod tests { #[test] fn serve_options_validate_tcp_with_acme_rejected() { - let opts = - ServeOptions::new(make_key_source()).acme_domain("example.com"); + let opts = ServeOptions::new(make_key_source()).acme_domain("example.com"); assert!(opts.validate().is_err()); } @@ -626,8 +640,8 @@ mod tests { #[test] fn server_new_creates_server() { - let opts = ServeOptions::new(make_key_source()) - .authorized_keys(make_authorized_keys_source()); + let opts = + ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source()); let server = Server::new(opts).unwrap(); assert_eq!(server.max_auth_attempts, 10); } @@ -662,8 +676,8 @@ mod tests { #[test] fn serve_options_debug_redacts_keys() { - let opts = ServeOptions::new(make_key_source()) - .authorized_keys(make_authorized_keys_source()); + let opts = + ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source()); let debug_str = format!("{:?}", opts); assert!(debug_str.contains("")); assert!(!debug_str.contains("OPENSSH")); @@ -715,8 +729,8 @@ mod tests { #[test] fn server_shutdown_sender_clones() { - let opts = ServeOptions::new(make_key_source()) - .authorized_keys(make_authorized_keys_source()); + let opts = + ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source()); let server = Server::new(opts).unwrap(); let sender = server.shutdown_sender(); assert!(!server.is_shutdown()); @@ -726,8 +740,7 @@ mod tests { #[test] fn server_holds_listen_addr() { - let opts = ServeOptions::new(make_key_source()) - .listen_addr("0.0.0.0:443"); + let opts = ServeOptions::new(make_key_source()).listen_addr("0.0.0.0:443"); let server = Server::new(opts).unwrap(); assert_eq!(server.listen_addr, "0.0.0.0:443"); } @@ -747,12 +760,10 @@ mod tests { let server = Server::new(opts).unwrap(); let shutdown_tx = server.shutdown_sender(); - let server_handle = tokio::spawn(async move { - server - .run(acceptor, None) - .await - .expect("server run failed") - }); + let server_handle = + tokio::spawn( + async move { server.run(acceptor, None).await.expect("server run failed") }, + ); tokio::time::sleep(Duration::from_millis(50)).await; @@ -760,6 +771,9 @@ mod tests { let result = tokio::time::timeout(Duration::from_secs(5), server_handle).await; - assert!(result.is_ok(), "server should have shut down within timeout"); + assert!( + result.is_ok(), + "server should have shut down within timeout" + ); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/server/stealth.rs b/crates/alknet-core/src/server/stealth.rs index 42e212f..1481205 100644 --- a/crates/alknet-core/src/server/stealth.rs +++ b/crates/alknet-core/src/server/stealth.rs @@ -134,7 +134,10 @@ mod tests { let mut all_data = Vec::new(); reader.read_to_end(&mut all_data).await.unwrap(); - assert!(all_data.starts_with(banner), "banner bytes must be preserved after detection"); + assert!( + all_data.starts_with(banner), + "banner bytes must be preserved after detection" + ); } #[tokio::test] @@ -142,7 +145,10 @@ mod tests { let (client, server) = duplex(1024); let (mut client_read, mut client_write) = tokio::io::split(client); - client_write.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap(); + client_write + .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + .await + .unwrap(); drop(client_write); let (detection, mut reader) = detect_protocol(server).await; @@ -206,7 +212,10 @@ mod tests { let (client, server) = duplex(1024); let mut client = client; - client.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap(); + client + .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + .await + .unwrap(); let (detection, mut reader) = detect_protocol(server).await; assert_eq!(detection, ProtocolDetection::Http); @@ -223,4 +232,4 @@ mod tests { let result = client.read(&mut extra).await; assert!(result.is_err() || result.unwrap() == 0); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/socks5/mod.rs b/crates/alknet-core/src/socks5/mod.rs index 88658d7..7089644 100644 --- a/crates/alknet-core/src/socks5/mod.rs +++ b/crates/alknet-core/src/socks5/mod.rs @@ -52,9 +52,7 @@ impl Socks5Server { } pub fn with_addr(channel_opener: C, addr: &str) -> Self { - let listen_addr: SocketAddr = addr - .parse() - .expect("invalid SOCKS5 listen address"); + let listen_addr: SocketAddr = addr.parse().expect("invalid SOCKS5 listen address"); Self { listen_addr, channel_opener: Arc::new(channel_opener), @@ -80,10 +78,7 @@ impl Socks5Server { } } -async fn handle_socks5_connection( - mut socket: S, - opener: Arc, -) -> Result<(), Socks5Error> +async fn handle_socks5_connection(mut socket: S, opener: Arc) -> Result<(), Socks5Error> where S: AsyncRead + AsyncWrite + Unpin, C: ChannelOpener, @@ -173,7 +168,11 @@ impl HandleChannelOpener { impl ChannelOpener for HandleChannelOpener { type Stream = russh::ChannelStream; - async fn open_channel(&self, host: String, port: u16) -> Result { + async fn open_channel( + &self, + host: String, + port: u16, + ) -> Result { let handle = self.handle.lock().await; if handle.is_closed() { return Err(ChannelOpenError::SessionClosed); @@ -241,7 +240,10 @@ mod tests { } async fn do_handshake(client: &mut DuplexStream) -> [u8; 2] { - client.write_all(&build_socks5_greeting(&[0x00])).await.unwrap(); + client + .write_all(&build_socks5_greeting(&[0x00])) + .await + .unwrap(); client.flush().await.unwrap(); let mut resp = [0u8; 2]; client.read_exact(&mut resp).await.unwrap(); @@ -264,9 +266,8 @@ mod tests { let (mut client, server) = duplex(4096); let opener = MockChannelOpener { fail: false }; - let server_handle = tokio::spawn(async move { - handle_socks5_connection(server, Arc::new(opener)).await - }); + let server_handle = + tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await }); let resp = do_handshake(&mut client).await; assert_eq!(resp, [0x05, 0x00]); @@ -284,9 +285,8 @@ mod tests { let (mut client, server) = duplex(4096); let opener = MockChannelOpener { fail: false }; - let server_handle = tokio::spawn(async move { - handle_socks5_connection(server, Arc::new(opener)).await - }); + let server_handle = + tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await }); client .write_all(&build_socks5_greeting(&[0x02])) @@ -301,10 +301,7 @@ mod tests { drop(client); let result = server_handle.await.unwrap(); assert!(result.is_err()); - assert!(matches!( - result.unwrap_err(), - Socks5Error::NoAcceptableAuth - )); + assert!(matches!(result.unwrap_err(), Socks5Error::NoAcceptableAuth)); } #[tokio::test] @@ -312,9 +309,8 @@ mod tests { let (mut client, server) = duplex(4096); let opener = MockChannelOpener { fail: false }; - let server_handle = tokio::spawn(async move { - handle_socks5_connection(server, Arc::new(opener)).await - }); + let server_handle = + tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await }); do_handshake(&mut client).await; let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 443).await; @@ -329,9 +325,8 @@ mod tests { let (mut client, server) = duplex(4096); let opener = MockChannelOpener { fail: false }; - let server_handle = tokio::spawn(async move { - handle_socks5_connection(server, Arc::new(opener)).await - }); + let server_handle = + tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await }); do_handshake(&mut client).await; @@ -354,9 +349,8 @@ mod tests { let (mut client, server) = duplex(4096); let opener = MockChannelOpener { fail: false }; - let server_handle = tokio::spawn(async move { - handle_socks5_connection(server, Arc::new(opener)).await - }); + let server_handle = + tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await }); do_handshake(&mut client).await; @@ -381,9 +375,8 @@ mod tests { let (mut client, server) = duplex(4096); let opener = MockChannelOpener { fail: true }; - let server_handle = tokio::spawn(async move { - handle_socks5_connection(server, Arc::new(opener)).await - }); + let server_handle = + tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await }); do_handshake(&mut client).await; let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 80).await; @@ -399,9 +392,8 @@ mod tests { let (mut client, server) = duplex(4096); let opener = MockChannelOpener { fail: false }; - let server_handle = tokio::spawn(async move { - handle_socks5_connection(server, Arc::new(opener)).await - }); + let server_handle = + tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await }); do_handshake(&mut client).await; @@ -450,9 +442,10 @@ mod tests { stream: Arc::clone(&ssh_stream), }; - let server_handle = tokio::spawn(async move { - handle_socks5_connection(server_sock, Arc::new(opener)).await - }); + let server_handle = + tokio::spawn( + async move { handle_socks5_connection(server_sock, Arc::new(opener)).await }, + ); do_handshake(&mut client_sock).await; let reply_buf = do_connect_ipv4(&mut client_sock, [127, 0, 0, 1], 80).await; @@ -494,4 +487,4 @@ mod tests { let server = Socks5Server::with_addr(opener, "127.0.0.1:9050"); assert_eq!(server.listen_addr(), "127.0.0.1:9050".parse().unwrap()); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/socks5/protocol.rs b/crates/alknet-core/src/socks5/protocol.rs index 272c187..1ca0efd 100644 --- a/crates/alknet-core/src/socks5/protocol.rs +++ b/crates/alknet-core/src/socks5/protocol.rs @@ -169,10 +169,7 @@ mod tests { let req = Socks5Request::read_from(&mut cursor).await.unwrap(); assert_eq!(req.version, 0x05); assert_eq!(req.command, 0x01); - assert_eq!( - req.address, - Socks5Address::Ipv4(Ipv4Addr::new(10, 0, 0, 1)) - ); + assert_eq!(req.address, Socks5Address::Ipv4(Ipv4Addr::new(10, 0, 0, 1))); assert_eq!(req.port, 443); } @@ -201,7 +198,10 @@ mod tests { let req = Socks5Request::read_from(&mut cursor).await.unwrap(); assert_eq!(req.version, 0x05); assert_eq!(req.command, 0x01); - assert_eq!(req.address, Socks5Address::Domain("example.com".to_string())); + assert_eq!( + req.address, + Socks5Address::Domain("example.com".to_string()) + ); assert_eq!(req.port, 443); } @@ -301,4 +301,4 @@ mod tests { let port = cursor.read_u16().await.unwrap(); assert_eq!(port, 8080); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/testutil.rs b/crates/alknet-core/src/testutil.rs index 62e8ae4..a73e9e3 100644 --- a/crates/alknet-core/src/testutil.rs +++ b/crates/alknet-core/src/testutil.rs @@ -1,5 +1,5 @@ -use tokio::io::{DuplexStream, AsyncRead, AsyncWrite}; use anyhow::Result; +use tokio::io::{AsyncRead, AsyncWrite, DuplexStream}; #[cfg(feature = "transport-traits")] pub use crate::transport::{Transport, TransportAcceptor, TransportInfo, TransportKind}; @@ -9,10 +9,10 @@ pub use local_traits::{Transport, TransportAcceptor, TransportInfo, TransportKin #[cfg(not(feature = "transport-traits"))] mod local_traits { - use std::net::SocketAddr; use anyhow::Result; - use tokio::io::{AsyncRead, AsyncWrite}; use async_trait::async_trait; + use std::net::SocketAddr; + use tokio::io::{AsyncRead, AsyncWrite}; #[async_trait] pub trait Transport: Send + Sync + 'static { @@ -138,4 +138,4 @@ impl TransportAcceptor for MockTransportAcceptor { pub fn mock_pair(buf_size: usize) -> (MockStream, MockStream) { let (client, server) = tokio::io::duplex(buf_size); (MockStream::new(client), MockStream::new(server)) -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/transport/acme.rs b/crates/alknet-core/src/transport/acme.rs index c87cc46..b3080c3 100644 --- a/crates/alknet-core/src/transport/acme.rs +++ b/crates/alknet-core/src/transport/acme.rs @@ -7,9 +7,9 @@ use rustls::crypto::aws_lc_rs::default_provider; use rustls::ServerConfig; use rustls_acme::caches::DirCache; use rustls_acme::{AcmeConfig, AcmeState, ResolvesServerCertAcme}; -use tracing::{error, info}; use tokio::net::TcpListener; use tokio_rustls::TlsAcceptor as TokioTlsAcceptor; +use tracing::{error, info}; use super::{TransportAcceptor, TransportInfo, TransportKind}; @@ -94,14 +94,10 @@ impl AcmeCertProvider { .contact(self.contact.clone()); let state = match &self.cache_dir { - Some(cache_dir) => { - base_config.cache(DirCache::new(cache_dir.clone())).state() - } - None => { - base_config - .cache(rustls_acme::caches::NoCache::default()) - .state() - } + Some(cache_dir) => base_config.cache(DirCache::new(cache_dir.clone())).state(), + None => base_config + .cache(rustls_acme::caches::NoCache::default()) + .state(), }; let resolver = state.resolver(); @@ -132,10 +128,7 @@ pub struct AcmeTlsAcceptor { } impl AcmeTlsAcceptor { - pub async fn bind_acme( - addr: SocketAddr, - provider: Arc, - ) -> Result { + pub async fn bind_acme(addr: SocketAddr, provider: Arc) -> Result { let (state, resolver) = provider.build_acme_state(); let server_config = provider.build_server_config_with_resolver(resolver.clone())?; @@ -193,11 +186,7 @@ impl TransportAcceptor for AcmeTlsAcceptor { let (tcp_stream, remote_addr) = self.listener.accept().await?; let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?; - let server_name = tls_stream - .get_ref() - .1 - .server_name() - .map(|s| s.to_string()); + let server_name = tls_stream.get_ref().1.server_name().map(|s| s.to_string()); let info = TransportInfo { remote_addr: Some(remote_addr), @@ -277,8 +266,7 @@ mod tests { #[test] fn acme_cert_provider_build_state_with_cache() { - let provider = - AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/test_cache"); + let provider = AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/test_cache"); let (_state, resolver) = provider.build_acme_state(); assert!(Arc::strong_count(&resolver) >= 2); } @@ -288,7 +276,9 @@ mod tests { let _ = default_provider().install_default(); let provider = AcmeCertProvider::domain("example.com"); let (_, resolver) = provider.build_acme_state(); - let config = provider.build_server_config_with_resolver(resolver).unwrap(); + let config = provider + .build_server_config_with_resolver(resolver) + .unwrap(); assert!(!config.alpn_protocols.is_empty()); assert!(config .alpn_protocols @@ -359,4 +349,4 @@ mod tests { let acceptor = result.unwrap(); assert_eq!(acceptor.listen_addr().port(), 443); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/transport/iroh_transport.rs b/crates/alknet-core/src/transport/iroh_transport.rs index 08ef161..b1605b9 100644 --- a/crates/alknet-core/src/transport/iroh_transport.rs +++ b/crates/alknet-core/src/transport/iroh_transport.rs @@ -1,9 +1,7 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use iroh::{ - endpoint::RecvStream, - node_info::NodeIdExt, - Endpoint, NodeId, RelayMap, RelayMode, RelayUrl, + endpoint::RecvStream, node_info::NodeIdExt, Endpoint, NodeId, RelayMap, RelayMode, RelayUrl, }; use tokio::io; @@ -39,7 +37,9 @@ impl IrohTransport { proxy_url: Option, ) -> Result { let relay_url = relay_url.unwrap_or_else(|| { - DEFAULT_RELAY_URL.parse().expect("default relay URL is valid") + DEFAULT_RELAY_URL + .parse() + .expect("default relay URL is valid") }); let relay_map = RelayMap::from_url(relay_url); let mut builder = Endpoint::builder() @@ -49,7 +49,11 @@ impl IrohTransport { builder = builder.proxy_url(proxy.clone()); } let endpoint = builder.bind().await?; - Ok(Self { node_id, endpoint, owned: true }) + Ok(Self { + node_id, + endpoint, + owned: true, + }) } /// Create an iroh transport using an existing shared endpoint. @@ -60,7 +64,11 @@ impl IrohTransport { /// other protocol handlers on the same QUIC endpoint — one connection /// per peer, multiplexed by ALPN. pub fn from_endpoint(node_id: NodeId, endpoint: Endpoint) -> Self { - Self { node_id, endpoint, owned: false } + Self { + node_id, + endpoint, + owned: false, + } } pub fn endpoint_id(&self) -> String { @@ -115,12 +123,11 @@ impl IrohAcceptor { /// Bind a new iroh endpoint with a dedicated `alknet-ssh` ALPN. /// /// Use this when alknet is the only iroh service on this node. - pub async fn bind( - relay_url: Option, - proxy_url: Option, - ) -> Result { + pub async fn bind(relay_url: Option, proxy_url: Option) -> Result { let relay_url = relay_url.unwrap_or_else(|| { - DEFAULT_RELAY_URL.parse().expect("default relay URL is valid") + DEFAULT_RELAY_URL + .parse() + .expect("default relay URL is valid") }); let relay_map = RelayMap::from_url(relay_url); let mut builder = Endpoint::builder() @@ -130,7 +137,10 @@ impl IrohAcceptor { builder = builder.proxy_url(proxy.clone()); } let endpoint = builder.bind().await?; - Ok(Self { endpoint, owned: true }) + Ok(Self { + endpoint, + owned: true, + }) } /// Create an iroh acceptor using an existing shared endpoint. @@ -146,7 +156,10 @@ impl IrohAcceptor { /// [`IrohAcceptor::bind`] instead, which handles the accept loop /// internally. pub fn from_endpoint(endpoint: Endpoint) -> Self { - Self { endpoint, owned: false } + Self { + endpoint, + owned: false, + } } pub fn endpoint_id(&self) -> String { @@ -219,18 +232,14 @@ mod tests { #[test] fn iroh_transport_describe_format() { - let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng) - .public() - .into(); + let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into(); let desc = format!("iroh://{}", node_id.to_z32()); assert!(desc.starts_with("iroh://")); } #[tokio::test] async fn iroh_transport_connect_builds_endpoint() { - let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng) - .public() - .into(); + let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into(); let transport = IrohTransport::new(node_id, None, None).await.unwrap(); assert!(transport.describe().starts_with("iroh://")); assert!(!transport.endpoint_id().is_empty()); @@ -239,9 +248,7 @@ mod tests { #[tokio::test] async fn iroh_transport_from_endpoint() { - let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng) - .public() - .into(); + let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into(); let acceptor = IrohAcceptor::bind(None, None).await.unwrap(); let endpoint = acceptor.endpoint.clone(); let transport = IrohTransport::from_endpoint(node_id, endpoint); @@ -318,4 +325,4 @@ mod tests { transport.connect().await.unwrap(); let _server_stream = accept_handle.await.unwrap(); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/transport/mod.rs b/crates/alknet-core/src/transport/mod.rs index d104bf8..71a05ba 100644 --- a/crates/alknet-core/src/transport/mod.rs +++ b/crates/alknet-core/src/transport/mod.rs @@ -13,13 +13,13 @@ //! See [ADR-001](docs/architecture/decisions/001-pluggable-transport.md) and //! [ADR-004](docs/architecture/decisions/004-ssh-over-transport.md) for design rationale. -mod tcp; #[cfg(feature = "iroh")] mod iroh_transport; +mod tcp; -pub use tcp::{TcpAcceptor, TcpTransport}; #[cfg(feature = "iroh")] pub use iroh_transport::{IrohAcceptor, IrohTransport, ALPN as IROH_ALPN}; +pub use tcp::{TcpAcceptor, TcpTransport}; #[cfg(feature = "tls")] mod tls; @@ -89,12 +89,8 @@ pub struct TransportInfo { #[derive(Debug, Clone)] pub enum TransportKind { Tcp, - Tls { - server_name: Option, - }, - Iroh { - endpoint_id: String, - }, + Tls { server_name: Option }, + Iroh { endpoint_id: String }, } #[cfg(test)] @@ -185,4 +181,4 @@ mod tests { assert_eq!(endpoint_id, "abc123"); } } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/transport/tcp.rs b/crates/alknet-core/src/transport/tcp.rs index a0454a8..473f1d7 100644 --- a/crates/alknet-core/src/transport/tcp.rs +++ b/crates/alknet-core/src/transport/tcp.rs @@ -159,4 +159,4 @@ mod tests { .unwrap(); assert_ne!(acceptor.listen_addr().port(), 0); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/src/transport/tls.rs b/crates/alknet-core/src/transport/tls.rs index 8c53d25..e8c9a2f 100644 --- a/crates/alknet-core/src/transport/tls.rs +++ b/crates/alknet-core/src/transport/tls.rs @@ -7,7 +7,9 @@ use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, Server use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig}; use tokio::net::{TcpListener, TcpStream}; -use tokio_rustls::{client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector}; +use tokio_rustls::{ + client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector, +}; #[cfg(feature = "acme")] use rustls::crypto::aws_lc_rs::default_provider; @@ -169,7 +171,9 @@ impl TlsAcceptor { .map_err(|e| anyhow!("failed to set protocol versions: {}", e))? .with_no_client_auth() .with_cert_resolver(acme_resolver); - server_config.alpn_protocols.push(ACME_TLS_ALPN_NAME.to_vec()); + server_config + .alpn_protocols + .push(ACME_TLS_ALPN_NAME.to_vec()); let server_config = Arc::new(server_config); let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone()); @@ -195,11 +199,7 @@ impl TransportAcceptor for TlsAcceptor { let (tcp_stream, remote_addr) = self.listener.accept().await?; let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?; - let server_name = tls_stream - .get_ref() - .1 - .server_name() - .map(|s| s.to_string()); + let server_name = tls_stream.get_ref().1.server_name().map(|s| s.to_string()); let info = TransportInfo { remote_addr: Some(remote_addr), @@ -324,10 +324,7 @@ mod tests { let (mut server, info) = accept_handle.await.unwrap(); assert!(info.remote_addr.is_some()); - assert!(matches!( - info.transport_kind, - TransportKind::Tls { .. } - )); + assert!(matches!(info.transport_kind, TransportKind::Tls { .. })); client.write_all(b"hello tls").await.unwrap(); let mut buf = [0u8; 9]; @@ -429,4 +426,4 @@ mod tests { let verifier = NoVerifier; assert!(verifier.supported_verify_schemes().len() > 0); } -} \ No newline at end of file +} diff --git a/crates/alknet-core/tests/auth_tests.rs b/crates/alknet-core/tests/auth_tests.rs index 4504cca..0111cd3 100644 --- a/crates/alknet-core/tests/auth_tests.rs +++ b/crates/alknet-core/tests/auth_tests.rs @@ -1,2 +1,2 @@ #[tokio::test] -async fn auth_placeholder() {} \ No newline at end of file +async fn auth_placeholder() {} diff --git a/crates/alknet-core/tests/client_tests.rs b/crates/alknet-core/tests/client_tests.rs index 2276dcd..6b4477a 100644 --- a/crates/alknet-core/tests/client_tests.rs +++ b/crates/alknet-core/tests/client_tests.rs @@ -1,2 +1,2 @@ #[tokio::test] -async fn client_placeholder() {} \ No newline at end of file +async fn client_placeholder() {} diff --git a/crates/alknet-core/tests/server_tests.rs b/crates/alknet-core/tests/server_tests.rs index 5de0852..5bdedd3 100644 --- a/crates/alknet-core/tests/server_tests.rs +++ b/crates/alknet-core/tests/server_tests.rs @@ -1,2 +1,2 @@ #[tokio::test] -async fn server_placeholder() {} \ No newline at end of file +async fn server_placeholder() {} diff --git a/crates/alknet-core/tests/transport_tests.rs b/crates/alknet-core/tests/transport_tests.rs index ead868b..0439be4 100644 --- a/crates/alknet-core/tests/transport_tests.rs +++ b/crates/alknet-core/tests/transport_tests.rs @@ -1,4 +1,6 @@ -use alknet_core::testutil::{MockTransport, MockTransportAcceptor, Transport, TransportAcceptor, mock_pair}; +use alknet_core::testutil::{ + mock_pair, MockTransport, MockTransportAcceptor, Transport, TransportAcceptor, +}; #[tokio::test] async fn mock_transport_connect() { @@ -23,4 +25,4 @@ async fn mock_pair_communicates() { let mut buf = [0u8; 5]; server.read_exact(&mut buf).await.unwrap(); assert_eq!(&buf, b"hello"); -} \ No newline at end of file +} diff --git a/crates/alknet-napi/src/serve.rs b/crates/alknet-napi/src/serve.rs index 7a87b06..baa0a14 100644 --- a/crates/alknet-napi/src/serve.rs +++ b/crates/alknet-napi/src/serve.rs @@ -328,7 +328,12 @@ impl russh::server::Handler for NapiServerHandler { session: &mut russh::server::Session, ) -> std::result::Result<(), Self::Error> { tracing::warn!(channel = %channel, "rejected x11 request"); - let _ = (single_connection, x11_auth_protocol, x11_auth_cookie, x11_screen_number); + let _ = ( + single_connection, + x11_auth_protocol, + x11_auth_cookie, + x11_screen_number, + ); let _ = session.channel_failure(channel); Ok(()) } @@ -348,7 +353,11 @@ impl russh::server::Handler for NapiServerHandler { port: &mut u32, _session: &mut russh::server::Session, ) -> std::result::Result { - tracing::warn!(address = address, port = *port, "rejected tcpip-forward request"); + tracing::warn!( + address = address, + port = *port, + "rejected tcpip-forward request" + ); Ok(false) } @@ -367,7 +376,10 @@ impl russh::server::Handler for NapiServerHandler { socket_path: &str, _session: &mut russh::server::Session, ) -> std::result::Result { - tracing::warn!(socket_path = socket_path, "rejected streamlocal-forward request"); + tracing::warn!( + socket_path = socket_path, + "rejected streamlocal-forward request" + ); Ok(false) } @@ -542,8 +554,8 @@ pub async fn serve(options: AlknetServeOptions) -> napi::Result { })?, ); - let private_key = - alknet_core::auth::keys::load_private_key(host_key_source.clone()).map_err(|e| { + let private_key = alknet_core::auth::keys::load_private_key(host_key_source.clone()) + .map_err(|e| { napi::Error::new(napi::Status::InvalidArg, format!("host key error: {}", e)) })?; @@ -635,26 +647,28 @@ pub async fn serve(options: AlknetServeOptions) -> napi::Result { ) })?; - let acceptor = TlsAcceptor::bind(addr, certs, key, None).await.map_err(|e| { - napi::Error::new( - napi::Status::GenericFailure, - format!("tls bind failed: {}", e), - ) - })?; + let acceptor = TlsAcceptor::bind(addr, certs, key, None) + .await + .map_err(|e| { + napi::Error::new( + napi::Status::GenericFailure, + format!("tls bind failed: {}", e), + ) + })?; let actual_listen = acceptor.listen_addr().to_string(); let auth_config = Arc::new( ServerAuthConfig::from_keys_and_ca(authorized_keys_source, cert_authority_source) .map_err(|e| { - napi::Error::new( - napi::Status::InvalidArg, - format!("auth config error: {}", e), - ) - })?, + napi::Error::new( + napi::Status::InvalidArg, + format!("auth config error: {}", e), + ) + })?, ); - let private_key = - alknet_core::auth::keys::load_private_key(host_key_source.clone()).map_err(|e| { + let private_key = alknet_core::auth::keys::load_private_key(host_key_source.clone()) + .map_err(|e| { napi::Error::new(napi::Status::InvalidArg, format!("host key error: {}", e)) })?; @@ -728,11 +742,11 @@ pub async fn serve(options: AlknetServeOptions) -> napi::Result { let auth_config = Arc::new( ServerAuthConfig::from_keys_and_ca(authorized_keys_source, cert_authority_source) .map_err(|e| { - napi::Error::new( - napi::Status::InvalidArg, - format!("auth config error: {}", e), - ) - })?, + napi::Error::new( + napi::Status::InvalidArg, + format!("auth config error: {}", e), + ) + })?, ); let private_key = diff --git a/crates/alknet/src/main.rs b/crates/alknet/src/main.rs index 070ab84..93f9eea 100644 --- a/crates/alknet/src/main.rs +++ b/crates/alknet/src/main.rs @@ -10,8 +10,6 @@ use std::net::SocketAddr; use std::process; use std::sync::Arc; -use anyhow::{anyhow, Result}; -use clap::{Parser, Subcommand, ValueEnum}; use alknet_core::auth::keys::KeySource; use alknet_core::client::{ConnectOptions, TransportMode}; use alknet_core::server::{ServeOptions, ServeTransportMode, Server}; @@ -21,6 +19,8 @@ use alknet_core::transport::TcpTransport; #[cfg(feature = "tls")] use alknet_core::transport::TlsTransport; use alknet_core::transport::Transport; +use anyhow::{anyhow, Result}; +use clap::{Parser, Subcommand, ValueEnum}; #[derive(Parser)] #[command(name = "alknet", version, about = "Alknet SSH tunnel tool")] @@ -76,7 +76,7 @@ enum Commands { insecure: bool, }, - #[command( about = "Start the alknet server (accept SSH connections)")] + #[command(about = "Start the alknet server (accept SSH connections)")] Serve { #[arg(long, help = "SSH host key path (required)")] key: String,