feat(core): implement StaticConfig/DynamicConfig split with ArcSwap hot-reload

Split alknet-core configuration into StaticConfig (immutable after startup)
and DynamicConfig (hot-reloadable at runtime via ArcSwap).

- Add StaticConfig struct in config/static_config.rs with all fields per ADR-030
- Add DynamicConfig struct with AuthPolicy, ForwardingPolicy, RateLimitConfig
- Add ForwardingPolicy with allow_all()/deny_all() defaults (ADR-031)
- Add ConfigReloadHandle with reload() method for runtime config updates
- Replace Arc<ServerAuthConfig> with Arc<ArcSwap<DynamicConfig>> in ServerHandler
- Add config_reload_handle() to Server for obtaining reload handles
- Add AuthPolicy with authenticate_publickey/authenticate_certificate methods
- All existing tests pass with the new config structure
- Default DynamicConfig produces identical behavior to current code
This commit is contained in:
2026-06-07 14:03:46 +00:00
parent a7f0dcdeb9
commit ee1b3f3819
36 changed files with 964 additions and 393 deletions

10
Cargo.lock generated
View File

@@ -75,6 +75,7 @@ version = "0.1.0"
dependencies = [
"alknet-core",
"anyhow",
"arc-swap",
"async-trait",
"futures",
"ipnetwork",
@@ -185,6 +186,15 @@ version = "1.0.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
[[package]]
name = "arc-swap"
version = "1.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a3a1fd6f75306b68087b831f025c712524bcb19aad54e557b1129cfa0a2b207"
dependencies = [
"rustversion",
]
[[package]]
name = "asn1-rs"
version = "0.6.2"

View File

@@ -34,6 +34,7 @@ iroh = { version = "0.34", optional = true }
url = { version = "2", optional = true }
async-trait = "0.1"
ipnetwork = "0.21.1"
arc-swap = "1"
[dev-dependencies]
alknet-core = { path = ".", features = ["testutil", "tls", "iroh"] }

View File

@@ -173,4 +173,4 @@ mod tests {
let key2 = config.private_key();
assert!(Arc::ptr_eq(&key1, &key2));
}
}
}

View File

@@ -6,7 +6,7 @@
use std::path::PathBuf;
use russh::keys::{PrivateKey, PublicKey, decode_secret_key, parse_public_key_base64};
use russh::keys::{decode_secret_key, parse_public_key_base64, PrivateKey, PublicKey};
use crate::error::ConfigError;
@@ -98,10 +98,7 @@ fn parse_authorized_keys_line(line: &str) -> Option<Result<(PublicKey, Vec<Strin
|| parts[0].starts_with("principals=")
{
let opts_str = parts[0];
options = opts_str
.split(',')
.map(|s| s.to_string())
.collect();
options = opts_str.split(',').map(|s| s.to_string()).collect();
key_type_idx = 1;
} else if parts[0].starts_with("ssh-") || parts[0].starts_with("ecdsa-") {
key_type_idx = 0;
@@ -218,9 +215,7 @@ mod tests {
#[test]
fn parse_authorized_keys_multiple_entries() {
let content = format!(
"{ED25519_PUBLIC_KEY}\n# comment line\n\n{ED25519_PUBLIC_KEY_2}\n"
);
let content = format!("{ED25519_PUBLIC_KEY}\n# comment line\n\n{ED25519_PUBLIC_KEY_2}\n");
let f = make_authorized_keys(&content);
let source = KeySource::File(f.path().to_path_buf());
let keys = load_public_keys(source).unwrap();
@@ -260,4 +255,4 @@ mod tests {
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].options, vec!["cert-authority"]);
}
}
}

View File

@@ -8,5 +8,5 @@ pub mod keys;
pub mod server_auth;
pub use client_auth::{ClientAuthConfig, ClientHandler};
pub use keys::{CertAuthorityEntry, KeySource, load_private_key, load_public_keys};
pub use server_auth::ServerAuthConfig;
pub use keys::{load_private_key, load_public_keys, CertAuthorityEntry, KeySource};
pub use server_auth::ServerAuthConfig;

View File

