feat(core): implement StaticConfig/DynamicConfig split with ArcSwap hot-reload
Split alknet-core configuration into StaticConfig (immutable after startup) and DynamicConfig (hot-reloadable at runtime via ArcSwap). - Add StaticConfig struct in config/static_config.rs with all fields per ADR-030 - Add DynamicConfig struct with AuthPolicy, ForwardingPolicy, RateLimitConfig - Add ForwardingPolicy with allow_all()/deny_all() defaults (ADR-031) - Add ConfigReloadHandle with reload() method for runtime config updates - Replace Arc<ServerAuthConfig> with Arc<ArcSwap<DynamicConfig>> in ServerHandler - Add config_reload_handle() to Server for obtaining reload handles - Add AuthPolicy with authenticate_publickey/authenticate_certificate methods - All existing tests pass with the new config structure - Default DynamicConfig produces identical behavior to current code
This commit is contained in:
10
Cargo.lock
generated
10
Cargo.lock
generated
@@ -75,6 +75,7 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"alknet-core",
|
"alknet-core",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"arc-swap",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"futures",
|
"futures",
|
||||||
"ipnetwork",
|
"ipnetwork",
|
||||||
@@ -185,6 +186,15 @@ version = "1.0.102"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
|
checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "arc-swap"
|
||||||
|
version = "1.9.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6a3a1fd6f75306b68087b831f025c712524bcb19aad54e557b1129cfa0a2b207"
|
||||||
|
dependencies = [
|
||||||
|
"rustversion",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "asn1-rs"
|
name = "asn1-rs"
|
||||||
version = "0.6.2"
|
version = "0.6.2"
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ iroh = { version = "0.34", optional = true }
|
|||||||
url = { version = "2", optional = true }
|
url = { version = "2", optional = true }
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
ipnetwork = "0.21.1"
|
ipnetwork = "0.21.1"
|
||||||
|
arc-swap = "1"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
alknet-core = { path = ".", features = ["testutil", "tls", "iroh"] }
|
alknet-core = { path = ".", features = ["testutil", "tls", "iroh"] }
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
use russh::keys::{PrivateKey, PublicKey, decode_secret_key, parse_public_key_base64};
|
use russh::keys::{decode_secret_key, parse_public_key_base64, PrivateKey, PublicKey};
|
||||||
|
|
||||||
use crate::error::ConfigError;
|
use crate::error::ConfigError;
|
||||||
|
|
||||||
@@ -98,10 +98,7 @@ fn parse_authorized_keys_line(line: &str) -> Option<Result<(PublicKey, Vec<Strin
|
|||||||
|| parts[0].starts_with("principals=")
|
|| parts[0].starts_with("principals=")
|
||||||
{
|
{
|
||||||
let opts_str = parts[0];
|
let opts_str = parts[0];
|
||||||
options = opts_str
|
options = opts_str.split(',').map(|s| s.to_string()).collect();
|
||||||
.split(',')
|
|
||||||
.map(|s| s.to_string())
|
|
||||||
.collect();
|
|
||||||
key_type_idx = 1;
|
key_type_idx = 1;
|
||||||
} else if parts[0].starts_with("ssh-") || parts[0].starts_with("ecdsa-") {
|
} else if parts[0].starts_with("ssh-") || parts[0].starts_with("ecdsa-") {
|
||||||
key_type_idx = 0;
|
key_type_idx = 0;
|
||||||
@@ -218,9 +215,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_authorized_keys_multiple_entries() {
|
fn parse_authorized_keys_multiple_entries() {
|
||||||
let content = format!(
|
let content = format!("{ED25519_PUBLIC_KEY}\n# comment line\n\n{ED25519_PUBLIC_KEY_2}\n");
|
||||||
"{ED25519_PUBLIC_KEY}\n# comment line\n\n{ED25519_PUBLIC_KEY_2}\n"
|
|
||||||
);
|
|
||||||
let f = make_authorized_keys(&content);
|
let f = make_authorized_keys(&content);
|
||||||
let source = KeySource::File(f.path().to_path_buf());
|
let source = KeySource::File(f.path().to_path_buf());
|
||||||
let keys = load_public_keys(source).unwrap();
|
let keys = load_public_keys(source).unwrap();
|
||||||
|
|||||||
@@ -8,5 +8,5 @@ pub mod keys;
|
|||||||
pub mod server_auth;
|
pub mod server_auth;
|
||||||
|
|
||||||
pub use client_auth::{ClientAuthConfig, ClientHandler};
|
pub use client_auth::{ClientAuthConfig, ClientHandler};
|
||||||
pub use keys::{CertAuthorityEntry, KeySource, load_private_key, load_public_keys};
|
pub use keys::{load_private_key, load_public_keys, CertAuthorityEntry, KeySource};
|
||||||
pub use server_auth::ServerAuthConfig;
|
pub use server_auth::ServerAuthConfig;
|
||||||
@@ -13,7 +13,7 @@ use ipnetwork::IpNetwork;
|
|||||||
use russh::keys::helpers::EncodedExt;
|
use russh::keys::helpers::EncodedExt;
|
||||||
use russh::keys::{Certificate, PublicKey};
|
use russh::keys::{Certificate, PublicKey};
|
||||||
|
|
||||||
use super::keys::{CertAuthorityEntry, KeySource, load_cert_authority_entries, load_public_keys};
|
use super::keys::{load_cert_authority_entries, load_public_keys, CertAuthorityEntry, KeySource};
|
||||||
use crate::error::AuthError;
|
use crate::error::AuthError;
|
||||||
|
|
||||||
/// Server-side authentication configuration.
|
/// Server-side authentication configuration.
|
||||||
@@ -41,10 +41,7 @@ impl ServerAuthConfig {
|
|||||||
None => HashSet::new(),
|
None => HashSet::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let encoded_keys: HashSet<Vec<u8>> = authorized_keys
|
let encoded_keys: HashSet<Vec<u8>> = authorized_keys.iter().map(encode_key_data).collect();
|
||||||
.iter()
|
|
||||||
.map(encode_key_data)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let cert_authorities = match cert_authority_source {
|
let cert_authorities = match cert_authority_source {
|
||||||
Some(src) => load_cert_authority_entries(src)?,
|
Some(src) => load_cert_authority_entries(src)?,
|
||||||
@@ -135,10 +132,7 @@ fn check_critical_options(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_extensions(
|
fn check_extensions(cert: &Certificate, ca_entry: &CertAuthorityEntry) -> Result<(), AuthError> {
|
||||||
cert: &Certificate,
|
|
||||||
ca_entry: &CertAuthorityEntry,
|
|
||||||
) -> Result<(), AuthError> {
|
|
||||||
let ca_permit_port_forwarding = ca_entry
|
let ca_permit_port_forwarding = ca_entry
|
||||||
.options
|
.options
|
||||||
.iter()
|
.iter()
|
||||||
@@ -188,8 +182,8 @@ fn check_source_address(allowed: &str, client_ip: Option<IpAddr>) -> bool {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use rand_core::OsRng;
|
use rand_core::OsRng;
|
||||||
use russh::keys::{Certificate, PrivateKey, decode_secret_key};
|
|
||||||
use russh::keys::ssh_key::certificate::{Builder, CertType};
|
use russh::keys::ssh_key::certificate::{Builder, CertType};
|
||||||
|
use russh::keys::{decode_secret_key, Certificate, PrivateKey};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
const CA_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+gAAAJjP22Bpz9tg\naQAAAAtzc2gtZWQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+g\nAAAEBcRrWyUU+lLpjHbaaYN5YeOlvz6HnuBndUWevEmHk00jqkUoEjfbsmxEWZlQtqU2Om\nhQ8kxXHOyT1sZsMHJq36AAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
const CA_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+gAAAJjP22Bpz9tg\naQAAAAtzc2gtZWQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+g\nAAAEBcRrWyUU+lLpjHbaaYN5YeOlvz6HnuBndUWevEmHk00jqkUoEjfbsmxEWZlQtqU2Om\nhQ8kxXHOyT1sZsMHJq36AAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||||
@@ -218,12 +212,8 @@ mod tests {
|
|||||||
principals: Vec<&str>,
|
principals: Vec<&str>,
|
||||||
) -> Certificate {
|
) -> Certificate {
|
||||||
let key_data: russh::keys::ssh_key::public::KeyData = user_pub.into();
|
let key_data: russh::keys::ssh_key::public::KeyData = user_pub.into();
|
||||||
let mut builder = Builder::new_with_random_nonce(
|
let mut builder =
|
||||||
&mut OsRng,
|
Builder::new_with_random_nonce(&mut OsRng, key_data, valid_after, valid_before)
|
||||||
key_data,
|
|
||||||
valid_after,
|
|
||||||
valid_before,
|
|
||||||
)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
builder.cert_type(CertType::User).unwrap();
|
builder.cert_type(CertType::User).unwrap();
|
||||||
@@ -252,11 +242,7 @@ mod tests {
|
|||||||
} else {
|
} else {
|
||||||
format!("cert-authority,{}", options.join(","))
|
format!("cert-authority,{}", options.join(","))
|
||||||
};
|
};
|
||||||
let line = format!(
|
let line = format!("{} {} CA\n", opts, ca_pub.to_openssh().unwrap());
|
||||||
"{} {} CA\n",
|
|
||||||
opts,
|
|
||||||
ca_pub.to_openssh().unwrap()
|
|
||||||
);
|
|
||||||
f.write_all(line.as_bytes()).unwrap();
|
f.write_all(line.as_bytes()).unwrap();
|
||||||
f.flush().unwrap();
|
f.flush().unwrap();
|
||||||
f
|
f
|
||||||
@@ -357,13 +343,8 @@ mod tests {
|
|||||||
let user_pub = user_key.public_key().clone();
|
let user_pub = user_key.public_key().clone();
|
||||||
let now = now_secs();
|
let now = now_secs();
|
||||||
let key_data: russh::keys::ssh_key::public::KeyData = (&user_pub).into();
|
let key_data: russh::keys::ssh_key::public::KeyData = (&user_pub).into();
|
||||||
let mut builder = Builder::new_with_random_nonce(
|
let mut builder =
|
||||||
&mut OsRng,
|
Builder::new_with_random_nonce(&mut OsRng, key_data, now - 60, now + 3600).unwrap();
|
||||||
key_data,
|
|
||||||
now - 60,
|
|
||||||
now + 3600,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
builder.cert_type(CertType::User).unwrap();
|
builder.cert_type(CertType::User).unwrap();
|
||||||
builder.all_principals_valid().unwrap();
|
builder.all_principals_valid().unwrap();
|
||||||
let cert = builder.sign(&ca_key).unwrap();
|
let cert = builder.sign(&ca_key).unwrap();
|
||||||
@@ -383,7 +364,13 @@ mod tests {
|
|||||||
let other_ca_key = load_other_key();
|
let other_ca_key = load_other_key();
|
||||||
let user_pub = user_key.public_key().clone();
|
let user_pub = user_key.public_key().clone();
|
||||||
let now = now_secs();
|
let now = now_secs();
|
||||||
let cert = make_cert(&other_ca_key, &user_pub, now - 60, now + 3600, vec!["testuser"]);
|
let cert = make_cert(
|
||||||
|
&other_ca_key,
|
||||||
|
&user_pub,
|
||||||
|
now - 60,
|
||||||
|
now + 3600,
|
||||||
|
vec!["testuser"],
|
||||||
|
);
|
||||||
let ca_key = load_ca_key();
|
let ca_key = load_ca_key();
|
||||||
let ca_pub = ca_key.public_key().clone();
|
let ca_pub = ca_key.public_key().clone();
|
||||||
let f = make_ca_file(&ca_pub, &[]);
|
let f = make_ca_file(&ca_pub, &[]);
|
||||||
@@ -398,8 +385,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn no_config_accepts_nothing() {
|
fn no_config_accepts_nothing() {
|
||||||
let config =
|
let config = ServerAuthConfig::from_keys_and_ca(None, None).unwrap();
|
||||||
ServerAuthConfig::from_keys_and_ca(None, None).unwrap();
|
|
||||||
let other_pub = load_other_key().public_key().clone();
|
let other_pub = load_other_key().public_key().clone();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
config.authenticate_publickey(&other_pub),
|
config.authenticate_publickey(&other_pub),
|
||||||
|
|||||||
@@ -113,11 +113,7 @@ impl<T: Transport> ChannelManager<T> {
|
|||||||
.await
|
.await
|
||||||
.map_err(|_| ChannelError::ChannelClosed)?;
|
.map_err(|_| ChannelError::ChannelClosed)?;
|
||||||
|
|
||||||
self.inner
|
self.inner.forwards.write().await.insert(ForwardRequest {
|
||||||
.forwards
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.insert(ForwardRequest {
|
|
||||||
addr: addr.to_string(),
|
addr: addr.to_string(),
|
||||||
port,
|
port,
|
||||||
});
|
});
|
||||||
@@ -132,11 +128,7 @@ impl<T: Transport> ChannelManager<T> {
|
|||||||
.await
|
.await
|
||||||
.map_err(|_| ChannelError::ChannelClosed)?;
|
.map_err(|_| ChannelError::ChannelClosed)?;
|
||||||
|
|
||||||
self.inner
|
self.inner.forwards.write().await.remove(&ForwardRequest {
|
||||||
.forwards
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.remove(&ForwardRequest {
|
|
||||||
addr: addr.to_string(),
|
addr: addr.to_string(),
|
||||||
port,
|
port,
|
||||||
});
|
});
|
||||||
@@ -226,10 +218,7 @@ impl<T: Transport> ChannelManager<T> {
|
|||||||
for fwd in forwards.iter() {
|
for fwd in forwards.iter() {
|
||||||
match handle.tcpip_forward(&fwd.addr, fwd.port).await {
|
match handle.tcpip_forward(&fwd.addr, fwd.port).await {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
debug!(
|
debug!("re-registered tcpip_forward: {}:{}", fwd.addr, fwd.port);
|
||||||
"re-registered tcpip_forward: {}:{}",
|
|
||||||
fwd.addr, fwd.port
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!(
|
warn!(
|
||||||
|
|||||||
@@ -197,10 +197,7 @@ pub struct ClientSession<T: Transport> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Transport> ClientSession<T> {
|
impl<T: Transport> ClientSession<T> {
|
||||||
pub async fn new(
|
pub async fn new(opts: ConnectOptions, transport: Arc<T>) -> Result<Self, ConnectError> {
|
||||||
opts: ConnectOptions,
|
|
||||||
transport: Arc<T>,
|
|
||||||
) -> Result<Self, ConnectError> {
|
|
||||||
opts.validate().map_err(ConnectError::Config)?;
|
opts.validate().map_err(ConnectError::Config)?;
|
||||||
|
|
||||||
let auth_config = Arc::new(
|
let auth_config = Arc::new(
|
||||||
@@ -283,13 +280,10 @@ impl<T: Transport> ClientSession<T> {
|
|||||||
let remote_specs = build_remote_specs(&self.opts)?;
|
let remote_specs = build_remote_specs(&self.opts)?;
|
||||||
|
|
||||||
for spec in &remote_specs {
|
for spec in &remote_specs {
|
||||||
let remote_forwarder = RemoteForwarder::new(spec.clone())
|
let remote_forwarder =
|
||||||
.map_err(|_| ConnectError::ForwardFailed)?;
|
RemoteForwarder::new(spec.clone()).map_err(|_| ConnectError::ForwardFailed)?;
|
||||||
let mut h = self.handle.lock().await;
|
let mut h = self.handle.lock().await;
|
||||||
remote_forwarder
|
remote_forwarder.register(&mut h).await.map_err(|_| {
|
||||||
.register(&mut h)
|
|
||||||
.await
|
|
||||||
.map_err(|_| {
|
|
||||||
warn!("failed to register remote forward {}", spec);
|
warn!("failed to register remote forward {}", spec);
|
||||||
ConnectError::ForwardFailed
|
ConnectError::ForwardFailed
|
||||||
})?;
|
})?;
|
||||||
@@ -307,7 +301,9 @@ impl<T: Transport> ClientSession<T> {
|
|||||||
let fwd_shutdown = self.shutdown_rx.clone();
|
let fwd_shutdown = self.shutdown_rx.clone();
|
||||||
let forward_task = tokio::spawn(async move {
|
let forward_task = tokio::spawn(async move {
|
||||||
crate::client::forward::run_local_forwarders(
|
crate::client::forward::run_local_forwarders(
|
||||||
local_forwarders, fwd_handle, fwd_shutdown,
|
local_forwarders,
|
||||||
|
fwd_handle,
|
||||||
|
fwd_shutdown,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
});
|
});
|
||||||
@@ -358,7 +354,14 @@ impl<T: Transport> ClientSession<T> {
|
|||||||
|
|
||||||
let handler = ClientHandler::from_config(&reconnect_auth);
|
let handler = ClientHandler::from_config(&reconnect_auth);
|
||||||
let username = reconnect_username.clone();
|
let username = reconnect_username.clone();
|
||||||
match establish_session(&*reconnect_transport, handler, &reconnect_auth, &username).await {
|
match establish_session(
|
||||||
|
&*reconnect_transport,
|
||||||
|
handler,
|
||||||
|
&reconnect_auth,
|
||||||
|
&username,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(new_handle) => {
|
Ok(new_handle) => {
|
||||||
info!("reconnection successful");
|
info!("reconnection successful");
|
||||||
{
|
{
|
||||||
@@ -370,8 +373,13 @@ impl<T: Transport> ClientSession<T> {
|
|||||||
Ok(rf) => {
|
Ok(rf) => {
|
||||||
let mut h = reconnect_handle.lock().await;
|
let mut h = reconnect_handle.lock().await;
|
||||||
match rf.register(&mut h).await {
|
match rf.register(&mut h).await {
|
||||||
Ok(_) => debug!("re-registered remote forward: {}", spec),
|
Ok(_) => {
|
||||||
Err(e) => warn!("failed to re-register remote forward {}: {e}", spec),
|
debug!("re-registered remote forward: {}", spec)
|
||||||
|
}
|
||||||
|
Err(e) => warn!(
|
||||||
|
"failed to re-register remote forward {}: {e}",
|
||||||
|
spec
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => warn!("failed to create remote forwarder: {e}"),
|
Err(e) => warn!("failed to create remote forwarder: {e}"),
|
||||||
@@ -493,12 +501,10 @@ fn build_local_forwarders(opts: &ConnectOptions) -> Result<Vec<LocalForwarder>,
|
|||||||
name: format!("invalid forward spec: {}", spec_str),
|
name: format!("invalid forward spec: {}", spec_str),
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
forwarders.push(
|
forwarders.push(LocalForwarder::new(spec).map_err(|e| {
|
||||||
LocalForwarder::new(spec).map_err(|e| {
|
|
||||||
warn!("failed to create local forwarder: {}", e);
|
warn!("failed to create local forwarder: {}", e);
|
||||||
ConnectError::ForwardFailed
|
ConnectError::ForwardFailed
|
||||||
})?,
|
})?);
|
||||||
);
|
|
||||||
}
|
}
|
||||||
Ok(forwarders)
|
Ok(forwarders)
|
||||||
}
|
}
|
||||||
@@ -576,7 +582,10 @@ mod tests {
|
|||||||
assert_eq!(opts.forwards.len(), 1);
|
assert_eq!(opts.forwards.len(), 1);
|
||||||
assert_eq!(opts.remote_forwards.len(), 1);
|
assert_eq!(opts.remote_forwards.len(), 1);
|
||||||
assert_eq!(opts.proxy.as_deref(), Some("socks5://127.0.0.1:1080"));
|
assert_eq!(opts.proxy.as_deref(), Some("socks5://127.0.0.1:1080"));
|
||||||
assert_eq!(opts.iroh_relay.as_deref(), Some("https://relay.example.com"));
|
assert_eq!(
|
||||||
|
opts.iroh_relay.as_deref(),
|
||||||
|
Some("https://relay.example.com")
|
||||||
|
);
|
||||||
assert_eq!(opts.tls_server_name.as_deref(), Some("alknet.test"));
|
assert_eq!(opts.tls_server_name.as_deref(), Some("alknet.test"));
|
||||||
assert!(opts.insecure);
|
assert!(opts.insecure);
|
||||||
}
|
}
|
||||||
@@ -650,9 +659,18 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn connect_error_variants() {
|
fn connect_error_variants() {
|
||||||
assert_eq!(ConnectError::ConnectionFailed.to_string(), "connection failed");
|
assert_eq!(
|
||||||
assert_eq!(ConnectError::AuthFailed.to_string(), "authentication failed");
|
ConnectError::ConnectionFailed.to_string(),
|
||||||
assert_eq!(ConnectError::ForwardFailed.to_string(), "forward setup failed");
|
"connection failed"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ConnectError::AuthFailed.to_string(),
|
||||||
|
"authentication failed"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ConnectError::ForwardFailed.to_string(),
|
||||||
|
"forward setup failed"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -703,7 +721,10 @@ mod tests {
|
|||||||
let transport = Arc::new(FailTransport);
|
let transport = Arc::new(FailTransport);
|
||||||
let result = ClientSession::new(opts, transport).await;
|
let result = ClientSession::new(opts, transport).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed));
|
assert!(matches!(
|
||||||
|
result.err().unwrap(),
|
||||||
|
ConnectError::ConnectionFailed
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -714,7 +735,10 @@ mod tests {
|
|||||||
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||||
let result = ClientSession::new(opts, transport).await;
|
let result = ClientSession::new(opts, transport).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
assert!(matches!(result.err().unwrap(), ConnectError::ConnectionFailed));
|
assert!(matches!(
|
||||||
|
result.err().unwrap(),
|
||||||
|
ConnectError::ConnectionFailed
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -750,7 +774,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn build_remote_specs_valid() {
|
fn build_remote_specs_valid() {
|
||||||
let opts = ConnectOptions::new(make_identity()).remote_forward("0.0.0.0:8080:127.0.0.1:3000");
|
let opts =
|
||||||
|
ConnectOptions::new(make_identity()).remote_forward("0.0.0.0:8080:127.0.0.1:3000");
|
||||||
let result = build_remote_specs(&opts);
|
let result = build_remote_specs(&opts);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
assert_eq!(result.unwrap().len(), 1);
|
assert_eq!(result.unwrap().len(), 1);
|
||||||
@@ -798,8 +823,8 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn integration_mock_transport_session() {
|
async fn integration_mock_transport_session() {
|
||||||
use crate::socks5::{ChannelOpener, ChannelOpenError};
|
use crate::socks5::{ChannelOpenError, ChannelOpener};
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
|
||||||
struct MockOpener;
|
struct MockOpener;
|
||||||
@@ -839,9 +864,7 @@ mod tests {
|
|||||||
conn.read_exact(&mut auth_resp).await.unwrap();
|
conn.read_exact(&mut auth_resp).await.unwrap();
|
||||||
assert_eq!(auth_resp, [0x05, 0x00]);
|
assert_eq!(auth_resp, [0x05, 0x00]);
|
||||||
|
|
||||||
let connect_req = [
|
let connect_req = [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80];
|
||||||
0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80,
|
|
||||||
];
|
|
||||||
conn.write_all(&connect_req).await.unwrap();
|
conn.write_all(&connect_req).await.unwrap();
|
||||||
|
|
||||||
let mut reply = [0u8; 10];
|
let mut reply = [0u8; 10];
|
||||||
|
|||||||
@@ -205,12 +205,7 @@ async fn proxy_local_to_remote<H: client::Handler + Send + 'static>(
|
|||||||
|
|
||||||
let handle_guard = handle.lock().await;
|
let handle_guard = handle.lock().await;
|
||||||
let channel = handle_guard
|
let channel = handle_guard
|
||||||
.channel_open_direct_tcpip(
|
.channel_open_direct_tcpip(remote_host, remote_port as u32, &local_addr, 0)
|
||||||
remote_host,
|
|
||||||
remote_port as u32,
|
|
||||||
&local_addr,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.map_err(|e| ForwardError::ChannelOpenFailed {
|
.map_err(|e| ForwardError::ChannelOpenFailed {
|
||||||
source: Box::new(e) as _,
|
source: Box::new(e) as _,
|
||||||
@@ -470,10 +465,7 @@ mod tests {
|
|||||||
let bound_addr = listener.local_addr().unwrap();
|
let bound_addr = listener.local_addr().unwrap();
|
||||||
drop(listener);
|
drop(listener);
|
||||||
|
|
||||||
let spec = PortForwardSpec::local(&format!(
|
let spec = PortForwardSpec::local(&format!("127.0.0.1:{}:remote:5432", bound_addr.port()))
|
||||||
"127.0.0.1:{}:remote:5432",
|
|
||||||
bound_addr.port()
|
|
||||||
))
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let forwarder = LocalForwarder::new(spec).unwrap();
|
let forwarder = LocalForwarder::new(spec).unwrap();
|
||||||
assert_eq!(forwarder.local_port(), bound_addr.port());
|
assert_eq!(forwarder.local_port(), bound_addr.port());
|
||||||
|
|||||||
395
crates/alknet-core/src/config/dynamic_config.rs
Normal file
395
crates/alknet-core/src/config/dynamic_config.rs
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use arc_swap::ArcSwap;
|
||||||
|
|
||||||
|
use crate::auth::ServerAuthConfig;
|
||||||
|
|
||||||
|
pub struct AuthPolicy {
|
||||||
|
pub authorized_keys: std::collections::HashSet<russh::keys::PublicKey>,
|
||||||
|
pub cert_authorities: Vec<crate::auth::keys::CertAuthorityEntry>,
|
||||||
|
encoded_keys: std::collections::HashSet<Vec<u8>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode_key_data(key: &russh::keys::PublicKey) -> Vec<u8> {
|
||||||
|
use russh::keys::helpers::EncodedExt;
|
||||||
|
key.key_data().encoded().unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthPolicy {
|
||||||
|
pub fn new(
|
||||||
|
authorized_keys: std::collections::HashSet<russh::keys::PublicKey>,
|
||||||
|
cert_authorities: Vec<crate::auth::keys::CertAuthorityEntry>,
|
||||||
|
) -> Self {
|
||||||
|
let encoded_keys = authorized_keys.iter().map(encode_key_data).collect();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
authorized_keys,
|
||||||
|
cert_authorities,
|
||||||
|
encoded_keys,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_server_auth_config(config: ServerAuthConfig) -> Self {
|
||||||
|
Self::new(config.authorized_keys, config.cert_authorities)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn empty() -> Self {
|
||||||
|
Self::new(std::collections::HashSet::new(), Vec::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn authenticate_publickey(
|
||||||
|
&self,
|
||||||
|
key: &russh::keys::PublicKey,
|
||||||
|
) -> Result<(), crate::error::AuthError> {
|
||||||
|
let encoded = encode_key_data(key);
|
||||||
|
if self.encoded_keys.contains(&encoded) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
Err(crate::error::AuthError::KeyRejected)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn authenticate_certificate(
|
||||||
|
&self,
|
||||||
|
cert: &russh::keys::Certificate,
|
||||||
|
user: &str,
|
||||||
|
client_ip: Option<std::net::IpAddr>,
|
||||||
|
) -> Result<(), crate::error::AuthError> {
|
||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
let matching_ca = self
|
||||||
|
.cert_authorities
|
||||||
|
.iter()
|
||||||
|
.find(|ca| cert.signature_key() == ca.public_key.key_data());
|
||||||
|
|
||||||
|
let ca_entry = match matching_ca {
|
||||||
|
Some(entry) => entry,
|
||||||
|
None => return Err(crate::error::AuthError::CertInvalid),
|
||||||
|
};
|
||||||
|
|
||||||
|
if cert.verify_signature().is_err() {
|
||||||
|
return Err(crate::error::AuthError::CertInvalid);
|
||||||
|
}
|
||||||
|
|
||||||
|
let now = SystemTime::now();
|
||||||
|
let now_secs = now
|
||||||
|
.duration_since(SystemTime::UNIX_EPOCH)
|
||||||
|
.map(|d| d.as_secs())
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
if now_secs < cert.valid_after() || now_secs >= cert.valid_before() {
|
||||||
|
return Err(crate::error::AuthError::CertExpired);
|
||||||
|
}
|
||||||
|
|
||||||
|
let principals = cert.valid_principals();
|
||||||
|
if !principals.is_empty() && !principals.iter().any(|p| p == user) {
|
||||||
|
return Err(crate::error::AuthError::CertPrincipalMismatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
check_critical_options(cert, ca_entry, client_ip)?;
|
||||||
|
check_extensions(cert, ca_entry)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_critical_options(
|
||||||
|
cert: &russh::keys::Certificate,
|
||||||
|
ca_entry: &crate::auth::keys::CertAuthorityEntry,
|
||||||
|
client_ip: Option<std::net::IpAddr>,
|
||||||
|
) -> Result<(), crate::error::AuthError> {
|
||||||
|
let ca_has_no_pty = ca_entry.options.iter().any(|o| o == "no-pty");
|
||||||
|
|
||||||
|
for (name, data) in cert.critical_options().iter() {
|
||||||
|
match name.as_str() {
|
||||||
|
"source-address" => {
|
||||||
|
if !check_source_address(data, client_ip) {
|
||||||
|
return Err(crate::error::AuthError::CertInvalid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"force-command" => {}
|
||||||
|
"no-pty" => {}
|
||||||
|
_ => {
|
||||||
|
let _ = ca_has_no_pty;
|
||||||
|
return Err(crate::error::AuthError::CertInvalid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_extensions(
|
||||||
|
cert: &russh::keys::Certificate,
|
||||||
|
ca_entry: &crate::auth::keys::CertAuthorityEntry,
|
||||||
|
) -> Result<(), crate::error::AuthError> {
|
||||||
|
let ca_permit_port_forwarding = ca_entry
|
||||||
|
.options
|
||||||
|
.iter()
|
||||||
|
.any(|o| o == "permit-port-forwarding");
|
||||||
|
|
||||||
|
if ca_permit_port_forwarding {
|
||||||
|
let cert_allows = cert
|
||||||
|
.extensions()
|
||||||
|
.iter()
|
||||||
|
.any(|(n, _)| n == "permit-port-forwarding");
|
||||||
|
if !cert_allows {
|
||||||
|
return Err(crate::error::AuthError::CertInvalid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_source_address(allowed: &str, client_ip: Option<std::net::IpAddr>) -> bool {
|
||||||
|
use ipnetwork::IpNetwork;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
let Some(ip) = client_ip else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
for pattern in allowed.split(',') {
|
||||||
|
let pattern = pattern.trim();
|
||||||
|
if pattern.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(cidr) = IpNetwork::from_str(pattern) {
|
||||||
|
if cidr.contains(ip) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(net_ip) = IpAddr::from_str(pattern) {
|
||||||
|
if net_ip == ip {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for AuthPolicy {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("AuthPolicy")
|
||||||
|
.field("authorized_keys_count", &self.authorized_keys.len())
|
||||||
|
.field("cert_authorities_count", &self.cert_authorities.len())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for AuthPolicy {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
authorized_keys: self.authorized_keys.clone(),
|
||||||
|
cert_authorities: self.cert_authorities.clone(),
|
||||||
|
encoded_keys: self.encoded_keys.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub enum ForwardingAction {
|
||||||
|
Allow,
|
||||||
|
Deny,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ForwardingRule {
|
||||||
|
pub action: ForwardingAction,
|
||||||
|
pub principals: Vec<String>,
|
||||||
|
pub transports: Vec<crate::server::handler::TransportKind>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ForwardingPolicy {
|
||||||
|
pub default: ForwardingAction,
|
||||||
|
pub rules: Vec<ForwardingRule>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ForwardingPolicy {
|
||||||
|
pub fn allow_all() -> Self {
|
||||||
|
Self {
|
||||||
|
default: ForwardingAction::Allow,
|
||||||
|
rules: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deny_all() -> Self {
|
||||||
|
Self {
|
||||||
|
default: ForwardingAction::Deny,
|
||||||
|
rules: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RateLimitConfig {
|
||||||
|
pub max_connections_per_ip: usize,
|
||||||
|
pub max_auth_attempts: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RateLimitConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_connections_per_ip: 0,
|
||||||
|
max_auth_attempts: 10,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DynamicConfig {
|
||||||
|
pub auth: AuthPolicy,
|
||||||
|
pub forwarding: ForwardingPolicy,
|
||||||
|
pub rate_limits: RateLimitConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DynamicConfig {
|
||||||
|
pub fn new(auth: AuthPolicy) -> Self {
|
||||||
|
Self {
|
||||||
|
auth,
|
||||||
|
forwarding: ForwardingPolicy::allow_all(),
|
||||||
|
rate_limits: RateLimitConfig::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_forwarding_policy(mut self, policy: ForwardingPolicy) -> Self {
|
||||||
|
self.forwarding = policy;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_rate_limits(mut self, limits: RateLimitConfig) -> Self {
|
||||||
|
self.rate_limits = limits;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DynamicConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
auth: AuthPolicy::empty(),
|
||||||
|
forwarding: ForwardingPolicy::allow_all(),
|
||||||
|
rate_limits: RateLimitConfig::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ConfigReloadHandle {
|
||||||
|
pub(crate) dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConfigReloadHandle {
|
||||||
|
pub fn reload(&self, new_config: DynamicConfig) {
|
||||||
|
self.dynamic.store(Arc::new(new_config));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dynamic(&self) -> Arc<DynamicConfig> {
|
||||||
|
self.dynamic.load_full()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for ConfigReloadHandle {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("ConfigReloadHandle").finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_dynamic_config() -> (Arc<ArcSwap<DynamicConfig>>, ConfigReloadHandle) {
|
||||||
|
let inner = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||||
|
let handle = ConfigReloadHandle {
|
||||||
|
dynamic: Arc::clone(&inner),
|
||||||
|
};
|
||||||
|
(inner, handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forwarding_policy_allow_all_default() {
|
||||||
|
let policy = ForwardingPolicy::allow_all();
|
||||||
|
assert_eq!(policy.default, ForwardingAction::Allow);
|
||||||
|
assert!(policy.rules.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forwarding_policy_deny_all() {
|
||||||
|
let policy = ForwardingPolicy::deny_all();
|
||||||
|
assert_eq!(policy.default, ForwardingAction::Deny);
|
||||||
|
assert!(policy.rules.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dynamic_config_default() {
|
||||||
|
let config = DynamicConfig::default();
|
||||||
|
assert_eq!(config.forwarding.default, ForwardingAction::Allow);
|
||||||
|
assert_eq!(config.rate_limits.max_connections_per_ip, 0);
|
||||||
|
assert_eq!(config.rate_limits.max_auth_attempts, 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_reload_handle_updates_dynamic() {
|
||||||
|
let (arc_swap, handle) = new_dynamic_config();
|
||||||
|
let initial = arc_swap.load();
|
||||||
|
assert_eq!(initial.forwarding.default, ForwardingAction::Allow);
|
||||||
|
|
||||||
|
let new_config = DynamicConfig {
|
||||||
|
auth: AuthPolicy::empty(),
|
||||||
|
forwarding: ForwardingPolicy::deny_all(),
|
||||||
|
rate_limits: RateLimitConfig::default(),
|
||||||
|
};
|
||||||
|
handle.reload(new_config);
|
||||||
|
|
||||||
|
let updated = arc_swap.load();
|
||||||
|
assert_eq!(updated.forwarding.default, ForwardingAction::Deny);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dynamic_config_with_forwarding_policy_builder() {
|
||||||
|
let config = DynamicConfig::new(AuthPolicy::empty())
|
||||||
|
.with_forwarding_policy(ForwardingPolicy::deny_all());
|
||||||
|
assert_eq!(config.forwarding.default, ForwardingAction::Deny);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rate_limit_config_custom() {
|
||||||
|
let limits = RateLimitConfig {
|
||||||
|
max_connections_per_ip: 5,
|
||||||
|
max_auth_attempts: 3,
|
||||||
|
};
|
||||||
|
assert_eq!(limits.max_connections_per_ip, 5);
|
||||||
|
assert_eq!(limits.max_auth_attempts, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forwarding_action_equality() {
|
||||||
|
assert_eq!(ForwardingAction::Allow, ForwardingAction::Allow);
|
||||||
|
assert_eq!(ForwardingAction::Deny, ForwardingAction::Deny);
|
||||||
|
assert_ne!(ForwardingAction::Allow, ForwardingAction::Deny);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_policy_empty_rejects_all() {
|
||||||
|
let policy = AuthPolicy::empty();
|
||||||
|
let key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
||||||
|
let other_ssh_key =
|
||||||
|
russh::keys::parse_public_key_base64(key_text.split_whitespace().nth(1).unwrap())
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
policy.authenticate_publickey(&other_ssh_key),
|
||||||
|
Err(crate::error::AuthError::KeyRejected)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_policy_debug_redacts_keys() {
|
||||||
|
let policy = AuthPolicy::empty();
|
||||||
|
let debug_str = format!("{:?}", policy);
|
||||||
|
assert!(debug_str.contains("authorized_keys_count"));
|
||||||
|
assert!(debug_str.contains("cert_authorities_count"));
|
||||||
|
}
|
||||||
|
}
|
||||||
8
crates/alknet-core/src/config/mod.rs
Normal file
8
crates/alknet-core/src/config/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
pub mod dynamic_config;
|
||||||
|
pub mod static_config;
|
||||||
|
|
||||||
|
pub use dynamic_config::{
|
||||||
|
new_dynamic_config, AuthPolicy, ConfigReloadHandle, DynamicConfig, ForwardingAction,
|
||||||
|
ForwardingPolicy, ForwardingRule, RateLimitConfig,
|
||||||
|
};
|
||||||
|
pub use static_config::StaticConfig;
|
||||||
101
crates/alknet-core/src/config/static_config.rs
Normal file
101
crates/alknet-core/src/config/static_config.rs
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
use crate::server::handler::{ProxyConfig, ProxyMode};
|
||||||
|
use crate::server::serve::ServeTransportMode;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
|
pub struct StaticConfig {
|
||||||
|
pub transport_mode: ServeTransportMode,
|
||||||
|
pub listen_addr: String,
|
||||||
|
pub tls_cert: Option<String>,
|
||||||
|
pub tls_key: Option<String>,
|
||||||
|
pub acme_domain: Option<String>,
|
||||||
|
pub stealth: bool,
|
||||||
|
pub host_key: russh::keys::PrivateKey,
|
||||||
|
pub host_key_algorithm: russh::keys::Algorithm,
|
||||||
|
pub max_auth_attempts: usize,
|
||||||
|
pub max_connections_per_ip: usize,
|
||||||
|
pub proxy_config: Option<ProxyConfig>,
|
||||||
|
pub iroh_relay: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for StaticConfig {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("StaticConfig")
|
||||||
|
.field("transport_mode", &self.transport_mode)
|
||||||
|
.field("listen_addr", &self.listen_addr)
|
||||||
|
.field("tls_cert", &self.tls_cert.as_ref().map(|_| "<redacted>"))
|
||||||
|
.field("tls_key", &self.tls_key.as_ref().map(|_| "<redacted>"))
|
||||||
|
.field("acme_domain", &self.acme_domain)
|
||||||
|
.field("stealth", &self.stealth)
|
||||||
|
.field("host_key_algorithm", &self.host_key_algorithm)
|
||||||
|
.field("max_auth_attempts", &self.max_auth_attempts)
|
||||||
|
.field("max_connections_per_ip", &self.max_connections_per_ip)
|
||||||
|
.field("proxy_config", &self.proxy_config)
|
||||||
|
.field("iroh_relay", &self.iroh_relay)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StaticConfig {
|
||||||
|
pub fn from_serve_options(
|
||||||
|
opts: crate::server::serve::ServeOptions,
|
||||||
|
) -> Result<(Self, crate::config::DynamicConfig), crate::error::ConfigError> {
|
||||||
|
opts.validate()?;
|
||||||
|
|
||||||
|
let host_key = crate::auth::keys::load_private_key(opts.key.clone())?;
|
||||||
|
let host_key_algorithm = host_key.algorithm();
|
||||||
|
|
||||||
|
let auth_config = crate::auth::ServerAuthConfig::from_keys_and_ca(
|
||||||
|
opts.authorized_keys.clone(),
|
||||||
|
opts.cert_authority.clone(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let auth_policy = crate::config::AuthPolicy::from_server_auth_config(auth_config);
|
||||||
|
|
||||||
|
let dynamic = crate::config::DynamicConfig::new(auth_policy);
|
||||||
|
|
||||||
|
let proxy_config = parse_proxy_config(opts.proxy.as_deref());
|
||||||
|
|
||||||
|
let static_config = StaticConfig {
|
||||||
|
transport_mode: opts.transport_mode,
|
||||||
|
listen_addr: opts.listen_addr,
|
||||||
|
tls_cert: opts.tls_cert,
|
||||||
|
tls_key: opts.tls_key,
|
||||||
|
acme_domain: opts.acme_domain,
|
||||||
|
stealth: opts.stealth,
|
||||||
|
host_key,
|
||||||
|
host_key_algorithm,
|
||||||
|
max_auth_attempts: opts.max_auth_attempts,
|
||||||
|
max_connections_per_ip: opts.max_connections_per_ip,
|
||||||
|
proxy_config,
|
||||||
|
iroh_relay: opts.iroh_relay,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((static_config, dynamic))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_proxy_config(proxy: Option<&str>) -> Option<ProxyConfig> {
|
||||||
|
proxy.map(|url| {
|
||||||
|
if url.starts_with("socks5://") {
|
||||||
|
let addr: SocketAddr = url
|
||||||
|
.strip_prefix("socks5://")
|
||||||
|
.unwrap()
|
||||||
|
.parse()
|
||||||
|
.expect("invalid socks5 proxy address");
|
||||||
|
ProxyConfig {
|
||||||
|
mode: ProxyMode::Socks5(addr),
|
||||||
|
}
|
||||||
|
} else if url.starts_with("http://") {
|
||||||
|
let addr: SocketAddr = url
|
||||||
|
.strip_prefix("http://")
|
||||||
|
.unwrap()
|
||||||
|
.parse()
|
||||||
|
.expect("invalid http connect proxy address");
|
||||||
|
ProxyConfig {
|
||||||
|
mode: ProxyMode::HttpConnect(addr),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic!("unsupported proxy URL scheme: {url}");
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -97,7 +97,10 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn transport_error_display() {
|
fn transport_error_display() {
|
||||||
assert_eq!(TransportError::ConnectionFailed.to_string(), "connection failed");
|
assert_eq!(
|
||||||
|
TransportError::ConnectionFailed.to_string(),
|
||||||
|
"connection failed"
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
TransportError::HandshakeFailed {
|
TransportError::HandshakeFailed {
|
||||||
source: io::Error::new(io::ErrorKind::ConnectionRefused, "tls failed")
|
source: io::Error::new(io::ErrorKind::ConnectionRefused, "tls failed")
|
||||||
@@ -120,13 +123,19 @@ mod tests {
|
|||||||
assert_eq!(AuthError::KeyRejected.to_string(), "key rejected");
|
assert_eq!(AuthError::KeyRejected.to_string(), "key rejected");
|
||||||
assert_eq!(AuthError::CertInvalid.to_string(), "certificate invalid");
|
assert_eq!(AuthError::CertInvalid.to_string(), "certificate invalid");
|
||||||
assert_eq!(AuthError::CertExpired.to_string(), "certificate expired");
|
assert_eq!(AuthError::CertExpired.to_string(), "certificate expired");
|
||||||
assert_eq!(AuthError::CertPrincipalMismatch.to_string(), "certificate principal mismatch");
|
assert_eq!(
|
||||||
|
AuthError::CertPrincipalMismatch.to_string(),
|
||||||
|
"certificate principal mismatch"
|
||||||
|
);
|
||||||
assert_eq!(AuthError::NoMatchingKey.to_string(), "no matching key");
|
assert_eq!(AuthError::NoMatchingKey.to_string(), "no matching key");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn channel_error_display() {
|
fn channel_error_display() {
|
||||||
assert_eq!(ChannelError::TargetUnreachable.to_string(), "target unreachable");
|
assert_eq!(
|
||||||
|
ChannelError::TargetUnreachable.to_string(),
|
||||||
|
"target unreachable"
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
ChannelError::ProxyConnectFailed {
|
ChannelError::ProxyConnectFailed {
|
||||||
source: io::Error::new(io::ErrorKind::ConnectionRefused, "refused")
|
source: io::Error::new(io::ErrorKind::ConnectionRefused, "refused")
|
||||||
@@ -160,7 +169,10 @@ mod tests {
|
|||||||
.to_string(),
|
.to_string(),
|
||||||
"bind failed"
|
"bind failed"
|
||||||
);
|
);
|
||||||
assert_eq!(ConfigError::IncompatibleOptions.to_string(), "incompatible options");
|
assert_eq!(
|
||||||
|
ConfigError::IncompatibleOptions.to_string(),
|
||||||
|
"incompatible options"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -184,7 +196,10 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn forward_error_display() {
|
fn forward_error_display() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
ForwardError::InvalidSpec { spec: "bad".to_string() }.to_string(),
|
ForwardError::InvalidSpec {
|
||||||
|
spec: "bad".to_string()
|
||||||
|
}
|
||||||
|
.to_string(),
|
||||||
"invalid port forward spec: bad"
|
"invalid port forward spec: bad"
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -209,7 +224,9 @@ mod tests {
|
|||||||
let forward_err = ForwardError::BindFailed { source: io_err };
|
let forward_err = ForwardError::BindFailed { source: io_err };
|
||||||
assert!(forward_err.source().is_some());
|
assert!(forward_err.source().is_some());
|
||||||
|
|
||||||
let plain = ForwardError::InvalidSpec { spec: "bad".to_string() };
|
let plain = ForwardError::InvalidSpec {
|
||||||
|
spec: "bad".to_string(),
|
||||||
|
};
|
||||||
assert!(plain.source().is_none());
|
assert!(plain.source().is_none());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -50,18 +50,23 @@
|
|||||||
//! }
|
//! }
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
pub mod transport;
|
|
||||||
pub mod client;
|
|
||||||
pub mod server;
|
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
pub mod socks5;
|
pub mod client;
|
||||||
|
pub mod config;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
|
pub mod server;
|
||||||
|
pub mod socks5;
|
||||||
|
pub mod transport;
|
||||||
|
|
||||||
#[cfg(feature = "testutil")]
|
#[cfg(feature = "testutil")]
|
||||||
pub mod testutil;
|
pub mod testutil;
|
||||||
|
|
||||||
pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError};
|
|
||||||
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
|
||||||
pub use client::channel_manager::{ChannelManager, ForwardRequest};
|
pub use client::channel_manager::{ChannelManager, ForwardRequest};
|
||||||
pub use client::connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
|
pub use client::connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
|
||||||
pub use server::serve::{Server, ServeError, ServeOptions, ServeTransportMode};
|
pub use config::{
|
||||||
|
AuthPolicy, ConfigReloadHandle, DynamicConfig, ForwardingAction, ForwardingPolicy,
|
||||||
|
ForwardingRule, RateLimitConfig, StaticConfig,
|
||||||
|
};
|
||||||
|
pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError};
|
||||||
|
pub use server::serve::{ServeError, ServeOptions, ServeTransportMode, Server};
|
||||||
|
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||||
|
|||||||
@@ -46,7 +46,10 @@ async fn connect_direct(target: SocketAddr) -> Result<TcpStream, ChannelProxyErr
|
|||||||
.map_err(|e| map_connection_error(e, target))
|
.map_err(|e| map_connection_error(e, target))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn connect_socks5(target: SocketAddr, proxy_addr: SocketAddr) -> Result<TcpStream, ChannelProxyError> {
|
async fn connect_socks5(
|
||||||
|
target: SocketAddr,
|
||||||
|
proxy_addr: SocketAddr,
|
||||||
|
) -> Result<TcpStream, ChannelProxyError> {
|
||||||
let mut stream = TcpStream::connect(proxy_addr)
|
let mut stream = TcpStream::connect(proxy_addr)
|
||||||
.await
|
.await
|
||||||
.map_err(ChannelProxyError::from)?;
|
.map_err(ChannelProxyError::from)?;
|
||||||
@@ -134,10 +137,7 @@ async fn connect_http_connect(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let response_str = String::from_utf8_lossy(&response);
|
let response_str = String::from_utf8_lossy(&response);
|
||||||
let status_line = response_str
|
let status_line = response_str.lines().next().unwrap_or("");
|
||||||
.lines()
|
|
||||||
.next()
|
|
||||||
.unwrap_or("");
|
|
||||||
|
|
||||||
if status_line.contains("200") {
|
if status_line.contains("200") {
|
||||||
Ok(stream)
|
Ok(stream)
|
||||||
@@ -279,11 +279,7 @@ mod tests {
|
|||||||
.parse()
|
.parse()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let reply = vec![
|
let reply = vec![0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0];
|
||||||
0x05, 0x00, 0x00, 0x01,
|
|
||||||
0, 0, 0, 0,
|
|
||||||
0, 0,
|
|
||||||
];
|
|
||||||
proxy_sock.write_all(&reply).await.unwrap();
|
proxy_sock.write_all(&reply).await.unwrap();
|
||||||
|
|
||||||
let mut target_stream = TcpStream::connect(target).await.unwrap();
|
let mut target_stream = TcpStream::connect(target).await.unwrap();
|
||||||
@@ -323,11 +319,7 @@ mod tests {
|
|||||||
let mut port_bytes = [0u8; 2];
|
let mut port_bytes = [0u8; 2];
|
||||||
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
|
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
|
||||||
|
|
||||||
let reply = vec![
|
let reply = vec![0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0];
|
||||||
0x05, 0x05, 0x00, 0x01,
|
|
||||||
0, 0, 0, 0,
|
|
||||||
0, 0,
|
|
||||||
];
|
|
||||||
proxy_sock.write_all(&reply).await.unwrap();
|
proxy_sock.write_all(&reply).await.unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -2,16 +2,15 @@ use std::net::{IpAddr, SocketAddr};
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use arc_swap::ArcSwap;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use russh::keys::ssh_key::HashAlg;
|
use russh::keys::ssh_key::HashAlg;
|
||||||
use russh::server::{Auth, Handler, Msg, Session};
|
use russh::server::{Auth, Handler, Msg, Session};
|
||||||
use russh::Channel;
|
use russh::Channel;
|
||||||
use russh::ChannelId;
|
use russh::ChannelId;
|
||||||
|
|
||||||
use crate::auth::ServerAuthConfig;
|
use crate::config::DynamicConfig;
|
||||||
use crate::server::control_channel::{
|
use crate::server::control_channel::{ControlChannelHandler, ControlChannelRouter, ALKNET_PREFIX};
|
||||||
ControlChannelHandler, ControlChannelRouter, ALKNET_PREFIX,
|
|
||||||
};
|
|
||||||
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -44,7 +43,7 @@ impl std::fmt::Display for TransportKind {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct ServerHandler {
|
pub struct ServerHandler {
|
||||||
auth_config: Arc<ServerAuthConfig>,
|
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
outbound_proxy: Option<ProxyConfig>,
|
outbound_proxy: Option<ProxyConfig>,
|
||||||
remote_addr: Option<SocketAddr>,
|
remote_addr: Option<SocketAddr>,
|
||||||
@@ -59,7 +58,7 @@ pub struct ServerHandler {
|
|||||||
|
|
||||||
impl ServerHandler {
|
impl ServerHandler {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
auth_config: Arc<ServerAuthConfig>,
|
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||||
outbound_proxy: Option<ProxyConfig>,
|
outbound_proxy: Option<ProxyConfig>,
|
||||||
remote_addr: Option<SocketAddr>,
|
remote_addr: Option<SocketAddr>,
|
||||||
transport: TransportKind,
|
transport: TransportKind,
|
||||||
@@ -89,7 +88,7 @@ impl ServerHandler {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
auth_config,
|
dynamic,
|
||||||
outbound_proxy,
|
outbound_proxy,
|
||||||
remote_addr,
|
remote_addr,
|
||||||
control_channel_router: ControlChannelRouter::without_handler(),
|
control_channel_router: ControlChannelRouter::without_handler(),
|
||||||
@@ -127,10 +126,7 @@ impl Drop for ServerHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ServerHandler {
|
impl ServerHandler {
|
||||||
pub fn with_control_channel_handler(
|
pub fn with_control_channel_handler(mut self, handler: Box<dyn ControlChannelHandler>) -> Self {
|
||||||
mut self,
|
|
||||||
handler: Box<dyn ControlChannelHandler>,
|
|
||||||
) -> Self {
|
|
||||||
self.control_channel_router = ControlChannelRouter::with_handler(handler);
|
self.control_channel_router = ControlChannelRouter::with_handler(handler);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
@@ -172,7 +168,8 @@ impl Handler for ServerHandler {
|
|||||||
.map_or("unknown".to_string(), |a| a.to_string());
|
.map_or("unknown".to_string(), |a| a.to_string());
|
||||||
|
|
||||||
let russh_pub = russh::keys::PublicKey::new(public_key.key_data().clone(), user);
|
let russh_pub = russh::keys::PublicKey::new(public_key.key_data().clone(), user);
|
||||||
let result = self.auth_config.authenticate_publickey(&russh_pub);
|
let auth_config = self.dynamic.load();
|
||||||
|
let result = auth_config.auth.authenticate_publickey(&russh_pub);
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
@@ -226,9 +223,12 @@ impl Handler for ServerHandler {
|
|||||||
});
|
});
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let target = match format!("{target_host}:{target_port}").parse::<std::net::SocketAddr>() {
|
let target =
|
||||||
|
match format!("{target_host}:{target_port}").parse::<std::net::SocketAddr>() {
|
||||||
Ok(addr) => addr,
|
Ok(addr) => addr,
|
||||||
Err(_) => match tokio::net::lookup_host((&target_host[..], target_port as u16)).await {
|
Err(_) => match tokio::net::lookup_host((&target_host[..], target_port as u16))
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(mut addrs) => match addrs.next() {
|
Ok(mut addrs) => match addrs.next() {
|
||||||
Some(addr) => addr,
|
Some(addr) => addr,
|
||||||
None => return,
|
None => return,
|
||||||
@@ -236,7 +236,12 @@ impl Handler for ServerHandler {
|
|||||||
Err(_) => return,
|
Err(_) => return,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
crate::server::channel_proxy::proxy_channel(channel.into_stream(), target, &proxy_config).await;
|
crate::server::channel_proxy::proxy_channel(
|
||||||
|
channel.into_stream(),
|
||||||
|
target,
|
||||||
|
&proxy_config,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
});
|
});
|
||||||
|
|
||||||
let _ = (originator_address, originator_port);
|
let _ = (originator_address, originator_port);
|
||||||
@@ -389,7 +394,12 @@ impl Handler for ServerHandler {
|
|||||||
channel = %channel,
|
channel = %channel,
|
||||||
"rejected x11 request on channel"
|
"rejected x11 request on channel"
|
||||||
);
|
);
|
||||||
let _ = (single_connection, x11_auth_protocol, x11_auth_cookie, x11_screen_number);
|
let _ = (
|
||||||
|
single_connection,
|
||||||
|
x11_auth_protocol,
|
||||||
|
x11_auth_cookie,
|
||||||
|
x11_screen_number,
|
||||||
|
);
|
||||||
let _ = session.channel_failure(channel);
|
let _ = session.channel_failure(channel);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -469,6 +479,8 @@ impl Handler for ServerHandler {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::auth::keys::KeySource;
|
use crate::auth::keys::KeySource;
|
||||||
|
use crate::auth::ServerAuthConfig;
|
||||||
|
use crate::config::AuthPolicy;
|
||||||
use russh::keys::{decode_secret_key, PrivateKey};
|
use russh::keys::{decode_secret_key, PrivateKey};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
@@ -487,19 +499,19 @@ mod tests {
|
|||||||
decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap()
|
decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_auth_config(keys_content: &str) -> Arc<ServerAuthConfig> {
|
fn make_auth_config(keys_content: &str) -> Arc<ArcSwap<DynamicConfig>> {
|
||||||
let f = make_authorized_keys_file(keys_content);
|
let f = make_authorized_keys_file(keys_content);
|
||||||
Arc::new(
|
let server_auth =
|
||||||
ServerAuthConfig::from_keys_and_ca(
|
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
|
||||||
Some(KeySource::File(f.path().to_path_buf())),
|
.unwrap();
|
||||||
None,
|
let auth_policy = AuthPolicy::from_server_auth_config(server_auth);
|
||||||
)
|
let dynamic = DynamicConfig::new(auth_policy);
|
||||||
.unwrap(),
|
Arc::new(ArcSwap::new(Arc::new(dynamic)))
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_empty_auth_config() -> Arc<ServerAuthConfig> {
|
fn make_empty_auth_config() -> Arc<ArcSwap<DynamicConfig>> {
|
||||||
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
|
let dynamic = DynamicConfig::default();
|
||||||
|
Arc::new(ArcSwap::new(Arc::new(dynamic)))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_limiter() -> Arc<ConnectionRateLimiter> {
|
fn default_limiter() -> Arc<ConnectionRateLimiter> {
|
||||||
@@ -507,11 +519,18 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn make_handler(
|
fn make_handler(
|
||||||
auth_config: Arc<ServerAuthConfig>,
|
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||||
outbound_proxy: Option<ProxyConfig>,
|
outbound_proxy: Option<ProxyConfig>,
|
||||||
remote_addr: Option<SocketAddr>,
|
remote_addr: Option<SocketAddr>,
|
||||||
) -> ServerHandler {
|
) -> ServerHandler {
|
||||||
ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10)
|
ServerHandler::new(
|
||||||
|
dynamic,
|
||||||
|
outbound_proxy,
|
||||||
|
remote_addr,
|
||||||
|
TransportKind::Tcp,
|
||||||
|
default_limiter(),
|
||||||
|
10,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -530,9 +549,8 @@ mod tests {
|
|||||||
let mut handler = make_handler(auth_config, None, None);
|
let mut handler = make_handler(auth_config, None, None);
|
||||||
|
|
||||||
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
||||||
let other_ssh_key = russh::keys::parse_public_key_base64(
|
let other_ssh_key =
|
||||||
other_key_text.split_whitespace().nth(1).unwrap(),
|
russh::keys::parse_public_key_base64(other_key_text.split_whitespace().nth(1).unwrap())
|
||||||
)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let result = handler
|
let result = handler
|
||||||
@@ -553,10 +571,7 @@ mod tests {
|
|||||||
let mut handler = make_handler(auth_config, None, None);
|
let mut handler = make_handler(auth_config, None, None);
|
||||||
|
|
||||||
let ssh_key = load_key().public_key().clone();
|
let ssh_key = load_key().public_key().clone();
|
||||||
let result = handler
|
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||||
.auth_publickey("testuser", &ssh_key)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
result,
|
result,
|
||||||
Auth::Reject {
|
Auth::Reject {
|
||||||
@@ -629,8 +644,16 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn one_handler_per_connection() {
|
fn one_handler_per_connection() {
|
||||||
let auth_config = make_empty_auth_config();
|
let auth_config = make_empty_auth_config();
|
||||||
let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
|
let handler1 = make_handler(
|
||||||
let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
|
auth_config.clone(),
|
||||||
|
None,
|
||||||
|
Some("10.0.0.1:22".parse().unwrap()),
|
||||||
|
);
|
||||||
|
let handler2 = make_handler(
|
||||||
|
auth_config.clone(),
|
||||||
|
None,
|
||||||
|
Some("10.0.0.2:22".parse().unwrap()),
|
||||||
|
);
|
||||||
|
|
||||||
assert!(handler1.remote_addr != handler2.remote_addr);
|
assert!(handler1.remote_addr != handler2.remote_addr);
|
||||||
}
|
}
|
||||||
@@ -651,10 +674,20 @@ mod tests {
|
|||||||
let ssh_key = load_key().public_key().clone();
|
let ssh_key = load_key().public_key().clone();
|
||||||
|
|
||||||
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||||
assert_eq!(r1, Auth::Reject { proceed_with_methods: None });
|
assert_eq!(
|
||||||
|
r1,
|
||||||
|
Auth::Reject {
|
||||||
|
proceed_with_methods: None
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||||
assert_eq!(r2, Auth::Reject { proceed_with_methods: None });
|
assert_eq!(
|
||||||
|
r2,
|
||||||
|
Auth::Reject {
|
||||||
|
proceed_with_methods: None
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
assert!(!handler.auth_limiter.check());
|
assert!(!handler.auth_limiter.check());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,10 +16,12 @@ pub mod stealth;
|
|||||||
|
|
||||||
pub use channel_proxy::{connect_outbound, proxy_channel};
|
pub use channel_proxy::{connect_outbound, proxy_channel};
|
||||||
pub use control_channel::{
|
pub use control_channel::{
|
||||||
ControlChannelHandler, ControlChannelRouter, DuplexStream, ALKNET_CONTROL_DESTINATION,
|
is_reserved_destination, ControlChannelHandler, ControlChannelRouter, DuplexStream,
|
||||||
ALKNET_PREFIX, is_reserved_destination,
|
ALKNET_CONTROL_DESTINATION, ALKNET_PREFIX,
|
||||||
};
|
};
|
||||||
pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
|
pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
|
||||||
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||||
pub use serve::{Server, ServeError, ServeOptions, ServeTransportMode};
|
pub use serve::{ServeError, ServeOptions, ServeTransportMode, Server};
|
||||||
pub use stealth::{ProtocolDetection, detect_protocol, send_fake_nginx_404, validate_stealth_config};
|
pub use stealth::{
|
||||||
|
detect_protocol, send_fake_nginx_404, validate_stealth_config, ProtocolDetection,
|
||||||
|
};
|
||||||
|
|||||||
@@ -8,12 +8,14 @@ use std::net::SocketAddr;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use arc_swap::ArcSwap;
|
||||||
use russh::server::{self, Config};
|
use russh::server::{self, Config};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn};
|
||||||
|
|
||||||
use crate::auth::keys::KeySource;
|
use crate::auth::keys::KeySource;
|
||||||
use crate::auth::server_auth::ServerAuthConfig;
|
use crate::auth::server_auth::ServerAuthConfig;
|
||||||
|
use crate::config::{AuthPolicy, ConfigReloadHandle, DynamicConfig};
|
||||||
use crate::error::ConfigError;
|
use crate::error::ConfigError;
|
||||||
use crate::server::handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
|
use crate::server::handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
|
||||||
use crate::server::rate_limit::ConnectionRateLimiter;
|
use crate::server::rate_limit::ConnectionRateLimiter;
|
||||||
@@ -228,7 +230,7 @@ struct ActiveSession {
|
|||||||
/// Supports stealth mode (TLS only), outbound proxy routing, and connection rate limiting.
|
/// Supports stealth mode (TLS only), outbound proxy routing, and connection rate limiting.
|
||||||
pub struct Server {
|
pub struct Server {
|
||||||
config: Arc<server::Config>,
|
config: Arc<server::Config>,
|
||||||
auth_config: Arc<ServerAuthConfig>,
|
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||||
outbound_proxy: Option<ProxyConfig>,
|
outbound_proxy: Option<ProxyConfig>,
|
||||||
stealth: bool,
|
stealth: bool,
|
||||||
@@ -244,17 +246,24 @@ impl Server {
|
|||||||
pub fn new(opts: ServeOptions) -> Result<Self, ServeError> {
|
pub fn new(opts: ServeOptions) -> Result<Self, ServeError> {
|
||||||
opts.validate().map_err(ServeError::Config)?;
|
opts.validate().map_err(ServeError::Config)?;
|
||||||
|
|
||||||
let private_key =
|
let private_key = crate::auth::keys::load_private_key(opts.key.clone())
|
||||||
crate::auth::keys::load_private_key(opts.key.clone()).map_err(ServeError::KeyLoadFailed)?;
|
.map_err(ServeError::KeyLoadFailed)?;
|
||||||
|
|
||||||
let auth_config = Arc::new(
|
let auth_config = ServerAuthConfig::from_keys_and_ca(
|
||||||
ServerAuthConfig::from_keys_and_ca(opts.authorized_keys.clone(), opts.cert_authority.clone())
|
opts.authorized_keys.clone(),
|
||||||
.map_err(ServeError::KeyLoadFailed)?,
|
opts.cert_authority.clone(),
|
||||||
);
|
)
|
||||||
|
.map_err(ServeError::KeyLoadFailed)?;
|
||||||
|
|
||||||
|
let auth_policy = AuthPolicy::from_server_auth_config(auth_config);
|
||||||
|
let dynamic_config = DynamicConfig::new(auth_policy);
|
||||||
|
|
||||||
|
let max_auth_attempts = opts.max_auth_attempts;
|
||||||
|
let max_connections_per_ip = opts.max_connections_per_ip;
|
||||||
|
|
||||||
let config = Arc::new(Config {
|
let config = Arc::new(Config {
|
||||||
keys: vec![private_key],
|
keys: vec![private_key],
|
||||||
max_auth_attempts: opts.max_auth_attempts,
|
max_auth_attempts,
|
||||||
methods: russh::MethodSet::PUBLICKEY,
|
methods: russh::MethodSet::PUBLICKEY,
|
||||||
preferred: russh::Preferred::DEFAULT,
|
preferred: russh::Preferred::DEFAULT,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@@ -262,19 +271,21 @@ impl Server {
|
|||||||
|
|
||||||
let outbound_proxy = parse_proxy_config(opts.proxy.as_deref());
|
let outbound_proxy = parse_proxy_config(opts.proxy.as_deref());
|
||||||
|
|
||||||
let connection_limiter = Arc::new(ConnectionRateLimiter::new(opts.max_connections_per_ip));
|
let connection_limiter = Arc::new(ConnectionRateLimiter::new(max_connections_per_ip));
|
||||||
|
|
||||||
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
|
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
|
||||||
|
|
||||||
|
let dynamic = Arc::new(ArcSwap::new(Arc::new(dynamic_config)));
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
config,
|
config,
|
||||||
auth_config,
|
dynamic,
|
||||||
connection_limiter,
|
connection_limiter,
|
||||||
outbound_proxy,
|
outbound_proxy,
|
||||||
stealth: opts.stealth,
|
stealth: opts.stealth,
|
||||||
transport_mode: opts.transport_mode,
|
transport_mode: opts.transport_mode,
|
||||||
listen_addr: opts.listen_addr,
|
listen_addr: opts.listen_addr,
|
||||||
max_auth_attempts: opts.max_auth_attempts,
|
max_auth_attempts,
|
||||||
shutdown_tx,
|
shutdown_tx,
|
||||||
shutdown_rx,
|
shutdown_rx,
|
||||||
sessions: Arc::new(tokio::sync::Mutex::new(Vec::new())),
|
sessions: Arc::new(tokio::sync::Mutex::new(Vec::new())),
|
||||||
@@ -285,6 +296,12 @@ impl Server {
|
|||||||
self.shutdown_tx.clone()
|
self.shutdown_tx.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn config_reload_handle(&self) -> ConfigReloadHandle {
|
||||||
|
ConfigReloadHandle {
|
||||||
|
dynamic: Arc::clone(&self.dynamic),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn shutdown(&self) -> Result<(), ServeError> {
|
pub async fn shutdown(&self) -> Result<(), ServeError> {
|
||||||
info!("initiating graceful shutdown");
|
info!("initiating graceful shutdown");
|
||||||
let _ = self.shutdown_tx.send(true);
|
let _ = self.shutdown_tx.send(true);
|
||||||
@@ -292,11 +309,15 @@ impl Server {
|
|||||||
{
|
{
|
||||||
let sessions = self.sessions.lock().await;
|
let sessions = self.sessions.lock().await;
|
||||||
for session in sessions.iter() {
|
for session in sessions.iter() {
|
||||||
if let Err(e) = session.handle.disconnect(
|
if let Err(e) = session
|
||||||
|
.handle
|
||||||
|
.disconnect(
|
||||||
russh::Disconnect::ByApplication,
|
russh::Disconnect::ByApplication,
|
||||||
"shutdown".to_string(),
|
"shutdown".to_string(),
|
||||||
String::new(),
|
String::new(),
|
||||||
).await {
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
warn!("failed to send SSH disconnect: {e}");
|
warn!("failed to send SSH disconnect: {e}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -392,7 +413,7 @@ impl Server {
|
|||||||
let handler_transport_kind = transport_kind;
|
let handler_transport_kind = transport_kind;
|
||||||
|
|
||||||
let handler = ServerHandler::new(
|
let handler = ServerHandler::new(
|
||||||
Arc::clone(&server.auth_config),
|
Arc::clone(&server.dynamic),
|
||||||
server.outbound_proxy.clone(),
|
server.outbound_proxy.clone(),
|
||||||
remote_addr,
|
remote_addr,
|
||||||
handler_transport_kind,
|
handler_transport_kind,
|
||||||
@@ -410,14 +431,8 @@ impl Server {
|
|||||||
let transport_is_tls = server.transport_mode == ServeTransportMode::Tls;
|
let transport_is_tls = server.transport_mode == ServeTransportMode::Tls;
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let result = handle_connection(
|
let result =
|
||||||
stream,
|
handle_connection(stream, config, handler, sessions, stealth, transport_is_tls)
|
||||||
config,
|
|
||||||
handler,
|
|
||||||
sessions,
|
|
||||||
stealth,
|
|
||||||
transport_is_tls,
|
|
||||||
)
|
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
if let Err(e) = result {
|
if let Err(e) = result {
|
||||||
@@ -611,8 +626,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn serve_options_validate_tcp_with_acme_rejected() {
|
fn serve_options_validate_tcp_with_acme_rejected() {
|
||||||
let opts =
|
let opts = ServeOptions::new(make_key_source()).acme_domain("example.com");
|
||||||
ServeOptions::new(make_key_source()).acme_domain("example.com");
|
|
||||||
assert!(opts.validate().is_err());
|
assert!(opts.validate().is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -626,8 +640,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn server_new_creates_server() {
|
fn server_new_creates_server() {
|
||||||
let opts = ServeOptions::new(make_key_source())
|
let opts =
|
||||||
.authorized_keys(make_authorized_keys_source());
|
ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source());
|
||||||
let server = Server::new(opts).unwrap();
|
let server = Server::new(opts).unwrap();
|
||||||
assert_eq!(server.max_auth_attempts, 10);
|
assert_eq!(server.max_auth_attempts, 10);
|
||||||
}
|
}
|
||||||
@@ -662,8 +676,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn serve_options_debug_redacts_keys() {
|
fn serve_options_debug_redacts_keys() {
|
||||||
let opts = ServeOptions::new(make_key_source())
|
let opts =
|
||||||
.authorized_keys(make_authorized_keys_source());
|
ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source());
|
||||||
let debug_str = format!("{:?}", opts);
|
let debug_str = format!("{:?}", opts);
|
||||||
assert!(debug_str.contains("<KeySource>"));
|
assert!(debug_str.contains("<KeySource>"));
|
||||||
assert!(!debug_str.contains("OPENSSH"));
|
assert!(!debug_str.contains("OPENSSH"));
|
||||||
@@ -715,8 +729,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn server_shutdown_sender_clones() {
|
fn server_shutdown_sender_clones() {
|
||||||
let opts = ServeOptions::new(make_key_source())
|
let opts =
|
||||||
.authorized_keys(make_authorized_keys_source());
|
ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source());
|
||||||
let server = Server::new(opts).unwrap();
|
let server = Server::new(opts).unwrap();
|
||||||
let sender = server.shutdown_sender();
|
let sender = server.shutdown_sender();
|
||||||
assert!(!server.is_shutdown());
|
assert!(!server.is_shutdown());
|
||||||
@@ -726,8 +740,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn server_holds_listen_addr() {
|
fn server_holds_listen_addr() {
|
||||||
let opts = ServeOptions::new(make_key_source())
|
let opts = ServeOptions::new(make_key_source()).listen_addr("0.0.0.0:443");
|
||||||
.listen_addr("0.0.0.0:443");
|
|
||||||
let server = Server::new(opts).unwrap();
|
let server = Server::new(opts).unwrap();
|
||||||
assert_eq!(server.listen_addr, "0.0.0.0:443");
|
assert_eq!(server.listen_addr, "0.0.0.0:443");
|
||||||
}
|
}
|
||||||
@@ -747,12 +760,10 @@ mod tests {
|
|||||||
let server = Server::new(opts).unwrap();
|
let server = Server::new(opts).unwrap();
|
||||||
let shutdown_tx = server.shutdown_sender();
|
let shutdown_tx = server.shutdown_sender();
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
let server_handle =
|
||||||
server
|
tokio::spawn(
|
||||||
.run(acceptor, None)
|
async move { server.run(acceptor, None).await.expect("server run failed") },
|
||||||
.await
|
);
|
||||||
.expect("server run failed")
|
|
||||||
});
|
|
||||||
|
|
||||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
|
||||||
@@ -760,6 +771,9 @@ mod tests {
|
|||||||
|
|
||||||
let result = tokio::time::timeout(Duration::from_secs(5), server_handle).await;
|
let result = tokio::time::timeout(Duration::from_secs(5), server_handle).await;
|
||||||
|
|
||||||
assert!(result.is_ok(), "server should have shut down within timeout");
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"server should have shut down within timeout"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -134,7 +134,10 @@ mod tests {
|
|||||||
|
|
||||||
let mut all_data = Vec::new();
|
let mut all_data = Vec::new();
|
||||||
reader.read_to_end(&mut all_data).await.unwrap();
|
reader.read_to_end(&mut all_data).await.unwrap();
|
||||||
assert!(all_data.starts_with(banner), "banner bytes must be preserved after detection");
|
assert!(
|
||||||
|
all_data.starts_with(banner),
|
||||||
|
"banner bytes must be preserved after detection"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -142,7 +145,10 @@ mod tests {
|
|||||||
let (client, server) = duplex(1024);
|
let (client, server) = duplex(1024);
|
||||||
let (mut client_read, mut client_write) = tokio::io::split(client);
|
let (mut client_read, mut client_write) = tokio::io::split(client);
|
||||||
|
|
||||||
client_write.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
|
client_write
|
||||||
|
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
drop(client_write);
|
drop(client_write);
|
||||||
|
|
||||||
let (detection, mut reader) = detect_protocol(server).await;
|
let (detection, mut reader) = detect_protocol(server).await;
|
||||||
@@ -206,7 +212,10 @@ mod tests {
|
|||||||
let (client, server) = duplex(1024);
|
let (client, server) = duplex(1024);
|
||||||
let mut client = client;
|
let mut client = client;
|
||||||
|
|
||||||
client.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await.unwrap();
|
client
|
||||||
|
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let (detection, mut reader) = detect_protocol(server).await;
|
let (detection, mut reader) = detect_protocol(server).await;
|
||||||
assert_eq!(detection, ProtocolDetection::Http);
|
assert_eq!(detection, ProtocolDetection::Http);
|
||||||
|
|||||||
@@ -52,9 +52,7 @@ impl<C: ChannelOpener> Socks5Server<C> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_addr(channel_opener: C, addr: &str) -> Self {
|
pub fn with_addr(channel_opener: C, addr: &str) -> Self {
|
||||||
let listen_addr: SocketAddr = addr
|
let listen_addr: SocketAddr = addr.parse().expect("invalid SOCKS5 listen address");
|
||||||
.parse()
|
|
||||||
.expect("invalid SOCKS5 listen address");
|
|
||||||
Self {
|
Self {
|
||||||
listen_addr,
|
listen_addr,
|
||||||
channel_opener: Arc::new(channel_opener),
|
channel_opener: Arc::new(channel_opener),
|
||||||
@@ -80,10 +78,7 @@ impl<C: ChannelOpener> Socks5Server<C> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_socks5_connection<S, C>(
|
async fn handle_socks5_connection<S, C>(mut socket: S, opener: Arc<C>) -> Result<(), Socks5Error>
|
||||||
mut socket: S,
|
|
||||||
opener: Arc<C>,
|
|
||||||
) -> Result<(), Socks5Error>
|
|
||||||
where
|
where
|
||||||
S: AsyncRead + AsyncWrite + Unpin,
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
C: ChannelOpener,
|
C: ChannelOpener,
|
||||||
@@ -173,7 +168,11 @@ impl<H: russh::client::Handler> HandleChannelOpener<H> {
|
|||||||
impl<H: russh::client::Handler + Send + Sync + 'static> ChannelOpener for HandleChannelOpener<H> {
|
impl<H: russh::client::Handler + Send + Sync + 'static> ChannelOpener for HandleChannelOpener<H> {
|
||||||
type Stream = russh::ChannelStream<russh::client::Msg>;
|
type Stream = russh::ChannelStream<russh::client::Msg>;
|
||||||
|
|
||||||
async fn open_channel(&self, host: String, port: u16) -> Result<Self::Stream, ChannelOpenError> {
|
async fn open_channel(
|
||||||
|
&self,
|
||||||
|
host: String,
|
||||||
|
port: u16,
|
||||||
|
) -> Result<Self::Stream, ChannelOpenError> {
|
||||||
let handle = self.handle.lock().await;
|
let handle = self.handle.lock().await;
|
||||||
if handle.is_closed() {
|
if handle.is_closed() {
|
||||||
return Err(ChannelOpenError::SessionClosed);
|
return Err(ChannelOpenError::SessionClosed);
|
||||||
@@ -241,7 +240,10 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn do_handshake(client: &mut DuplexStream) -> [u8; 2] {
|
async fn do_handshake(client: &mut DuplexStream) -> [u8; 2] {
|
||||||
client.write_all(&build_socks5_greeting(&[0x00])).await.unwrap();
|
client
|
||||||
|
.write_all(&build_socks5_greeting(&[0x00]))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
client.flush().await.unwrap();
|
client.flush().await.unwrap();
|
||||||
let mut resp = [0u8; 2];
|
let mut resp = [0u8; 2];
|
||||||
client.read_exact(&mut resp).await.unwrap();
|
client.read_exact(&mut resp).await.unwrap();
|
||||||
@@ -264,9 +266,8 @@ mod tests {
|
|||||||
let (mut client, server) = duplex(4096);
|
let (mut client, server) = duplex(4096);
|
||||||
let opener = MockChannelOpener { fail: false };
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
let server_handle =
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||||
});
|
|
||||||
|
|
||||||
let resp = do_handshake(&mut client).await;
|
let resp = do_handshake(&mut client).await;
|
||||||
assert_eq!(resp, [0x05, 0x00]);
|
assert_eq!(resp, [0x05, 0x00]);
|
||||||
@@ -284,9 +285,8 @@ mod tests {
|
|||||||
let (mut client, server) = duplex(4096);
|
let (mut client, server) = duplex(4096);
|
||||||
let opener = MockChannelOpener { fail: false };
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
let server_handle =
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||||
});
|
|
||||||
|
|
||||||
client
|
client
|
||||||
.write_all(&build_socks5_greeting(&[0x02]))
|
.write_all(&build_socks5_greeting(&[0x02]))
|
||||||
@@ -301,10 +301,7 @@ mod tests {
|
|||||||
drop(client);
|
drop(client);
|
||||||
let result = server_handle.await.unwrap();
|
let result = server_handle.await.unwrap();
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
assert!(matches!(
|
assert!(matches!(result.unwrap_err(), Socks5Error::NoAcceptableAuth));
|
||||||
result.unwrap_err(),
|
|
||||||
Socks5Error::NoAcceptableAuth
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -312,9 +309,8 @@ mod tests {
|
|||||||
let (mut client, server) = duplex(4096);
|
let (mut client, server) = duplex(4096);
|
||||||
let opener = MockChannelOpener { fail: false };
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
let server_handle =
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||||
});
|
|
||||||
|
|
||||||
do_handshake(&mut client).await;
|
do_handshake(&mut client).await;
|
||||||
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 443).await;
|
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 443).await;
|
||||||
@@ -329,9 +325,8 @@ mod tests {
|
|||||||
let (mut client, server) = duplex(4096);
|
let (mut client, server) = duplex(4096);
|
||||||
let opener = MockChannelOpener { fail: false };
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
let server_handle =
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||||
});
|
|
||||||
|
|
||||||
do_handshake(&mut client).await;
|
do_handshake(&mut client).await;
|
||||||
|
|
||||||
@@ -354,9 +349,8 @@ mod tests {
|
|||||||
let (mut client, server) = duplex(4096);
|
let (mut client, server) = duplex(4096);
|
||||||
let opener = MockChannelOpener { fail: false };
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
let server_handle =
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||||
});
|
|
||||||
|
|
||||||
do_handshake(&mut client).await;
|
do_handshake(&mut client).await;
|
||||||
|
|
||||||
@@ -381,9 +375,8 @@ mod tests {
|
|||||||
let (mut client, server) = duplex(4096);
|
let (mut client, server) = duplex(4096);
|
||||||
let opener = MockChannelOpener { fail: true };
|
let opener = MockChannelOpener { fail: true };
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
let server_handle =
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||||
});
|
|
||||||
|
|
||||||
do_handshake(&mut client).await;
|
do_handshake(&mut client).await;
|
||||||
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 80).await;
|
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 80).await;
|
||||||
@@ -399,9 +392,8 @@ mod tests {
|
|||||||
let (mut client, server) = duplex(4096);
|
let (mut client, server) = duplex(4096);
|
||||||
let opener = MockChannelOpener { fail: false };
|
let opener = MockChannelOpener { fail: false };
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
let server_handle =
|
||||||
handle_socks5_connection(server, Arc::new(opener)).await
|
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||||
});
|
|
||||||
|
|
||||||
do_handshake(&mut client).await;
|
do_handshake(&mut client).await;
|
||||||
|
|
||||||
@@ -450,9 +442,10 @@ mod tests {
|
|||||||
stream: Arc::clone(&ssh_stream),
|
stream: Arc::clone(&ssh_stream),
|
||||||
};
|
};
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
let server_handle =
|
||||||
handle_socks5_connection(server_sock, Arc::new(opener)).await
|
tokio::spawn(
|
||||||
});
|
async move { handle_socks5_connection(server_sock, Arc::new(opener)).await },
|
||||||
|
);
|
||||||
|
|
||||||
do_handshake(&mut client_sock).await;
|
do_handshake(&mut client_sock).await;
|
||||||
let reply_buf = do_connect_ipv4(&mut client_sock, [127, 0, 0, 1], 80).await;
|
let reply_buf = do_connect_ipv4(&mut client_sock, [127, 0, 0, 1], 80).await;
|
||||||
|
|||||||
@@ -169,10 +169,7 @@ mod tests {
|
|||||||
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||||
assert_eq!(req.version, 0x05);
|
assert_eq!(req.version, 0x05);
|
||||||
assert_eq!(req.command, 0x01);
|
assert_eq!(req.command, 0x01);
|
||||||
assert_eq!(
|
assert_eq!(req.address, Socks5Address::Ipv4(Ipv4Addr::new(10, 0, 0, 1)));
|
||||||
req.address,
|
|
||||||
Socks5Address::Ipv4(Ipv4Addr::new(10, 0, 0, 1))
|
|
||||||
);
|
|
||||||
assert_eq!(req.port, 443);
|
assert_eq!(req.port, 443);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,7 +198,10 @@ mod tests {
|
|||||||
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||||
assert_eq!(req.version, 0x05);
|
assert_eq!(req.version, 0x05);
|
||||||
assert_eq!(req.command, 0x01);
|
assert_eq!(req.command, 0x01);
|
||||||
assert_eq!(req.address, Socks5Address::Domain("example.com".to_string()));
|
assert_eq!(
|
||||||
|
req.address,
|
||||||
|
Socks5Address::Domain("example.com".to_string())
|
||||||
|
);
|
||||||
assert_eq!(req.port, 443);
|
assert_eq!(req.port, 443);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use tokio::io::{DuplexStream, AsyncRead, AsyncWrite};
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, DuplexStream};
|
||||||
|
|
||||||
#[cfg(feature = "transport-traits")]
|
#[cfg(feature = "transport-traits")]
|
||||||
pub use crate::transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
pub use crate::transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||||
@@ -9,10 +9,10 @@ pub use local_traits::{Transport, TransportAcceptor, TransportInfo, TransportKin
|
|||||||
|
|
||||||
#[cfg(not(feature = "transport-traits"))]
|
#[cfg(not(feature = "transport-traits"))]
|
||||||
mod local_traits {
|
mod local_traits {
|
||||||
use std::net::SocketAddr;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Transport: Send + Sync + 'static {
|
pub trait Transport: Send + Sync + 'static {
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ use rustls::crypto::aws_lc_rs::default_provider;
|
|||||||
use rustls::ServerConfig;
|
use rustls::ServerConfig;
|
||||||
use rustls_acme::caches::DirCache;
|
use rustls_acme::caches::DirCache;
|
||||||
use rustls_acme::{AcmeConfig, AcmeState, ResolvesServerCertAcme};
|
use rustls_acme::{AcmeConfig, AcmeState, ResolvesServerCertAcme};
|
||||||
use tracing::{error, info};
|
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use tokio_rustls::TlsAcceptor as TokioTlsAcceptor;
|
use tokio_rustls::TlsAcceptor as TokioTlsAcceptor;
|
||||||
|
use tracing::{error, info};
|
||||||
|
|
||||||
use super::{TransportAcceptor, TransportInfo, TransportKind};
|
use super::{TransportAcceptor, TransportInfo, TransportKind};
|
||||||
|
|
||||||
@@ -94,14 +94,10 @@ impl AcmeCertProvider {
|
|||||||
.contact(self.contact.clone());
|
.contact(self.contact.clone());
|
||||||
|
|
||||||
let state = match &self.cache_dir {
|
let state = match &self.cache_dir {
|
||||||
Some(cache_dir) => {
|
Some(cache_dir) => base_config.cache(DirCache::new(cache_dir.clone())).state(),
|
||||||
base_config.cache(DirCache::new(cache_dir.clone())).state()
|
None => base_config
|
||||||
}
|
|
||||||
None => {
|
|
||||||
base_config
|
|
||||||
.cache(rustls_acme::caches::NoCache::default())
|
.cache(rustls_acme::caches::NoCache::default())
|
||||||
.state()
|
.state(),
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let resolver = state.resolver();
|
let resolver = state.resolver();
|
||||||
@@ -132,10 +128,7 @@ pub struct AcmeTlsAcceptor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl AcmeTlsAcceptor {
|
impl AcmeTlsAcceptor {
|
||||||
pub async fn bind_acme(
|
pub async fn bind_acme(addr: SocketAddr, provider: Arc<AcmeCertProvider>) -> Result<Self> {
|
||||||
addr: SocketAddr,
|
|
||||||
provider: Arc<AcmeCertProvider>,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let (state, resolver) = provider.build_acme_state();
|
let (state, resolver) = provider.build_acme_state();
|
||||||
|
|
||||||
let server_config = provider.build_server_config_with_resolver(resolver.clone())?;
|
let server_config = provider.build_server_config_with_resolver(resolver.clone())?;
|
||||||
@@ -193,11 +186,7 @@ impl TransportAcceptor for AcmeTlsAcceptor {
|
|||||||
let (tcp_stream, remote_addr) = self.listener.accept().await?;
|
let (tcp_stream, remote_addr) = self.listener.accept().await?;
|
||||||
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
|
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
|
||||||
|
|
||||||
let server_name = tls_stream
|
let server_name = tls_stream.get_ref().1.server_name().map(|s| s.to_string());
|
||||||
.get_ref()
|
|
||||||
.1
|
|
||||||
.server_name()
|
|
||||||
.map(|s| s.to_string());
|
|
||||||
|
|
||||||
let info = TransportInfo {
|
let info = TransportInfo {
|
||||||
remote_addr: Some(remote_addr),
|
remote_addr: Some(remote_addr),
|
||||||
@@ -277,8 +266,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn acme_cert_provider_build_state_with_cache() {
|
fn acme_cert_provider_build_state_with_cache() {
|
||||||
let provider =
|
let provider = AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/test_cache");
|
||||||
AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/test_cache");
|
|
||||||
let (_state, resolver) = provider.build_acme_state();
|
let (_state, resolver) = provider.build_acme_state();
|
||||||
assert!(Arc::strong_count(&resolver) >= 2);
|
assert!(Arc::strong_count(&resolver) >= 2);
|
||||||
}
|
}
|
||||||
@@ -288,7 +276,9 @@ mod tests {
|
|||||||
let _ = default_provider().install_default();
|
let _ = default_provider().install_default();
|
||||||
let provider = AcmeCertProvider::domain("example.com");
|
let provider = AcmeCertProvider::domain("example.com");
|
||||||
let (_, resolver) = provider.build_acme_state();
|
let (_, resolver) = provider.build_acme_state();
|
||||||
let config = provider.build_server_config_with_resolver(resolver).unwrap();
|
let config = provider
|
||||||
|
.build_server_config_with_resolver(resolver)
|
||||||
|
.unwrap();
|
||||||
assert!(!config.alpn_protocols.is_empty());
|
assert!(!config.alpn_protocols.is_empty());
|
||||||
assert!(config
|
assert!(config
|
||||||
.alpn_protocols
|
.alpn_protocols
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use iroh::{
|
use iroh::{
|
||||||
endpoint::RecvStream,
|
endpoint::RecvStream, node_info::NodeIdExt, Endpoint, NodeId, RelayMap, RelayMode, RelayUrl,
|
||||||
node_info::NodeIdExt,
|
|
||||||
Endpoint, NodeId, RelayMap, RelayMode, RelayUrl,
|
|
||||||
};
|
};
|
||||||
use tokio::io;
|
use tokio::io;
|
||||||
|
|
||||||
@@ -39,7 +37,9 @@ impl IrohTransport {
|
|||||||
proxy_url: Option<url::Url>,
|
proxy_url: Option<url::Url>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let relay_url = relay_url.unwrap_or_else(|| {
|
let relay_url = relay_url.unwrap_or_else(|| {
|
||||||
DEFAULT_RELAY_URL.parse().expect("default relay URL is valid")
|
DEFAULT_RELAY_URL
|
||||||
|
.parse()
|
||||||
|
.expect("default relay URL is valid")
|
||||||
});
|
});
|
||||||
let relay_map = RelayMap::from_url(relay_url);
|
let relay_map = RelayMap::from_url(relay_url);
|
||||||
let mut builder = Endpoint::builder()
|
let mut builder = Endpoint::builder()
|
||||||
@@ -49,7 +49,11 @@ impl IrohTransport {
|
|||||||
builder = builder.proxy_url(proxy.clone());
|
builder = builder.proxy_url(proxy.clone());
|
||||||
}
|
}
|
||||||
let endpoint = builder.bind().await?;
|
let endpoint = builder.bind().await?;
|
||||||
Ok(Self { node_id, endpoint, owned: true })
|
Ok(Self {
|
||||||
|
node_id,
|
||||||
|
endpoint,
|
||||||
|
owned: true,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an iroh transport using an existing shared endpoint.
|
/// Create an iroh transport using an existing shared endpoint.
|
||||||
@@ -60,7 +64,11 @@ impl IrohTransport {
|
|||||||
/// other protocol handlers on the same QUIC endpoint — one connection
|
/// other protocol handlers on the same QUIC endpoint — one connection
|
||||||
/// per peer, multiplexed by ALPN.
|
/// per peer, multiplexed by ALPN.
|
||||||
pub fn from_endpoint(node_id: NodeId, endpoint: Endpoint) -> Self {
|
pub fn from_endpoint(node_id: NodeId, endpoint: Endpoint) -> Self {
|
||||||
Self { node_id, endpoint, owned: false }
|
Self {
|
||||||
|
node_id,
|
||||||
|
endpoint,
|
||||||
|
owned: false,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn endpoint_id(&self) -> String {
|
pub fn endpoint_id(&self) -> String {
|
||||||
@@ -115,12 +123,11 @@ impl IrohAcceptor {
|
|||||||
/// Bind a new iroh endpoint with a dedicated `alknet-ssh` ALPN.
|
/// Bind a new iroh endpoint with a dedicated `alknet-ssh` ALPN.
|
||||||
///
|
///
|
||||||
/// Use this when alknet is the only iroh service on this node.
|
/// Use this when alknet is the only iroh service on this node.
|
||||||
pub async fn bind(
|
pub async fn bind(relay_url: Option<RelayUrl>, proxy_url: Option<url::Url>) -> Result<Self> {
|
||||||
relay_url: Option<RelayUrl>,
|
|
||||||
proxy_url: Option<url::Url>,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let relay_url = relay_url.unwrap_or_else(|| {
|
let relay_url = relay_url.unwrap_or_else(|| {
|
||||||
DEFAULT_RELAY_URL.parse().expect("default relay URL is valid")
|
DEFAULT_RELAY_URL
|
||||||
|
.parse()
|
||||||
|
.expect("default relay URL is valid")
|
||||||
});
|
});
|
||||||
let relay_map = RelayMap::from_url(relay_url);
|
let relay_map = RelayMap::from_url(relay_url);
|
||||||
let mut builder = Endpoint::builder()
|
let mut builder = Endpoint::builder()
|
||||||
@@ -130,7 +137,10 @@ impl IrohAcceptor {
|
|||||||
builder = builder.proxy_url(proxy.clone());
|
builder = builder.proxy_url(proxy.clone());
|
||||||
}
|
}
|
||||||
let endpoint = builder.bind().await?;
|
let endpoint = builder.bind().await?;
|
||||||
Ok(Self { endpoint, owned: true })
|
Ok(Self {
|
||||||
|
endpoint,
|
||||||
|
owned: true,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an iroh acceptor using an existing shared endpoint.
|
/// Create an iroh acceptor using an existing shared endpoint.
|
||||||
@@ -146,7 +156,10 @@ impl IrohAcceptor {
|
|||||||
/// [`IrohAcceptor::bind`] instead, which handles the accept loop
|
/// [`IrohAcceptor::bind`] instead, which handles the accept loop
|
||||||
/// internally.
|
/// internally.
|
||||||
pub fn from_endpoint(endpoint: Endpoint) -> Self {
|
pub fn from_endpoint(endpoint: Endpoint) -> Self {
|
||||||
Self { endpoint, owned: false }
|
Self {
|
||||||
|
endpoint,
|
||||||
|
owned: false,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn endpoint_id(&self) -> String {
|
pub fn endpoint_id(&self) -> String {
|
||||||
@@ -219,18 +232,14 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn iroh_transport_describe_format() {
|
fn iroh_transport_describe_format() {
|
||||||
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng)
|
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into();
|
||||||
.public()
|
|
||||||
.into();
|
|
||||||
let desc = format!("iroh://{}", node_id.to_z32());
|
let desc = format!("iroh://{}", node_id.to_z32());
|
||||||
assert!(desc.starts_with("iroh://"));
|
assert!(desc.starts_with("iroh://"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn iroh_transport_connect_builds_endpoint() {
|
async fn iroh_transport_connect_builds_endpoint() {
|
||||||
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng)
|
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into();
|
||||||
.public()
|
|
||||||
.into();
|
|
||||||
let transport = IrohTransport::new(node_id, None, None).await.unwrap();
|
let transport = IrohTransport::new(node_id, None, None).await.unwrap();
|
||||||
assert!(transport.describe().starts_with("iroh://"));
|
assert!(transport.describe().starts_with("iroh://"));
|
||||||
assert!(!transport.endpoint_id().is_empty());
|
assert!(!transport.endpoint_id().is_empty());
|
||||||
@@ -239,9 +248,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn iroh_transport_from_endpoint() {
|
async fn iroh_transport_from_endpoint() {
|
||||||
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng)
|
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into();
|
||||||
.public()
|
|
||||||
.into();
|
|
||||||
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
|
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
|
||||||
let endpoint = acceptor.endpoint.clone();
|
let endpoint = acceptor.endpoint.clone();
|
||||||
let transport = IrohTransport::from_endpoint(node_id, endpoint);
|
let transport = IrohTransport::from_endpoint(node_id, endpoint);
|
||||||
|
|||||||
@@ -13,13 +13,13 @@
|
|||||||
//! See [ADR-001](docs/architecture/decisions/001-pluggable-transport.md) and
|
//! See [ADR-001](docs/architecture/decisions/001-pluggable-transport.md) and
|
||||||
//! [ADR-004](docs/architecture/decisions/004-ssh-over-transport.md) for design rationale.
|
//! [ADR-004](docs/architecture/decisions/004-ssh-over-transport.md) for design rationale.
|
||||||
|
|
||||||
mod tcp;
|
|
||||||
#[cfg(feature = "iroh")]
|
#[cfg(feature = "iroh")]
|
||||||
mod iroh_transport;
|
mod iroh_transport;
|
||||||
|
mod tcp;
|
||||||
|
|
||||||
pub use tcp::{TcpAcceptor, TcpTransport};
|
|
||||||
#[cfg(feature = "iroh")]
|
#[cfg(feature = "iroh")]
|
||||||
pub use iroh_transport::{IrohAcceptor, IrohTransport, ALPN as IROH_ALPN};
|
pub use iroh_transport::{IrohAcceptor, IrohTransport, ALPN as IROH_ALPN};
|
||||||
|
pub use tcp::{TcpAcceptor, TcpTransport};
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
mod tls;
|
mod tls;
|
||||||
@@ -89,12 +89,8 @@ pub struct TransportInfo {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum TransportKind {
|
pub enum TransportKind {
|
||||||
Tcp,
|
Tcp,
|
||||||
Tls {
|
Tls { server_name: Option<String> },
|
||||||
server_name: Option<String>,
|
Iroh { endpoint_id: String },
|
||||||
},
|
|
||||||
Iroh {
|
|
||||||
endpoint_id: String,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, Server
|
|||||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
|
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
|
||||||
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig};
|
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio_rustls::{client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector};
|
use tokio_rustls::{
|
||||||
|
client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector,
|
||||||
|
};
|
||||||
|
|
||||||
#[cfg(feature = "acme")]
|
#[cfg(feature = "acme")]
|
||||||
use rustls::crypto::aws_lc_rs::default_provider;
|
use rustls::crypto::aws_lc_rs::default_provider;
|
||||||
@@ -169,7 +171,9 @@ impl TlsAcceptor {
|
|||||||
.map_err(|e| anyhow!("failed to set protocol versions: {}", e))?
|
.map_err(|e| anyhow!("failed to set protocol versions: {}", e))?
|
||||||
.with_no_client_auth()
|
.with_no_client_auth()
|
||||||
.with_cert_resolver(acme_resolver);
|
.with_cert_resolver(acme_resolver);
|
||||||
server_config.alpn_protocols.push(ACME_TLS_ALPN_NAME.to_vec());
|
server_config
|
||||||
|
.alpn_protocols
|
||||||
|
.push(ACME_TLS_ALPN_NAME.to_vec());
|
||||||
|
|
||||||
let server_config = Arc::new(server_config);
|
let server_config = Arc::new(server_config);
|
||||||
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
|
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
|
||||||
@@ -195,11 +199,7 @@ impl TransportAcceptor for TlsAcceptor {
|
|||||||
let (tcp_stream, remote_addr) = self.listener.accept().await?;
|
let (tcp_stream, remote_addr) = self.listener.accept().await?;
|
||||||
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
|
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
|
||||||
|
|
||||||
let server_name = tls_stream
|
let server_name = tls_stream.get_ref().1.server_name().map(|s| s.to_string());
|
||||||
.get_ref()
|
|
||||||
.1
|
|
||||||
.server_name()
|
|
||||||
.map(|s| s.to_string());
|
|
||||||
|
|
||||||
let info = TransportInfo {
|
let info = TransportInfo {
|
||||||
remote_addr: Some(remote_addr),
|
remote_addr: Some(remote_addr),
|
||||||
@@ -324,10 +324,7 @@ mod tests {
|
|||||||
|
|
||||||
let (mut server, info) = accept_handle.await.unwrap();
|
let (mut server, info) = accept_handle.await.unwrap();
|
||||||
assert!(info.remote_addr.is_some());
|
assert!(info.remote_addr.is_some());
|
||||||
assert!(matches!(
|
assert!(matches!(info.transport_kind, TransportKind::Tls { .. }));
|
||||||
info.transport_kind,
|
|
||||||
TransportKind::Tls { .. }
|
|
||||||
));
|
|
||||||
|
|
||||||
client.write_all(b"hello tls").await.unwrap();
|
client.write_all(b"hello tls").await.unwrap();
|
||||||
let mut buf = [0u8; 9];
|
let mut buf = [0u8; 9];
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
use alknet_core::testutil::{MockTransport, MockTransportAcceptor, Transport, TransportAcceptor, mock_pair};
|
use alknet_core::testutil::{
|
||||||
|
mock_pair, MockTransport, MockTransportAcceptor, Transport, TransportAcceptor,
|
||||||
|
};
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn mock_transport_connect() {
|
async fn mock_transport_connect() {
|
||||||
|
|||||||
@@ -328,7 +328,12 @@ impl russh::server::Handler for NapiServerHandler {
|
|||||||
session: &mut russh::server::Session,
|
session: &mut russh::server::Session,
|
||||||
) -> std::result::Result<(), Self::Error> {
|
) -> std::result::Result<(), Self::Error> {
|
||||||
tracing::warn!(channel = %channel, "rejected x11 request");
|
tracing::warn!(channel = %channel, "rejected x11 request");
|
||||||
let _ = (single_connection, x11_auth_protocol, x11_auth_cookie, x11_screen_number);
|
let _ = (
|
||||||
|
single_connection,
|
||||||
|
x11_auth_protocol,
|
||||||
|
x11_auth_cookie,
|
||||||
|
x11_screen_number,
|
||||||
|
);
|
||||||
let _ = session.channel_failure(channel);
|
let _ = session.channel_failure(channel);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -348,7 +353,11 @@ impl russh::server::Handler for NapiServerHandler {
|
|||||||
port: &mut u32,
|
port: &mut u32,
|
||||||
_session: &mut russh::server::Session,
|
_session: &mut russh::server::Session,
|
||||||
) -> std::result::Result<bool, Self::Error> {
|
) -> std::result::Result<bool, Self::Error> {
|
||||||
tracing::warn!(address = address, port = *port, "rejected tcpip-forward request");
|
tracing::warn!(
|
||||||
|
address = address,
|
||||||
|
port = *port,
|
||||||
|
"rejected tcpip-forward request"
|
||||||
|
);
|
||||||
Ok(false)
|
Ok(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -367,7 +376,10 @@ impl russh::server::Handler for NapiServerHandler {
|
|||||||
socket_path: &str,
|
socket_path: &str,
|
||||||
_session: &mut russh::server::Session,
|
_session: &mut russh::server::Session,
|
||||||
) -> std::result::Result<bool, Self::Error> {
|
) -> std::result::Result<bool, Self::Error> {
|
||||||
tracing::warn!(socket_path = socket_path, "rejected streamlocal-forward request");
|
tracing::warn!(
|
||||||
|
socket_path = socket_path,
|
||||||
|
"rejected streamlocal-forward request"
|
||||||
|
);
|
||||||
Ok(false)
|
Ok(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -542,8 +554,8 @@ pub async fn serve(options: AlknetServeOptions) -> napi::Result<AlknetServer> {
|
|||||||
})?,
|
})?,
|
||||||
);
|
);
|
||||||
|
|
||||||
let private_key =
|
let private_key = alknet_core::auth::keys::load_private_key(host_key_source.clone())
|
||||||
alknet_core::auth::keys::load_private_key(host_key_source.clone()).map_err(|e| {
|
.map_err(|e| {
|
||||||
napi::Error::new(napi::Status::InvalidArg, format!("host key error: {}", e))
|
napi::Error::new(napi::Status::InvalidArg, format!("host key error: {}", e))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
@@ -635,7 +647,9 @@ pub async fn serve(options: AlknetServeOptions) -> napi::Result<AlknetServer> {
|
|||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let acceptor = TlsAcceptor::bind(addr, certs, key, None).await.map_err(|e| {
|
let acceptor = TlsAcceptor::bind(addr, certs, key, None)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
napi::Error::new(
|
napi::Error::new(
|
||||||
napi::Status::GenericFailure,
|
napi::Status::GenericFailure,
|
||||||
format!("tls bind failed: {}", e),
|
format!("tls bind failed: {}", e),
|
||||||
@@ -653,8 +667,8 @@ pub async fn serve(options: AlknetServeOptions) -> napi::Result<AlknetServer> {
|
|||||||
})?,
|
})?,
|
||||||
);
|
);
|
||||||
|
|
||||||
let private_key =
|
let private_key = alknet_core::auth::keys::load_private_key(host_key_source.clone())
|
||||||
alknet_core::auth::keys::load_private_key(host_key_source.clone()).map_err(|e| {
|
.map_err(|e| {
|
||||||
napi::Error::new(napi::Status::InvalidArg, format!("host key error: {}", e))
|
napi::Error::new(napi::Status::InvalidArg, format!("host key error: {}", e))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
|||||||
@@ -10,8 +10,6 @@ use std::net::SocketAddr;
|
|||||||
use std::process;
|
use std::process;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use clap::{Parser, Subcommand, ValueEnum};
|
|
||||||
use alknet_core::auth::keys::KeySource;
|
use alknet_core::auth::keys::KeySource;
|
||||||
use alknet_core::client::{ConnectOptions, TransportMode};
|
use alknet_core::client::{ConnectOptions, TransportMode};
|
||||||
use alknet_core::server::{ServeOptions, ServeTransportMode, Server};
|
use alknet_core::server::{ServeOptions, ServeTransportMode, Server};
|
||||||
@@ -21,6 +19,8 @@ use alknet_core::transport::TcpTransport;
|
|||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
use alknet_core::transport::TlsTransport;
|
use alknet_core::transport::TlsTransport;
|
||||||
use alknet_core::transport::Transport;
|
use alknet_core::transport::Transport;
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use clap::{Parser, Subcommand, ValueEnum};
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "alknet", version, about = "Alknet SSH tunnel tool")]
|
#[command(name = "alknet", version, about = "Alknet SSH tunnel tool")]
|
||||||
@@ -76,7 +76,7 @@ enum Commands {
|
|||||||
insecure: bool,
|
insecure: bool,
|
||||||
},
|
},
|
||||||
|
|
||||||
#[command( about = "Start the alknet server (accept SSH connections)")]
|
#[command(about = "Start the alknet server (accept SSH connections)")]
|
||||||
Serve {
|
Serve {
|
||||||
#[arg(long, help = "SSH host key path (required)")]
|
#[arg(long, help = "SSH host key path (required)")]
|
||||||
key: String,
|
key: String,
|
||||||
|
|||||||
Reference in New Issue
Block a user