@@ -13,7 +13,7 @@ use ipnetwork::IpNetwork;
use russh::keys::helpers::EncodedExt;
use russh::keys::{Certificate, PublicKey};
use super::keys::{CertAuthorityEntry, KeySource, load_cert_authority_entries, load_public_keys};
use super::keys::{load_cert_authority_entries, load_public_keys, CertAuthorityEntry, KeySource};
use crate::error::AuthError;
/// Server-side authentication configuration.
@@ -41,10 +41,7 @@ impl ServerAuthConfig {
None => HashSet::new(),
};
let encoded_keys: HashSet<Vec<u8>> = authorized_keys
.iter()
.map(encode_key_data)
.collect();
let encoded_keys: HashSet<Vec<u8>> = authorized_keys.iter().map(encode_key_data).collect();
let cert_authorities = match cert_authority_source {
Some(src) => load_cert_authority_entries(src)?,
@@ -135,10 +132,7 @@ fn check_critical_options(
Ok(())
}
fn check_extensions(
cert: &Certificate,
ca_entry: &CertAuthorityEntry,
) -> Result<(), AuthError> {
fn check_extensions(cert: &Certificate, ca_entry: &CertAuthorityEntry) -> Result<(), AuthError> {
let ca_permit_port_forwarding = ca_entry
.options
.iter()
@@ -188,8 +182,8 @@ fn check_source_address(allowed: &str, client_ip: Option<IpAddr>) -> bool {
mod tests {
use super::*;
use rand_core::OsRng;
use russh::keys::{Certificate, PrivateKey, decode_secret_key};
use russh::keys::ssh_key::certificate::{Builder, CertType};
use russh::keys::{decode_secret_key, Certificate, PrivateKey};
use std::io::Write;
const CA_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+gAAAJjP22Bpz9tg\naQAAAAtzc2gtZWQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+g\nAAAEBcRrWyUU+lLpjHbaaYN5YeOlvz6HnuBndUWevEmHk00jqkUoEjfbsmxEWZlQtqU2Om\nhQ8kxXHOyT1sZsMHJq36AAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
@@ -218,13 +212,9 @@ mod tests {
principals: Vec<&str>,
) -> Certificate {
let key_data: russh::keys::ssh_key::public::KeyData = user_pub.into();
let mut builder = Builder::new_with_random_nonce(
&mut OsRng,
key_data,
valid_after,
valid_before,
)
.unwrap();
let mut builder =
Builder::new_with_random_nonce(&mut OsRng, key_data, valid_after, valid_before)
.unwrap();
builder.cert_type(CertType::User).unwrap();
@@ -252,11 +242,7 @@ mod tests {
} else {
format!("cert-authority,{}", options.join(","))
};
let line = format!(
"{} {} CA\n",
opts,
ca_pub.to_openssh().unwrap()
);
let line = format!("{} {} CA\n", opts, ca_pub.to_openssh().unwrap());
f.write_all(line.as_bytes()).unwrap();
f.flush().unwrap();
f
@@ -357,13 +343,8 @@ mod tests {
let user_pub = user_key.public_key().clone();
let now = now_secs();
let key_data: russh::keys::ssh_key::public::KeyData = (&user_pub).into();
let mut builder = Builder::new_with_random_nonce(
&mut OsRng,
key_data,
now - 60,
now + 3600,
)
.unwrap();
let mut builder =
Builder::new_with_random_nonce(&mut OsRng, key_data, now - 60, now + 3600).unwrap();
builder.cert_type(CertType::User).unwrap();
builder.all_principals_valid().unwrap();
let cert = builder.sign(&ca_key).unwrap();
@@ -383,7 +364,13 @@ mod tests {
let other_ca_key = load_other_key();
let user_pub = user_key.public_key().clone();
let now = now_secs();
let cert = make_cert(&other_ca_key, &user_pub, now - 60, now + 3600, vec!["testuser"]);
let cert = make_cert(
&other_ca_key,
&user_pub,
now - 60,
now + 3600,
vec!["testuser"],
);
let ca_key = load_ca_key();
let ca_pub = ca_key.public_key().clone();
let f = make_ca_file(&ca_pub, &[]);
@@ -398,12 +385,11 @@ mod tests {
#[test]
fn no_config_accepts_nothing() {
let config =
ServerAuthConfig::from_keys_and_ca(None, None).unwrap();
let config = ServerAuthConfig::from_keys_and_ca(None, None).unwrap();
let other_pub = load_other_key().public_key().clone();
assert_eq!(
config.authenticate_publickey(&other_pub),
Err(AuthError::KeyRejected)
);
}
}
}

View File

@@ -113,14 +113,10 @@ impl<T: Transport> ChannelManager<T> {
.await
.map_err(|_| ChannelError::ChannelClosed)?;
self.inner
.forwards
.write()
.await
.insert(ForwardRequest {
addr: addr.to_string(),
port,
});
self.inner.forwards.write().await.insert(ForwardRequest {
addr: addr.to_string(),
port,
});
Ok(result)
}
@@ -132,14 +128,10 @@ impl<T: Transport> ChannelManager<T> {
.await
.map_err(|_| ChannelError::ChannelClosed)?;
self.inner
.forwards
.write()
.await
.remove(&ForwardRequest {
addr: addr.to_string(),
port,
});
self.inner.forwards.write().await.remove(&ForwardRequest {
addr: addr.to_string(),
port,
});
Ok(())
}
@@ -226,10 +218,7 @@ impl<T: Transport> ChannelManager<T> {
for fwd in forwards.iter() {
match handle.tcpip_forward(&fwd.addr, fwd.port).await {
Ok(_) => {
debug!(
"re-registered tcpip_forward: {}:{}",
fwd.addr, fwd.port
);
debug!("re-registered tcpip_forward: {}:{}", fwd.addr, fwd.port);
}
Err(e) => {
warn!(
@@ -476,4 +465,4 @@ mod tests {
assert!(duration >= Duration::from_secs(1));
}
}
}
}

View File

@@ -197,10 +197,7 @@ pub struct ClientSession<T: Transport> {
}
impl<T: Transport> ClientSession<T> {
pub async fn new(
opts: ConnectOptions,
transport: Arc<T>,
) -> Result<Self, ConnectError> {
pub async fn new(opts: ConnectOptions, transport: Arc<T>) -> Result<Self, ConnectError> {
opts.validate().map_err(ConnectError::Config)?;
let auth_config = Arc::new(
@@ -283,16 +280,13 @@ impl<T: Transport> ClientSession<T> {
let remote_specs = build_remote_specs(&self.opts)?;
for spec in &remote_specs {
let remote_forwarder = RemoteForwarder::new(spec.clone())
.map_err(|_| ConnectError::ForwardFailed)?;
let remote_forwarder =
RemoteForwarder::new(spec.clone()).map_err(|_| ConnectError::ForwardFailed)?;
let mut h = self.handle.lock().await;
remote_forwarder
.register(&mut h)
.await
.map_err(|_| {
warn!("failed to register remote forward {}", spec);
ConnectError::ForwardFailed
})?;
remote_forwarder.register(&mut h).await.map_err(|_| {
warn!("failed to register remote forward {}", spec);
ConnectError::ForwardFailed
})?;
info!("registered remote forward: {}", spec);
}
@@ -307,7 +301,9 @@ impl<T: Transport> ClientSession<T> {
let fwd_shutdown = self.shutdown_rx.clone();
let forward_task = tokio::spawn(async move {
crate::client::forward::run_local_forwarders(
local_forwarders, fwd_handle, fwd_shutdown,
local_forwarders,
fwd_handle,
fwd_shutdown,
)
.await;
});
@@ -358,7 +354,14 @@ impl<T: Transport> ClientSession<T> {
let handler = ClientHandler::from_config(&reconnect_auth);
let username = reconnect_username.clone();
match establish_session(&*reconnect_transport, handler, &reconnect_auth, &username).await {
match establish_session(
&*reconnect_transport,
handler,
&reconnect_auth,
&username,
)
.await
{
Ok(new_handle) => {
info!("reconnection successful");
{
@@ -370,8 +373,13 @@ impl<T: Transport> ClientSession<T> {
Ok(rf) => {
let mut h = reconnect_handle.lock().await;
match rf.register(&mut h).await {
Ok(_) => debug!("re-registered remote forward: {}", spec),
Err(e) => warn!("failed to re-register remote forward {}: {e}", spec),
Ok(_) => {
debug!("re-registered remote forward: {}", spec)
}
Err(e) => warn!(
"failed to re-register remote forward {}: {e}",
spec
),
}
}
Err(e) => warn!("failed to create remote forwarder: {e}"),
@@ -493,12 +501,10 @@ fn build_local_forwarders(opts: &ConnectOptions) -> Result<Vec<LocalForwarder>,
name: format!("invalid forward spec: {}", spec_str),
})
})?;
forwarders.push(
LocalForwarder::new(spec).map_err(|e| {
warn!("failed to create local forwarder: {}", e);
ConnectError::ForwardFailed
})?,
);
forwarders.push(LocalForwarder::new(spec).map_err(|e| {
warn!("failed to create local forwarder: {}", e);
ConnectError::ForwardFailed
})?);
}
Ok(forwarders)
}
@@ -576,7 +582,10 @@ mod tests {
assert_eq!(opts.forwards.len(), 1);
assert_eq!(opts.remote_forwards.len(), 1);
assert_eq!(opts.proxy.as_deref(), Some("socks5://127.0.0.1:1080"));
assert_eq!(opts.iroh_relay.as_deref(), Some("https://relay.example.com"));
assert_eq!(
opts.iroh_relay.as_deref(),
Some("https://relay.example.com")
);
assert_eq!(opts.tls_server_name.as_deref(), Some("alknet.test"));
assert!(opts.insecure);
}
@@ -650,9 +659,18 @@ mod tests {
#[test]
fn connect_error_variants() {
assert_eq!(ConnectError::ConnectionFailed.to_string(), "connection failed");
assert_eq!(ConnectError::AuthFailed.to_string(), "authentication failed");
assert_eq!(ConnectError::ForwardFailed.to_string(), "forward setup failed");
assert_eq!(
ConnectError::ConnectionFailed.to_string(),
"connection failed"
);
assert_eq!(
ConnectError::AuthFailed.to_string(),
"authentication failed"
);
assert_eq!(
ConnectError::ForwardFailed.to_string(),
"forward setup failed"
);
}
#[test]
@@ -703,7 +721,10 @@ mod tests {
let transport = Arc::new(FailTransport);
let result = ClientSession::new(opts, transport).await;
assert!(result.is_err());
assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed));
assert!(matches!(
result.err().unwrap(),
ConnectError::ConnectionFailed
));
}
#[tokio::test]
@@ -714,7 +735,10 @@ mod tests {
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
let result = ClientSession::new(opts, transport).await;
assert!(result.is_err());
assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed));
assert!(matches!(
result.err().unwrap(),
ConnectError::ConnectionFailed
));
}
#[test]
@@ -750,7 +774,8 @@ mod tests {
#[test]
fn build_remote_specs_valid() {
let opts = ConnectOptions::new(make_identity()).remote_forward("0.0.0.0:8080:127.0.0.1:3000");
let opts =
ConnectOptions::new(make_identity()).remote_forward("0.0.0.0:8080:127.0.0.1:3000");
let result = build_remote_specs(&opts);
assert!(result.is_ok());
assert_eq!(result.unwrap().len(), 1);
@@ -798,8 +823,8 @@ mod tests {
#[tokio::test]
async fn integration_mock_transport_session() {
use crate::socks5::{ChannelOpener, ChannelOpenError};
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use crate::socks5::{ChannelOpenError, ChannelOpener};
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
struct MockOpener;
@@ -839,9 +864,7 @@ mod tests {
conn.read_exact(&mut auth_resp).await.unwrap();
assert_eq!(auth_resp, [0x05, 0x00]);
let connect_req = [
0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80,
];
let connect_req = [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80];
conn.write_all(&connect_req).await.unwrap();
let mut reply = [0u8; 10];
@@ -851,4 +874,4 @@ mod tests {
conn.write_all(b"test data").await.unwrap();
conn.shutdown().await.unwrap();
}
}
}

View File

@@ -205,12 +205,7 @@ async fn proxy_local_to_remote<H: client::Handler + Send + 'static>(
let handle_guard = handle.lock().await;
let channel = handle_guard
.channel_open_direct_tcpip(
remote_host,
remote_port as u32,
&local_addr,
0,
)
.channel_open_direct_tcpip(remote_host, remote_port as u32, &local_addr, 0)
.await
.map_err(|e| ForwardError::ChannelOpenFailed {
source: Box::new(e) as _,
@@ -470,11 +465,8 @@ mod tests {
let bound_addr = listener.local_addr().unwrap();
drop(listener);
let spec = PortForwardSpec::local(&format!(
"127.0.0.1:{}:remote:5432",
bound_addr.port()
))
.unwrap();
let spec = PortForwardSpec::local(&format!("127.0.0.1:{}:remote:5432", bound_addr.port()))
.unwrap();
let forwarder = LocalForwarder::new(spec).unwrap();
assert_eq!(forwarder.local_port(), bound_addr.port());
}
@@ -534,4 +526,4 @@ mod tests {
let forwarder = RemoteForwarder::new(spec.clone()).unwrap();
assert_eq!(forwarder.spec(), &spec);
}
}
}

View File

@@ -14,4 +14,4 @@ pub mod forward;
pub use channel_manager::{ChannelManager, ForwardRequest};
pub use connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
pub use forward::{LocalForwarder, PortForwardSpec, PortForwardSpecKind, RemoteForwarder};
pub use forward::{LocalForwarder, PortForwardSpec, PortForwardSpecKind, RemoteForwarder};

View File

@@ -0,0 +1,395 @@
use std::sync::Arc;
use arc_swap::ArcSwap;
use crate::auth::ServerAuthConfig;
pub struct AuthPolicy {
pub authorized_keys: std::collections::HashSet<russh::keys::PublicKey>,
pub cert_authorities: Vec<crate::auth::keys::CertAuthorityEntry>,
encoded_keys: std::collections::HashSet<Vec<u8>>,
}
fn encode_key_data(key: &russh::keys::PublicKey) -> Vec<u8> {
use russh::keys::helpers::EncodedExt;
key.key_data().encoded().unwrap_or_default()
}
impl AuthPolicy {
pub fn new(
authorized_keys: std::collections::HashSet<russh::keys::PublicKey>,
cert_authorities: Vec<crate::auth::keys::CertAuthorityEntry>,
) -> Self {
let encoded_keys = authorized_keys.iter().map(encode_key_data).collect();
Self {
authorized_keys,
cert_authorities,
encoded_keys,
}
}
pub fn from_server_auth_config(config: ServerAuthConfig) -> Self {
Self::new(config.authorized_keys, config.cert_authorities)
}
pub fn empty() -> Self {
Self::new(std::collections::HashSet::new(), Vec::new())
}
pub fn authenticate_publickey(
&self,
key: &russh::keys::PublicKey,
) -> Result<(), crate::error::AuthError> {
let encoded = encode_key_data(key);
if self.encoded_keys.contains(&encoded) {
return Ok(());
}
Err(crate::error::AuthError::KeyRejected)
}
pub fn authenticate_certificate(
&self,
cert: &russh::keys::Certificate,
user: &str,
client_ip: Option<std::net::IpAddr>,
) -> Result<(), crate::error::AuthError> {
use std::time::SystemTime;
let matching_ca = self
.cert_authorities
.iter()
.find(|ca| cert.signature_key() == ca.public_key.key_data());
let ca_entry = match matching_ca {
Some(entry) => entry,
None => return Err(crate::error::AuthError::CertInvalid),
};
if cert.verify_signature().is_err() {
return Err(crate::error::AuthError::CertInvalid);
}
let now = SystemTime::now();
let now_secs = now
.duration_since(SystemTime::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
if now_secs < cert.valid_after() || now_secs >= cert.valid_before() {
return Err(crate::error::AuthError::CertExpired);
}
let principals = cert.valid_principals();
if !principals.is_empty() && !principals.iter().any(|p| p == user) {
return Err(crate::error::AuthError::CertPrincipalMismatch);
}
check_critical_options(cert, ca_entry, client_ip)?;
check_extensions(cert, ca_entry)?;
Ok(())
}
}
fn check_critical_options(
cert: &russh::keys::Certificate,
ca_entry: &crate::auth::keys::CertAuthorityEntry,
client_ip: Option<std::net::IpAddr>,
) -> Result<(), crate::error::AuthError> {
let ca_has_no_pty = ca_entry.options.iter().any(|o| o == "no-pty");
for (name, data) in cert.critical_options().iter() {
match name.as_str() {
"source-address" => {
if !check_source_address(data, client_ip) {
return Err(crate::error::AuthError::CertInvalid);
}
}
"force-command" => {}
"no-pty" => {}
_ => {
let _ = ca_has_no_pty;
return Err(crate::error::AuthError::CertInvalid);
}
}
}
Ok(())
}
fn check_extensions(
cert: &russh::keys::Certificate,
ca_entry: &crate::auth::keys::CertAuthorityEntry,
) -> Result<(), crate::error::AuthError> {
let ca_permit_port_forwarding = ca_entry
.options
.iter()
.any(|o| o == "permit-port-forwarding");
if ca_permit_port_forwarding {
let cert_allows = cert
.extensions()
.iter()
.any(|(n, _)| n == "permit-port-forwarding");
if !cert_allows {
return Err(crate::error::AuthError::CertInvalid);
}
}
Ok(())
}
fn check_source_address(allowed: &str, client_ip: Option<std::net::IpAddr>) -> bool {
use ipnetwork::IpNetwork;
use std::net::IpAddr;
use std::str::FromStr;
let Some(ip) = client_ip else {
return false;
};
for pattern in allowed.split(',') {
let pattern = pattern.trim();
if pattern.is_empty() {
continue;
}
if let Ok(cidr) = IpNetwork::from_str(pattern) {
if cidr.contains(ip) {
return true;
}
}
if let Ok(net_ip) = IpAddr::from_str(pattern) {
if net_ip == ip {
return true;
}
}
}
false
}
impl std::fmt::Debug for AuthPolicy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AuthPolicy")
.field("authorized_keys_count", &self.authorized_keys.len())
.field("cert_authorities_count", &self.cert_authorities.len())
.finish()
}
}
impl Clone for AuthPolicy {
fn clone(&self) -> Self {
Self {
authorized_keys: self.authorized_keys.clone(),
cert_authorities: self.cert_authorities.clone(),
encoded_keys: self.encoded_keys.clone(),
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub enum ForwardingAction {
Allow,
Deny,
}
#[derive(Debug, Clone)]
pub struct ForwardingRule {
pub action: ForwardingAction,
pub principals: Vec<String>,
pub transports: Vec<crate::server::handler::TransportKind>,
}
#[derive(Debug, Clone)]
pub struct ForwardingPolicy {
pub default: ForwardingAction,
pub rules: Vec<ForwardingRule>,
}
impl ForwardingPolicy {
pub fn allow_all() -> Self {
Self {
default: ForwardingAction::Allow,
rules: Vec::new(),
}
}
pub fn deny_all() -> Self {
Self {
default: ForwardingAction::Deny,
rules: Vec::new(),
}
}
}
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
pub max_connections_per_ip: usize,
pub max_auth_attempts: usize,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
max_connections_per_ip: 0,
max_auth_attempts: 10,
}
}
}
#[derive(Debug, Clone)]
pub struct DynamicConfig {
pub auth: AuthPolicy,
pub forwarding: ForwardingPolicy,
pub rate_limits: RateLimitConfig,
}
impl DynamicConfig {
pub fn new(auth: AuthPolicy) -> Self {
Self {
auth,
forwarding: ForwardingPolicy::allow_all(),
rate_limits: RateLimitConfig::default(),
}
}
pub fn with_forwarding_policy(mut self, policy: ForwardingPolicy) -> Self {
self.forwarding = policy;
self
}
pub fn with_rate_limits(mut self, limits: RateLimitConfig) -> Self {
self.rate_limits = limits;
self
}
}
impl Default for DynamicConfig {
fn default() -> Self {
Self {
auth: AuthPolicy::empty(),
forwarding: ForwardingPolicy::allow_all(),
rate_limits: RateLimitConfig::default(),
}
}
}
pub struct ConfigReloadHandle {
pub(crate) dynamic: Arc<ArcSwap<DynamicConfig>>,
}
impl ConfigReloadHandle {
pub fn reload(&self, new_config: DynamicConfig) {
self.dynamic.store(Arc::new(new_config));
}
pub fn dynamic(&self) -> Arc<DynamicConfig> {
self.dynamic.load_full()
}
}
impl std::fmt::Debug for ConfigReloadHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConfigReloadHandle").finish()
}
}
pub fn new_dynamic_config() -> (Arc<ArcSwap<DynamicConfig>>, ConfigReloadHandle) {
let inner = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
let handle = ConfigReloadHandle {
dynamic: Arc::clone(&inner),
};
(inner, handle)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn forwarding_policy_allow_all_default() {
let policy = ForwardingPolicy::allow_all();
assert_eq!(policy.default, ForwardingAction::Allow);
assert!(policy.rules.is_empty());
}
#[test]
fn forwarding_policy_deny_all() {
let policy = ForwardingPolicy::deny_all();
assert_eq!(policy.default, ForwardingAction::Deny);
assert!(policy.rules.is_empty());
}
#[test]
fn dynamic_config_default() {
let config = DynamicConfig::default();
assert_eq!(config.forwarding.default, ForwardingAction::Allow);
assert_eq!(config.rate_limits.max_connections_per_ip, 0);
assert_eq!(config.rate_limits.max_auth_attempts, 10);
}
#[test]
fn config_reload_handle_updates_dynamic() {
let (arc_swap, handle) = new_dynamic_config();
let initial = arc_swap.load();
assert_eq!(initial.forwarding.default, ForwardingAction::Allow);
let new_config = DynamicConfig {
auth: AuthPolicy::empty(),
forwarding: ForwardingPolicy::deny_all(),
rate_limits: RateLimitConfig::default(),
};
handle.reload(new_config);
let updated = arc_swap.load();
assert_eq!(updated.forwarding.default, ForwardingAction::Deny);
}
#[test]
fn dynamic_config_with_forwarding_policy_builder() {
let config = DynamicConfig::new(AuthPolicy::empty())
.with_forwarding_policy(ForwardingPolicy::deny_all());
assert_eq!(config.forwarding.default, ForwardingAction::Deny);
}
#[test]
fn rate_limit_config_custom() {
let limits = RateLimitConfig {
max_connections_per_ip: 5,
max_auth_attempts: 3,
};
assert_eq!(limits.max_connections_per_ip, 5);
assert_eq!(limits.max_auth_attempts, 3);
}
#[test]
fn forwarding_action_equality() {
assert_eq!(ForwardingAction::Allow, ForwardingAction::Allow);
assert_eq!(ForwardingAction::Deny, ForwardingAction::Deny);
assert_ne!(ForwardingAction::Allow, ForwardingAction::Deny);
}
#[test]
fn auth_policy_empty_rejects_all() {
let policy = AuthPolicy::empty();
let key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
let other_ssh_key =
russh::keys::parse_public_key_base64(key_text.split_whitespace().nth(1).unwrap())
.unwrap();
assert_eq!(
policy.authenticate_publickey(&other_ssh_key),
Err(crate::error::AuthError::KeyRejected)
);
}
#[test]
fn auth_policy_debug_redacts_keys() {
let policy = AuthPolicy::empty();
let debug_str = format!("{:?}", policy);
assert!(debug_str.contains("authorized_keys_count"));
assert!(debug_str.contains("cert_authorities_count"));
}
}

View File

@@ -0,0 +1,8 @@
pub mod dynamic_config;
pub mod static_config;
pub use dynamic_config::{
new_dynamic_config, AuthPolicy, ConfigReloadHandle, DynamicConfig, ForwardingAction,
ForwardingPolicy, ForwardingRule, RateLimitConfig,
};
pub use static_config::StaticConfig;

View File

@@ -0,0 +1,101 @@
use crate::server::handler::{ProxyConfig, ProxyMode};
use crate::server::serve::ServeTransportMode;
use std::net::SocketAddr;
pub struct StaticConfig {
pub transport_mode: ServeTransportMode,
pub listen_addr: String,
pub tls_cert: Option<String>,
pub tls_key: Option<String>,
pub acme_domain: Option<String>,
pub stealth: bool,
pub host_key: russh::keys::PrivateKey,
pub host_key_algorithm: russh::keys::Algorithm,
pub max_auth_attempts: usize,
pub max_connections_per_ip: usize,
pub proxy_config: Option<ProxyConfig>,
pub iroh_relay: Option<String>,
}
impl std::fmt::Debug for StaticConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("StaticConfig")
.field("transport_mode", &self.transport_mode)
.field("listen_addr", &self.listen_addr)
.field("tls_cert", &self.tls_cert.as_ref().map(|_| "<redacted>"))
.field("tls_key", &self.tls_key.as_ref().map(|_| "<redacted>"))
.field("acme_domain", &self.acme_domain)
.field("stealth", &self.stealth)
.field("host_key_algorithm", &self.host_key_algorithm)
.field("max_auth_attempts", &self.max_auth_attempts)
.field("max_connections_per_ip", &self.max_connections_per_ip)
.field("proxy_config", &self.proxy_config)
.field("iroh_relay", &self.iroh_relay)
.finish()
}
}
impl StaticConfig {
pub fn from_serve_options(
opts: crate::server::serve::ServeOptions,
) -> Result<(Self, crate::config::DynamicConfig), crate::error::ConfigError> {
opts.validate()?;
let host_key = crate::auth::keys::load_private_key(opts.key.clone())?;
let host_key_algorithm = host_key.algorithm();
let auth_config = crate::auth::ServerAuthConfig::from_keys_and_ca(
opts.authorized_keys.clone(),
opts.cert_authority.clone(),
)?;
let auth_policy = crate::config::AuthPolicy::from_server_auth_config(auth_config);
let dynamic = crate::config::DynamicConfig::new(auth_policy);
let proxy_config = parse_proxy_config(opts.proxy.as_deref());
let static_config = StaticConfig {
transport_mode: opts.transport_mode,
listen_addr: opts.listen_addr,
tls_cert: opts.tls_cert,
tls_key: opts.tls_key,
acme_domain: opts.acme_domain,
stealth: opts.stealth,
host_key,
host_key_algorithm,
max_auth_attempts: opts.max_auth_attempts,
max_connections_per_ip: opts.max_connections_per_ip,
proxy_config,
iroh_relay: opts.iroh_relay,
};
Ok((static_config, dynamic))
}
}
fn parse_proxy_config(proxy: Option<&str>) -> Option<ProxyConfig> {
proxy.map(|url| {
if url.starts_with("socks5://") {
let addr: SocketAddr = url
.strip_prefix("socks5://")
.unwrap()
.parse()
.expect("invalid socks5 proxy address");
ProxyConfig {
mode: ProxyMode::Socks5(addr),
}
} else if url.starts_with("http://") {
let addr: SocketAddr = url
.strip_prefix("http://")
.unwrap()
.parse()
.expect("invalid http connect proxy address");
ProxyConfig {
mode: ProxyMode::HttpConnect(addr),
}
} else {
panic!("unsupported proxy URL scheme: {url}");
}
})
}

View File

@@ -97,7 +97,10 @@ mod tests {
#[test]
fn transport_error_display() {
assert_eq!(TransportError::ConnectionFailed.to_string(), "connection failed");
assert_eq!(
TransportError::ConnectionFailed.to_string(),
"connection failed"
);
assert_eq!(
TransportError::HandshakeFailed {
source: io::Error::new(io::ErrorKind::ConnectionRefused, "tls failed")
@@ -120,13 +123,19 @@ mod tests {
assert_eq!(AuthError::KeyRejected.to_string(), "key rejected");
assert_eq!(AuthError::CertInvalid.to_string(), "certificate invalid");
assert_eq!(AuthError::CertExpired.to_string(), "certificate expired");
assert_eq!(AuthError::CertPrincipalMismatch.to_string(), "certificate principal mismatch");
assert_eq!(
AuthError::CertPrincipalMismatch.to_string(),
"certificate principal mismatch"
);
assert_eq!(AuthError::NoMatchingKey.to_string(), "no matching key");
}
#[test]
fn channel_error_display() {
assert_eq!(ChannelError::TargetUnreachable.to_string(), "target unreachable");
assert_eq!(
ChannelError::TargetUnreachable.to_string(),
"target unreachable"
);
assert_eq!(
ChannelError::ProxyConnectFailed {
source: io::Error::new(io::ErrorKind::ConnectionRefused, "refused")
@@ -160,7 +169,10 @@ mod tests {
.to_string(),
"bind failed"
);
assert_eq!(ConfigError::IncompatibleOptions.to_string(), "incompatible options");
assert_eq!(
ConfigError::IncompatibleOptions.to_string(),
"incompatible options"
);
}
#[test]
@@ -184,7 +196,10 @@ mod tests {
#[test]
fn forward_error_display() {
assert_eq!(
ForwardError::InvalidSpec { spec: "bad".to_string() }.to_string(),
ForwardError::InvalidSpec {
spec: "bad".to_string()
}
.to_string(),
"invalid port forward spec: bad"
);
assert_eq!(
@@ -209,7 +224,9 @@ mod tests {
let forward_err = ForwardError::BindFailed { source: io_err };
assert!(forward_err.source().is_some());
let plain = ForwardError::InvalidSpec { spec: "bad".to_string() };
let plain = ForwardError::InvalidSpec {
spec: "bad".to_string(),
};
assert!(plain.source().is_none());
}
}
}

View File

@@ -50,18 +50,23 @@
//! }
//! ```
pub mod transport;
pub mod client;
pub mod server;
pub mod auth;
pub mod socks5;
pub mod client;
pub mod config;
pub mod error;
pub mod server;
pub mod socks5;
pub mod transport;
#[cfg(feature = "testutil")]
pub mod testutil;
pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError};
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
pub use client::channel_manager::{ChannelManager, ForwardRequest};
pub use client::connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
pub use server::serve::{Server, ServeError, ServeOptions, ServeTransportMode};
pub use config::{
AuthPolicy, ConfigReloadHandle, DynamicConfig, ForwardingAction, ForwardingPolicy,
ForwardingRule, RateLimitConfig, StaticConfig,
};
pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError};
pub use server::serve::{ServeError, ServeOptions, ServeTransportMode, Server};
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};

View File

@@ -46,7 +46,10 @@ async fn connect_direct(target: SocketAddr) -> Result<TcpStream, ChannelProxyErr
.map_err(|e| map_connection_error(e, target))
}
async fn connect_socks5(target: SocketAddr, proxy_addr: SocketAddr) -> Result<TcpStream, ChannelProxyError> {
async fn connect_socks5(
target: SocketAddr,
proxy_addr: SocketAddr,
) -> Result<TcpStream, ChannelProxyError> {
let mut stream = TcpStream::connect(proxy_addr)
.await
.map_err(ChannelProxyError::from)?;
@@ -134,10 +137,7 @@ async fn connect_http_connect(
}
let response_str = String::from_utf8_lossy(&response);
let status_line = response_str
.lines()
.next()
.unwrap_or("");
let status_line = response_str.lines().next().unwrap_or("");
if status_line.contains("200") {
Ok(stream)
@@ -279,11 +279,7 @@ mod tests {
.parse()
.unwrap();
let reply = vec![
0x05, 0x00, 0x00, 0x01,
0, 0, 0, 0,
0, 0,
];
let reply = vec![0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0];
proxy_sock.write_all(&reply).await.unwrap();
let mut target_stream = TcpStream::connect(target).await.unwrap();
@@ -323,11 +319,7 @@ mod tests {
let mut port_bytes = [0u8; 2];
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
let reply = vec![
0x05, 0x05, 0x00, 0x01,
0, 0, 0, 0,
0, 0,
];
let reply = vec![0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0];
proxy_sock.write_all(&reply).await.unwrap();
});
@@ -560,4 +552,4 @@ mod tests {
let proxy = direct_config();
proxy_channel(channel, target, &proxy).await;
}
}
}

View File

@@ -189,4 +189,4 @@ mod tests {
fn control_channel_destination_matches_prefix() {
assert!(is_reserved_destination(ALKNET_CONTROL_DESTINATION));
}
}
}

View File

@@ -2,16 +2,15 @@ use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Instant;
use arc_swap::ArcSwap;
use async_trait::async_trait;
use russh::keys::ssh_key::HashAlg;
use russh::server::{Auth, Handler, Msg, Session};
use russh::Channel;
use russh::ChannelId;
use crate::auth::ServerAuthConfig;
use crate::server::control_channel::{
ControlChannelHandler, ControlChannelRouter, ALKNET_PREFIX,
};
use crate::config::DynamicConfig;
use crate::server::control_channel::{ControlChannelHandler, ControlChannelRouter, ALKNET_PREFIX};
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
#[derive(Debug, Clone)]
@@ -44,7 +43,7 @@ impl std::fmt::Display for TransportKind {
}
pub struct ServerHandler {
auth_config: Arc<ServerAuthConfig>,
dynamic: Arc<ArcSwap<DynamicConfig>>,
#[allow(dead_code)]
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
@@ -59,7 +58,7 @@ pub struct ServerHandler {
impl ServerHandler {
pub fn new(
auth_config: Arc<ServerAuthConfig>,
dynamic: Arc<ArcSwap<DynamicConfig>>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
transport: TransportKind,
@@ -89,7 +88,7 @@ impl ServerHandler {
};
Self {
auth_config,
dynamic,
outbound_proxy,
remote_addr,
control_channel_router: ControlChannelRouter::without_handler(),
@@ -127,10 +126,7 @@ impl Drop for ServerHandler {
}
impl ServerHandler {
pub fn with_control_channel_handler(
mut self,
handler: Box<dyn ControlChannelHandler>,
) -> Self {
pub fn with_control_channel_handler(mut self, handler: Box<dyn ControlChannelHandler>) -> Self {
self.control_channel_router = ControlChannelRouter::with_handler(handler);
self
}
@@ -172,7 +168,8 @@ impl Handler for ServerHandler {
.map_or("unknown".to_string(), |a| a.to_string());
let russh_pub = russh::keys::PublicKey::new(public_key.key_data().clone(), user);
let result = self.auth_config.authenticate_publickey(&russh_pub);
let auth_config = self.dynamic.load();
let result = auth_config.auth.authenticate_publickey(&russh_pub);
match result {
Ok(()) => {
@@ -226,17 +223,25 @@ impl Handler for ServerHandler {
});
tokio::spawn(async move {
let target = match format!("{target_host}:{target_port}").parse::<std::net::SocketAddr>() {
Ok(addr) => addr,
Err(_) => match tokio::net::lookup_host((&target_host[..], target_port as u16)).await {
Ok(mut addrs) => match addrs.next() {
Some(addr) => addr,
None => return,
let target =
match format!("{target_host}:{target_port}").parse::<std::net::SocketAddr>() {
Ok(addr) => addr,
Err(_) => match tokio::net::lookup_host((&target_host[..], target_port as u16))
.await
{
Ok(mut addrs) => match addrs.next() {
Some(addr) => addr,
None => return,
},
Err(_) => return,
},
Err(_) => return,
},
};
crate::server::channel_proxy::proxy_channel(channel.into_stream(), target, &proxy_config).await;
};
crate::server::channel_proxy::proxy_channel(
channel.into_stream(),
target,
&proxy_config,
)
.await;
});
let _ = (originator_address, originator_port);
@@ -389,7 +394,12 @@ impl Handler for ServerHandler {
channel = %channel,
"rejected x11 request on channel"
);
let _ = (single_connection, x11_auth_protocol, x11_auth_cookie, x11_screen_number);
let _ = (
single_connection,
x11_auth_protocol,
x11_auth_cookie,
x11_screen_number,
);
let _ = session.channel_failure(channel);
Ok(())
}
@@ -469,6 +479,8 @@ impl Handler for ServerHandler {
mod tests {
use super::*;
use crate::auth::keys::KeySource;
use crate::auth::ServerAuthConfig;
use crate::config::AuthPolicy;
use russh::keys::{decode_secret_key, PrivateKey};
use std::io::Write;
@@ -487,19 +499,19 @@ mod tests {
decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap()
}
fn make_auth_config(keys_content: &str) -> Arc<ServerAuthConfig> {
fn make_auth_config(keys_content: &str) -> Arc<ArcSwap<DynamicConfig>> {
let f = make_authorized_keys_file(keys_content);
Arc::new(
ServerAuthConfig::from_keys_and_ca(
Some(KeySource::File(f.path().to_path_buf())),
None,
)
.unwrap(),
)
let server_auth =
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
.unwrap();
let auth_policy = AuthPolicy::from_server_auth_config(server_auth);
let dynamic = DynamicConfig::new(auth_policy);
Arc::new(ArcSwap::new(Arc::new(dynamic)))
}
fn make_empty_auth_config() -> Arc<ServerAuthConfig> {
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
fn make_empty_auth_config() -> Arc<ArcSwap<DynamicConfig>> {
let dynamic = DynamicConfig::default();
Arc::new(ArcSwap::new(Arc::new(dynamic)))
}
fn default_limiter() -> Arc<ConnectionRateLimiter> {
@@ -507,11 +519,18 @@ mod tests {
}
fn make_handler(
auth_config: Arc<ServerAuthConfig>,
dynamic: Arc<ArcSwap<DynamicConfig>>,
outbound_proxy: Option<ProxyConfig>,
remote_addr: Option<SocketAddr>,
) -> ServerHandler {
ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10)
ServerHandler::new(
dynamic,
outbound_proxy,
remote_addr,
TransportKind::Tcp,
default_limiter(),
10,
)
}
#[tokio::test]
@@ -530,10 +549,9 @@ mod tests {
let mut handler = make_handler(auth_config, None, None);
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
let other_ssh_key = russh::keys::parse_public_key_base64(
other_key_text.split_whitespace().nth(1).unwrap(),
)
.unwrap();
let other_ssh_key =
russh::keys::parse_public_key_base64(other_key_text.split_whitespace().nth(1).unwrap())
.unwrap();
let result = handler
.auth_publickey("testuser", &other_ssh_key)
@@ -553,10 +571,7 @@ mod tests {
let mut handler = make_handler(auth_config, None, None);
let ssh_key = load_key().public_key().clone();
let result = handler
.auth_publickey("testuser", &ssh_key)
.await
.unwrap();
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
assert_eq!(
result,
Auth::Reject {
@@ -629,8 +644,16 @@ mod tests {
#[test]
fn one_handler_per_connection() {
let auth_config = make_empty_auth_config();
let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
let handler1 = make_handler(
auth_config.clone(),
None,
Some("10.0.0.1:22".parse().unwrap()),
);
let handler2 = make_handler(
auth_config.clone(),
None,
Some("10.0.0.2:22".parse().unwrap()),
);
assert!(handler1.remote_addr != handler2.remote_addr);
}
@@ -651,10 +674,20 @@ mod tests {
let ssh_key = load_key().public_key().clone();
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
assert_eq!(r1, Auth::Reject { proceed_with_methods: None });
assert_eq!(
r1,
Auth::Reject {
proceed_with_methods: None
}
);
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
assert_eq!(r2, Auth::Reject { proceed_with_methods: None });
assert_eq!(
r2,
Auth::Reject {
proceed_with_methods: None
}
);
assert!(!handler.auth_limiter.check());
}
@@ -733,4 +766,4 @@ mod tests {
10,
);
}
}
}

View File

@@ -16,10 +16,12 @@ pub mod stealth;
pub use channel_proxy::{connect_outbound, proxy_channel};
pub use control_channel::{
ControlChannelHandler, ControlChannelRouter, DuplexStream, ALKNET_CONTROL_DESTINATION,
ALKNET_PREFIX, is_reserved_destination,
is_reserved_destination, ControlChannelHandler, ControlChannelRouter, DuplexStream,
ALKNET_CONTROL_DESTINATION, ALKNET_PREFIX,
};
pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
pub use serve::{Server, ServeError, ServeOptions, ServeTransportMode};
pub use stealth::{ProtocolDetection, detect_protocol, send_fake_nginx_404, validate_stealth_config};
pub use serve::{ServeError, ServeOptions, ServeTransportMode, Server};
pub use stealth::{
detect_protocol, send_fake_nginx_404, validate_stealth_config, ProtocolDetection,
};

View File

@@ -197,4 +197,4 @@ mod tests {
h.join().unwrap();
}
}
}
}

View File

@@ -8,12 +8,14 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use arc_swap::ArcSwap;
use russh::server::{self, Config};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, warn};
use crate::auth::keys::KeySource;
use crate::auth::server_auth::ServerAuthConfig;
use crate::config::{AuthPolicy, ConfigReloadHandle, DynamicConfig};
use crate::error::ConfigError;
use crate::server::handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
use crate::server::rate_limit::ConnectionRateLimiter;
@@ -228,7 +230,7 @@ struct ActiveSession {
/// Supports stealth mode (TLS only), outbound proxy routing, and connection rate limiting.
pub struct Server {
config: Arc<server::Config>,
auth_config: Arc<ServerAuthConfig>,
dynamic: Arc<ArcSwap<DynamicConfig>>,
connection_limiter: Arc<ConnectionRateLimiter>,
outbound_proxy: Option<ProxyConfig>,
stealth: bool,
@@ -244,17 +246,24 @@ impl Server {
pub fn new(opts: ServeOptions) -> Result<Self, ServeError> {
opts.validate().map_err(ServeError::Config)?;
let private_key =
crate::auth::keys::load_private_key(opts.key.clone()).map_err(ServeError::KeyLoadFailed)?;
let private_key = crate::auth::keys::load_private_key(opts.key.clone())
.map_err(ServeError::KeyLoadFailed)?;
let auth_config = Arc::new(
ServerAuthConfig::from_keys_and_ca(opts.authorized_keys.clone(), opts.cert_authority.clone())
.map_err(ServeError::KeyLoadFailed)?,
);
let auth_config = ServerAuthConfig::from_keys_and_ca(
opts.authorized_keys.clone(),
opts.cert_authority.clone(),
)
.map_err(ServeError::KeyLoadFailed)?;
let auth_policy = AuthPolicy::from_server_auth_config(auth_config);
let dynamic_config = DynamicConfig::new(auth_policy);
let max_auth_attempts = opts.max_auth_attempts;
let max_connections_per_ip = opts.max_connections_per_ip;
let config = Arc::new(Config {
keys: vec![private_key],
max_auth_attempts: opts.max_auth_attempts,
max_auth_attempts,
methods: russh::MethodSet::PUBLICKEY,
preferred: russh::Preferred::DEFAULT,
..Default::default()
@@ -262,19 +271,21 @@ impl Server {
let outbound_proxy = parse_proxy_config(opts.proxy.as_deref());
let connection_limiter = Arc::new(ConnectionRateLimiter::new(opts.max_connections_per_ip));
let connection_limiter = Arc::new(ConnectionRateLimiter::new(max_connections_per_ip));
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
let dynamic = Arc::new(ArcSwap::new(Arc::new(dynamic_config)));
Ok(Self {
config,
auth_config,
dynamic,
connection_limiter,
outbound_proxy,
stealth: opts.stealth,
transport_mode: opts.transport_mode,
listen_addr: opts.listen_addr,
max_auth_attempts: opts.max_auth_attempts,
max_auth_attempts,
shutdown_tx,
shutdown_rx,
sessions: Arc::new(tokio::sync::Mutex::new(Vec::new())),
@@ -285,6 +296,12 @@ impl Server {
self.shutdown_tx.clone()
}
pub fn config_reload_handle(&self) -> ConfigReloadHandle {
ConfigReloadHandle {
dynamic: Arc::clone(&self.dynamic),
}
}
pub async fn shutdown(&self) -> Result<(), ServeError> {
info!("initiating graceful shutdown");
let _ = self.shutdown_tx.send(true);
@@ -292,11 +309,15 @@ impl Server {
{
let sessions = self.sessions.lock().await;
for session in sessions.iter() {
if let Err(e) = session.handle.disconnect(
russh::Disconnect::ByApplication,
"shutdown".to_string(),
String::new(),
).await {
if let Err(e) = session
.handle
.disconnect(
russh::Disconnect::ByApplication,
"shutdown".to_string(),
String::new(),
)
.await
{
warn!("failed to send SSH disconnect: {e}");
}
}
@@ -392,7 +413,7 @@ impl Server {
let handler_transport_kind = transport_kind;
let handler = ServerHandler::new(
Arc::clone(&server.auth_config),
Arc::clone(&server.dynamic),
server.outbound_proxy.clone(),
remote_addr,
handler_transport_kind,
@@ -410,15 +431,9 @@ impl Server {
let transport_is_tls = server.transport_mode == ServeTransportMode::Tls;
tokio::spawn(async move {
let result = handle_connection(
stream,
config,
handler,
sessions,
stealth,
transport_is_tls,
)
.await;
let result =
handle_connection(stream, config, handler, sessions, stealth, transport_is_tls)
.await;
if let Err(e) = result {
warn!("connection error: {e}");
@@ -611,8 +626,7 @@ mod tests {
#[test]
fn serve_options_validate_tcp_with_acme_rejected() {
let opts =
ServeOptions::new(make_key_source()).acme_domain("example.com");
let opts = ServeOptions::new(make_key_source()).acme_domain("example.com");
assert!(opts.validate().is_err());
}
@@ -626,8 +640,8 @@ mod tests {
#[test]
fn server_new_creates_server() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source());
let opts =
ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source());
let server = Server::new(opts).unwrap();
assert_eq!(server.max_auth_attempts, 10);
}
@@ -662,8 +676,8 @@ mod tests {
#[test]
fn serve_options_debug_redacts_keys() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source());
let opts =
ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source());
let debug_str = format!("{:?}", opts);
assert!(debug_str.contains("<KeySource>"));
assert!(!debug_str.contains("OPENSSH"));
@@ -715,8 +729,8 @@ mod tests {
#[test]
fn server_shutdown_sender_clones() {
let opts = ServeOptions::new(make_key_source())
.authorized_keys(make_authorized_keys_source());
let opts =
ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source());
let server = Server::new(opts).unwrap();
let sender = server.shutdown_sender();
assert!(!server.is_shutdown());
@@ -726,8 +740,7 @@ mod tests {
#[test]
fn server_holds_listen_addr() {
let opts = ServeOptions::new(make_key_source())
.listen_addr("0.0.0.0:443");
let opts = ServeOptions::new(make_key_source()).listen_addr("0.0.0.0:443");
let server = Server::new(opts).unwrap();
assert_eq!(server.listen_addr, "0.0.0.0:443");
}
@@ -747,12 +760,10 @@ mod tests {
let server = Server::new(opts).unwrap();
let shutdown_tx = server.shutdown_sender();
let server_handle = tokio::spawn(async move {
server
.run(acceptor, None)
.await
.expect("server run failed")
});
let server_handle =
tokio::spawn(
async move { server.run(acceptor, None).await.expect("server run failed") },
);
tokio::time::sleep(Duration::from_millis(50)).await;
@@ -760,6 +771,9 @@ mod tests {
let result = tokio::time::timeout(Duration::from_secs(5), server_handle).await;
assert!(result.is_ok(), "server should have shut down within timeout");
assert!(
result.is_ok(),
"server should have shut down within timeout"
);
}
}
}

View File

@@ -134,7 +134,10 @@ mod tests {
let mut all_data = Vec::new();
reader.read_to_end(&mut all_data).await.unwrap();
assert!(all_data.starts_with(banner), "banner bytes must be preserved after detection");
assert!(
all_data.starts_with(banner),
"banner bytes must be preserved after detection"
);
}
#[tokio::test]
@@ -142,7 +145,10 @@ mod tests {
let (client, server) = duplex(1024);
let (mut client_read, mut client_write) = tokio::io::split(client);
client_write.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
client_write
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
.await
.unwrap();
drop(client_write);
let (detection, mut reader) = detect_protocol(server).await;
@@ -206,7 +212,10 @@ mod tests {
let (client, server) = duplex(1024);
let mut client = client;
client.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
client
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
.await
.unwrap();
let (detection, mut reader) = detect_protocol(server).await;
assert_eq!(detection, ProtocolDetection::Http);
@@ -223,4 +232,4 @@ mod tests {
let result = client.read(&mut extra).await;
assert!(result.is_err() || result.unwrap() == 0);
}
}
}

View File

@@ -52,9 +52,7 @@ impl<C: ChannelOpener> Socks5Server<C> {
}
pub fn with_addr(channel_opener: C, addr: &str) -> Self {
let listen_addr: SocketAddr = addr
.parse()
.expect("invalid SOCKS5 listen address");
let listen_addr: SocketAddr = addr.parse().expect("invalid SOCKS5 listen address");
Self {
listen_addr,
channel_opener: Arc::new(channel_opener),
@@ -80,10 +78,7 @@ impl<C: ChannelOpener> Socks5Server<C> {
}
}
async fn handle_socks5_connection<S, C>(
mut socket: S,
opener: Arc<C>,
) -> Result<(), Socks5Error>
async fn handle_socks5_connection<S, C>(mut socket: S, opener: Arc<C>) -> Result<(), Socks5Error>
where
S: AsyncRead + AsyncWrite + Unpin,
C: ChannelOpener,
@@ -173,7 +168,11 @@ impl<H: russh::client::Handler> HandleChannelOpener<H> {
impl<H: russh::client::Handler + Send + Sync + 'static> ChannelOpener for HandleChannelOpener<H> {
type Stream = russh::ChannelStream<russh::client::Msg>;
async fn open_channel(&self, host: String, port: u16) -> Result<Self::Stream, ChannelOpenError> {
async fn open_channel(
&self,
host: String,
port: u16,
) -> Result<Self::Stream, ChannelOpenError> {
let handle = self.handle.lock().await;
if handle.is_closed() {
return Err(ChannelOpenError::SessionClosed);
@@ -241,7 +240,10 @@ mod tests {
}
async fn do_handshake(client: &mut DuplexStream) -> [u8; 2] {
client.write_all(&build_socks5_greeting(&[0x00])).await.unwrap();
client
.write_all(&build_socks5_greeting(&[0x00]))
.await
.unwrap();
client.flush().await.unwrap();
let mut resp = [0u8; 2];
client.read_exact(&mut resp).await.unwrap();
@@ -264,9 +266,8 @@ mod tests {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: false };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
let server_handle =
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
let resp = do_handshake(&mut client).await;
assert_eq!(resp, [0x05, 0x00]);
@@ -284,9 +285,8 @@ mod tests {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: false };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
let server_handle =
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
client
.write_all(&build_socks5_greeting(&[0x02]))
@@ -301,10 +301,7 @@ mod tests {
drop(client);
let result = server_handle.await.unwrap();
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
Socks5Error::NoAcceptableAuth
));
assert!(matches!(result.unwrap_err(), Socks5Error::NoAcceptableAuth));
}
#[tokio::test]
@@ -312,9 +309,8 @@ mod tests {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: false };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
let server_handle =
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
do_handshake(&mut client).await;
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 443).await;
@@ -329,9 +325,8 @@ mod tests {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: false };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
let server_handle =
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
do_handshake(&mut client).await;
@@ -354,9 +349,8 @@ mod tests {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: false };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
let server_handle =
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
do_handshake(&mut client).await;
@@ -381,9 +375,8 @@ mod tests {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: true };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
let server_handle =
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
do_handshake(&mut client).await;
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 80).await;
@@ -399,9 +392,8 @@ mod tests {
let (mut client, server) = duplex(4096);
let opener = MockChannelOpener { fail: false };
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server, Arc::new(opener)).await
});
let server_handle =
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
do_handshake(&mut client).await;
@@ -450,9 +442,10 @@ mod tests {
stream: Arc::clone(&ssh_stream),
};
let server_handle = tokio::spawn(async move {
handle_socks5_connection(server_sock, Arc::new(opener)).await
});
let server_handle =
tokio::spawn(
async move { handle_socks5_connection(server_sock, Arc::new(opener)).await },
);
do_handshake(&mut client_sock).await;
let reply_buf = do_connect_ipv4(&mut client_sock, [127, 0, 0, 1], 80).await;
@@ -494,4 +487,4 @@ mod tests {
let server = Socks5Server::with_addr(opener, "127.0.0.1:9050");
assert_eq!(server.listen_addr(), "127.0.0.1:9050".parse().unwrap());
}
}
}

View File

@@ -169,10 +169,7 @@ mod tests {
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
assert_eq!(req.version, 0x05);
assert_eq!(req.command, 0x01);
assert_eq!(
req.address,
Socks5Address::Ipv4(Ipv4Addr::new(10, 0, 0, 1))
);
assert_eq!(req.address, Socks5Address::Ipv4(Ipv4Addr::new(10, 0, 0, 1)));
assert_eq!(req.port, 443);
}
@@ -201,7 +198,10 @@ mod tests {
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
assert_eq!(req.version, 0x05);
assert_eq!(req.command, 0x01);
assert_eq!(req.address, Socks5Address::Domain("example.com".to_string()));
assert_eq!(
req.address,
Socks5Address::Domain("example.com".to_string())
);
assert_eq!(req.port, 443);
}
@@ -301,4 +301,4 @@ mod tests {
let port = cursor.read_u16().await.unwrap();
assert_eq!(port, 8080);
}
}
}

View File

@@ -1,5 +1,5 @@
use tokio::io::{DuplexStream, AsyncRead, AsyncWrite};
use anyhow::Result;
use tokio::io::{AsyncRead, AsyncWrite, DuplexStream};
#[cfg(feature = "transport-traits")]
pub use crate::transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
@@ -9,10 +9,10 @@ pub use local_traits::{Transport, TransportAcceptor, TransportInfo, TransportKin
#[cfg(not(feature = "transport-traits"))]
mod local_traits {
use std::net::SocketAddr;
use anyhow::Result;
use tokio::io::{AsyncRead, AsyncWrite};
use async_trait::async_trait;
use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncWrite};
#[async_trait]
pub trait Transport: Send + Sync + 'static {
@@ -138,4 +138,4 @@ impl TransportAcceptor for MockTransportAcceptor {
pub fn mock_pair(buf_size: usize) -> (MockStream, MockStream) {
let (client, server) = tokio::io::duplex(buf_size);
(MockStream::new(client), MockStream::new(server))
}
}

View File

@@ -7,9 +7,9 @@ use rustls::crypto::aws_lc_rs::default_provider;
use rustls::ServerConfig;
use rustls_acme::caches::DirCache;
use rustls_acme::{AcmeConfig, AcmeState, ResolvesServerCertAcme};
use tracing::{error, info};
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor as TokioTlsAcceptor;
use tracing::{error, info};
use super::{TransportAcceptor, TransportInfo, TransportKind};
@@ -94,14 +94,10 @@ impl AcmeCertProvider {
.contact(self.contact.clone());
let state = match &self.cache_dir {
Some(cache_dir) => {
base_config.cache(DirCache::new(cache_dir.clone())).state()
}
None => {
base_config
.cache(rustls_acme::caches::NoCache::default())
.state()
}
Some(cache_dir) => base_config.cache(DirCache::new(cache_dir.clone())).state(),
None => base_config
.cache(rustls_acme::caches::NoCache::default())
.state(),
};
let resolver = state.resolver();
@@ -132,10 +128,7 @@ pub struct AcmeTlsAcceptor {
}
impl AcmeTlsAcceptor {
pub async fn bind_acme(
addr: SocketAddr,
provider: Arc<AcmeCertProvider>,
) -> Result<Self> {
pub async fn bind_acme(addr: SocketAddr, provider: Arc<AcmeCertProvider>) -> Result<Self> {
let (state, resolver) = provider.build_acme_state();
let server_config = provider.build_server_config_with_resolver(resolver.clone())?;
@@ -193,11 +186,7 @@ impl TransportAcceptor for AcmeTlsAcceptor {
let (tcp_stream, remote_addr) = self.listener.accept().await?;
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
let server_name = tls_stream
.get_ref()
.1
.server_name()
.map(|s| s.to_string());
let server_name = tls_stream.get_ref().1.server_name().map(|s| s.to_string());
let info = TransportInfo {
remote_addr: Some(remote_addr),
@@ -277,8 +266,7 @@ mod tests {
#[test]
fn acme_cert_provider_build_state_with_cache() {
let provider =
AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/test_cache");
let provider = AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/test_cache");
let (_state, resolver) = provider.build_acme_state();
assert!(Arc::strong_count(&resolver) >= 2);
}
@@ -288,7 +276,9 @@ mod tests {
let _ = default_provider().install_default();
let provider = AcmeCertProvider::domain("example.com");
let (_, resolver) = provider.build_acme_state();
let config = provider.build_server_config_with_resolver(resolver).unwrap();
let config = provider
.build_server_config_with_resolver(resolver)
.unwrap();
assert!(!config.alpn_protocols.is_empty());
assert!(config
.alpn_protocols
@@ -359,4 +349,4 @@ mod tests {
let acceptor = result.unwrap();
assert_eq!(acceptor.listen_addr().port(), 443);
}
}
}

View File

@@ -1,9 +1,7 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use iroh::{
endpoint::RecvStream,
node_info::NodeIdExt,
Endpoint, NodeId, RelayMap, RelayMode, RelayUrl,
endpoint::RecvStream, node_info::NodeIdExt, Endpoint, NodeId, RelayMap, RelayMode, RelayUrl,
};
use tokio::io;
@@ -39,7 +37,9 @@ impl IrohTransport {
proxy_url: Option<url::Url>,
) -> Result<Self> {
let relay_url = relay_url.unwrap_or_else(|| {
DEFAULT_RELAY_URL.parse().expect("default relay URL is valid")
DEFAULT_RELAY_URL
.parse()
.expect("default relay URL is valid")
});
let relay_map = RelayMap::from_url(relay_url);
let mut builder = Endpoint::builder()
@@ -49,7 +49,11 @@ impl IrohTransport {
builder = builder.proxy_url(proxy.clone());
}
let endpoint = builder.bind().await?;
Ok(Self { node_id, endpoint, owned: true })
Ok(Self {
node_id,
endpoint,
owned: true,
})
}
/// Create an iroh transport using an existing shared endpoint.
@@ -60,7 +64,11 @@ impl IrohTransport {
/// other protocol handlers on the same QUIC endpoint — one connection
/// per peer, multiplexed by ALPN.
pub fn from_endpoint(node_id: NodeId, endpoint: Endpoint) -> Self {
Self { node_id, endpoint, owned: false }
Self {
node_id,
endpoint,
owned: false,
}
}
pub fn endpoint_id(&self) -> String {
@@ -115,12 +123,11 @@ impl IrohAcceptor {
/// Bind a new iroh endpoint with a dedicated `alknet-ssh` ALPN.
///
/// Use this when alknet is the only iroh service on this node.
pub async fn bind(
relay_url: Option<RelayUrl>,
proxy_url: Option<url::Url>,
) -> Result<Self> {
pub async fn bind(relay_url: Option<RelayUrl>, proxy_url: Option<url::Url>) -> Result<Self> {
let relay_url = relay_url.unwrap_or_else(|| {
DEFAULT_RELAY_URL.parse().expect("default relay URL is valid")
DEFAULT_RELAY_URL
.parse()
.expect("default relay URL is valid")
});
let relay_map = RelayMap::from_url(relay_url);
let mut builder = Endpoint::builder()
@@ -130,7 +137,10 @@ impl IrohAcceptor {
builder = builder.proxy_url(proxy.clone());
}
let endpoint = builder.bind().await?;
Ok(Self { endpoint, owned: true })
Ok(Self {
endpoint,
owned: true,
})
}
/// Create an iroh acceptor using an existing shared endpoint.
@@ -146,7 +156,10 @@ impl IrohAcceptor {
/// [`IrohAcceptor::bind`] instead, which handles the accept loop
/// internally.
pub fn from_endpoint(endpoint: Endpoint) -> Self {
Self { endpoint, owned: false }
Self {
endpoint,
owned: false,
}
}
pub fn endpoint_id(&self) -> String {
@@ -219,18 +232,14 @@ mod tests {
#[test]
fn iroh_transport_describe_format() {
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng)
.public()
.into();
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into();
let desc = format!("iroh://{}", node_id.to_z32());
assert!(desc.starts_with("iroh://"));
}
#[tokio::test]
async fn iroh_transport_connect_builds_endpoint() {
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng)
.public()
.into();
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into();
let transport = IrohTransport::new(node_id, None, None).await.unwrap();
assert!(transport.describe().starts_with("iroh://"));
assert!(!transport.endpoint_id().is_empty());
@@ -239,9 +248,7 @@ mod tests {
#[tokio::test]
async fn iroh_transport_from_endpoint() {
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng)
.public()
.into();
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into();
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
let endpoint = acceptor.endpoint.clone();
let transport = IrohTransport::from_endpoint(node_id, endpoint);
@@ -318,4 +325,4 @@ mod tests {
transport.connect().await.unwrap();
let _server_stream = accept_handle.await.unwrap();
}
}
}

View File

@@ -13,13 +13,13 @@
//! See [ADR-001](docs/architecture/decisions/001-pluggable-transport.md) and
//! [ADR-004](docs/architecture/decisions/004-ssh-over-transport.md) for design rationale.
mod tcp;
#[cfg(feature = "iroh")]
mod iroh_transport;
mod tcp;
pub use tcp::{TcpAcceptor, TcpTransport};
#[cfg(feature = "iroh")]
pub use iroh_transport::{IrohAcceptor, IrohTransport, ALPN as IROH_ALPN};
pub use tcp::{TcpAcceptor, TcpTransport};
#[cfg(feature = "tls")]
mod tls;
@@ -89,12 +89,8 @@ pub struct TransportInfo {
#[derive(Debug, Clone)]
pub enum TransportKind {
Tcp,
Tls {
server_name: Option<String>,
},
Iroh {
endpoint_id: String,
},
Tls { server_name: Option<String> },
Iroh { endpoint_id: String },
}
#[cfg(test)]
@@ -185,4 +181,4 @@ mod tests {
assert_eq!(endpoint_id, "abc123");
}
}
}
}

View File

@@ -159,4 +159,4 @@ mod tests {
.unwrap();
assert_ne!(acceptor.listen_addr().port(), 0);
}
}
}

View File

@@ -7,7 +7,9 @@ use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, Server
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::{client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector};
use tokio_rustls::{
client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector,
};
#[cfg(feature = "acme")]
use rustls::crypto::aws_lc_rs::default_provider;
@@ -169,7 +171,9 @@ impl TlsAcceptor {
.map_err(|e| anyhow!("failed to set protocol versions: {}", e))?
.with_no_client_auth()
.with_cert_resolver(acme_resolver);
server_config.alpn_protocols.push(ACME_TLS_ALPN_NAME.to_vec());
server_config
.alpn_protocols
.push(ACME_TLS_ALPN_NAME.to_vec());
let server_config = Arc::new(server_config);
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
@@ -195,11 +199,7 @@ impl TransportAcceptor for TlsAcceptor {
let (tcp_stream, remote_addr) = self.listener.accept().await?;
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
let server_name = tls_stream
.get_ref()
.1
.server_name()
.map(|s| s.to_string());
let server_name = tls_stream.get_ref().1.server_name().map(|s| s.to_string());
let info = TransportInfo {
remote_addr: Some(remote_addr),
@@ -324,10 +324,7 @@ mod tests {
let (mut server, info) = accept_handle.await.unwrap();
assert!(info.remote_addr.is_some());
assert!(matches!(
info.transport_kind,
TransportKind::Tls { .. }
));
assert!(matches!(info.transport_kind, TransportKind::Tls { .. }));
client.write_all(b"hello tls").await.unwrap();
let mut buf = [0u8; 9];
@@ -429,4 +426,4 @@ mod tests {
let verifier = NoVerifier;
assert!(verifier.supported_verify_schemes().len() > 0);
}
}
}

View File

@@ -1,2 +1,2 @@
#[tokio::test]
async fn auth_placeholder() {}
async fn auth_placeholder() {}

View File

@@ -1,2 +1,2 @@
#[tokio::test]
async fn client_placeholder() {}
async fn client_placeholder() {}

View File

@@ -1,2 +1,2 @@
#[tokio::test]
async fn server_placeholder() {}
async fn server_placeholder() {}

View File

@@ -1,4 +1,6 @@
use alknet_core::testutil::{MockTransport, MockTransportAcceptor, Transport, TransportAcceptor, mock_pair};
use alknet_core::testutil::{
mock_pair, MockTransport, MockTransportAcceptor, Transport, TransportAcceptor,
};
#[tokio::test]
async fn mock_transport_connect() {
@@ -23,4 +25,4 @@ async fn mock_pair_communicates() {
let mut buf = [0u8; 5];
server.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello");
}
}

View File

@@ -328,7 +328,12 @@ impl russh::server::Handler for NapiServerHandler {
session: &mut russh::server::Session,
) -> std::result::Result<(), Self::Error> {
tracing::warn!(channel = %channel, "rejected x11 request");
let _ = (single_connection, x11_auth_protocol, x11_auth_cookie, x11_screen_number);
let _ = (
single_connection,
x11_auth_protocol,
x11_auth_cookie,
x11_screen_number,
);
let _ = session.channel_failure(channel);
Ok(())
}
@@ -348,7 +353,11 @@ impl russh::server::Handler for NapiServerHandler {
port: &mut u32,
_session: &mut russh::server::Session,
) -> std::result::Result<bool, Self::Error> {
tracing::warn!(address = address, port = *port, "rejected tcpip-forward request");
tracing::warn!(
address = address,
port = *port,
"rejected tcpip-forward request"
);
Ok(false)
}
@@ -367,7 +376,10 @@ impl russh::server::Handler for NapiServerHandler {
socket_path: &str,
_session: &mut russh::server::Session,
) -> std::result::Result<bool, Self::Error> {
tracing::warn!(socket_path = socket_path, "rejected streamlocal-forward request");
tracing::warn!(
socket_path = socket_path,
"rejected streamlocal-forward request"
);
Ok(false)
}
@@ -542,8 +554,8 @@ pub async fn serve(options: AlknetServeOptions) -> napi::Result<AlknetServer> {
})?,
);
let private_key =
alknet_core::auth::keys::load_private_key(host_key_source.clone()).map_err(|e| {
let private_key = alknet_core::auth::keys::load_private_key(host_key_source.clone())
.map_err(|e| {
napi::Error::new(napi::Status::InvalidArg, format!("host key error: {}", e))
})?;
@@ -635,26 +647,28 @@ pub async fn serve(options: AlknetServeOptions) -> napi::Result<AlknetServer> {
)
})?;
let acceptor = TlsAcceptor::bind(addr, certs, key, None).await.map_err(|e| {
napi::Error::new(
napi::Status::GenericFailure,
format!("tls bind failed: {}", e),
)
})?;
let acceptor = TlsAcceptor::bind(addr, certs, key, None)
.await
.map_err(|e| {
napi::Error::new(
napi::Status::GenericFailure,
format!("tls bind failed: {}", e),
)
})?;
let actual_listen = acceptor.listen_addr().to_string();
let auth_config = Arc::new(
ServerAuthConfig::from_keys_and_ca(authorized_keys_source, cert_authority_source)
.map_err(|e| {
napi::Error::new(
napi::Status::InvalidArg,
format!("auth config error: {}", e),
)
})?,
napi::Error::new(
napi::Status::InvalidArg,
format!("auth config error: {}", e),
)
})?,
);
let private_key =
alknet_core::auth::keys::load_private_key(host_key_source.clone()).map_err(|e| {
let private_key = alknet_core::auth::keys::load_private_key(host_key_source.clone())
.map_err(|e| {
napi::Error::new(napi::Status::InvalidArg, format!("host key error: {}", e))
})?;
@@ -728,11 +742,11 @@ pub async fn serve(options: AlknetServeOptions) -> napi::Result<AlknetServer> {
let auth_config = Arc::new(
ServerAuthConfig::from_keys_and_ca(authorized_keys_source, cert_authority_source)
.map_err(|e| {
napi::Error::new(
napi::Status::InvalidArg,
format!("auth config error: {}", e),
)
})?,
napi::Error::new(
napi::Status::InvalidArg,
format!("auth config error: {}", e),
)
})?,
);
let private_key =

View File

@@ -10,8 +10,6 @@ use std::net::SocketAddr;
use std::process;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use clap::{Parser, Subcommand, ValueEnum};
use alknet_core::auth::keys::KeySource;
use alknet_core::client::{ConnectOptions, TransportMode};
use alknet_core::server::{ServeOptions, ServeTransportMode, Server};
@@ -21,6 +19,8 @@ use alknet_core::transport::TcpTransport;
#[cfg(feature = "tls")]
use alknet_core::transport::TlsTransport;
use alknet_core::transport::Transport;
use anyhow::{anyhow, Result};
use clap::{Parser, Subcommand, ValueEnum};
#[derive(Parser)]
#[command(name = "alknet", version, about = "Alknet SSH tunnel tool")]
@@ -76,7 +76,7 @@ enum Commands {
insecure: bool,
},
#[command( about = "Start the alknet server (accept SSH connections)")]
#[command(about = "Start the alknet server (accept SSH connections)")]
Serve {
#[arg(long, help = "SSH host key path (required)")]
key: String,