greenfield: clean slate for ALPN-as-service pivot
Delete old source crates (alknet-core, alknet, alknet-napi), old architecture docs (ADRs, specs, open questions), old research docs (phase2, event-sourcing, feasibility, etc.), old tasks, and obsolete reference material (gitserver/MPL, honker, nats, rustfs, polyglot, keystone, distributed-identity). Keep: alknet-secret (standalone, compiles), pivot docs, iroh and ssh references, rudolfs reference (MIT/Apache, fork candidate), ops docs, sdd_process.md, and licenses. Previous implementation preserved at /workspace/@alkdev/alknet-main/ for reference during porting. Workspace compiles: cargo check + 14 tests pass for alknet-secret.
This commit is contained in:
@@ -1,56 +0,0 @@
|
||||
[package]
|
||||
name = "alknet-core"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Core library for Alknet: pluggable SSH tunnel transport, SOCKS5 proxy, port forwarding, and authentication"
|
||||
repository.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "alknet_core"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
tls = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-pki-types", "dep:webpki-roots"]
|
||||
iroh = ["dep:iroh", "dep:url"]
|
||||
acme = ["dep:rustls-acme", "dep:futures", "tls"]
|
||||
http = ["dep:axum", "dep:hyper", "dep:hyper-util", "dep:tower", "dep:http-body-util"]
|
||||
irpc = []
|
||||
testutil = []
|
||||
transport-traits = []
|
||||
|
||||
[dependencies]
|
||||
russh = "0.49"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tracing = "0.1"
|
||||
anyhow = "1"
|
||||
thiserror = "2"
|
||||
tokio-util = { version = "0.7", features = ["compat"] }
|
||||
tokio-rustls = { version = "0.26", optional = true }
|
||||
rustls = { version = "0.23", optional = true, features = ["aws_lc_rs"] }
|
||||
rustls-pki-types = { version = "1", optional = true }
|
||||
rustls-acme = { version = "0.12", optional = true }
|
||||
futures = { version = "0.3", optional = true }
|
||||
webpki-roots = { version = "0.26", optional = true }
|
||||
iroh = { version = "0.34", optional = true }
|
||||
url = { version = "2", optional = true }
|
||||
async-trait = "0.1"
|
||||
ipnetwork = "0.21.1"
|
||||
arc-swap = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
sha2 = "0.10"
|
||||
hex = "0.4"
|
||||
axum = { version = "0.8", optional = true }
|
||||
hyper = { version = "1", optional = true }
|
||||
hyper-util = { version = "0.1", features = ["tokio", "server", "service"], optional = true }
|
||||
tower = { version = "0.5", optional = true }
|
||||
http-body-util = { version = "0.1", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
alknet-core = { path = ".", features = ["testutil", "tls", "iroh", "http"] }
|
||||
tempfile = "3"
|
||||
rcgen = "0.14"
|
||||
rand_core = "0.6"
|
||||
ssh-key = { version = "0.6", features = ["ed25519", "alloc"] }
|
||||
rand = "0.10.1"
|
||||
@@ -1,262 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
|
||||
use crate::auth::identity::{AuthToken, ConfigIdentityProvider, Identity, IdentityProvider};
|
||||
use crate::config::DynamicConfig;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum AuthProtocol {
|
||||
VerifyPubkey {
|
||||
fingerprint: String,
|
||||
key_data: Vec<u8>,
|
||||
},
|
||||
VerifyToken {
|
||||
token_bytes: Vec<u8>,
|
||||
timestamp: u64,
|
||||
},
|
||||
ReloadKeys,
|
||||
CheckAccess {
|
||||
identity: Identity,
|
||||
operation: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum AuthResult {
|
||||
Ok(Identity),
|
||||
Denied(String),
|
||||
}
|
||||
|
||||
pub struct AuthServiceImpl {
|
||||
provider: ConfigIdentityProvider,
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
}
|
||||
|
||||
impl AuthServiceImpl {
|
||||
pub fn new(dynamic: Arc<ArcSwap<DynamicConfig>>) -> Self {
|
||||
let provider = ConfigIdentityProvider::new(Arc::clone(&dynamic));
|
||||
Self { provider, dynamic }
|
||||
}
|
||||
|
||||
pub fn verify_pubkey(&self, fingerprint: &str) -> AuthResult {
|
||||
match self.provider.resolve_from_fingerprint(fingerprint) {
|
||||
Some(identity) => AuthResult::Ok(identity),
|
||||
None => AuthResult::Denied(format!("key not authorized: {}", fingerprint)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn verify_token(&self, token: &AuthToken) -> AuthResult {
|
||||
match self.provider.resolve_from_token(token) {
|
||||
Some(identity) => AuthResult::Ok(identity),
|
||||
None => AuthResult::Denied("token verification failed".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reload_keys(&self) {
|
||||
self.dynamic.rcu(Arc::clone);
|
||||
}
|
||||
|
||||
pub fn check_access(&self, identity: &Identity, operation: &str) -> AuthResult {
|
||||
if identity.scopes.iter().any(|s| s == operation) {
|
||||
AuthResult::Ok(identity.clone())
|
||||
} else {
|
||||
AuthResult::Denied(format!(
|
||||
"identity {} lacks scope: {}",
|
||||
identity.id, operation
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AuthServiceImpl {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AuthServiceImpl").finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::auth::ServerAuthConfig;
|
||||
use crate::config::AuthPolicy;
|
||||
use russh::keys::ssh_key::HashAlg;
|
||||
use russh::keys::PrivateKey;
|
||||
use std::io::Write;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
|
||||
|
||||
fn load_key() -> PrivateKey {
|
||||
russh::keys::decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap()
|
||||
}
|
||||
|
||||
fn make_authorized_keys_file(keys_content: &str) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
f.write_all(keys_content.as_bytes()).unwrap();
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
fn make_service(keys_content: &str) -> (AuthServiceImpl, Arc<ArcSwap<DynamicConfig>>) {
|
||||
let f = make_authorized_keys_file(keys_content);
|
||||
let server_auth =
|
||||
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
|
||||
.unwrap();
|
||||
let auth_policy = AuthPolicy::from_server_auth_config(server_auth);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let service = AuthServiceImpl::new(Arc::clone(&arc_swap));
|
||||
(service, arc_swap)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_service_verify_pubkey_valid() {
|
||||
let (service, _) = make_service(ED25519_PUBLIC_KEY);
|
||||
let key = load_key().public_key().clone();
|
||||
let fingerprint = format!("{}", key.fingerprint(HashAlg::Sha256));
|
||||
let result = service.verify_pubkey(&fingerprint);
|
||||
assert!(matches!(result, AuthResult::Ok(_)));
|
||||
if let AuthResult::Ok(identity) = result {
|
||||
assert_eq!(identity.id, fingerprint);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_service_verify_pubkey_invalid() {
|
||||
let (service, _) = make_service(ED25519_PUBLIC_KEY);
|
||||
let result = service.verify_pubkey("SHA256:invalid");
|
||||
assert!(matches!(result, AuthResult::Denied(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_service_verify_pubkey_matches_identity_provider() {
|
||||
let (service, arc_swap) = make_service(ED25519_PUBLIC_KEY);
|
||||
let provider = ConfigIdentityProvider::new(Arc::clone(&arc_swap));
|
||||
let key = load_key().public_key().clone();
|
||||
let fingerprint = format!("{}", key.fingerprint(HashAlg::Sha256));
|
||||
|
||||
let service_result = service.verify_pubkey(&fingerprint);
|
||||
let provider_result = provider.resolve_from_fingerprint(&fingerprint);
|
||||
|
||||
match service_result {
|
||||
AuthResult::Ok(identity) => {
|
||||
assert_eq!(identity, provider_result.unwrap());
|
||||
}
|
||||
AuthResult::Denied(_) => {
|
||||
assert!(provider_result.is_none());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_service_verify_token_returns_denied() {
|
||||
let (service, _) = make_service(ED25519_PUBLIC_KEY);
|
||||
let token = AuthToken {
|
||||
raw: b"test-token".to_vec(),
|
||||
};
|
||||
let result = service.verify_token(&token);
|
||||
assert!(matches!(result, AuthResult::Denied(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_service_check_access_granted() {
|
||||
let (service, _) = make_service(ED25519_PUBLIC_KEY);
|
||||
let key = load_key().public_key().clone();
|
||||
let fingerprint = format!("{}", key.fingerprint(HashAlg::Sha256));
|
||||
let identity = Identity {
|
||||
id: fingerprint,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: std::collections::HashMap::new(),
|
||||
};
|
||||
let result = service.check_access(&identity, "relay:connect");
|
||||
assert!(matches!(result, AuthResult::Ok(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_service_check_access_denied() {
|
||||
let (service, _) = make_service(ED25519_PUBLIC_KEY);
|
||||
let identity = Identity {
|
||||
id: "test".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: std::collections::HashMap::new(),
|
||||
};
|
||||
let result = service.check_access(&identity, "admin:write");
|
||||
assert!(matches!(result, AuthResult::Denied(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_protocol_variants() {
|
||||
let identity = Identity {
|
||||
id: "SHA256:abc".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: std::collections::HashMap::new(),
|
||||
};
|
||||
|
||||
let verify_pubkey = AuthProtocol::VerifyPubkey {
|
||||
fingerprint: "SHA256:abc".to_string(),
|
||||
key_data: vec![1, 2, 3],
|
||||
};
|
||||
match &verify_pubkey {
|
||||
AuthProtocol::VerifyPubkey {
|
||||
fingerprint,
|
||||
key_data,
|
||||
} => {
|
||||
assert_eq!(fingerprint, "SHA256:abc");
|
||||
assert_eq!(key_data, &vec![1, 2, 3]);
|
||||
}
|
||||
_ => panic!("expected VerifyPubkey variant"),
|
||||
}
|
||||
|
||||
let verify_token = AuthProtocol::VerifyToken {
|
||||
token_bytes: vec![4, 5, 6],
|
||||
timestamp: 12345,
|
||||
};
|
||||
match &verify_token {
|
||||
AuthProtocol::VerifyToken {
|
||||
token_bytes,
|
||||
timestamp,
|
||||
} => {
|
||||
assert_eq!(token_bytes, &vec![4, 5, 6]);
|
||||
assert_eq!(*timestamp, 12345);
|
||||
}
|
||||
_ => panic!("expected VerifyToken variant"),
|
||||
}
|
||||
|
||||
assert!(matches!(AuthProtocol::ReloadKeys, AuthProtocol::ReloadKeys));
|
||||
|
||||
let check = AuthProtocol::CheckAccess {
|
||||
identity: identity.clone(),
|
||||
operation: "relay:connect".to_string(),
|
||||
};
|
||||
match &check {
|
||||
AuthProtocol::CheckAccess {
|
||||
identity: id,
|
||||
operation,
|
||||
} => {
|
||||
assert_eq!(id.id, "SHA256:abc");
|
||||
assert_eq!(operation, "relay:connect");
|
||||
}
|
||||
_ => panic!("expected CheckAccess variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_result_ok_identity() {
|
||||
let identity = Identity {
|
||||
id: "test".to_string(),
|
||||
scopes: vec![],
|
||||
resources: std::collections::HashMap::new(),
|
||||
};
|
||||
let result = AuthResult::Ok(identity.clone());
|
||||
assert_eq!(result, AuthResult::Ok(identity));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_result_denied_message() {
|
||||
let result = AuthResult::Denied("access denied".to_string());
|
||||
assert_eq!(result, AuthResult::Denied("access denied".to_string()));
|
||||
}
|
||||
}
|
||||
@@ -1,176 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use russh::client;
|
||||
use russh::keys::key::PrivateKeyWithHashAlg;
|
||||
use russh::keys::{PrivateKey, PublicKey};
|
||||
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::error::ConfigError;
|
||||
|
||||
/// Client-side SSH authentication configuration.
|
||||
///
|
||||
/// Holds the private key used for SSH authentication and an optional
|
||||
/// public key override. When no public key is provided, it is derived
|
||||
/// from the private key.
|
||||
pub struct ClientAuthConfig {
|
||||
private_key: Arc<PrivateKey>,
|
||||
public_key: PublicKey,
|
||||
}
|
||||
|
||||
impl ClientAuthConfig {
|
||||
/// Load a `ClientAuthConfig` from a key source (file or in-memory).
|
||||
pub fn from_key_source(source: KeySource) -> Result<Self, ConfigError> {
|
||||
let private_key = crate::auth::keys::load_private_key(source)?;
|
||||
let public_key = private_key.public_key().clone();
|
||||
Ok(Self {
|
||||
private_key: Arc::new(private_key),
|
||||
public_key,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the private key wrapped in `Arc` for use with russh authentication.
|
||||
pub fn private_key(&self) -> Arc<PrivateKey> {
|
||||
Arc::clone(&self.private_key)
|
||||
}
|
||||
|
||||
/// Returns the public key derived from (or overridden for) this config.
|
||||
pub fn public_key(&self) -> &PublicKey {
|
||||
&self.public_key
|
||||
}
|
||||
|
||||
/// Authenticate with the given SSH session handle and username.
|
||||
pub async fn authenticate<H: client::Handler>(
|
||||
&self,
|
||||
handle: &mut client::Handle<H>,
|
||||
username: &str,
|
||||
) -> Result<bool, russh::Error> {
|
||||
let key_with_alg = PrivateKeyWithHashAlg::new(Arc::clone(&self.private_key), None)?;
|
||||
handle.authenticate_publickey(username, key_with_alg).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Client handler implementing `russh::client::Handler`.
|
||||
///
|
||||
/// Provides the callbacks required by russh during the SSH handshake.
|
||||
/// Server key verification is delegated to a configurable callback;
|
||||
/// the default accepts all server keys (suitable for testing or when
|
||||
/// transport-layer verification — e.g. TLS — is already in place).
|
||||
pub struct ClientHandler {
|
||||
pub_key: PublicKey,
|
||||
check_server_key_fn: Box<dyn Fn(&PublicKey) -> bool + Send + Sync>,
|
||||
}
|
||||
|
||||
impl ClientHandler {
|
||||
/// Create a new client handler from a `ClientAuthConfig`.
|
||||
pub fn from_config(config: &ClientAuthConfig) -> Self {
|
||||
Self {
|
||||
pub_key: config.public_key().clone(),
|
||||
check_server_key_fn: Box::new(|_| true),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a client handler with a custom server key verification callback.
|
||||
pub fn with_server_key_check(
|
||||
config: &ClientAuthConfig,
|
||||
check_fn: impl Fn(&PublicKey) -> bool + Send + Sync + 'static,
|
||||
) -> Self {
|
||||
Self {
|
||||
pub_key: config.public_key().clone(),
|
||||
check_server_key_fn: Box::new(check_fn),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the public key associated with this handler.
|
||||
pub fn public_key(&self) -> &PublicKey {
|
||||
&self.pub_key
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl client::Handler for ClientHandler {
|
||||
type Error = russh::Error;
|
||||
|
||||
async fn check_server_key(
|
||||
&mut self,
|
||||
server_public_key: &PublicKey,
|
||||
) -> Result<bool, Self::Error> {
|
||||
Ok((self.check_server_key_fn)(server_public_key))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use russh::client::Handler;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
#[test]
|
||||
fn from_key_source_memory() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
assert_eq!(
|
||||
config.public_key().algorithm(),
|
||||
russh::keys::Algorithm::Ed25519
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_from_config() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let handler = ClientHandler::from_config(&config);
|
||||
assert_eq!(
|
||||
handler.public_key().algorithm(),
|
||||
russh::keys::Algorithm::Ed25519
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_with_custom_server_key_check() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let handler = ClientHandler::with_server_key_check(&config, |_pk| false);
|
||||
assert_eq!(
|
||||
handler.public_key().algorithm(),
|
||||
russh::keys::Algorithm::Ed25519
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_key_source_invalid_key() {
|
||||
let source = KeySource::Memory(b"not a key".to_vec());
|
||||
let result = ClientAuthConfig::from_key_source(source);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handler_check_server_key_accepts_by_default() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let mut handler = ClientHandler::from_config(&config);
|
||||
let some_key = config.public_key().clone();
|
||||
let result = handler.check_server_key(&some_key).await.unwrap();
|
||||
assert!(result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handler_check_server_key_rejects_with_custom_fn() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let mut handler = ClientHandler::with_server_key_check(&config, |_pk| false);
|
||||
let some_key = config.public_key().clone();
|
||||
let result = handler.check_server_key(&some_key).await.unwrap();
|
||||
assert!(!result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_key_arc_dedup() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let key1 = config.private_key();
|
||||
let key2 = config.private_key();
|
||||
assert!(Arc::ptr_eq(&key1, &key2));
|
||||
}
|
||||
}
|
||||
@@ -1,349 +0,0 @@
|
||||
//! Identity resolution and the `IdentityProvider` trait.
|
||||
//!
|
||||
//! See [ADR-029](docs/architecture/decisions/029-identity-provider.md) and
|
||||
//! [ADR-028](docs/architecture/decisions/028-identity-model.md).
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
|
||||
use crate::config::DynamicConfig;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Identity {
|
||||
pub id: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub resources: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AuthToken {
|
||||
pub raw: Vec<u8>,
|
||||
}
|
||||
|
||||
pub trait IdentityProvider: Send + Sync + 'static {
|
||||
fn resolve_from_fingerprint(&self, fingerprint: &str) -> Option<Identity>;
|
||||
|
||||
fn resolve_from_token(&self, token: &AuthToken) -> Option<Identity>;
|
||||
}
|
||||
|
||||
pub struct ConfigIdentityProvider {
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
}
|
||||
|
||||
impl ConfigIdentityProvider {
|
||||
pub fn new(dynamic: Arc<ArcSwap<DynamicConfig>>) -> Self {
|
||||
Self { dynamic }
|
||||
}
|
||||
}
|
||||
|
||||
impl IdentityProvider for ConfigIdentityProvider {
|
||||
fn resolve_from_fingerprint(&self, fingerprint: &str) -> Option<Identity> {
|
||||
let config = self.dynamic.load();
|
||||
let auth = &config.auth;
|
||||
auth.resolve_identity_from_fingerprint(fingerprint)
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, token: &AuthToken) -> Option<Identity> {
|
||||
let config = self.dynamic.load();
|
||||
let auth = &config.auth;
|
||||
let token_str = String::from_utf8_lossy(&token.raw);
|
||||
if token_str.starts_with(crate::config::API_KEY_PREFIX) {
|
||||
return auth.resolve_api_key(&token_str);
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::auth::ServerAuthConfig;
|
||||
use crate::config::AuthPolicy;
|
||||
use russh::keys::ssh_key::HashAlg;
|
||||
use russh::keys::PrivateKey;
|
||||
use std::io::Write;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
|
||||
|
||||
fn load_key() -> PrivateKey {
|
||||
russh::keys::decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap()
|
||||
}
|
||||
|
||||
fn make_authorized_keys_file(keys_content: &str) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
f.write_all(keys_content.as_bytes()).unwrap();
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
fn make_provider(keys_content: &str) -> (ConfigIdentityProvider, Arc<ArcSwap<DynamicConfig>>) {
|
||||
let f = make_authorized_keys_file(keys_content);
|
||||
let server_auth =
|
||||
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
|
||||
.unwrap();
|
||||
let auth_policy = AuthPolicy::from_server_auth_config(server_auth);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(Arc::clone(&arc_swap));
|
||||
(provider, arc_swap)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_fields() {
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert(
|
||||
"service".to_string(),
|
||||
vec!["gitea".to_string(), "registry".to_string()],
|
||||
);
|
||||
let identity = Identity {
|
||||
id: "SHA256:abc123".to_string(),
|
||||
scopes: vec![
|
||||
"relay:connect".to_string(),
|
||||
"service:gitea:read".to_string(),
|
||||
],
|
||||
resources,
|
||||
};
|
||||
assert_eq!(identity.id, "SHA256:abc123");
|
||||
assert_eq!(identity.scopes, vec!["relay:connect", "service:gitea:read"]);
|
||||
assert_eq!(
|
||||
identity.resources.get("service").unwrap(),
|
||||
&vec!["gitea".to_string(), "registry".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_equality() {
|
||||
let id1 = Identity {
|
||||
id: "test".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let id2 = Identity {
|
||||
id: "test".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
assert_eq!(id1, id2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_inequality_different_id() {
|
||||
let id1 = Identity {
|
||||
id: "a".to_string(),
|
||||
scopes: vec![],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let id2 = Identity {
|
||||
id: "b".to_string(),
|
||||
scopes: vec![],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
assert_ne!(id1, id2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_resolves_valid_fingerprint() {
|
||||
let (provider, _) = make_provider(ED25519_PUBLIC_KEY);
|
||||
let key = load_key().public_key().clone();
|
||||
let fingerprint = format!("{}", key.fingerprint(HashAlg::Sha256));
|
||||
let identity = provider.resolve_from_fingerprint(&fingerprint);
|
||||
assert!(identity.is_some());
|
||||
let identity = identity.unwrap();
|
||||
assert_eq!(identity.id, fingerprint);
|
||||
assert!(!identity.scopes.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_rejects_invalid_fingerprint() {
|
||||
let (provider, _) = make_provider(ED25519_PUBLIC_KEY);
|
||||
let identity = provider.resolve_from_fingerprint("SHA256:invalid");
|
||||
assert!(identity.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_empty_config_rejects_all() {
|
||||
let dynamic = DynamicConfig::default();
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(arc_swap);
|
||||
let identity = provider.resolve_from_fingerprint("SHA256:anything");
|
||||
assert!(identity.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_resolve_from_token_returns_none() {
|
||||
let (provider, _) = make_provider(ED25519_PUBLIC_KEY);
|
||||
let token = AuthToken {
|
||||
raw: b"test-token".to_vec(),
|
||||
};
|
||||
assert!(provider.resolve_from_token(&token).is_none());
|
||||
}
|
||||
|
||||
fn compute_api_key_hash(token: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(token.as_bytes());
|
||||
let result = hasher.finalize();
|
||||
format!("sha256:{}", hex::encode(result))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_resolves_valid_api_key() {
|
||||
let token = "alk_testsecret123";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = crate::config::ApiKeyEntry {
|
||||
prefix: "alk_test".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "test key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let auth_policy = crate::config::AuthPolicy::with_api_keys(
|
||||
std::collections::HashSet::new(),
|
||||
Vec::new(),
|
||||
vec![entry],
|
||||
);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(arc_swap);
|
||||
|
||||
let auth_token = AuthToken {
|
||||
raw: token.as_bytes().to_vec(),
|
||||
};
|
||||
let identity = provider.resolve_from_token(&auth_token);
|
||||
assert!(identity.is_some());
|
||||
let identity = identity.unwrap();
|
||||
assert_eq!(identity.id, "alk_test");
|
||||
assert_eq!(identity.scopes, vec!["relay:connect"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_rejects_expired_api_key() {
|
||||
let token = "alk_expiredkey1";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = crate::config::ApiKeyEntry {
|
||||
prefix: "alk_expi".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "expired key".to_string(),
|
||||
expires_at: Some(1),
|
||||
};
|
||||
let auth_policy = crate::config::AuthPolicy::with_api_keys(
|
||||
std::collections::HashSet::new(),
|
||||
Vec::new(),
|
||||
vec![entry],
|
||||
);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(arc_swap);
|
||||
|
||||
let auth_token = AuthToken {
|
||||
raw: token.as_bytes().to_vec(),
|
||||
};
|
||||
assert!(provider.resolve_from_token(&auth_token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_rejects_wrong_hash_api_key() {
|
||||
let entry = crate::config::ApiKeyEntry {
|
||||
prefix: "alk_test".to_string(),
|
||||
hash: "sha256:0000000000000000000000000000000000000000000000000000000000000000"
|
||||
.to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "bad hash".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let auth_policy = crate::config::AuthPolicy::with_api_keys(
|
||||
std::collections::HashSet::new(),
|
||||
Vec::new(),
|
||||
vec![entry],
|
||||
);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(arc_swap);
|
||||
|
||||
let auth_token = AuthToken {
|
||||
raw: b"alk_testsecret123".to_vec(),
|
||||
};
|
||||
assert!(provider.resolve_from_token(&auth_token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_api_key_unknown_prefix_falls_through() {
|
||||
let token = "alk_testsecret123";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = crate::config::ApiKeyEntry {
|
||||
prefix: "alk_other".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "other key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let auth_policy = crate::config::AuthPolicy::with_api_keys(
|
||||
std::collections::HashSet::new(),
|
||||
Vec::new(),
|
||||
vec![entry],
|
||||
);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(arc_swap);
|
||||
|
||||
let auth_token = AuthToken {
|
||||
raw: token.as_bytes().to_vec(),
|
||||
};
|
||||
assert!(provider.resolve_from_token(&auth_token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_api_key_scopes_in_identity() {
|
||||
let token = "alk_scopedkey12";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = crate::config::ApiKeyEntry {
|
||||
prefix: "alk_sco".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string(), "secrets:derive".to_string()],
|
||||
description: "scoped key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let auth_policy = crate::config::AuthPolicy::with_api_keys(
|
||||
std::collections::HashSet::new(),
|
||||
Vec::new(),
|
||||
vec![entry],
|
||||
);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(arc_swap);
|
||||
|
||||
let auth_token = AuthToken {
|
||||
raw: token.as_bytes().to_vec(),
|
||||
};
|
||||
let identity = provider.resolve_from_token(&auth_token).unwrap();
|
||||
assert_eq!(identity.scopes, vec!["relay:connect", "secrets:derive"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_token_holds_raw_bytes() {
|
||||
let token = AuthToken { raw: vec![1, 2, 3] };
|
||||
assert_eq!(token.raw, vec![1, 2, 3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_reflects_config_reload() {
|
||||
let (provider, arc_swap) = make_provider(ED25519_PUBLIC_KEY);
|
||||
let key = load_key().public_key().clone();
|
||||
let fingerprint = format!("{}", key.fingerprint(HashAlg::Sha256));
|
||||
|
||||
let identity = provider.resolve_from_fingerprint(&fingerprint);
|
||||
assert!(identity.is_some());
|
||||
|
||||
let new_dynamic = DynamicConfig::default();
|
||||
arc_swap.store(Arc::new(new_dynamic));
|
||||
|
||||
let identity = provider.resolve_from_fingerprint(&fingerprint);
|
||||
assert!(identity.is_none());
|
||||
}
|
||||
}
|
||||
@@ -1,258 +0,0 @@
|
||||
//! Key loading and parsing for SSH authentication.
|
||||
//!
|
||||
//! Supports `KeySource` (file path or in-memory) for private keys, public keys,
|
||||
//! and certificate authority entries. All keys must be in OpenSSH format.
|
||||
//! PEM-encoded keys (PKCS#1, PKCS#8) are rejected with a clear error message.
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use russh::keys::{decode_secret_key, parse_public_key_base64, PrivateKey, PublicKey};
|
||||
|
||||
use crate::error::ConfigError;
|
||||
|
||||
/// Source for key material — either a filesystem path or in-memory bytes.
|
||||
///
|
||||
/// Used throughout the API to accept keys without committing to a specific
|
||||
/// loading mechanism. In-memory keys are primarily for the NAPI wrapper.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KeySource {
|
||||
File(PathBuf),
|
||||
Memory(Vec<u8>),
|
||||
}
|
||||
|
||||
/// A certificate authority entry parsed from an `authorized_keys` file.
|
||||
///
|
||||
/// Contains the CA public key and its associated options (e.g., `cert-authority`,
|
||||
/// `permit-port-forwarding`). Used by `ServerAuthConfig` for certificate validation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CertAuthorityEntry {
|
||||
pub public_key: PublicKey,
|
||||
pub options: Vec<String>,
|
||||
}
|
||||
|
||||
fn resolve_bytes(source: &KeySource) -> Result<Vec<u8>, ConfigError> {
|
||||
match source {
|
||||
KeySource::File(path) => {
|
||||
if !path.exists() {
|
||||
return Err(ConfigError::KeyFileNotFound {
|
||||
path: path.display().to_string(),
|
||||
});
|
||||
}
|
||||
std::fs::read(path).map_err(|_| ConfigError::KeyFileNotFound {
|
||||
path: path.display().to_string(),
|
||||
})
|
||||
}
|
||||
KeySource::Memory(data) => Ok(data.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_openssh_private_key(data: &[u8]) -> Result<(), ConfigError> {
|
||||
let s = String::from_utf8_lossy(data);
|
||||
if s.contains("-----BEGIN OPENSSH PRIVATE KEY-----") {
|
||||
return Ok(());
|
||||
}
|
||||
if s.contains("-----BEGIN RSA PRIVATE KEY-----")
|
||||
|| s.contains("-----BEGIN PRIVATE KEY-----")
|
||||
|| s.contains("-----BEGIN ENCRYPTED PRIVATE KEY-----")
|
||||
|| s.contains("-----BEGIN EC PRIVATE KEY-----")
|
||||
{
|
||||
return Err(ConfigError::InvalidFlag {
|
||||
name: "PEM-encoded key is not supported; use OpenSSH format (-----BEGIN OPENSSH PRIVATE KEY-----)".to_string(),
|
||||
});
|
||||
}
|
||||
Err(ConfigError::InvalidFlag {
|
||||
name: "unrecognized private key format; expected OpenSSH format (-----BEGIN OPENSSH PRIVATE KEY-----)".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_private_key(source: KeySource) -> Result<PrivateKey, ConfigError> {
|
||||
let data = resolve_bytes(&source)?;
|
||||
check_openssh_private_key(&data)?;
|
||||
let s = String::from_utf8_lossy(&data);
|
||||
decode_secret_key(&s, None).map_err(|e| ConfigError::InvalidFlag {
|
||||
name: format!("failed to decode private key: {e}"),
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_authorized_keys_line(line: &str) -> Option<Result<(PublicKey, Vec<String>), ConfigError>> {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = line.splitn(4, ' ').collect();
|
||||
if parts.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut options = Vec::new();
|
||||
let key_type_idx;
|
||||
|
||||
if parts[0].starts_with("cert-authority")
|
||||
|| parts[0].starts_with("no-")
|
||||
|| parts[0].starts_with("permit-")
|
||||
|| parts[0].starts_with("from=")
|
||||
|| parts[0].starts_with("command=")
|
||||
|| parts[0].starts_with("environment=")
|
||||
|| parts[0].starts_with("tunnel=")
|
||||
|| parts[0].starts_with("principals=")
|
||||
{
|
||||
let opts_str = parts[0];
|
||||
options = opts_str.split(',').map(|s| s.to_string()).collect();
|
||||
key_type_idx = 1;
|
||||
} else if parts[0].starts_with("ssh-") || parts[0].starts_with("ecdsa-") {
|
||||
key_type_idx = 0;
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
|
||||
if parts.len() <= key_type_idx {
|
||||
return None;
|
||||
}
|
||||
|
||||
let key_base64 = parts[key_type_idx + 1];
|
||||
match parse_public_key_base64(key_base64) {
|
||||
Ok(pk) => Some(Ok((pk, options))),
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_public_keys(source: KeySource) -> Result<Vec<PublicKey>, ConfigError> {
|
||||
let data = resolve_bytes(&source)?;
|
||||
let s = String::from_utf8_lossy(&data);
|
||||
let mut keys = Vec::new();
|
||||
for line in s.lines() {
|
||||
if let Some(Ok((pk, _))) = parse_authorized_keys_line(line) {
|
||||
keys.push(pk);
|
||||
}
|
||||
}
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
pub fn load_cert_authority_entries(
|
||||
source: KeySource,
|
||||
) -> Result<Vec<CertAuthorityEntry>, ConfigError> {
|
||||
let data = resolve_bytes(&source)?;
|
||||
let s = String::from_utf8_lossy(&data);
|
||||
let mut entries = Vec::new();
|
||||
for line in s.lines() {
|
||||
if let Some(result) = parse_authorized_keys_line(line) {
|
||||
match result {
|
||||
Ok((pk, options)) if !options.is_empty() => {
|
||||
entries.push(CertAuthorityEntry {
|
||||
public_key: pk,
|
||||
options,
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Write;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
|
||||
|
||||
const PEM_PRIVATE_KEY: &[u8] = b"-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEINTuctv5E1hK1bbY8fdp+K06/nwoy/HU++CXqI9EdVhC\n-----END PRIVATE KEY-----\n";
|
||||
|
||||
fn make_authorized_keys(content: &str) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
write!(f, "{content}").unwrap();
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
fn make_private_key_file(content: &str) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
f.write_all(content.as_bytes()).unwrap();
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_ed25519_key_from_file() {
|
||||
let f = make_private_key_file(ED25519_PRIVATE_KEY);
|
||||
let source = KeySource::File(f.path().to_path_buf());
|
||||
let key = load_private_key(source).unwrap();
|
||||
assert_eq!(key.algorithm(), russh::keys::Algorithm::Ed25519);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_ed25519_key_from_memory() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let key = load_private_key(source).unwrap();
|
||||
assert_eq!(key.algorithm(), russh::keys::Algorithm::Ed25519);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_key_file_not_found() {
|
||||
let source = KeySource::File(PathBuf::from("/nonexistent/key"));
|
||||
let result = load_private_key(source);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(matches!(err, ConfigError::KeyFileNotFound { .. }));
|
||||
assert!(err.to_string().contains("/nonexistent/key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_pem_format() {
|
||||
let source = KeySource::Memory(PEM_PRIVATE_KEY.to_vec());
|
||||
let result = load_private_key(source);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(matches!(err, ConfigError::InvalidFlag { .. }));
|
||||
assert!(err.to_string().contains("PEM"));
|
||||
}
|
||||
|
||||
const ED25519_PUBLIC_KEY_2: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
||||
|
||||
#[test]
|
||||
fn parse_authorized_keys_multiple_entries() {
|
||||
let content = format!("{ED25519_PUBLIC_KEY}\n# comment line\n\n{ED25519_PUBLIC_KEY_2}\n");
|
||||
let f = make_authorized_keys(&content);
|
||||
let source = KeySource::File(f.path().to_path_buf());
|
||||
let keys = load_public_keys(source).unwrap();
|
||||
assert_eq!(keys.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_authorized_keys_from_memory() {
|
||||
let content = format!("{ED25519_PUBLIC_KEY}\n");
|
||||
let source = KeySource::Memory(content.into_bytes());
|
||||
let keys = load_public_keys(source).unwrap();
|
||||
assert_eq!(keys.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_cert_authority_entry() {
|
||||
let content =
|
||||
"cert-authority,permit-port-forwarding ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV CA name\n";
|
||||
let f = make_authorized_keys(content);
|
||||
let source = KeySource::File(f.path().to_path_buf());
|
||||
let entries = load_cert_authority_entries(source).unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert_eq!(entries[0].options.len(), 2);
|
||||
assert_eq!(entries[0].options[0], "cert-authority");
|
||||
assert_eq!(entries[0].options[1], "permit-port-forwarding");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_mixed_authorized_keys() {
|
||||
let content = format!(
|
||||
"{ED25519_PUBLIC_KEY}\ncert-authority ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE CA name\n"
|
||||
);
|
||||
let source = KeySource::Memory(content.into_bytes());
|
||||
let keys = load_public_keys(source.clone()).unwrap();
|
||||
assert_eq!(keys.len(), 2);
|
||||
let entries = load_cert_authority_entries(source).unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert_eq!(entries[0].options, vec!["cert-authority"]);
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
//! SSH authentication (Ed25519 public key and OpenSSH certificate authority).
|
||||
//!
|
||||
//! Supports file-path and in-memory key sources. No password authentication.
|
||||
//! See ADR-012 for the design rationale.
|
||||
|
||||
#[cfg(feature = "irpc")]
|
||||
pub mod auth_protocol;
|
||||
pub mod client_auth;
|
||||
pub mod identity;
|
||||
pub mod keys;
|
||||
pub mod server_auth;
|
||||
|
||||
#[cfg(feature = "irpc")]
|
||||
pub use auth_protocol::{AuthProtocol, AuthResult, AuthServiceImpl};
|
||||
pub use client_auth::{ClientAuthConfig, ClientHandler};
|
||||
pub use identity::{AuthToken, ConfigIdentityProvider, Identity, IdentityProvider};
|
||||
pub use keys::{load_private_key, load_public_keys, CertAuthorityEntry, KeySource};
|
||||
pub use server_auth::ServerAuthConfig;
|
||||
@@ -1,395 +0,0 @@
|
||||
//! Server-side authentication configuration and validation.
|
||||
//!
|
||||
//! `ServerAuthConfig` holds the set of authorized public keys and optional certificate
|
||||
//! authority entries. Authentication is key-based only (Ed25519 + optional OpenSSH CA).
|
||||
//! No password authentication. See ADR-012.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use ipnetwork::IpNetwork;
|
||||
use russh::keys::helpers::EncodedExt;
|
||||
use russh::keys::{Certificate, PublicKey};
|
||||
|
||||
use super::keys::{load_cert_authority_entries, load_public_keys, CertAuthorityEntry, KeySource};
|
||||
use crate::error::AuthError;
|
||||
|
||||
/// Server-side authentication configuration.
|
||||
///
|
||||
/// Holds authorized public keys (constant-time comparison) and optional certificate
|
||||
/// authority entries for validating OpenSSH certificates.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServerAuthConfig {
|
||||
pub authorized_keys: HashSet<PublicKey>,
|
||||
pub cert_authorities: Vec<CertAuthorityEntry>,
|
||||
encoded_keys: HashSet<Vec<u8>>,
|
||||
}
|
||||
|
||||
fn encode_key_data(key: &PublicKey) -> Vec<u8> {
|
||||
key.key_data().encoded().unwrap_or_default()
|
||||
}
|
||||
|
||||
impl ServerAuthConfig {
|
||||
pub fn from_keys_and_ca(
|
||||
authorized_keys_source: Option<KeySource>,
|
||||
cert_authority_source: Option<KeySource>,
|
||||
) -> Result<Self, crate::error::ConfigError> {
|
||||
let authorized_keys: HashSet<PublicKey> = match authorized_keys_source {
|
||||
Some(src) => load_public_keys(src)?.into_iter().collect(),
|
||||
None => HashSet::new(),
|
||||
};
|
||||
|
||||
let encoded_keys: HashSet<Vec<u8>> = authorized_keys.iter().map(encode_key_data).collect();
|
||||
|
||||
let cert_authorities = match cert_authority_source {
|
||||
Some(src) => load_cert_authority_entries(src)?,
|
||||
None => Vec::new(),
|
||||
};
|
||||
|
||||
Ok(ServerAuthConfig {
|
||||
authorized_keys,
|
||||
cert_authorities,
|
||||
encoded_keys,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn authenticate_publickey(&self, key: &PublicKey) -> Result<(), AuthError> {
|
||||
let encoded = encode_key_data(key);
|
||||
if self.encoded_keys.contains(&encoded) {
|
||||
return Ok(());
|
||||
}
|
||||
Err(AuthError::KeyRejected)
|
||||
}
|
||||
|
||||
pub fn authenticate_certificate(
|
||||
&self,
|
||||
cert: &Certificate,
|
||||
user: &str,
|
||||
client_ip: Option<IpAddr>,
|
||||
) -> Result<(), AuthError> {
|
||||
let matching_ca = self
|
||||
.cert_authorities
|
||||
.iter()
|
||||
.find(|ca| cert.signature_key() == ca.public_key.key_data());
|
||||
|
||||
let ca_entry = match matching_ca {
|
||||
Some(entry) => entry,
|
||||
None => return Err(AuthError::CertInvalid),
|
||||
};
|
||||
|
||||
if cert.verify_signature().is_err() {
|
||||
return Err(AuthError::CertInvalid);
|
||||
}
|
||||
|
||||
let now = SystemTime::now();
|
||||
let now_secs = now
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
|
||||
if now_secs < cert.valid_after() || now_secs >= cert.valid_before() {
|
||||
return Err(AuthError::CertExpired);
|
||||
}
|
||||
|
||||
let principals = cert.valid_principals();
|
||||
if !principals.is_empty() && !principals.iter().any(|p| p == user) {
|
||||
return Err(AuthError::CertPrincipalMismatch);
|
||||
}
|
||||
|
||||
check_critical_options(cert, ca_entry, client_ip)?;
|
||||
|
||||
check_extensions(cert, ca_entry)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn check_critical_options(
|
||||
cert: &Certificate,
|
||||
ca_entry: &CertAuthorityEntry,
|
||||
client_ip: Option<IpAddr>,
|
||||
) -> Result<(), AuthError> {
|
||||
let ca_has_no_pty = ca_entry.options.iter().any(|o| o == "no-pty");
|
||||
|
||||
for (name, data) in cert.critical_options().iter() {
|
||||
match name.as_str() {
|
||||
"source-address" => {
|
||||
if !check_source_address(data, client_ip) {
|
||||
return Err(AuthError::CertInvalid);
|
||||
}
|
||||
}
|
||||
"force-command" => {}
|
||||
"no-pty" => {}
|
||||
_ => {
|
||||
let _ = ca_has_no_pty;
|
||||
return Err(AuthError::CertInvalid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_extensions(cert: &Certificate, ca_entry: &CertAuthorityEntry) -> Result<(), AuthError> {
|
||||
let ca_permit_port_forwarding = ca_entry
|
||||
.options
|
||||
.iter()
|
||||
.any(|o| o == "permit-port-forwarding");
|
||||
|
||||
if ca_permit_port_forwarding {
|
||||
let cert_allows = cert
|
||||
.extensions()
|
||||
.iter()
|
||||
.any(|(n, _)| n == "permit-port-forwarding");
|
||||
if !cert_allows {
|
||||
return Err(AuthError::CertInvalid);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_source_address(allowed: &str, client_ip: Option<IpAddr>) -> bool {
|
||||
let Some(ip) = client_ip else {
|
||||
return false;
|
||||
};
|
||||
|
||||
for pattern in allowed.split(',') {
|
||||
let pattern = pattern.trim();
|
||||
if pattern.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(cidr) = IpNetwork::from_str(pattern) {
|
||||
if cidr.contains(ip) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(net_ip) = IpAddr::from_str(pattern) {
|
||||
if net_ip == ip {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand_core::OsRng;
|
||||
use russh::keys::ssh_key::certificate::{Builder, CertType};
|
||||
use russh::keys::{decode_secret_key, Certificate, PrivateKey};
|
||||
use std::io::Write;
|
||||
|
||||
const CA_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+gAAAJjP22Bpz9tg\naQAAAAtzc2gtZWQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+g\nAAAEBcRrWyUU+lLpjHbaaYN5YeOlvz6HnuBndUWevEmHk00jqkUoEjfbsmxEWZlQtqU2Om\nhQ8kxXHOyT1sZsMHJq36AAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const USER_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACAoTr8X7HqltuKBdBdB2Vjb+K7bi3vVPcuWAYIb3ur5NgAAAJgM/+f3DP/n\n9wAAAAtzc2gtZWQyNTUxOQAAACAoTr8X7HqltuKBdBdB2Vjb+K7bi3vVPcuWAYIb3ur5Ng\nAAAEADN/ZEFvX/mflX8aEGwS/tMzys564rYEaMzd4vmYKZkShOvxfseqW24oF0F0HZWNv4\nrtuLe9U9y5YBghve6vk2AAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const OTHER_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACC/7V2LLT4WRm1Mfje8eSPWlhN+kNXz2ryKoqCkSrGzdgAAAJgXj2UzF49l\nMwAAAAtzc2gtZWQyNTUxOQAAACC/7V2LLT4WRm1Mfje8eSPWlhN+kNXz2ryKoqCkSrGzdg\nAAAEBVadyi5nAUfkjpp4zyQ08b8h1o4RTEgwtLejTjX5Tycb/tXYstPhZGbUx+N7x5I9aW\nE36Q1fPavIqioKRKsbN2AAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
fn load_ca_key() -> PrivateKey {
|
||||
decode_secret_key(CA_PRIVATE_KEY, None).unwrap()
|
||||
}
|
||||
|
||||
fn load_user_key() -> PrivateKey {
|
||||
decode_secret_key(USER_PRIVATE_KEY, None).unwrap()
|
||||
}
|
||||
|
||||
fn load_other_key() -> PrivateKey {
|
||||
decode_secret_key(OTHER_PRIVATE_KEY, None).unwrap()
|
||||
}
|
||||
|
||||
fn make_cert(
|
||||
ca_key: &PrivateKey,
|
||||
user_pub: &PublicKey,
|
||||
valid_after: u64,
|
||||
valid_before: u64,
|
||||
principals: Vec<&str>,
|
||||
) -> Certificate {
|
||||
let key_data: russh::keys::ssh_key::public::KeyData = user_pub.into();
|
||||
let mut builder =
|
||||
Builder::new_with_random_nonce(&mut OsRng, key_data, valid_after, valid_before)
|
||||
.unwrap();
|
||||
|
||||
builder.cert_type(CertType::User).unwrap();
|
||||
|
||||
for p in principals {
|
||||
builder.valid_principal(p).unwrap();
|
||||
}
|
||||
|
||||
builder.sign(ca_key).unwrap()
|
||||
}
|
||||
|
||||
fn make_authorized_keys_file(keys: &[&PublicKey]) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
for key in keys {
|
||||
let line = format!("{}\n", key.to_openssh().unwrap());
|
||||
f.write_all(line.as_bytes()).unwrap();
|
||||
}
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
fn make_ca_file(ca_pub: &PublicKey, options: &[&str]) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
let opts = if options.is_empty() {
|
||||
"cert-authority".to_string()
|
||||
} else {
|
||||
format!("cert-authority,{}", options.join(","))
|
||||
};
|
||||
let line = format!("{} {} CA\n", opts, ca_pub.to_openssh().unwrap());
|
||||
f.write_all(line.as_bytes()).unwrap();
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
fn now_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_key_accepted() {
|
||||
let user_key = load_user_key();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let f = make_authorized_keys_file(&[&user_pub]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
|
||||
.unwrap();
|
||||
assert!(config.authenticate_publickey(&user_pub).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_key_rejected() {
|
||||
let user_key = load_user_key();
|
||||
let other_key = load_other_key();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let other_pub = other_key.public_key().clone();
|
||||
let f = make_authorized_keys_file(&[&user_pub]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
config.authenticate_publickey(&other_pub),
|
||||
Err(AuthError::KeyRejected)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cert_authority_signed_cert_accepted() {
|
||||
let ca_key = load_ca_key();
|
||||
let user_key = load_user_key();
|
||||
let ca_pub = ca_key.public_key().clone();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let now = now_secs();
|
||||
let cert = make_cert(&ca_key, &user_pub, now - 60, now + 3600, vec!["testuser"]);
|
||||
let f = make_ca_file(&ca_pub, &[]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
|
||||
.unwrap();
|
||||
assert!(config
|
||||
.authenticate_certificate(&cert, "testuser", None)
|
||||
.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expired_cert_rejected() {
|
||||
let ca_key = load_ca_key();
|
||||
let user_key = load_user_key();
|
||||
let ca_pub = ca_key.public_key().clone();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let now = now_secs();
|
||||
let cert = make_cert(&ca_key, &user_pub, now - 7200, now - 3600, vec!["testuser"]);
|
||||
let f = make_ca_file(&ca_pub, &[]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
config.authenticate_certificate(&cert, "testuser", None),
|
||||
Err(AuthError::CertExpired)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_principal_rejected() {
|
||||
let ca_key = load_ca_key();
|
||||
let user_key = load_user_key();
|
||||
let ca_pub = ca_key.public_key().clone();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let now = now_secs();
|
||||
let cert = make_cert(&ca_key, &user_pub, now - 60, now + 3600, vec!["alice"]);
|
||||
let f = make_ca_file(&ca_pub, &[]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
config.authenticate_certificate(&cert, "bob", None),
|
||||
Err(AuthError::CertPrincipalMismatch)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cert_wildcard_principals_accepts_any_user() {
|
||||
let ca_key = load_ca_key();
|
||||
let user_key = load_user_key();
|
||||
let ca_pub = ca_key.public_key().clone();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let now = now_secs();
|
||||
let key_data: russh::keys::ssh_key::public::KeyData = (&user_pub).into();
|
||||
let mut builder =
|
||||
Builder::new_with_random_nonce(&mut OsRng, key_data, now - 60, now + 3600).unwrap();
|
||||
builder.cert_type(CertType::User).unwrap();
|
||||
builder.all_principals_valid().unwrap();
|
||||
let cert = builder.sign(&ca_key).unwrap();
|
||||
|
||||
let f = make_ca_file(&ca_pub, &[]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
|
||||
.unwrap();
|
||||
assert!(config
|
||||
.authenticate_certificate(&cert, "anyuser", None)
|
||||
.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cert_wrong_ca_rejected() {
|
||||
let user_key = load_user_key();
|
||||
let other_ca_key = load_other_key();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let now = now_secs();
|
||||
let cert = make_cert(
|
||||
&other_ca_key,
|
||||
&user_pub,
|
||||
now - 60,
|
||||
now + 3600,
|
||||
vec!["testuser"],
|
||||
);
|
||||
let ca_key = load_ca_key();
|
||||
let ca_pub = ca_key.public_key().clone();
|
||||
let f = make_ca_file(&ca_pub, &[]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
config.authenticate_certificate(&cert, "testuser", None),
|
||||
Err(AuthError::CertInvalid)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_config_accepts_nothing() {
|
||||
let config = ServerAuthConfig::from_keys_and_ca(None, None).unwrap();
|
||||
let other_pub = load_other_key().public_key().clone();
|
||||
assert_eq!(
|
||||
config.authenticate_publickey(&other_pub),
|
||||
Err(AuthError::KeyRejected)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::call::OperationEnv;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OperationContext {
|
||||
pub request_id: String,
|
||||
pub parent_request_id: Option<String>,
|
||||
pub identity: Option<crate::auth::Identity>,
|
||||
pub metadata: HashMap<String, Value>,
|
||||
pub env: OperationEnv,
|
||||
pub trusted: bool,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::call::OperationRegistry;
|
||||
|
||||
fn make_context() -> OperationContext {
|
||||
let registry = OperationRegistry::new();
|
||||
OperationContext {
|
||||
request_id: "req-1".to_string(),
|
||||
parent_request_id: None,
|
||||
identity: None,
|
||||
metadata: HashMap::new(),
|
||||
env: OperationEnv::local(registry),
|
||||
trusted: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_context_fields() {
|
||||
let ctx = make_context();
|
||||
assert_eq!(ctx.request_id, "req-1");
|
||||
assert!(ctx.parent_request_id.is_none());
|
||||
assert!(ctx.identity.is_none());
|
||||
assert!(ctx.metadata.is_empty());
|
||||
assert!(!ctx.trusted);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_context_with_parent() {
|
||||
let registry = OperationRegistry::new();
|
||||
let ctx = OperationContext {
|
||||
request_id: "req-2".to_string(),
|
||||
parent_request_id: Some("req-1".to_string()),
|
||||
identity: None,
|
||||
metadata: HashMap::new(),
|
||||
env: OperationEnv::local(registry),
|
||||
trusted: true,
|
||||
};
|
||||
assert_eq!(ctx.parent_request_id, Some("req-1".to_string()));
|
||||
assert!(ctx.trusted);
|
||||
}
|
||||
}
|
||||
@@ -1,190 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::call::context::OperationContext;
|
||||
use crate::call::registry::OperationRegistry;
|
||||
use crate::call::response::ResponseEnvelope;
|
||||
use crate::credentials::{CredentialProvider, CredentialSet, SecretStoreCredentialProvider};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OperationEnv {
|
||||
registry: Arc<OperationRegistry>,
|
||||
credential_provider: Arc<dyn CredentialProvider>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OperationEnv {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OperationEnv")
|
||||
.field("registry", &self.registry)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperationEnv {
|
||||
pub fn local(registry: OperationRegistry) -> Self {
|
||||
Self {
|
||||
registry: Arc::new(registry),
|
||||
credential_provider: Arc::new(SecretStoreCredentialProvider::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_credential_provider(
|
||||
registry: OperationRegistry,
|
||||
credential_provider: Arc<dyn CredentialProvider>,
|
||||
) -> Self {
|
||||
Self {
|
||||
registry: Arc::new(registry),
|
||||
credential_provider,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn credentials(&self, service: &str) -> Option<CredentialSet> {
|
||||
self.credential_provider.get_credentials(service)
|
||||
}
|
||||
|
||||
pub fn invoke(&self, namespace: &str, operation: &str, input: Value) -> ResponseEnvelope {
|
||||
let name = format!("/{namespace}/{operation}");
|
||||
let request_id = format!("env{name}");
|
||||
let context = OperationContext {
|
||||
request_id: request_id.clone(),
|
||||
parent_request_id: None,
|
||||
identity: None,
|
||||
metadata: std::collections::HashMap::new(),
|
||||
env: self.clone(),
|
||||
trusted: true,
|
||||
};
|
||||
self.registry.invoke(&name, input, context)
|
||||
}
|
||||
|
||||
pub fn registry_ref(&self) -> &OperationRegistry {
|
||||
&self.registry
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::call::registry::OperationRegistryBuilder;
|
||||
use crate::call::spec::{AccessControl, OperationSpec, OperationType};
|
||||
use crate::config::{AuthPolicy, DynamicConfig};
|
||||
use crate::credentials::ConfigCredentialProvider;
|
||||
use arc_swap::ArcSwap;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_spec(name: &str, namespace: &str) -> OperationSpec {
|
||||
OperationSpec {
|
||||
name: name.to_string(),
|
||||
namespace: namespace.to_string(),
|
||||
op_type: OperationType::Query,
|
||||
input_schema: serde_json::json!({}),
|
||||
output_schema: serde_json::json!({}),
|
||||
access_control: AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_env_local_invoke() {
|
||||
let registry = OperationRegistryBuilder::new()
|
||||
.with(
|
||||
make_spec("/auth/verify", "auth"),
|
||||
Arc::new(|_input, _ctx| {
|
||||
ResponseEnvelope::ok("env-/auth/verify", serde_json::json!({"verified": true}))
|
||||
}),
|
||||
)
|
||||
.build();
|
||||
|
||||
let env = OperationEnv::local(registry);
|
||||
let result = env.invoke("auth", "verify", serde_json::json!({"token": "abc"}));
|
||||
assert!(result.result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_env_invoke_missing() {
|
||||
let registry = OperationRegistry::new();
|
||||
let env = OperationEnv::local(registry);
|
||||
let result = env.invoke("auth", "verify", serde_json::json!(null));
|
||||
assert!(result.result.is_err());
|
||||
let err = result.result.unwrap_err();
|
||||
assert_eq!(err.code, "NOT_FOUND");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_env_invoke_trusted() {
|
||||
let registry = OperationRegistryBuilder::new()
|
||||
.with(
|
||||
make_spec("/auth/verify", "auth"),
|
||||
Arc::new(|_input, ctx| {
|
||||
assert!(ctx.trusted);
|
||||
ResponseEnvelope::ok(&ctx.request_id, serde_json::json!({"ok": true}))
|
||||
}),
|
||||
)
|
||||
.build();
|
||||
|
||||
let env = OperationEnv::local(registry);
|
||||
let result = env.invoke("auth", "verify", serde_json::json!(null));
|
||||
assert!(result.result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_env_provides_credentials_from_handler_context() {
|
||||
let mut credentials = HashMap::new();
|
||||
credentials.insert(
|
||||
"vast-ai".to_string(),
|
||||
CredentialSet::Bearer {
|
||||
token: "test-token".to_string(),
|
||||
},
|
||||
);
|
||||
let config = DynamicConfig::new(AuthPolicy::empty()).with_credentials(credentials);
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(config)));
|
||||
let provider = Arc::new(ConfigCredentialProvider::new(dynamic));
|
||||
|
||||
let registry = OperationRegistryBuilder::new()
|
||||
.with(
|
||||
make_spec("/test/creds", "test"),
|
||||
Arc::new(|_input, ctx| {
|
||||
let creds = ctx.env.credentials("vast-ai");
|
||||
match creds {
|
||||
Some(CredentialSet::Bearer { token }) => ResponseEnvelope::ok(
|
||||
&ctx.request_id,
|
||||
serde_json::json!({"token": token}),
|
||||
),
|
||||
_ => ResponseEnvelope::ok(
|
||||
&ctx.request_id,
|
||||
serde_json::json!({"found": false}),
|
||||
),
|
||||
}
|
||||
}),
|
||||
)
|
||||
.build();
|
||||
|
||||
let env = OperationEnv::with_credential_provider(registry, provider);
|
||||
let result = env.invoke("test", "creds", serde_json::json!(null));
|
||||
assert!(result.result.is_ok());
|
||||
let value = result.result.unwrap();
|
||||
assert_eq!(value["token"], "test-token");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_env_credentials_returns_none_for_missing_service() {
|
||||
let config = DynamicConfig::default();
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(config)));
|
||||
let provider = Arc::new(ConfigCredentialProvider::new(dynamic));
|
||||
|
||||
let registry = OperationRegistry::new();
|
||||
let env = OperationEnv::with_credential_provider(registry, provider);
|
||||
assert!(env.credentials("nonexistent").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_env_default_credentials_returns_none() {
|
||||
let registry = OperationRegistry::new();
|
||||
let env = OperationEnv::local(registry);
|
||||
assert!(env.credentials("vast-ai").is_none());
|
||||
}
|
||||
}
|
||||
@@ -1,141 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct EventEnvelope {
|
||||
#[serde(rename = "type")]
|
||||
pub r#type: String,
|
||||
pub id: String,
|
||||
pub payload: Value,
|
||||
}
|
||||
|
||||
impl EventEnvelope {
|
||||
pub fn new(event_type: impl Into<String>, id: impl Into<String>, payload: Value) -> Self {
|
||||
Self {
|
||||
r#type: event_type.into(),
|
||||
id: id.into(),
|
||||
payload,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_requested(id: impl Into<String>, payload: Value) -> Self {
|
||||
Self::new(super::events::CALL_REQUESTED, id, payload)
|
||||
}
|
||||
|
||||
pub fn call_responded(id: impl Into<String>, payload: Value) -> Self {
|
||||
Self::new(super::events::CALL_RESPONDED, id, payload)
|
||||
}
|
||||
|
||||
pub fn call_completed(id: impl Into<String>, payload: Value) -> Self {
|
||||
Self::new(super::events::CALL_COMPLETED, id, payload)
|
||||
}
|
||||
|
||||
pub fn call_aborted(id: impl Into<String>, payload: Value) -> Self {
|
||||
Self::new(super::events::CALL_ABORTED, id, payload)
|
||||
}
|
||||
|
||||
pub fn call_error(
|
||||
id: impl Into<String>,
|
||||
code: impl Into<String>,
|
||||
message: impl Into<String>,
|
||||
retryable: bool,
|
||||
) -> Self {
|
||||
Self::new(
|
||||
super::events::CALL_ERROR,
|
||||
id,
|
||||
serde_json::json!({
|
||||
"code": code.into(),
|
||||
"message": message.into(),
|
||||
"retryable": retryable,
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn event_envelope_new() {
|
||||
let env = EventEnvelope::new(
|
||||
"call.requested",
|
||||
"req-1",
|
||||
serde_json::json!({"key": "value"}),
|
||||
);
|
||||
assert_eq!(env.r#type, "call.requested");
|
||||
assert_eq!(env.id, "req-1");
|
||||
assert_eq!(env.payload, serde_json::json!({"key": "value"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_serialization() {
|
||||
let env = EventEnvelope::new(
|
||||
"call.requested",
|
||||
"req-1",
|
||||
serde_json::json!({"key": "value"}),
|
||||
);
|
||||
let serialized = serde_json::to_string(&env).unwrap();
|
||||
let deserialized: EventEnvelope = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.r#type, "call.requested");
|
||||
assert_eq!(deserialized.id, "req-1");
|
||||
assert_eq!(deserialized.payload, serde_json::json!({"key": "value"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_serialization_type_field() {
|
||||
let env = EventEnvelope::new("call.requested", "req-1", serde_json::json!(null));
|
||||
let serialized = serde_json::to_string(&env).unwrap();
|
||||
assert!(serialized.contains("\"type\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_deserialization() {
|
||||
let json = r#"{"type":"call.responded","id":"req-42","payload":{"result":"ok"}}"#;
|
||||
let env: EventEnvelope = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(env.r#type, "call.responded");
|
||||
assert_eq!(env.id, "req-42");
|
||||
assert_eq!(env.payload["result"], "ok");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_call_requested() {
|
||||
let env = EventEnvelope::call_requested("req-1", serde_json::json!({"op": "test"}));
|
||||
assert_eq!(env.r#type, "call.requested");
|
||||
assert_eq!(env.id, "req-1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_call_responded() {
|
||||
let env = EventEnvelope::call_responded("req-1", serde_json::json!({"data": 42}));
|
||||
assert_eq!(env.r#type, "call.responded");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_call_completed() {
|
||||
let env = EventEnvelope::call_completed("req-1", serde_json::json!(null));
|
||||
assert_eq!(env.r#type, "call.completed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_call_aborted() {
|
||||
let env = EventEnvelope::call_aborted("req-1", serde_json::json!({"reason": "cancelled"}));
|
||||
assert_eq!(env.r#type, "call.aborted");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_call_error() {
|
||||
let env = EventEnvelope::call_error("req-1", "TIMEOUT", "timed out", true);
|
||||
assert_eq!(env.r#type, "call.error");
|
||||
assert_eq!(env.id, "req-1");
|
||||
assert_eq!(env.payload["code"], "TIMEOUT");
|
||||
assert_eq!(env.payload["message"], "timed out");
|
||||
assert_eq!(env.payload["retryable"], true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_empty_id() {
|
||||
let env = EventEnvelope::new("event.broadcast", "", serde_json::json!({"msg": "hello"}));
|
||||
assert_eq!(env.id, "");
|
||||
}
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
pub const CALL_REQUESTED: &str = "call.requested";
|
||||
pub const CALL_RESPONDED: &str = "call.responded";
|
||||
pub const CALL_COMPLETED: &str = "call.completed";
|
||||
pub const CALL_ABORTED: &str = "call.aborted";
|
||||
pub const CALL_ERROR: &str = "call.error";
|
||||
|
||||
pub const SERVICE_LIST: &str = "/services/list";
|
||||
pub const SERVICE_SCHEMA: &str = "/services/schema";
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn event_type_constants() {
|
||||
assert_eq!(CALL_REQUESTED, "call.requested");
|
||||
assert_eq!(CALL_RESPONDED, "call.responded");
|
||||
assert_eq!(CALL_COMPLETED, "call.completed");
|
||||
assert_eq!(CALL_ABORTED, "call.aborted");
|
||||
assert_eq!(CALL_ERROR, "call.error");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn service_operation_constants() {
|
||||
assert_eq!(SERVICE_LIST, "/services/list");
|
||||
assert_eq!(SERVICE_SCHEMA, "/services/schema");
|
||||
}
|
||||
}
|
||||
@@ -1,239 +0,0 @@
|
||||
use std::io;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
use crate::call::envelope::EventEnvelope;
|
||||
|
||||
pub fn encode(envelope: &EventEnvelope) -> Vec<u8> {
|
||||
let json = serde_json::to_vec(envelope).expect("EventEnvelope serialization must not fail");
|
||||
let len = json.len() as u32;
|
||||
let mut frame = Vec::with_capacity(4 + json.len());
|
||||
frame.extend_from_slice(&len.to_be_bytes());
|
||||
frame.extend_from_slice(&json);
|
||||
frame
|
||||
}
|
||||
|
||||
pub fn decode(data: &[u8]) -> Result<EventEnvelope, FrameDecodeError> {
|
||||
if data.len() < 4 {
|
||||
return Err(FrameDecodeError::TooShort {
|
||||
expected: 4,
|
||||
actual: data.len(),
|
||||
});
|
||||
}
|
||||
let len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
|
||||
if data.len() < 4 + len {
|
||||
return Err(FrameDecodeError::Incomplete {
|
||||
expected: 4 + len,
|
||||
actual: data.len(),
|
||||
});
|
||||
}
|
||||
let body = &data[4..4 + len];
|
||||
let envelope: EventEnvelope = serde_json::from_slice(body).map_err(FrameDecodeError::Json)?;
|
||||
Ok(envelope)
|
||||
}
|
||||
|
||||
pub fn decode_with_remainder(data: &[u8]) -> Result<(EventEnvelope, usize), FrameDecodeError> {
|
||||
if data.len() < 4 {
|
||||
return Err(FrameDecodeError::TooShort {
|
||||
expected: 4,
|
||||
actual: data.len(),
|
||||
});
|
||||
}
|
||||
let len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
|
||||
let total = 4 + len;
|
||||
if data.len() < total {
|
||||
return Err(FrameDecodeError::Incomplete {
|
||||
expected: total,
|
||||
actual: data.len(),
|
||||
});
|
||||
}
|
||||
let body = &data[4..total];
|
||||
let envelope: EventEnvelope = serde_json::from_slice(body).map_err(FrameDecodeError::Json)?;
|
||||
Ok((envelope, total))
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum FrameDecodeError {
|
||||
#[error("frame too short: expected at least {expected} bytes, got {actual}")]
|
||||
TooShort { expected: usize, actual: usize },
|
||||
#[error("incomplete frame: expected {expected} bytes, got {actual}")]
|
||||
Incomplete { expected: usize, actual: usize },
|
||||
#[error("JSON deserialization error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub struct FrameFramedReader<S> {
|
||||
stream: S,
|
||||
buf: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<S> FrameFramedReader<S>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
buf: Vec::with_capacity(4096),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_frame(&mut self) -> io::Result<Option<EventEnvelope>> {
|
||||
loop {
|
||||
if self.buf.len() >= 4 {
|
||||
let len = u32::from_be_bytes([self.buf[0], self.buf[1], self.buf[2], self.buf[3]])
|
||||
as usize;
|
||||
let total = 4 + len;
|
||||
if self.buf.len() >= total {
|
||||
let body = &self.buf[4..total];
|
||||
match serde_json::from_slice(body) {
|
||||
Ok(envelope) => {
|
||||
self.buf.drain(..total);
|
||||
return Ok(Some(envelope));
|
||||
}
|
||||
Err(e) => {
|
||||
self.buf.drain(..total);
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, e));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut tmp = [0u8; 4096];
|
||||
match self.stream.read(&mut tmp).await {
|
||||
Ok(0) => return Ok(None),
|
||||
Ok(n) => self.buf.extend_from_slice(&tmp[..n]),
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FrameFramedWriter<S> {
|
||||
stream: S,
|
||||
}
|
||||
|
||||
impl<S> FrameFramedWriter<S>
|
||||
where
|
||||
S: AsyncWrite + Unpin,
|
||||
{
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self { stream }
|
||||
}
|
||||
|
||||
pub async fn write_frame(&mut self, envelope: &EventEnvelope) -> io::Result<()> {
|
||||
let frame = encode(envelope);
|
||||
self.stream.write_all(&frame).await?;
|
||||
self.stream.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::call::events;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn frame_encode_decode_round_trip() {
|
||||
let envelope = EventEnvelope::new(
|
||||
events::CALL_REQUESTED,
|
||||
"req-1",
|
||||
json!({"namespace": "auth", "operation": "verify"}),
|
||||
);
|
||||
let frame = encode(&envelope);
|
||||
let decoded = decode(&frame).unwrap();
|
||||
assert_eq!(decoded, envelope);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_encode_starts_with_length_prefix() {
|
||||
let envelope = EventEnvelope::new(events::CALL_REQUESTED, "req-1", json!({}));
|
||||
let frame = encode(&envelope);
|
||||
let json = serde_json::to_vec(&envelope).unwrap();
|
||||
let expected_len = json.len() as u32;
|
||||
let stored_len = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]);
|
||||
assert_eq!(stored_len, expected_len);
|
||||
assert_eq!(frame.len(), 4 + json.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_decode_too_short() {
|
||||
let data = [0u8; 2];
|
||||
let result = decode(&data);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
FrameDecodeError::TooShort {
|
||||
expected: 4,
|
||||
actual: 2
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_decode_incomplete() {
|
||||
let len = 100u32;
|
||||
let mut data = Vec::new();
|
||||
data.extend_from_slice(&len.to_be_bytes());
|
||||
data.extend_from_slice(&[0u8; 10]);
|
||||
let result = decode(&data);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
FrameDecodeError::Incomplete {
|
||||
expected: 104,
|
||||
actual: 14
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_decode_invalid_json() {
|
||||
let json = b"not valid json";
|
||||
let mut data = Vec::new();
|
||||
data.extend_from_slice(&(json.len() as u32).to_be_bytes());
|
||||
data.extend_from_slice(json);
|
||||
let result = decode(&data);
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), FrameDecodeError::Json(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_decode_with_remainder() {
|
||||
let envelope = EventEnvelope::new(events::CALL_RESPONDED, "req-1", json!({"result": 42}));
|
||||
let frame = encode(&envelope);
|
||||
let mut extended = frame.clone();
|
||||
extended.extend_from_slice(&[0u8; 50]);
|
||||
let (decoded, consumed) = decode_with_remainder(&extended).unwrap();
|
||||
assert_eq!(decoded, envelope);
|
||||
assert_eq!(consumed, frame.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_encode_decode_empty_payload() {
|
||||
let envelope = EventEnvelope::new(events::CALL_COMPLETED, "req-1", json!(null));
|
||||
let frame = encode(&envelope);
|
||||
let decoded = decode(&frame).unwrap();
|
||||
assert_eq!(decoded, envelope);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_encode_decode_large_payload() {
|
||||
let large_data: Vec<i32> = (0..1000).collect();
|
||||
let envelope = EventEnvelope::new(events::CALL_RESPONDED, "req-big", json!(large_data));
|
||||
let frame = encode(&envelope);
|
||||
let decoded = decode(&frame).unwrap();
|
||||
assert_eq!(decoded, envelope);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_decode_with_remainder_too_short() {
|
||||
let data = [0u8; 1];
|
||||
let result = decode_with_remainder(&data);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
//! Call protocol layer (Layer 3) of the three-layer model.
|
||||
//!
|
||||
//! See [ADR-024](docs/architecture/decisions/024-call-protocol.md) and
|
||||
//! [ADR-033](docs/architecture/decisions/033-call-protocol-extensions.md).
|
||||
|
||||
pub mod context;
|
||||
pub mod env;
|
||||
pub mod envelope;
|
||||
pub mod events;
|
||||
pub mod frame;
|
||||
pub mod pending;
|
||||
pub mod registry;
|
||||
pub mod response;
|
||||
pub mod services;
|
||||
pub mod spec;
|
||||
|
||||
pub use context::OperationContext;
|
||||
pub use env::OperationEnv;
|
||||
pub use envelope::EventEnvelope;
|
||||
pub use events::{CALL_ABORTED, CALL_COMPLETED, CALL_ERROR, CALL_REQUESTED, CALL_RESPONDED};
|
||||
pub use frame::{
|
||||
decode, decode_with_remainder, encode, FrameDecodeError, FrameFramedReader, FrameFramedWriter,
|
||||
};
|
||||
pub use pending::PendingRequestMap;
|
||||
pub use registry::{Handler, OperationRegistry, OperationRegistryBuilder};
|
||||
pub use response::{CallError, ResponseEnvelope};
|
||||
pub use services::{register_default_operations, services_list_spec, services_schema_spec};
|
||||
pub use spec::{AccessControl, OperationSpec, OperationType};
|
||||
@@ -1,265 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
|
||||
use serde_json::Value;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
use crate::call::response::CallError;
|
||||
|
||||
enum PendingEntry {
|
||||
Call {
|
||||
tx: oneshot::Sender<Result<Value, CallError>>,
|
||||
timeout: Instant,
|
||||
},
|
||||
Subscribe {
|
||||
tx: mpsc::Sender<Result<Value, CallError>>,
|
||||
timeout: Option<Instant>,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct PendingRequestMap {
|
||||
pending: HashMap<String, PendingEntry>,
|
||||
}
|
||||
|
||||
impl PendingRequestMap {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
pending: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert_call(
|
||||
&mut self,
|
||||
request_id: impl Into<String>,
|
||||
tx: oneshot::Sender<Result<Value, CallError>>,
|
||||
timeout: Instant,
|
||||
) {
|
||||
self.pending
|
||||
.insert(request_id.into(), PendingEntry::Call { tx, timeout });
|
||||
}
|
||||
|
||||
pub fn insert_subscribe(
|
||||
&mut self,
|
||||
request_id: impl Into<String>,
|
||||
tx: mpsc::Sender<Result<Value, CallError>>,
|
||||
timeout: Option<Instant>,
|
||||
) {
|
||||
self.pending
|
||||
.insert(request_id.into(), PendingEntry::Subscribe { tx, timeout });
|
||||
}
|
||||
|
||||
pub fn resolve_call(&mut self, request_id: &str, value: Result<Value, CallError>) -> bool {
|
||||
if let Some(PendingEntry::Call { tx, .. }) = self.pending.remove(request_id) {
|
||||
let _ = tx.send(value);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_subscribe(&mut self, request_id: &str, value: Result<Value, CallError>) -> bool {
|
||||
match self.pending.get_mut(request_id) {
|
||||
Some(PendingEntry::Subscribe { tx, .. }) => tx.try_send(value).is_ok(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn complete_subscribe(&mut self, request_id: &str) -> bool {
|
||||
self.pending.remove(request_id).is_some()
|
||||
}
|
||||
|
||||
pub fn abort(&mut self, request_id: &str) -> bool {
|
||||
self.pending.remove(request_id).is_some()
|
||||
}
|
||||
|
||||
pub fn contains(&self, request_id: &str) -> bool {
|
||||
self.pending.contains_key(request_id)
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.pending.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.pending.is_empty()
|
||||
}
|
||||
|
||||
pub fn sweep_expired(&mut self, now: Instant) -> usize {
|
||||
let expired: Vec<String> = self
|
||||
.pending
|
||||
.iter()
|
||||
.filter(|(_, entry)| match entry {
|
||||
PendingEntry::Call { timeout, .. } => *timeout <= now,
|
||||
PendingEntry::Subscribe { timeout, .. } => timeout.is_some_and(|t| t <= now),
|
||||
})
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect();
|
||||
let count = expired.len();
|
||||
for id in &expired {
|
||||
self.pending.remove(id);
|
||||
}
|
||||
count
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PendingRequestMap {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_insert_and_resolve_call() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let timeout = Instant::now() + Duration::from_secs(30);
|
||||
map.insert_call("req-1", tx, timeout);
|
||||
assert!(map.contains("req-1"));
|
||||
assert_eq!(map.len(), 1);
|
||||
|
||||
let result = map.resolve_call("req-1", Ok(serde_json::json!({"status": "ok"})));
|
||||
assert!(result);
|
||||
assert!(map.is_empty());
|
||||
|
||||
let response = rx.await.unwrap();
|
||||
assert!(response.is_ok());
|
||||
assert_eq!(response.unwrap(), serde_json::json!({"status": "ok"}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_resolve_unknown_call() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let result = map.resolve_call("unknown", Ok(serde_json::json!(null)));
|
||||
assert!(!result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_insert_and_push_subscribe() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx, mut rx) = mpsc::channel(16);
|
||||
map.insert_subscribe("sub-1", tx, None);
|
||||
assert!(map.contains("sub-1"));
|
||||
|
||||
let pushed = map.push_subscribe("sub-1", Ok(serde_json::json!({"item": 1})));
|
||||
assert!(pushed);
|
||||
|
||||
let response = rx.recv().await.unwrap();
|
||||
assert!(response.is_ok());
|
||||
assert_eq!(response.unwrap(), serde_json::json!({"item": 1}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_complete_subscribe() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx, mut rx) = mpsc::channel(16);
|
||||
map.insert_subscribe("sub-1", tx, None);
|
||||
|
||||
map.push_subscribe("sub-1", Ok(serde_json::json!({"item": 1})));
|
||||
let completed = map.complete_subscribe("sub-1");
|
||||
assert!(completed);
|
||||
assert!(map.is_empty());
|
||||
|
||||
let _ = rx.recv().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_abort_call() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx, _rx) = oneshot::channel();
|
||||
let timeout = Instant::now() + Duration::from_secs(30);
|
||||
map.insert_call("req-1", tx, timeout);
|
||||
|
||||
let aborted = map.abort("req-1");
|
||||
assert!(aborted);
|
||||
assert!(map.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_abort_unknown() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let aborted = map.abort("unknown");
|
||||
assert!(!aborted);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_sweep_expired() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx1, _rx1) = oneshot::channel();
|
||||
let (tx2, _rx2) = oneshot::channel();
|
||||
let past = Instant::now() - Duration::from_secs(1);
|
||||
let future = Instant::now() + Duration::from_secs(30);
|
||||
|
||||
map.insert_call("expired-1", tx1, past);
|
||||
map.insert_call("active-1", tx2, future);
|
||||
|
||||
let swept = map.sweep_expired(Instant::now());
|
||||
assert_eq!(swept, 1);
|
||||
assert!(!map.contains("expired-1"));
|
||||
assert!(map.contains("active-1"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_sweep_subscribe_with_timeout() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx1, _rx1) = mpsc::channel(16);
|
||||
let (tx2, _rx2) = mpsc::channel(16);
|
||||
let past = Some(Instant::now() - Duration::from_secs(1));
|
||||
let future = Some(Instant::now() + Duration::from_secs(30));
|
||||
|
||||
map.insert_subscribe("expired-sub", tx1, past);
|
||||
map.insert_subscribe("active-sub", tx2, future);
|
||||
|
||||
let swept = map.sweep_expired(Instant::now());
|
||||
assert_eq!(swept, 1);
|
||||
assert!(!map.contains("expired-sub"));
|
||||
assert!(map.contains("active-sub"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_subscribe_no_timeout_not_swept() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx, _rx) = mpsc::channel(16);
|
||||
map.insert_subscribe("sub-no-timeout", tx, None);
|
||||
|
||||
let swept = map.sweep_expired(Instant::now());
|
||||
assert_eq!(swept, 0);
|
||||
assert!(map.contains("sub-no-timeout"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_push_unknown_subscribe() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let pushed = map.push_subscribe("unknown", Ok(serde_json::json!(null)));
|
||||
assert!(!pushed);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_call_error_response() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let timeout = Instant::now() + Duration::from_secs(30);
|
||||
map.insert_call("req-err", tx, timeout);
|
||||
|
||||
let result = map.resolve_call(
|
||||
"req-err",
|
||||
Err(CallError {
|
||||
code: "TIMEOUT".to_string(),
|
||||
message: "request timed out".to_string(),
|
||||
retryable: true,
|
||||
}),
|
||||
);
|
||||
assert!(result);
|
||||
assert!(map.is_empty());
|
||||
|
||||
let response = rx.await.unwrap();
|
||||
assert!(response.is_err());
|
||||
let err = response.unwrap_err();
|
||||
assert_eq!(err.code, "TIMEOUT");
|
||||
assert!(err.retryable);
|
||||
}
|
||||
}
|
||||
@@ -1,337 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::call::context::OperationContext;
|
||||
use crate::call::response::ResponseEnvelope;
|
||||
use crate::call::spec::OperationSpec;
|
||||
|
||||
pub type Handler = Arc<dyn Fn(Value, OperationContext) -> ResponseEnvelope + Send + Sync>;
|
||||
|
||||
pub struct OperationRegistry {
|
||||
operations: HashMap<String, (OperationSpec, Handler)>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OperationRegistry {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OperationRegistry")
|
||||
.field("operation_count", &self.operations.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperationRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
operations: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register(&mut self, spec: OperationSpec, handler: Handler) {
|
||||
self.operations.insert(spec.name.clone(), (spec, handler));
|
||||
}
|
||||
|
||||
pub fn lookup(&self, name: &str) -> Option<(&OperationSpec, &Handler)> {
|
||||
self.operations
|
||||
.get(name)
|
||||
.map(|(spec, handler)| (spec, handler))
|
||||
}
|
||||
|
||||
pub fn invoke(&self, name: &str, input: Value, context: OperationContext) -> ResponseEnvelope {
|
||||
match self.lookup(name) {
|
||||
Some((spec, handler)) => {
|
||||
if !context.trusted {
|
||||
if let Some(ref identity) = context.identity {
|
||||
if !spec.access_control.check(identity) {
|
||||
return ResponseEnvelope::err(
|
||||
&context.request_id,
|
||||
"FORBIDDEN",
|
||||
"access denied",
|
||||
false,
|
||||
);
|
||||
}
|
||||
} else if spec.access_control.has_restrictions() {
|
||||
return ResponseEnvelope::err(
|
||||
&context.request_id,
|
||||
"FORBIDDEN",
|
||||
"authentication required",
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
handler(input, context)
|
||||
}
|
||||
None => ResponseEnvelope::err(
|
||||
&context.request_id,
|
||||
"NOT_FOUND",
|
||||
format!("operation not found: {name}"),
|
||||
false,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list_operations(&self) -> Vec<&OperationSpec> {
|
||||
self.operations.values().map(|(spec, _)| spec).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OperationRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OperationRegistryBuilder {
|
||||
registry: OperationRegistry,
|
||||
}
|
||||
|
||||
impl OperationRegistryBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
registry: OperationRegistry::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with(mut self, spec: OperationSpec, handler: Handler) -> Self {
|
||||
self.registry.register(spec, handler);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> OperationRegistry {
|
||||
self.registry
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OperationRegistryBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::Identity;
|
||||
use crate::call::env::OperationEnv;
|
||||
use crate::call::spec::{AccessControl, OperationType};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_spec(name: &str, namespace: &str) -> OperationSpec {
|
||||
OperationSpec {
|
||||
name: name.to_string(),
|
||||
namespace: namespace.to_string(),
|
||||
op_type: OperationType::Query,
|
||||
input_schema: serde_json::json!({}),
|
||||
output_schema: serde_json::json!({}),
|
||||
access_control: AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn make_spec_with_acl(name: &str, namespace: &str, acl: AccessControl) -> OperationSpec {
|
||||
OperationSpec {
|
||||
name: name.to_string(),
|
||||
namespace: namespace.to_string(),
|
||||
op_type: OperationType::Mutation,
|
||||
input_schema: serde_json::json!({}),
|
||||
output_schema: serde_json::json!({}),
|
||||
access_control: acl,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_context(request_id: &str, identity: Option<Identity>) -> OperationContext {
|
||||
let registry = OperationRegistry::new();
|
||||
OperationContext {
|
||||
request_id: request_id.to_string(),
|
||||
parent_request_id: None,
|
||||
identity,
|
||||
metadata: HashMap::new(),
|
||||
env: OperationEnv::local(registry),
|
||||
trusted: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_and_lookup() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let spec = make_spec("fs/readFile", "fs");
|
||||
let handler: Handler = Arc::new(|input, _ctx| ResponseEnvelope::ok("req-1", input));
|
||||
registry.register(spec, handler);
|
||||
let found = registry.lookup("fs/readFile");
|
||||
assert!(found.is_some());
|
||||
let (spec, _) = found.unwrap();
|
||||
assert_eq!(spec.name, "fs/readFile");
|
||||
assert_eq!(spec.namespace, "fs");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lookup_missing_returns_none() {
|
||||
let registry = OperationRegistry::new();
|
||||
assert!(registry.lookup("missing").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_operation() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let spec = make_spec("fs/readFile", "fs");
|
||||
let handler: Handler = Arc::new(|input, ctx| ResponseEnvelope::ok(&ctx.request_id, input));
|
||||
registry.register(spec, handler);
|
||||
let context = make_context("req-1", None);
|
||||
let result = registry.invoke("fs/readFile", serde_json::json!({"path": "/tmp"}), context);
|
||||
assert!(result.result.is_ok());
|
||||
assert_eq!(result.result.unwrap(), serde_json::json!({"path": "/tmp"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_missing_operation() {
|
||||
let registry = OperationRegistry::new();
|
||||
let context = make_context("req-1", None);
|
||||
let result = registry.invoke("missing", serde_json::json!(null), context);
|
||||
assert!(result.result.is_err());
|
||||
let err = result.result.unwrap_err();
|
||||
assert_eq!(err.code, "NOT_FOUND");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_with_acl_check_allowed() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let acl = AccessControl {
|
||||
required_scopes: vec!["read".to_string()],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let spec = make_spec_with_acl("bash/exec", "bash", acl);
|
||||
let handler: Handler = Arc::new(|_input, ctx| {
|
||||
ResponseEnvelope::ok(&ctx.request_id, serde_json::json!("done"))
|
||||
});
|
||||
registry.register(spec, handler);
|
||||
|
||||
let identity = Identity {
|
||||
id: "user-1".to_string(),
|
||||
scopes: vec!["read".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let context = make_context("req-1", Some(identity));
|
||||
let result = registry.invoke("bash/exec", serde_json::json!(null), context);
|
||||
assert!(result.result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_with_acl_check_denied() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let acl = AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let spec = make_spec_with_acl("bash/exec", "bash", acl);
|
||||
let handler: Handler = Arc::new(|_input, ctx| {
|
||||
ResponseEnvelope::ok(&ctx.request_id, serde_json::json!("done"))
|
||||
});
|
||||
registry.register(spec, handler);
|
||||
|
||||
let identity = Identity {
|
||||
id: "user-1".to_string(),
|
||||
scopes: vec!["read".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let context = make_context("req-1", Some(identity));
|
||||
let result = registry.invoke("bash/exec", serde_json::json!(null), context);
|
||||
assert!(result.result.is_err());
|
||||
let err = result.result.unwrap_err();
|
||||
assert_eq!(err.code, "FORBIDDEN");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_trusted_skips_acl() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let acl = AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let spec = make_spec_with_acl("bash/exec", "bash", acl);
|
||||
let handler: Handler = Arc::new(|_input, ctx| {
|
||||
ResponseEnvelope::ok(&ctx.request_id, serde_json::json!("done"))
|
||||
});
|
||||
registry.register(spec, handler);
|
||||
|
||||
let identity = Identity {
|
||||
id: "user-1".to_string(),
|
||||
scopes: vec!["read".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let mut registry2 = OperationRegistry::new();
|
||||
let context = OperationContext {
|
||||
request_id: "req-1".to_string(),
|
||||
parent_request_id: None,
|
||||
identity: Some(identity),
|
||||
metadata: HashMap::new(),
|
||||
env: OperationEnv::local(registry2),
|
||||
trusted: true,
|
||||
};
|
||||
let result = registry.invoke("bash/exec", serde_json::json!(null), context);
|
||||
assert!(result.result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_no_identity_with_acl_denied() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let acl = AccessControl {
|
||||
required_scopes: vec!["read".to_string()],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let spec = make_spec_with_acl("bash/exec", "bash", acl);
|
||||
let handler: Handler = Arc::new(|_input, ctx| {
|
||||
ResponseEnvelope::ok(&ctx.request_id, serde_json::json!("done"))
|
||||
});
|
||||
registry.register(spec, handler);
|
||||
|
||||
let context = make_context("req-1", None);
|
||||
let result = registry.invoke("bash/exec", serde_json::json!(null), context);
|
||||
assert!(result.result.is_err());
|
||||
let err = result.result.unwrap_err();
|
||||
assert_eq!(err.code, "FORBIDDEN");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_operations() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(
|
||||
make_spec("fs/readFile", "fs"),
|
||||
Arc::new(|_, ctx| ResponseEnvelope::ok(&ctx.request_id, serde_json::json!(null))),
|
||||
);
|
||||
registry.register(
|
||||
make_spec("bash/exec", "bash"),
|
||||
Arc::new(|_, ctx| ResponseEnvelope::ok(&ctx.request_id, serde_json::json!(null))),
|
||||
);
|
||||
let ops = registry.list_operations();
|
||||
assert_eq!(ops.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registry_builder() {
|
||||
let registry = OperationRegistryBuilder::new()
|
||||
.with(
|
||||
make_spec("fs/readFile", "fs"),
|
||||
Arc::new(|input, ctx| ResponseEnvelope::ok(&ctx.request_id, input)),
|
||||
)
|
||||
.with(
|
||||
make_spec("bash/exec", "bash"),
|
||||
Arc::new(|input, ctx| ResponseEnvelope::ok(&ctx.request_id, input)),
|
||||
)
|
||||
.build();
|
||||
assert!(registry.lookup("fs/readFile").is_some());
|
||||
assert!(registry.lookup("bash/exec").is_some());
|
||||
}
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub struct CallError {
|
||||
pub code: String,
|
||||
pub message: String,
|
||||
pub retryable: bool,
|
||||
}
|
||||
|
||||
impl CallError {
|
||||
pub fn new(code: impl Into<String>, message: impl Into<String>, retryable: bool) -> Self {
|
||||
Self {
|
||||
code: code.into(),
|
||||
message: message.into(),
|
||||
retryable,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResponseEnvelope {
|
||||
pub request_id: String,
|
||||
pub result: Result<Value, CallError>,
|
||||
}
|
||||
|
||||
impl ResponseEnvelope {
|
||||
pub fn ok(request_id: impl Into<String>, value: Value) -> Self {
|
||||
Self {
|
||||
request_id: request_id.into(),
|
||||
result: Ok(value),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn err(
|
||||
request_id: impl Into<String>,
|
||||
code: impl Into<String>,
|
||||
message: impl Into<String>,
|
||||
retryable: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
request_id: request_id.into(),
|
||||
result: Err(CallError {
|
||||
code: code.into(),
|
||||
message: message.into(),
|
||||
retryable,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn call_error_fields() {
|
||||
let err = CallError {
|
||||
code: "NOT_FOUND".to_string(),
|
||||
message: "operation not found".to_string(),
|
||||
retryable: false,
|
||||
};
|
||||
assert_eq!(err.code, "NOT_FOUND");
|
||||
assert_eq!(err.message, "operation not found");
|
||||
assert!(!err.retryable);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_envelope_ok() {
|
||||
let env = ResponseEnvelope::ok("req-1", json!({"status": "ok"}));
|
||||
assert_eq!(env.request_id, "req-1");
|
||||
assert!(env.result.is_ok());
|
||||
assert_eq!(env.result.unwrap(), json!({"status": "ok"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_envelope_err() {
|
||||
let env = ResponseEnvelope::err("req-1", "NOT_FOUND", "operation not found", false);
|
||||
assert_eq!(env.request_id, "req-1");
|
||||
assert!(env.result.is_err());
|
||||
let err = env.result.unwrap_err();
|
||||
assert_eq!(err.code, "NOT_FOUND");
|
||||
assert_eq!(err.message, "operation not found");
|
||||
assert!(!err.retryable);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_envelope_serialization() {
|
||||
let env = ResponseEnvelope::ok("req-1", json!({"key": "value"}));
|
||||
let serialized = serde_json::to_string(&env).unwrap();
|
||||
let deserialized: ResponseEnvelope = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.request_id, "req-1");
|
||||
assert!(deserialized.result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_envelope_err_serialization() {
|
||||
let env = ResponseEnvelope::err("req-2", "TIMEOUT", "timed out", true);
|
||||
let serialized = serde_json::to_string(&env).unwrap();
|
||||
let deserialized: ResponseEnvelope = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.request_id, "req-2");
|
||||
let err = deserialized.result.unwrap_err();
|
||||
assert_eq!(err.code, "TIMEOUT");
|
||||
assert!(err.retryable);
|
||||
}
|
||||
}
|
||||
@@ -1,207 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::call::context::OperationContext;
|
||||
use crate::call::response::ResponseEnvelope;
|
||||
use crate::call::spec::{AccessControl, OperationSpec, OperationType};
|
||||
|
||||
pub fn services_list_spec() -> OperationSpec {
|
||||
OperationSpec {
|
||||
name: super::events::SERVICE_LIST.to_string(),
|
||||
namespace: "services".to_string(),
|
||||
op_type: OperationType::Query,
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
}),
|
||||
output_schema: serde_json::json!({
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"namespace": { "type": "string" },
|
||||
"op_type": { "type": "string" },
|
||||
},
|
||||
},
|
||||
}),
|
||||
access_control: AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn services_schema_spec() -> OperationSpec {
|
||||
OperationSpec {
|
||||
name: super::events::SERVICE_SCHEMA.to_string(),
|
||||
namespace: "services".to_string(),
|
||||
op_type: OperationType::Query,
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
},
|
||||
"required": ["name"],
|
||||
}),
|
||||
output_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"namespace": { "type": "string" },
|
||||
"op_type": { "type": "string" },
|
||||
"input_schema": { "type": "object" },
|
||||
"output_schema": { "type": "object" },
|
||||
},
|
||||
}),
|
||||
access_control: AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_default_operations(registry: &mut crate::call::OperationRegistry) {
|
||||
registry.register(services_list_spec(), Arc::new(services_list_handler));
|
||||
registry.register(services_schema_spec(), Arc::new(services_schema_handler));
|
||||
}
|
||||
|
||||
fn services_list_handler(_input: Value, ctx: OperationContext) -> ResponseEnvelope {
|
||||
let registry = &ctx.env.registry_ref();
|
||||
let specs = registry.list_operations();
|
||||
let ops: Vec<Value> = specs
|
||||
.iter()
|
||||
.map(|spec| {
|
||||
serde_json::json!({
|
||||
"name": spec.name,
|
||||
"namespace": spec.namespace,
|
||||
"op_type": format!("{:?}", spec.op_type).to_lowercase(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
ResponseEnvelope::ok(&ctx.request_id, serde_json::json!({ "operations": ops }))
|
||||
}
|
||||
|
||||
fn services_schema_handler(input: Value, ctx: OperationContext) -> ResponseEnvelope {
|
||||
let name = match input.get("name").and_then(|v| v.as_str()) {
|
||||
Some(n) => n.to_string(),
|
||||
None => {
|
||||
return ResponseEnvelope::err(
|
||||
&ctx.request_id,
|
||||
"INVALID_INPUT",
|
||||
"missing required field: name",
|
||||
false,
|
||||
);
|
||||
}
|
||||
};
|
||||
let registry = &ctx.env.registry_ref();
|
||||
match registry.lookup(&name) {
|
||||
Some((spec, _)) => ResponseEnvelope::ok(
|
||||
&ctx.request_id,
|
||||
serde_json::json!({
|
||||
"name": spec.name,
|
||||
"namespace": spec.namespace,
|
||||
"op_type": format!("{:?}", spec.op_type).to_lowercase(),
|
||||
"input_schema": spec.input_schema,
|
||||
"output_schema": spec.output_schema,
|
||||
}),
|
||||
),
|
||||
None => ResponseEnvelope::err(
|
||||
&ctx.request_id,
|
||||
"NOT_FOUND",
|
||||
format!("operation not found: {name}"),
|
||||
false,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::call::env::OperationEnv;
|
||||
|
||||
fn make_env() -> OperationEnv {
|
||||
let mut registry = crate::call::OperationRegistry::new();
|
||||
registry.register(services_list_spec(), Arc::new(services_list_handler));
|
||||
registry.register(services_schema_spec(), Arc::new(services_schema_handler));
|
||||
OperationEnv::local(registry)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_list_returns_operations() {
|
||||
let env = make_env();
|
||||
let result = env.invoke("services", "list", serde_json::json!({}));
|
||||
assert!(result.result.is_ok());
|
||||
let value = result.result.unwrap();
|
||||
let ops = value.get("operations").unwrap().as_array().unwrap();
|
||||
assert_eq!(ops.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_schema_returns_spec() {
|
||||
let env = make_env();
|
||||
let result = env.invoke(
|
||||
"services",
|
||||
"schema",
|
||||
serde_json::json!({"name": "/services/list"}),
|
||||
);
|
||||
assert!(result.result.is_ok());
|
||||
let value = result.result.unwrap();
|
||||
assert_eq!(value["name"], "/services/list");
|
||||
assert_eq!(value["namespace"], "services");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_schema_missing_name() {
|
||||
let env = make_env();
|
||||
let result = env.invoke("services", "schema", serde_json::json!({}));
|
||||
assert!(result.result.is_err());
|
||||
let err = result.result.unwrap_err();
|
||||
assert_eq!(err.code, "INVALID_INPUT");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_schema_not_found() {
|
||||
let env = make_env();
|
||||
let result = env.invoke(
|
||||
"services",
|
||||
"schema",
|
||||
serde_json::json!({"name": "/nonexistent/op"}),
|
||||
);
|
||||
assert!(result.result.is_err());
|
||||
let err = result.result.unwrap_err();
|
||||
assert_eq!(err.code, "NOT_FOUND");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_list_spec_fields() {
|
||||
let spec = services_list_spec();
|
||||
assert_eq!(spec.name, "/services/list");
|
||||
assert_eq!(spec.namespace, "services");
|
||||
assert_eq!(spec.op_type, OperationType::Query);
|
||||
assert!(!spec.access_control.has_restrictions());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_schema_spec_fields() {
|
||||
let spec = services_schema_spec();
|
||||
assert_eq!(spec.name, "/services/schema");
|
||||
assert_eq!(spec.namespace, "services");
|
||||
assert_eq!(spec.op_type, OperationType::Query);
|
||||
assert!(!spec.access_control.has_restrictions());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_default_operations_adds_both() {
|
||||
let mut registry = crate::call::OperationRegistry::new();
|
||||
register_default_operations(&mut registry);
|
||||
assert!(registry.lookup("/services/list").is_some());
|
||||
assert!(registry.lookup("/services/schema").is_some());
|
||||
assert_eq!(registry.list_operations().len(), 2);
|
||||
}
|
||||
}
|
||||
@@ -1,239 +0,0 @@
|
||||
//! Operation specifications (type, access control) for the call protocol.
|
||||
//!
|
||||
//! See [ADR-025](docs/architecture/decisions/025-operation-spec.md) and
|
||||
//! [ADR-033](docs/architecture/decisions/033-call-protocol-extensions.md).
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub enum OperationType {
|
||||
Query,
|
||||
Mutation,
|
||||
Subscription,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AccessControl {
|
||||
pub required_scopes: Vec<String>,
|
||||
pub required_scopes_any: Option<Vec<String>>,
|
||||
pub resource_type: Option<String>,
|
||||
pub resource_action: Option<String>,
|
||||
}
|
||||
|
||||
impl AccessControl {
|
||||
pub fn check(&self, identity: &crate::auth::Identity) -> bool {
|
||||
for scope in &self.required_scopes {
|
||||
if !identity.scopes.contains(scope) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(any) = &self.required_scopes_any {
|
||||
if !any.iter().any(|s| identity.scopes.contains(s)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(res_type) = &self.resource_type {
|
||||
if let Some(actions) = identity.resources.get(res_type) {
|
||||
if let Some(action) = &self.resource_action {
|
||||
if !actions.contains(action) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub fn has_restrictions(&self) -> bool {
|
||||
!self.required_scopes.is_empty()
|
||||
|| self.required_scopes_any.is_some()
|
||||
|| self.resource_type.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OperationSpec {
|
||||
pub name: String,
|
||||
pub namespace: String,
|
||||
pub op_type: OperationType,
|
||||
pub input_schema: Value,
|
||||
pub output_schema: Value,
|
||||
pub access_control: AccessControl,
|
||||
}
|
||||
|
||||
impl OperationSpec {
|
||||
pub fn path(&self) -> String {
|
||||
format!("/{}", self.name)
|
||||
}
|
||||
|
||||
pub fn namespace_from_name(name: &str) -> String {
|
||||
let trimmed = name.trim_start_matches('/');
|
||||
let parts: Vec<&str> = trimmed.split('/').collect();
|
||||
match parts.len() {
|
||||
n if n >= 3 => parts[1].to_string(),
|
||||
n if n >= 2 => parts[0].to_string(),
|
||||
_ => String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_identity(
|
||||
scopes: Vec<String>,
|
||||
resources: HashMap<String, Vec<String>>,
|
||||
) -> crate::auth::Identity {
|
||||
crate::auth::Identity {
|
||||
id: "test".to_string(),
|
||||
scopes,
|
||||
resources,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_allows_matching_scopes() {
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec!["read".to_string()],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let id = make_identity(vec!["read".to_string()], HashMap::new());
|
||||
assert!(ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_rejects_missing_scopes() {
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let id = make_identity(vec!["read".to_string()], HashMap::new());
|
||||
assert!(!ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_required_scopes_any_matches() {
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: Some(vec!["admin".to_string(), "read".to_string()]),
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let id = make_identity(vec!["read".to_string()], HashMap::new());
|
||||
assert!(ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_required_scopes_any_rejects() {
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: Some(vec!["admin".to_string()]),
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let id = make_identity(vec!["read".to_string()], HashMap::new());
|
||||
assert!(!ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_resource_check_matches() {
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert("service".to_string(), vec!["read".to_string()]);
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: Some("service".to_string()),
|
||||
resource_action: Some("read".to_string()),
|
||||
};
|
||||
let id = make_identity(vec![], resources);
|
||||
assert!(ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_resource_check_missing_resource_type() {
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: Some("service".to_string()),
|
||||
resource_action: Some("read".to_string()),
|
||||
};
|
||||
let id = make_identity(vec![], HashMap::new());
|
||||
assert!(!ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_resource_check_missing_action() {
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert("service".to_string(), vec!["write".to_string()]);
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: Some("service".to_string()),
|
||||
resource_action: Some("read".to_string()),
|
||||
};
|
||||
let id = make_identity(vec![], resources);
|
||||
assert!(!ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_combined_scopes_and_resources() {
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert("service".to_string(), vec!["read".to_string()]);
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec!["relay:connect".to_string()],
|
||||
required_scopes_any: Some(vec!["admin".to_string()]),
|
||||
resource_type: Some("service".to_string()),
|
||||
resource_action: Some("read".to_string()),
|
||||
};
|
||||
let id = make_identity(
|
||||
vec!["relay:connect".to_string(), "admin".to_string()],
|
||||
resources,
|
||||
);
|
||||
assert!(ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_type_variants() {
|
||||
assert_eq!(OperationType::Query, OperationType::Query);
|
||||
assert_ne!(OperationType::Query, OperationType::Mutation);
|
||||
assert_ne!(OperationType::Mutation, OperationType::Subscription);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_spec_namespace_from_name() {
|
||||
assert_eq!(OperationSpec::namespace_from_name("/auth/verify"), "auth");
|
||||
assert_eq!(OperationSpec::namespace_from_name("/fs/readFile"), "fs");
|
||||
assert_eq!(
|
||||
OperationSpec::namespace_from_name("/head/agent/chat"),
|
||||
"agent"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_spec_path() {
|
||||
let spec = OperationSpec {
|
||||
name: "auth/verify".to_string(),
|
||||
namespace: "auth".to_string(),
|
||||
op_type: OperationType::Query,
|
||||
input_schema: serde_json::json!({}),
|
||||
output_schema: serde_json::json!({}),
|
||||
access_control: AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
},
|
||||
};
|
||||
assert_eq!(spec.path(), "/auth/verify");
|
||||
}
|
||||
}
|
||||
@@ -1,468 +0,0 @@
|
||||
//! Channel manager with automatic reconnection.
|
||||
//!
|
||||
//! Owns the SSH session handle and provides `open_direct_tcpip()`,
|
||||
//! `request_tcpip_forward()`, and `cancel_tcpip_forward()`. Monitors
|
||||
//! the session for disconnect and attempts reconnection with exponential
|
||||
//! backoff (1s, 2s, 4s, ..., 30s cap). Re-registers remote forwards
|
||||
//! after successful reconnection.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use russh::client;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::client_auth::{ClientAuthConfig, ClientHandler};
|
||||
use crate::error::ChannelError;
|
||||
use crate::transport::Transport;
|
||||
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||
pub struct ForwardRequest {
|
||||
pub addr: String,
|
||||
pub port: u32,
|
||||
}
|
||||
|
||||
struct ChannelManagerInner<T: Transport> {
|
||||
transport: Arc<T>,
|
||||
auth_config: Arc<ClientAuthConfig>,
|
||||
handle: Arc<RwLock<client::Handle<ClientHandler>>>,
|
||||
username: String,
|
||||
forwards: RwLock<HashSet<ForwardRequest>>,
|
||||
reconnect_attempts: RwLock<u32>,
|
||||
}
|
||||
|
||||
pub struct ChannelManager<T: Transport> {
|
||||
inner: Arc<ChannelManagerInner<T>>,
|
||||
reconnect_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
|
||||
}
|
||||
|
||||
impl<T: Transport> ChannelManager<T> {
|
||||
pub async fn new(
|
||||
transport: Arc<T>,
|
||||
auth_config: Arc<ClientAuthConfig>,
|
||||
username: String,
|
||||
) -> Result<Self, ChannelError> {
|
||||
let handler = ClientHandler::from_config(&auth_config);
|
||||
let handle = Self::establish_session(&*transport, handler, &auth_config, &username)
|
||||
.await
|
||||
.map_err(|_| ChannelError::TargetUnreachable)?;
|
||||
|
||||
let inner = Arc::new(ChannelManagerInner {
|
||||
transport,
|
||||
auth_config,
|
||||
handle: Arc::new(RwLock::new(handle)),
|
||||
username,
|
||||
forwards: RwLock::new(HashSet::new()),
|
||||
reconnect_attempts: RwLock::new(0),
|
||||
});
|
||||
|
||||
let reconnect_handle = Arc::new(RwLock::new(None));
|
||||
let manager = Self {
|
||||
inner,
|
||||
reconnect_handle,
|
||||
};
|
||||
|
||||
manager.start_reconnect_monitor();
|
||||
Ok(manager)
|
||||
}
|
||||
|
||||
async fn establish_session(
|
||||
transport: &T,
|
||||
handler: ClientHandler,
|
||||
auth_config: &ClientAuthConfig,
|
||||
username: &str,
|
||||
) -> Result<client::Handle<ClientHandler>, russh::Error> {
|
||||
let stream = transport.connect().await.map_err(|e| {
|
||||
error!("transport connect failed: {e}");
|
||||
russh::Error::SendError
|
||||
})?;
|
||||
|
||||
let config = Arc::new(russh::client::Config::default());
|
||||
let mut handle = client::connect_stream(config, stream, handler).await?;
|
||||
|
||||
let auth_ok = auth_config.authenticate(&mut handle, username).await?;
|
||||
if !auth_ok {
|
||||
return Err(russh::Error::SendError);
|
||||
}
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
pub async fn open_direct_tcpip(
|
||||
&self,
|
||||
host: &str,
|
||||
port: u32,
|
||||
) -> Result<russh::Channel<russh::client::Msg>, ChannelError> {
|
||||
let handle = self.inner.handle.read().await;
|
||||
handle
|
||||
.channel_open_direct_tcpip(host, port, "127.0.0.1", 0)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
debug!("channel open failed: {e}");
|
||||
ChannelError::ChannelClosed
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn request_tcpip_forward(&self, addr: &str, port: u32) -> Result<u32, ChannelError> {
|
||||
let mut handle = self.inner.handle.write().await;
|
||||
let result = handle
|
||||
.tcpip_forward(addr, port)
|
||||
.await
|
||||
.map_err(|_| ChannelError::ChannelClosed)?;
|
||||
|
||||
self.inner.forwards.write().await.insert(ForwardRequest {
|
||||
addr: addr.to_string(),
|
||||
port,
|
||||
});
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub async fn cancel_tcpip_forward(&self, addr: &str, port: u32) -> Result<(), ChannelError> {
|
||||
let handle = self.inner.handle.read().await;
|
||||
handle
|
||||
.cancel_tcpip_forward(addr, port)
|
||||
.await
|
||||
.map_err(|_| ChannelError::ChannelClosed)?;
|
||||
|
||||
self.inner.forwards.write().await.remove(&ForwardRequest {
|
||||
addr: addr.to_string(),
|
||||
port,
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn is_connected(&self) -> bool {
|
||||
let handle = self.inner.handle.read().await;
|
||||
!handle.is_closed()
|
||||
}
|
||||
|
||||
fn start_reconnect_monitor(&self) {
|
||||
let inner = Arc::clone(&self.inner);
|
||||
let handle_arc = Arc::clone(&self.inner.handle);
|
||||
|
||||
let join_handle = tokio::spawn(async move {
|
||||
loop {
|
||||
time::sleep(Duration::from_secs(1)).await;
|
||||
let handle = handle_arc.read().await;
|
||||
if handle.is_closed() {
|
||||
drop(handle);
|
||||
info!("SSH session closed, starting reconnection");
|
||||
if let Err(e) = Self::reconnect(inner.clone()).await {
|
||||
error!("reconnection failed: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let reconnect_handle = Arc::clone(&self.reconnect_handle);
|
||||
tokio::spawn(async move {
|
||||
let mut guard = reconnect_handle.write().await;
|
||||
*guard = Some(join_handle);
|
||||
});
|
||||
}
|
||||
|
||||
async fn reconnect(inner: Arc<ChannelManagerInner<T>>) -> Result<(), ChannelError> {
|
||||
let mut attempts = inner.reconnect_attempts.write().await;
|
||||
let attempt_num = *attempts;
|
||||
let backoff = backoff_duration(attempt_num);
|
||||
*attempts += 1;
|
||||
drop(attempts);
|
||||
|
||||
warn!(
|
||||
"reconnect attempt #{}, waiting {:?}",
|
||||
attempt_num + 1,
|
||||
backoff
|
||||
);
|
||||
time::sleep(backoff).await;
|
||||
|
||||
let handler = ClientHandler::from_config(&inner.auth_config);
|
||||
match Self::establish_session(
|
||||
&*inner.transport,
|
||||
handler,
|
||||
&inner.auth_config,
|
||||
&inner.username,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(new_handle) => {
|
||||
info!("reconnection successful");
|
||||
{
|
||||
let mut handle_guard = inner.handle.write().await;
|
||||
*handle_guard = new_handle;
|
||||
}
|
||||
{
|
||||
let mut attempts = inner.reconnect_attempts.write().await;
|
||||
*attempts = 0;
|
||||
}
|
||||
Self::re_register_forwards(&inner).await;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("reconnection attempt failed: {e}");
|
||||
Err(ChannelError::ChannelClosed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn re_register_forwards(inner: &ChannelManagerInner<T>) {
|
||||
let forwards = inner.forwards.read().await;
|
||||
if forwards.is_empty() {
|
||||
return;
|
||||
}
|
||||
let mut handle = inner.handle.write().await;
|
||||
for fwd in forwards.iter() {
|
||||
match handle.tcpip_forward(&fwd.addr, fwd.port).await {
|
||||
Ok(_) => {
|
||||
debug!("re-registered tcpip_forward: {}:{}", fwd.addr, fwd.port);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"failed to re-register tcpip_forward {}:{}: {e}",
|
||||
fwd.addr, fwd.port
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Exponential backoff: 1s, 2s, 4s, 8s, 16s, 30s (cap), continues indefinitely.
|
||||
fn backoff_duration(attempt: u32) -> Duration {
|
||||
let secs: u64 = match attempt {
|
||||
0 => 1,
|
||||
1 => 2,
|
||||
2 => 4,
|
||||
3 => 8,
|
||||
4 => 16,
|
||||
_ => 30,
|
||||
};
|
||||
Duration::from_secs(secs)
|
||||
}
|
||||
|
||||
impl<T: Transport> Drop for ChannelManager<T> {
|
||||
fn drop(&mut self) {
|
||||
if let Ok(mut guard) = self.reconnect_handle.try_write() {
|
||||
if let Some(handle) = guard.take() {
|
||||
handle.abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use tokio::io::duplex;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
fn make_auth_config() -> Arc<ClientAuthConfig> {
|
||||
let source = crate::auth::keys::KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
Arc::new(ClientAuthConfig::from_key_source(source).unwrap())
|
||||
}
|
||||
|
||||
struct AlwaysFailTransport;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for AlwaysFailTransport {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||
Err(anyhow::anyhow!("always fails"))
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"always-fail".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
struct TrackConnectTransport {
|
||||
connect_count: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl TrackConnectTransport {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
connect_count: Arc::new(AtomicUsize::new(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for TrackConnectTransport {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||
self.connect_count.fetch_add(1, Ordering::SeqCst);
|
||||
let (client, _) = duplex(4096);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"track-connect".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
struct CountingFailTransport {
|
||||
fail_count: Arc<AtomicUsize>,
|
||||
succeed_after: usize,
|
||||
}
|
||||
|
||||
impl CountingFailTransport {
|
||||
fn new(succeed_after: usize) -> Self {
|
||||
Self {
|
||||
fail_count: Arc::new(AtomicUsize::new(0)),
|
||||
succeed_after,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for CountingFailTransport {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||
let count = self.fail_count.fetch_add(1, Ordering::SeqCst);
|
||||
if count < self.succeed_after {
|
||||
return Err(anyhow::anyhow!("connection failed (attempt {})", count));
|
||||
}
|
||||
let (client, _) = duplex(4096);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"counting-fail".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backoff_durations() {
|
||||
assert_eq!(backoff_duration(0), Duration::from_secs(1));
|
||||
assert_eq!(backoff_duration(1), Duration::from_secs(2));
|
||||
assert_eq!(backoff_duration(2), Duration::from_secs(4));
|
||||
assert_eq!(backoff_duration(3), Duration::from_secs(8));
|
||||
assert_eq!(backoff_duration(4), Duration::from_secs(16));
|
||||
assert_eq!(backoff_duration(5), Duration::from_secs(30));
|
||||
assert_eq!(backoff_duration(6), Duration::from_secs(30));
|
||||
assert_eq!(backoff_duration(100), Duration::from_secs(30));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backoff_sequence_matches_spec() {
|
||||
let sequence: Vec<Duration> = (0..6).map(backoff_duration).collect();
|
||||
assert_eq!(
|
||||
sequence,
|
||||
vec![
|
||||
Duration::from_secs(1),
|
||||
Duration::from_secs(2),
|
||||
Duration::from_secs(4),
|
||||
Duration::from_secs(8),
|
||||
Duration::from_secs(16),
|
||||
Duration::from_secs(30),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_forward_request_hash_eq() {
|
||||
let fwd1 = ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 8080,
|
||||
};
|
||||
let fwd2 = ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 8080,
|
||||
};
|
||||
let fwd3 = ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 9090,
|
||||
};
|
||||
assert_eq!(fwd1, fwd2);
|
||||
assert_ne!(fwd1, fwd3);
|
||||
let mut set = HashSet::new();
|
||||
set.insert(fwd1.clone());
|
||||
assert!(set.contains(&fwd2));
|
||||
assert!(!set.contains(&fwd3));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_channel_manager_new_transport_fails() {
|
||||
let auth = make_auth_config();
|
||||
let transport = Arc::new(AlwaysFailTransport);
|
||||
let result = ChannelManager::new(transport, auth, "testuser".to_string()).await;
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(ChannelError::TargetUnreachable) => {}
|
||||
other => panic!("expected TargetUnreachable, got {:?}", other.as_ref().err()),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_connect_called_on_new() {
|
||||
let transport = Arc::new(TrackConnectTransport::new());
|
||||
let connect_before = transport.connect_count.load(Ordering::SeqCst);
|
||||
assert_eq!(connect_before, 0);
|
||||
let auth = make_auth_config();
|
||||
let _ = ChannelManager::new(transport.clone(), auth, "testuser".to_string()).await;
|
||||
let connect_after = transport.connect_count.load(Ordering::SeqCst);
|
||||
assert!(connect_after > 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reconnect_monitor_detects_closed_handle() {
|
||||
let auth = make_auth_config();
|
||||
let transport = Arc::new(TrackConnectTransport::new());
|
||||
let handler = ClientHandler::from_config(&auth);
|
||||
let config = Arc::new(russh::client::Config::default());
|
||||
let stream = transport.connect().await.unwrap();
|
||||
let handle = client::connect_stream(config, stream, handler).await;
|
||||
match handle {
|
||||
Ok(h) => {
|
||||
assert!(!h.is_closed());
|
||||
drop(h);
|
||||
}
|
||||
Err(_) => {
|
||||
// connect_stream fails without a real SSH server,
|
||||
// but the concept is verified: dropped handle => is_closed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_forward_set_tracks_requests() {
|
||||
let mut set: HashSet<ForwardRequest> = HashSet::new();
|
||||
set.insert(ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 8080,
|
||||
});
|
||||
set.insert(ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 9090,
|
||||
});
|
||||
assert_eq!(set.len(), 2);
|
||||
set.remove(&ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 8080,
|
||||
});
|
||||
assert_eq!(set.len(), 1);
|
||||
assert!(set.contains(&ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 9090,
|
||||
}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backoff_indefinitely_beyond_cap() {
|
||||
for attempt in 0..50 {
|
||||
let duration = backoff_duration(attempt);
|
||||
assert!(duration <= Duration::from_secs(30));
|
||||
assert!(duration >= Duration::from_secs(1));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,877 +0,0 @@
|
||||
//! Client session management and connection logic.
|
||||
//!
|
||||
//! `ClientSession` establishes an SSH connection over a transport, authenticates,
|
||||
//! starts a SOCKS5 proxy, sets up port forwards, and monitors for reconnection.
|
||||
//! `ConnectOptions` provides a builder-pattern API for programmatic configuration.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use russh::client;
|
||||
use russh::keys::PrivateKey;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::client_auth::{ClientAuthConfig, ClientHandler};
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::client::forward::{LocalForwarder, PortForwardSpec, RemoteForwarder};
|
||||
use crate::error::ConfigError;
|
||||
use crate::socks5::{HandleChannelOpener, Socks5Server};
|
||||
use crate::transport::Transport;
|
||||
|
||||
const DEFAULT_SOCKS5_ADDR: &str = "127.0.0.1:1080";
|
||||
const DRAIN_TIMEOUT: Duration = Duration::from_secs(2);
|
||||
|
||||
/// Transport mode for the client connection.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TransportMode {
|
||||
Tcp,
|
||||
Tls,
|
||||
Iroh,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TransportMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TransportMode::Tcp => write!(f, "tcp"),
|
||||
TransportMode::Tls => write!(f, "tls"),
|
||||
TransportMode::Iroh => write!(f, "iroh"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Programmatic configuration for an alknet client session.
|
||||
///
|
||||
/// Construct with `ConnectOptions::new(key_source)` and chain builder methods.
|
||||
/// Call `validate()` before passing to `ClientSession::new()`.
|
||||
///
|
||||
/// ```
|
||||
/// use alknet_core::client::{ConnectOptions, TransportMode};
|
||||
/// use alknet_core::auth::keys::KeySource;
|
||||
///
|
||||
/// let opts = ConnectOptions::new(KeySource::File("/path/to/key".into()))
|
||||
/// .server("example.com:22")
|
||||
/// .transport_mode(TransportMode::Tcp)
|
||||
/// .socks5_addr("127.0.0.1:1080")
|
||||
/// .forward("5432:db.internal:5432");
|
||||
/// opts.validate().unwrap();
|
||||
/// ```
|
||||
#[derive(Clone)]
|
||||
pub struct ConnectOptions {
|
||||
pub server: Option<String>,
|
||||
pub peer: Option<String>,
|
||||
pub transport_mode: TransportMode,
|
||||
pub identity: KeySource,
|
||||
pub socks5_addr: String,
|
||||
pub forwards: Vec<String>,
|
||||
pub remote_forwards: Vec<String>,
|
||||
pub proxy: Option<String>,
|
||||
pub iroh_relay: Option<String>,
|
||||
pub tls_server_name: Option<String>,
|
||||
pub insecure: bool,
|
||||
}
|
||||
|
||||
impl ConnectOptions {
|
||||
pub fn new(identity: KeySource) -> Self {
|
||||
Self {
|
||||
server: None,
|
||||
peer: None,
|
||||
transport_mode: TransportMode::Tcp,
|
||||
identity,
|
||||
socks5_addr: DEFAULT_SOCKS5_ADDR.to_string(),
|
||||
forwards: Vec::new(),
|
||||
remote_forwards: Vec::new(),
|
||||
proxy: None,
|
||||
iroh_relay: None,
|
||||
tls_server_name: None,
|
||||
insecure: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn server(mut self, addr: impl Into<String>) -> Self {
|
||||
self.server = Some(addr.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn peer(mut self, endpoint_id: impl Into<String>) -> Self {
|
||||
self.peer = Some(endpoint_id.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn transport_mode(mut self, mode: TransportMode) -> Self {
|
||||
self.transport_mode = mode;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn socks5_addr(mut self, addr: impl Into<String>) -> Self {
|
||||
self.socks5_addr = addr.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn forward(mut self, spec: impl Into<String>) -> Self {
|
||||
self.forwards.push(spec.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn remote_forward(mut self, spec: impl Into<String>) -> Self {
|
||||
self.remote_forwards.push(spec.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn proxy(mut self, url: impl Into<String>) -> Self {
|
||||
self.proxy = Some(url.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn iroh_relay(mut self, url: impl Into<String>) -> Self {
|
||||
self.iroh_relay = Some(url.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tls_server_name(mut self, name: impl Into<String>) -> Self {
|
||||
self.tls_server_name = Some(name.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn insecure(mut self, insecure: bool) -> Self {
|
||||
self.insecure = insecure;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> Result<(), ConfigError> {
|
||||
match self.transport_mode {
|
||||
TransportMode::Tcp | TransportMode::Tls => {
|
||||
if self.server.is_none() {
|
||||
return Err(ConfigError::InvalidFlag {
|
||||
name: "--server is required for tcp/tls transport".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
TransportMode::Iroh => {
|
||||
if self.peer.is_none() {
|
||||
return Err(ConfigError::InvalidFlag {
|
||||
name: "--peer is required for iroh transport".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ConnectOptions {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ConnectOptions")
|
||||
.field("server", &self.server)
|
||||
.field("peer", &self.peer)
|
||||
.field("transport_mode", &self.transport_mode)
|
||||
.field("identity", &"<KeySource>")
|
||||
.field("socks5_addr", &self.socks5_addr)
|
||||
.field("forwards", &self.forwards)
|
||||
.field("remote_forwards", &self.remote_forwards)
|
||||
.field("proxy", &self.proxy)
|
||||
.field("iroh_relay", &self.iroh_relay)
|
||||
.field("tls_server_name", &self.tls_server_name)
|
||||
.field("insecure", &self.insecure)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// An active SSH client session over a transport.
|
||||
///
|
||||
/// Establishes the connection, authenticates, and runs a SOCKS5 proxy plus
|
||||
/// port forwards until shutdown or transport failure. On transport failure,
|
||||
/// attempts reconnection with exponential backoff (1s, 2s, 4s, ..., 30s cap).
|
||||
pub struct ClientSession<T: Transport> {
|
||||
opts: ConnectOptions,
|
||||
transport: Arc<T>,
|
||||
handle: Arc<Mutex<client::Handle<ClientHandler>>>,
|
||||
auth_config: Arc<ClientAuthConfig>,
|
||||
#[allow(dead_code)]
|
||||
private_key: Arc<PrivateKey>,
|
||||
#[allow(dead_code)]
|
||||
username: String,
|
||||
shutdown_tx: tokio::sync::watch::Sender<bool>,
|
||||
shutdown_rx: tokio::sync::watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl<T: Transport> ClientSession<T> {
|
||||
pub async fn new(opts: ConnectOptions, transport: Arc<T>) -> Result<Self, ConnectError> {
|
||||
opts.validate().map_err(ConnectError::Config)?;
|
||||
|
||||
let auth_config = Arc::new(
|
||||
ClientAuthConfig::from_key_source(opts.identity.clone())
|
||||
.map_err(ConnectError::Config)?,
|
||||
);
|
||||
let private_key = auth_config.private_key();
|
||||
|
||||
let username = derive_username();
|
||||
let handler = ClientHandler::from_config(&auth_config);
|
||||
|
||||
let stream = transport.connect().await.map_err(|e| {
|
||||
error!("transport connect failed: {e}");
|
||||
ConnectError::ConnectionFailed
|
||||
})?;
|
||||
|
||||
let config = Arc::new(client::Config::default());
|
||||
let mut handle = client::connect_stream(config, stream, handler)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("SSH connect failed: {e}");
|
||||
ConnectError::ConnectionFailed
|
||||
})?;
|
||||
|
||||
let auth_ok = auth_config
|
||||
.authenticate(&mut handle, &username)
|
||||
.await
|
||||
.map_err(|_| ConnectError::AuthFailed)?;
|
||||
if !auth_ok {
|
||||
return Err(ConnectError::AuthFailed);
|
||||
}
|
||||
|
||||
let handle = Arc::new(Mutex::new(handle));
|
||||
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
|
||||
|
||||
Ok(Self {
|
||||
opts,
|
||||
transport,
|
||||
handle,
|
||||
auth_config,
|
||||
private_key,
|
||||
username,
|
||||
shutdown_tx,
|
||||
shutdown_rx,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn handle(&self) -> Arc<Mutex<client::Handle<ClientHandler>>> {
|
||||
Arc::clone(&self.handle)
|
||||
}
|
||||
|
||||
pub fn auth_config(&self) -> &Arc<ClientAuthConfig> {
|
||||
&self.auth_config
|
||||
}
|
||||
|
||||
pub fn transport(&self) -> &Arc<T> {
|
||||
&self.transport
|
||||
}
|
||||
|
||||
pub fn options(&self) -> &ConnectOptions {
|
||||
&self.opts
|
||||
}
|
||||
|
||||
pub fn shutdown_sender(&self) -> tokio::sync::watch::Sender<bool> {
|
||||
self.shutdown_tx.clone()
|
||||
}
|
||||
|
||||
pub async fn run(self) -> Result<(), ConnectError> {
|
||||
let socks5_addr: SocketAddr = self.opts.socks5_addr.parse().map_err(|_| {
|
||||
ConnectError::Config(ConfigError::InvalidFlag {
|
||||
name: format!("invalid SOCKS5 address: {}", self.opts.socks5_addr),
|
||||
})
|
||||
})?;
|
||||
|
||||
let channel_opener = HandleChannelOpener::from_arc(Arc::clone(&self.handle));
|
||||
let socks5_server = Socks5Server::with_addr(channel_opener, &socks5_addr.to_string());
|
||||
let socks5_listen = socks5_server.listen_addr();
|
||||
|
||||
let local_forwarders = build_local_forwarders(&self.opts)?;
|
||||
let remote_specs = build_remote_specs(&self.opts)?;
|
||||
|
||||
for spec in &remote_specs {
|
||||
let remote_forwarder =
|
||||
RemoteForwarder::new(spec.clone()).map_err(|_| ConnectError::ForwardFailed)?;
|
||||
let mut h = self.handle.lock().await;
|
||||
remote_forwarder.register(&mut h).await.map_err(|_| {
|
||||
warn!("failed to register remote forward {}", spec);
|
||||
ConnectError::ForwardFailed
|
||||
})?;
|
||||
info!("registered remote forward: {}", spec);
|
||||
}
|
||||
|
||||
let socks5_task = tokio::spawn(async move {
|
||||
debug!("SOCKS5 server starting on {}", socks5_listen);
|
||||
if let Err(e) = socks5_server.run().await {
|
||||
error!("SOCKS5 server error: {e}");
|
||||
}
|
||||
});
|
||||
|
||||
let fwd_handle = Arc::clone(&self.handle);
|
||||
let fwd_shutdown = self.shutdown_rx.clone();
|
||||
let forward_task = tokio::spawn(async move {
|
||||
crate::client::forward::run_local_forwarders(
|
||||
local_forwarders,
|
||||
fwd_handle,
|
||||
fwd_shutdown,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
info!("alknet client running: SOCKS5 on {}", socks5_listen);
|
||||
|
||||
#[cfg(unix)]
|
||||
let signal_done = {
|
||||
let sig_tx = self.shutdown_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut sigterm_stream =
|
||||
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install SIGTERM handler");
|
||||
tokio::select! {
|
||||
_ = sigterm_stream.recv() => {
|
||||
info!("received SIGTERM");
|
||||
}
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
info!("received SIGINT (Ctrl+C)");
|
||||
}
|
||||
}
|
||||
let _ = sig_tx.send(true);
|
||||
})
|
||||
};
|
||||
|
||||
let mut wait_shutdown = self.shutdown_rx.clone();
|
||||
let reconnect_handle = Arc::clone(&self.handle);
|
||||
let reconnect_transport = Arc::clone(&self.transport);
|
||||
let reconnect_auth = Arc::clone(&self.auth_config);
|
||||
let reconnect_username = self.username.clone();
|
||||
let reconnect_shutdown = self.shutdown_rx.clone();
|
||||
let reconnect_remote_specs = remote_specs.clone();
|
||||
|
||||
let reconnect_monitor = tokio::spawn(async move {
|
||||
let mut attempts: u32 = 0;
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
if *reconnect_shutdown.borrow() {
|
||||
break;
|
||||
}
|
||||
let h = reconnect_handle.lock().await;
|
||||
if h.is_closed() {
|
||||
drop(h);
|
||||
info!("SSH session closed, starting reconnection");
|
||||
let backoff = backoff_duration(attempts);
|
||||
warn!("reconnect attempt #{}, waiting {:?}", attempts + 1, backoff);
|
||||
tokio::time::sleep(backoff).await;
|
||||
|
||||
let handler = ClientHandler::from_config(&reconnect_auth);
|
||||
let username = reconnect_username.clone();
|
||||
match establish_session(
|
||||
&*reconnect_transport,
|
||||
handler,
|
||||
&reconnect_auth,
|
||||
&username,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(new_handle) => {
|
||||
info!("reconnection successful");
|
||||
{
|
||||
let mut guard = reconnect_handle.lock().await;
|
||||
*guard = new_handle;
|
||||
}
|
||||
for spec in &reconnect_remote_specs {
|
||||
match RemoteForwarder::new(spec.clone()) {
|
||||
Ok(rf) => {
|
||||
let mut h = reconnect_handle.lock().await;
|
||||
match rf.register(&mut h).await {
|
||||
Ok(_) => {
|
||||
debug!("re-registered remote forward: {}", spec)
|
||||
}
|
||||
Err(e) => warn!(
|
||||
"failed to re-register remote forward {}: {e}",
|
||||
spec
|
||||
),
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("failed to create remote forwarder: {e}"),
|
||||
}
|
||||
}
|
||||
attempts = 0;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("reconnection attempt failed: {e}");
|
||||
attempts += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = wait_shutdown.changed() => {
|
||||
if *wait_shutdown.borrow() {
|
||||
info!("shutdown signal received");
|
||||
}
|
||||
}
|
||||
_ = socks5_task => {
|
||||
warn!("SOCKS5 server exited unexpectedly");
|
||||
}
|
||||
}
|
||||
|
||||
reconnect_monitor.abort();
|
||||
|
||||
#[cfg(unix)]
|
||||
signal_done.abort();
|
||||
|
||||
self.shutdown().await?;
|
||||
|
||||
forward_task.abort();
|
||||
let _ = forward_task.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) -> Result<(), ConnectError> {
|
||||
info!("initiating graceful shutdown");
|
||||
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
|
||||
{
|
||||
let handle = self.handle.lock().await;
|
||||
if !handle.is_closed() {
|
||||
if let Err(e) = handle
|
||||
.disconnect(russh::Disconnect::ByApplication, "shutdown", "")
|
||||
.await
|
||||
{
|
||||
warn!("failed to send SSH disconnect: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(DRAIN_TIMEOUT).await;
|
||||
|
||||
info!("graceful shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn derive_username() -> String {
|
||||
std::env::var("USER")
|
||||
.or_else(|_| std::env::var("USERNAME"))
|
||||
.unwrap_or_else(|_| "alknet".to_string())
|
||||
}
|
||||
|
||||
async fn establish_session<T: Transport>(
|
||||
transport: &T,
|
||||
handler: ClientHandler,
|
||||
auth_config: &ClientAuthConfig,
|
||||
username: &str,
|
||||
) -> Result<client::Handle<ClientHandler>, ConnectError> {
|
||||
let stream = transport.connect().await.map_err(|e| {
|
||||
error!("transport connect failed: {e}");
|
||||
ConnectError::ConnectionFailed
|
||||
})?;
|
||||
|
||||
let config = Arc::new(client::Config::default());
|
||||
let mut handle = client::connect_stream(config, stream, handler)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("SSH connect failed: {e}");
|
||||
ConnectError::ConnectionFailed
|
||||
})?;
|
||||
|
||||
let auth_ok = auth_config
|
||||
.authenticate(&mut handle, username)
|
||||
.await
|
||||
.map_err(|_| ConnectError::AuthFailed)?;
|
||||
if !auth_ok {
|
||||
return Err(ConnectError::AuthFailed);
|
||||
}
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
fn backoff_duration(attempt: u32) -> Duration {
|
||||
let secs: u64 = match attempt {
|
||||
0 => 1,
|
||||
1 => 2,
|
||||
2 => 4,
|
||||
3 => 8,
|
||||
4 => 16,
|
||||
_ => 30,
|
||||
};
|
||||
Duration::from_secs(secs)
|
||||
}
|
||||
|
||||
fn build_local_forwarders(opts: &ConnectOptions) -> Result<Vec<LocalForwarder>, ConnectError> {
|
||||
let mut forwarders = Vec::new();
|
||||
for spec_str in &opts.forwards {
|
||||
let spec = PortForwardSpec::local(spec_str).map_err(|e| {
|
||||
warn!("invalid local forward spec '{}': {}", spec_str, e);
|
||||
ConnectError::Config(ConfigError::InvalidFlag {
|
||||
name: format!("invalid forward spec: {}", spec_str),
|
||||
})
|
||||
})?;
|
||||
forwarders.push(LocalForwarder::new(spec).map_err(|e| {
|
||||
warn!("failed to create local forwarder: {}", e);
|
||||
ConnectError::ForwardFailed
|
||||
})?);
|
||||
}
|
||||
Ok(forwarders)
|
||||
}
|
||||
|
||||
fn build_remote_specs(opts: &ConnectOptions) -> Result<Vec<PortForwardSpec>, ConnectError> {
|
||||
let mut specs = Vec::new();
|
||||
for spec_str in &opts.remote_forwards {
|
||||
let spec = PortForwardSpec::remote(spec_str).map_err(|e| {
|
||||
warn!("invalid remote forward spec '{}': {}", spec_str, e);
|
||||
ConnectError::Config(ConfigError::InvalidFlag {
|
||||
name: format!("invalid remote forward spec: {}", spec_str),
|
||||
})
|
||||
})?;
|
||||
specs.push(spec);
|
||||
}
|
||||
Ok(specs)
|
||||
}
|
||||
|
||||
/// Errors that can occur during client connection setup and operation.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConnectError {
|
||||
#[error("connection failed")]
|
||||
ConnectionFailed,
|
||||
#[error("authentication failed")]
|
||||
AuthFailed,
|
||||
#[error("forward setup failed")]
|
||||
ForwardFailed,
|
||||
#[error("config error: {0}")]
|
||||
Config(#[from] ConfigError),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use tokio::io::duplex;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
fn make_identity() -> KeySource {
|
||||
KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_default_fields() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
assert!(opts.server.is_none());
|
||||
assert!(opts.peer.is_none());
|
||||
assert_eq!(opts.transport_mode, TransportMode::Tcp);
|
||||
assert_eq!(opts.socks5_addr, "127.0.0.1:1080");
|
||||
assert!(opts.forwards.is_empty());
|
||||
assert!(opts.remote_forwards.is_empty());
|
||||
assert!(opts.proxy.is_none());
|
||||
assert!(opts.iroh_relay.is_none());
|
||||
assert!(opts.tls_server_name.is_none());
|
||||
assert!(!opts.insecure);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_builder_pattern() {
|
||||
let opts = ConnectOptions::new(make_identity())
|
||||
.server("example.com:22")
|
||||
.transport_mode(TransportMode::Tls)
|
||||
.socks5_addr("127.0.0.1:9050")
|
||||
.forward("127.0.0.1:5432:db:5432")
|
||||
.remote_forward("0.0.0.0:8080:127.0.0.1:3000")
|
||||
.proxy("socks5://127.0.0.1:1080")
|
||||
.iroh_relay("https://relay.example.com")
|
||||
.tls_server_name("alknet.test")
|
||||
.insecure(true);
|
||||
|
||||
assert_eq!(opts.server.as_deref(), Some("example.com:22"));
|
||||
assert_eq!(opts.transport_mode, TransportMode::Tls);
|
||||
assert_eq!(opts.socks5_addr, "127.0.0.1:9050");
|
||||
assert_eq!(opts.forwards.len(), 1);
|
||||
assert_eq!(opts.remote_forwards.len(), 1);
|
||||
assert_eq!(opts.proxy.as_deref(), Some("socks5://127.0.0.1:1080"));
|
||||
assert_eq!(
|
||||
opts.iroh_relay.as_deref(),
|
||||
Some("https://relay.example.com")
|
||||
);
|
||||
assert_eq!(opts.tls_server_name.as_deref(), Some("alknet.test"));
|
||||
assert!(opts.insecure);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_tcp_requires_server() {
|
||||
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Tcp);
|
||||
assert!(opts.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_tcp_with_server_ok() {
|
||||
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||
assert!(opts.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_tls_requires_server() {
|
||||
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Tls);
|
||||
assert!(opts.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_tls_with_server_ok() {
|
||||
let opts = ConnectOptions::new(make_identity())
|
||||
.transport_mode(TransportMode::Tls)
|
||||
.server("example.com:443");
|
||||
assert!(opts.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_iroh_requires_peer() {
|
||||
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Iroh);
|
||||
assert!(opts.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_iroh_with_peer_ok() {
|
||||
let opts = ConnectOptions::new(make_identity())
|
||||
.transport_mode(TransportMode::Iroh)
|
||||
.peer("some-endpoint-id");
|
||||
assert!(opts.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_accepts_key_source_file() {
|
||||
let file_source = KeySource::File(std::path::PathBuf::from("/path/to/key"));
|
||||
let opts = ConnectOptions::new(file_source);
|
||||
match &opts.identity {
|
||||
KeySource::File(p) => assert_eq!(p, &std::path::PathBuf::from("/path/to/key")),
|
||||
_ => panic!("expected File variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_accepts_key_source_memory() {
|
||||
let mem_source = KeySource::Memory(b"key-data".to_vec());
|
||||
let opts = ConnectOptions::new(mem_source);
|
||||
match &opts.identity {
|
||||
KeySource::Memory(d) => assert_eq!(d, b"key-data"),
|
||||
_ => panic!("expected Memory variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_mode_display() {
|
||||
assert_eq!(TransportMode::Tcp.to_string(), "tcp");
|
||||
assert_eq!(TransportMode::Tls.to_string(), "tls");
|
||||
assert_eq!(TransportMode::Iroh.to_string(), "iroh");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_error_variants() {
|
||||
assert_eq!(
|
||||
ConnectError::ConnectionFailed.to_string(),
|
||||
"connection failed"
|
||||
);
|
||||
assert_eq!(
|
||||
ConnectError::AuthFailed.to_string(),
|
||||
"authentication failed"
|
||||
);
|
||||
assert_eq!(
|
||||
ConnectError::ForwardFailed.to_string(),
|
||||
"forward setup failed"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_debug_redacts_identity() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
let debug_str = format!("{:?}", opts);
|
||||
assert!(debug_str.contains("<KeySource>"));
|
||||
assert!(!debug_str.contains("OPENSSH"));
|
||||
}
|
||||
|
||||
struct FailTransport;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for FailTransport {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||
Err(anyhow::anyhow!("always fails"))
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"fail".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
struct DuplexTransport {
|
||||
connect_count: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for DuplexTransport {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||
self.connect_count.fetch_add(1, Ordering::SeqCst);
|
||||
let (client, _) = duplex(4096);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"duplex".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_session_new_transport_fails() {
|
||||
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||
let transport = Arc::new(FailTransport);
|
||||
let result = ClientSession::new(opts, transport).await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
result.err().unwrap(),
|
||||
ConnectError::ConnectionFailed
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_session_new_ssh_handshake_fails() {
|
||||
let transport = Arc::new(DuplexTransport {
|
||||
connect_count: Arc::new(AtomicUsize::new(0)),
|
||||
});
|
||||
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||
let result = ClientSession::new(opts, transport).await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
result.err().unwrap(),
|
||||
ConnectError::ConnectionFailed
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_local_forwarders_empty() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
let result = build_local_forwarders(&opts);
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_local_forwarders_valid() {
|
||||
let opts = ConnectOptions::new(make_identity()).forward("127.0.0.1:5432:db:5432");
|
||||
let result = build_local_forwarders(&opts);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_local_forwarders_invalid_spec() {
|
||||
let opts = ConnectOptions::new(make_identity()).forward("bad-spec");
|
||||
let result = build_local_forwarders(&opts);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_remote_specs_empty() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
let result = build_remote_specs(&opts);
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_remote_specs_valid() {
|
||||
let opts =
|
||||
ConnectOptions::new(make_identity()).remote_forward("0.0.0.0:8080:127.0.0.1:3000");
|
||||
let result = build_remote_specs(&opts);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_remote_specs_invalid() {
|
||||
let opts = ConnectOptions::new(make_identity()).remote_forward("bad");
|
||||
let result = build_remote_specs(&opts);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_socks5_addr() {
|
||||
assert_eq!(DEFAULT_SOCKS5_ADDR, "127.0.0.1:1080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn drain_timeout_is_two_seconds() {
|
||||
assert_eq!(DRAIN_TIMEOUT, Duration::from_secs(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_mode_equality() {
|
||||
assert_eq!(TransportMode::Tcp, TransportMode::Tcp);
|
||||
assert_ne!(TransportMode::Tcp, TransportMode::Tls);
|
||||
assert_ne!(TransportMode::Tls, TransportMode::Iroh);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_sends_disconnect_and_drains() {
|
||||
let transport = Arc::new(DuplexTransport {
|
||||
connect_count: Arc::new(AtomicUsize::new(0)),
|
||||
});
|
||||
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||
let result = ClientSession::new(opts, transport).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn socks5_is_always_enabled_by_default() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
assert!(!opts.socks5_addr.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_mock_transport_session() {
|
||||
use crate::socks5::{ChannelOpenError, ChannelOpener};
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
struct MockOpener;
|
||||
|
||||
impl ChannelOpener for MockOpener {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn open_channel(
|
||||
&self,
|
||||
_host: String,
|
||||
_port: u16,
|
||||
) -> Result<Self::Stream, ChannelOpenError> {
|
||||
let (client, _server) = duplex(4096);
|
||||
Ok(client)
|
||||
}
|
||||
}
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let bound_addr = listener.local_addr().unwrap();
|
||||
drop(listener);
|
||||
|
||||
let opener = MockOpener;
|
||||
let server = Socks5Server::with_addr(opener, &bound_addr.to_string());
|
||||
|
||||
let _server_task = tokio::spawn(async move {
|
||||
let _ = server.run().await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
let mut conn = TcpStream::connect(bound_addr).await.unwrap();
|
||||
|
||||
let greeting = [0x05, 0x01, 0x00];
|
||||
conn.write_all(&greeting).await.unwrap();
|
||||
|
||||
let mut auth_resp = [0u8; 2];
|
||||
conn.read_exact(&mut auth_resp).await.unwrap();
|
||||
assert_eq!(auth_resp, [0x05, 0x00]);
|
||||
|
||||
let connect_req = [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80];
|
||||
conn.write_all(&connect_req).await.unwrap();
|
||||
|
||||
let mut reply = [0u8; 10];
|
||||
conn.read_exact(&mut reply).await.unwrap();
|
||||
assert_eq!(reply[1], 0x00);
|
||||
|
||||
conn.write_all(b"test data").await.unwrap();
|
||||
conn.shutdown().await.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -1,529 +0,0 @@
|
||||
//! Local and remote port forwarding.
|
||||
//!
|
||||
//! `LocalForwarder` binds a local TCP listener and forwards each connection through
|
||||
//! an SSH `direct-tcpip` channel. `RemoteForwarder` requests `tcpip-forward` from
|
||||
//! the server and handles `forwarded-tcpip` channels. Specs follow the
|
||||
//! `bind_addr:bind_port:target_host:target_port` format.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use russh::client;
|
||||
use tokio::io;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::error::ForwardError;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PortForwardSpecKind {
|
||||
Local,
|
||||
Remote,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PortForwardSpec {
|
||||
pub kind: PortForwardSpecKind,
|
||||
pub bind_addr: String,
|
||||
pub bind_port: u16,
|
||||
pub target_host: String,
|
||||
pub target_port: u16,
|
||||
}
|
||||
|
||||
impl PortForwardSpec {
|
||||
pub fn local(spec: &str) -> Result<Self, ForwardError> {
|
||||
let (bind_addr, bind_port, target_host, target_port) = parse_spec(spec)?;
|
||||
Ok(Self {
|
||||
kind: PortForwardSpecKind::Local,
|
||||
bind_addr,
|
||||
bind_port,
|
||||
target_host,
|
||||
target_port,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn remote(spec: &str) -> Result<Self, ForwardError> {
|
||||
let (bind_addr, bind_port, target_host, target_port) = parse_spec(spec)?;
|
||||
Ok(Self {
|
||||
kind: PortForwardSpecKind::Remote,
|
||||
bind_addr,
|
||||
bind_port,
|
||||
target_host,
|
||||
target_port,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn listen_addr(&self) -> Result<SocketAddr, ForwardError> {
|
||||
format!("{}:{}", self.bind_addr, self.bind_port)
|
||||
.parse()
|
||||
.map_err(|_| ForwardError::InvalidSpec {
|
||||
spec: format!("{}:{}", self.bind_addr, self.bind_port),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn target_addr(&self) -> Result<SocketAddr, ForwardError> {
|
||||
format!("{}:{}", self.target_host, self.target_port)
|
||||
.parse()
|
||||
.map_err(|_| ForwardError::InvalidSpec {
|
||||
spec: format!("{}:{}", self.target_host, self.target_port),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PortForwardSpec {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let prefix = match self.kind {
|
||||
PortForwardSpecKind::Local => "-L",
|
||||
PortForwardSpecKind::Remote => "-R",
|
||||
};
|
||||
write!(
|
||||
f,
|
||||
"{} {}:{}:{}:{}",
|
||||
prefix, self.bind_addr, self.bind_port, self.target_host, self.target_port
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_spec(spec: &str) -> Result<(String, u16, String, u16), ForwardError> {
|
||||
let parts: Vec<&str> = spec.split(':').collect();
|
||||
if parts.len() != 4 {
|
||||
return Err(ForwardError::InvalidSpec {
|
||||
spec: spec.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let bind_addr = parts[0].to_string();
|
||||
let bind_port: u16 = parts[1].parse().map_err(|_| ForwardError::InvalidSpec {
|
||||
spec: spec.to_string(),
|
||||
})?;
|
||||
let target_host = parts[2].to_string();
|
||||
let target_port: u16 = parts[3].parse().map_err(|_| ForwardError::InvalidSpec {
|
||||
spec: spec.to_string(),
|
||||
})?;
|
||||
|
||||
Ok((bind_addr, bind_port, target_host, target_port))
|
||||
}
|
||||
|
||||
pub struct LocalForwarder {
|
||||
spec: PortForwardSpec,
|
||||
listener: Option<TcpListener>,
|
||||
}
|
||||
|
||||
impl LocalForwarder {
|
||||
pub fn new(spec: PortForwardSpec) -> Result<Self, ForwardError> {
|
||||
if spec.kind != PortForwardSpecKind::Local {
|
||||
return Err(ForwardError::InvalidSpec {
|
||||
spec: format!("expected local spec, got {:?}", spec.kind),
|
||||
});
|
||||
}
|
||||
Ok(Self {
|
||||
spec,
|
||||
listener: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn spec(&self) -> &PortForwardSpec {
|
||||
&self.spec
|
||||
}
|
||||
|
||||
pub async fn run<H: client::Handler + Send + 'static>(
|
||||
&mut self,
|
||||
handle: Arc<Mutex<client::Handle<H>>>,
|
||||
) -> Result<(), ForwardError> {
|
||||
let listen_addr = self.spec.listen_addr()?;
|
||||
let listener: TcpListener = TcpListener::bind(listen_addr)
|
||||
.await
|
||||
.map_err(|e| ForwardError::BindFailed { source: e })?;
|
||||
self.listener = Some(listener);
|
||||
let remote_host = self.spec.target_host.clone();
|
||||
let remote_port = self.spec.target_port;
|
||||
|
||||
info!(
|
||||
"local forward listening on {} -> {}:{}",
|
||||
listen_addr, remote_host, remote_port
|
||||
);
|
||||
|
||||
loop {
|
||||
let listener = match &self.listener {
|
||||
Some(l) => l,
|
||||
None => return Ok(()),
|
||||
};
|
||||
let accept_result = listener.accept().await;
|
||||
let (local_stream, local_addr) = match accept_result {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
let handle = handle.lock().await;
|
||||
if handle.is_closed() {
|
||||
debug!("local forward accept loop ending: ssh session closed");
|
||||
return Ok(());
|
||||
}
|
||||
drop(handle);
|
||||
error!("local forward accept error: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"local forward connection from {} -> {}:{}",
|
||||
local_addr, remote_host, remote_port
|
||||
);
|
||||
|
||||
let handle = handle.clone();
|
||||
let remote_host = remote_host.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) =
|
||||
proxy_local_to_remote(local_stream, handle, &remote_host, remote_port).await
|
||||
{
|
||||
debug!("local forward proxy error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stop(&mut self) {
|
||||
if let Some(listener) = self.listener.take() {
|
||||
drop(listener);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn local_port(&self) -> u16 {
|
||||
self.spec.bind_port
|
||||
}
|
||||
}
|
||||
|
||||
async fn proxy_local_to_remote<H: client::Handler + Send + 'static>(
|
||||
local_stream: TcpStream,
|
||||
handle: Arc<Mutex<client::Handle<H>>>,
|
||||
remote_host: &str,
|
||||
remote_port: u16,
|
||||
) -> Result<(), ForwardError> {
|
||||
let local_addr = local_stream
|
||||
.peer_addr()
|
||||
.map(|a| a.to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
let handle_guard = handle.lock().await;
|
||||
let channel = handle_guard
|
||||
.channel_open_direct_tcpip(remote_host, remote_port as u32, &local_addr, 0)
|
||||
.await
|
||||
.map_err(|e| ForwardError::ChannelOpenFailed {
|
||||
source: Box::new(e) as _,
|
||||
})?;
|
||||
drop(handle_guard);
|
||||
|
||||
let ssh_stream = channel.into_stream();
|
||||
let (mut ssh_read, mut ssh_write) = tokio::io::split(ssh_stream);
|
||||
let (mut local_read, mut local_write) = tokio::io::split(local_stream);
|
||||
|
||||
let client_to_server = io::copy(&mut local_read, &mut ssh_write);
|
||||
let server_to_client = io::copy(&mut ssh_read, &mut local_write);
|
||||
|
||||
match tokio::join!(client_to_server, server_to_client) {
|
||||
(Err(e), _) | (_, Err(e)) => {
|
||||
debug!("local forward bidirectional copy error: {}", e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct RemoteForwarder {
|
||||
spec: PortForwardSpec,
|
||||
cancel: Option<tokio::sync::oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
impl RemoteForwarder {
|
||||
pub fn new(spec: PortForwardSpec) -> Result<Self, ForwardError> {
|
||||
if spec.kind != PortForwardSpecKind::Remote {
|
||||
return Err(ForwardError::InvalidSpec {
|
||||
spec: format!("expected remote spec, got {:?}", spec.kind),
|
||||
});
|
||||
}
|
||||
Ok(Self { spec, cancel: None })
|
||||
}
|
||||
|
||||
pub fn spec(&self) -> &PortForwardSpec {
|
||||
&self.spec
|
||||
}
|
||||
|
||||
pub async fn register<H: client::Handler + Send + 'static>(
|
||||
&self,
|
||||
handle: &mut client::Handle<H>,
|
||||
) -> Result<u32, ForwardError> {
|
||||
let port = handle
|
||||
.tcpip_forward(&self.spec.bind_addr, self.spec.bind_port as u32)
|
||||
.await
|
||||
.map_err(|e| ForwardError::ChannelOpenFailed {
|
||||
source: Box::new(e) as _,
|
||||
})?;
|
||||
Ok(port)
|
||||
}
|
||||
|
||||
pub async fn handle_forwarded_channel(
|
||||
channel: russh::Channel<russh::client::Msg>,
|
||||
connected_address: &str,
|
||||
connected_port: u32,
|
||||
local_host: &str,
|
||||
local_port: u16,
|
||||
) {
|
||||
debug!(
|
||||
"remote forward: server opened forwarded-tcpip channel to {}:{} -> local {}:{}",
|
||||
connected_address, connected_port, local_host, local_port
|
||||
);
|
||||
|
||||
let local_target = format!("{}:{}", local_host, local_port);
|
||||
let local_stream = match TcpStream::connect(&local_target).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!(
|
||||
"remote forward: failed to connect to local target {}: {}",
|
||||
local_target, e
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let ssh_stream = channel.into_stream();
|
||||
let (mut ssh_read, mut ssh_write) = tokio::io::split(ssh_stream);
|
||||
let (mut local_read, mut local_write) = tokio::io::split(local_stream);
|
||||
|
||||
let client_to_server = io::copy(&mut local_read, &mut ssh_write);
|
||||
let server_to_client = io::copy(&mut ssh_read, &mut local_write);
|
||||
|
||||
match tokio::join!(client_to_server, server_to_client) {
|
||||
(Err(e), _) | (_, Err(e)) => {
|
||||
debug!("remote forward bidirectional copy error: {}", e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn unregister<H: client::Handler + Send + 'static>(
|
||||
&self,
|
||||
handle: &client::Handle<H>,
|
||||
) -> Result<(), ForwardError> {
|
||||
handle
|
||||
.cancel_tcpip_forward(&self.spec.bind_addr, self.spec.bind_port as u32)
|
||||
.await
|
||||
.map_err(|e| ForwardError::ChannelOpenFailed {
|
||||
source: Box::new(e) as _,
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn stop(&mut self) {
|
||||
if let Some(cancel) = self.cancel.take() {
|
||||
let _ = cancel.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_local_forwarders<H: client::Handler + Send + 'static>(
|
||||
forwarders: Vec<LocalForwarder>,
|
||||
handle: Arc<Mutex<client::Handle<H>>>,
|
||||
mut shutdown: tokio::sync::watch::Receiver<bool>,
|
||||
) -> Vec<LocalForwarder> {
|
||||
let mut forwarders = forwarders;
|
||||
let mut tasks = Vec::new();
|
||||
|
||||
for forwarder in forwarders.drain(..) {
|
||||
let handle = handle.clone();
|
||||
let spec = forwarder.spec().clone();
|
||||
let (_cancel_tx, cancel_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let mut fwd = forwarder;
|
||||
tokio::select! {
|
||||
result = fwd.run(handle) => {
|
||||
if let Err(e) = result {
|
||||
error!("local forward {} failed: {}", spec, e);
|
||||
}
|
||||
}
|
||||
_ = cancel_rx => {
|
||||
fwd.stop().await;
|
||||
}
|
||||
}
|
||||
fwd
|
||||
}));
|
||||
}
|
||||
|
||||
let _ = shutdown.changed().await;
|
||||
|
||||
for task in &tasks {
|
||||
task.abort();
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
for task in tasks {
|
||||
match task.await {
|
||||
Ok(fwd) => results.push(fwd),
|
||||
Err(e) => {
|
||||
if !e.is_cancelled() {
|
||||
error!("local forwarder task panicked: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_local_spec() {
|
||||
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||
assert_eq!(spec.kind, PortForwardSpecKind::Local);
|
||||
assert_eq!(spec.bind_addr, "127.0.0.1");
|
||||
assert_eq!(spec.bind_port, 5432);
|
||||
assert_eq!(spec.target_host, "db.internal");
|
||||
assert_eq!(spec.target_port, 5432);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_remote_spec() {
|
||||
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||
assert_eq!(spec.kind, PortForwardSpecKind::Remote);
|
||||
assert_eq!(spec.bind_addr, "0.0.0.0");
|
||||
assert_eq!(spec.bind_port, 8080);
|
||||
assert_eq!(spec.target_host, "127.0.0.1");
|
||||
assert_eq!(spec.target_port, 3000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_invalid_few_parts() {
|
||||
assert!(PortForwardSpec::local("127.0.0.1:5432:db").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_invalid_many_parts() {
|
||||
assert!(PortForwardSpec::local("a:b:c:d:e").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_invalid_port() {
|
||||
assert!(PortForwardSpec::local("127.0.0.1:abc:db:5432").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_invalid_target_port() {
|
||||
assert!(PortForwardSpec::local("127.0.0.1:5432:db:abc").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spec_display() {
|
||||
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||
assert_eq!(spec.to_string(), "-L 127.0.0.1:5432:db.internal:5432");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spec_display_remote() {
|
||||
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||
assert_eq!(spec.to_string(), "-R 0.0.0.0:8080:127.0.0.1:3000");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_forwarder_rejects_remote_spec() {
|
||||
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||
assert!(LocalForwarder::new(spec).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remote_forwarder_rejects_local_spec() {
|
||||
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||
assert!(RemoteForwarder::new(spec).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn listen_addr_valid() {
|
||||
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||
let addr = spec.listen_addr().unwrap();
|
||||
assert_eq!(addr.port(), 5432);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn listen_addr_invalid_host() {
|
||||
let spec = PortForwardSpec {
|
||||
kind: PortForwardSpecKind::Local,
|
||||
bind_addr: "!!!invalid".to_string(),
|
||||
bind_port: 5432,
|
||||
target_host: "db".to_string(),
|
||||
target_port: 5432,
|
||||
};
|
||||
assert!(spec.listen_addr().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_forward_bind_and_accept() {
|
||||
let spec = PortForwardSpec::local(&format!("127.0.0.1:0:remote:5432")).unwrap();
|
||||
let forwarder = LocalForwarder::new(spec).unwrap();
|
||||
|
||||
let listen_addr = forwarder.spec.listen_addr().unwrap();
|
||||
let listener = TcpListener::bind(listen_addr).await.unwrap();
|
||||
let bound_addr = listener.local_addr().unwrap();
|
||||
drop(listener);
|
||||
|
||||
let spec = PortForwardSpec::local(&format!("127.0.0.1:{}:remote:5432", bound_addr.port()))
|
||||
.unwrap();
|
||||
let forwarder = LocalForwarder::new(spec).unwrap();
|
||||
assert_eq!(forwarder.local_port(), bound_addr.port());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_forward_proxy_bidirectional() {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
let echo_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let _echo_addr = echo_listener.local_addr().unwrap();
|
||||
|
||||
let echo_server = tokio::spawn(async move {
|
||||
let (mut stream, _) = echo_listener.accept().await.unwrap();
|
||||
let mut buf = [0u8; 64];
|
||||
loop {
|
||||
let n = match stream.read(&mut buf).await {
|
||||
Ok(0) => break,
|
||||
Ok(n) => n,
|
||||
Err(_) => break,
|
||||
};
|
||||
if stream.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let local_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let local_addr = local_listener.local_addr().unwrap();
|
||||
|
||||
let proxy_task = tokio::spawn(async move {
|
||||
let (stream, _) = local_listener.accept().await.unwrap();
|
||||
let (mut read, mut write) = tokio::io::split(stream);
|
||||
let _ = io::copy(&mut read, &mut write).await;
|
||||
});
|
||||
|
||||
let mut local_conn = TcpStream::connect(local_addr).await.unwrap();
|
||||
local_conn.write_all(b"hello").await.unwrap();
|
||||
let mut buf = [0u8; 64];
|
||||
let n = local_conn.read(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf[..n], b"hello");
|
||||
|
||||
echo_server.abort();
|
||||
proxy_task.abort();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarder_spec_access() {
|
||||
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||
let forwarder = LocalForwarder::new(spec.clone()).unwrap();
|
||||
assert_eq!(forwarder.spec(), &spec);
|
||||
assert_eq!(forwarder.local_port(), 5432);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remote_forwarder_spec_access() {
|
||||
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||
let forwarder = RemoteForwarder::new(spec.clone()).unwrap();
|
||||
assert_eq!(forwarder.spec(), &spec);
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
//! Client-side SSH session management.
|
||||
//!
|
||||
//! Provides `ClientSession` for establishing an SSH connection over any transport,
|
||||
//! running a local SOCKS5 proxy, and managing port forwards. Also provides
|
||||
//! `ChannelManager` for programmatic channel management with automatic reconnection.
|
||||
//!
|
||||
//! The client always starts a SOCKS5 proxy (default `127.0.0.1:1080`) when running
|
||||
//! via `ClientSession::run()`. For VPN-like "route all traffic" behavior, use
|
||||
//! [tun2proxy](https://github.com/tun2proxy/tun2proxy) alongside the SOCKS5 proxy.
|
||||
|
||||
pub mod channel_manager;
|
||||
pub mod connect;
|
||||
pub mod forward;
|
||||
|
||||
pub use channel_manager::{ChannelManager, ForwardRequest};
|
||||
pub use connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
|
||||
pub use forward::{LocalForwarder, PortForwardSpec, PortForwardSpecKind, RemoteForwarder};
|
||||
@@ -1,99 +0,0 @@
|
||||
//! Configuration service for runtime config reload.
|
||||
//!
|
||||
//! See [ADR-030](docs/architecture/decisions/030-dynamic-config.md).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
|
||||
use super::{DynamicConfig, ForwardingPolicy, RateLimitConfig};
|
||||
|
||||
pub struct ConfigServiceImpl {
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
}
|
||||
|
||||
impl ConfigServiceImpl {
|
||||
pub fn new(dynamic: Arc<ArcSwap<DynamicConfig>>) -> Self {
|
||||
Self { dynamic }
|
||||
}
|
||||
|
||||
pub fn forwarding_policy(&self) -> Arc<ForwardingPolicy> {
|
||||
Arc::new(self.dynamic.load().forwarding.clone())
|
||||
}
|
||||
|
||||
pub fn rate_limits(&self) -> Arc<RateLimitConfig> {
|
||||
Arc::new(self.dynamic.load().rate_limits.clone())
|
||||
}
|
||||
|
||||
pub fn reload(&self, new_config: DynamicConfig) {
|
||||
self.dynamic.store(Arc::new(new_config));
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ConfigServiceImpl {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ConfigServiceImpl").finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "irpc")]
|
||||
#[allow(dead_code)]
|
||||
pub enum ConfigProtocol {
|
||||
GetForwardingPolicy,
|
||||
GetRateLimits,
|
||||
ReloadForwarding { policy: ForwardingPolicy },
|
||||
ReloadRateLimits { limits: RateLimitConfig },
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::AuthPolicy;
|
||||
|
||||
#[test]
|
||||
fn config_service_impl_forwarding_policy() {
|
||||
let (arc_swap, _) = super::super::new_dynamic_config();
|
||||
let service = ConfigServiceImpl::new(Arc::clone(&arc_swap));
|
||||
let policy = service.forwarding_policy();
|
||||
assert_eq!(policy.default, ForwardingPolicy::allow_all().default);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_service_impl_rate_limits() {
|
||||
let (arc_swap, _) = super::super::new_dynamic_config();
|
||||
let service = ConfigServiceImpl::new(Arc::clone(&arc_swap));
|
||||
let limits = service.rate_limits();
|
||||
assert_eq!(limits.max_auth_attempts, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_service_impl_reload() {
|
||||
let (arc_swap, _) = super::super::new_dynamic_config();
|
||||
let service = ConfigServiceImpl::new(Arc::clone(&arc_swap));
|
||||
assert_eq!(
|
||||
service.forwarding_policy().default,
|
||||
ForwardingPolicy::allow_all().default
|
||||
);
|
||||
|
||||
let new_config = DynamicConfig {
|
||||
auth: AuthPolicy::empty(),
|
||||
forwarding: ForwardingPolicy::deny_all(),
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
credentials: std::collections::HashMap::new(),
|
||||
};
|
||||
service.reload(new_config);
|
||||
|
||||
assert_eq!(
|
||||
service.forwarding_policy().default,
|
||||
ForwardingPolicy::deny_all().default
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_service_impl_debug() {
|
||||
let (arc_swap, _) = super::super::new_dynamic_config();
|
||||
let service = ConfigServiceImpl::new(Arc::clone(&arc_swap));
|
||||
let debug_str = format!("{:?}", service);
|
||||
assert!(debug_str.contains("ConfigServiceImpl"));
|
||||
}
|
||||
}
|
||||
@@ -1,603 +0,0 @@
|
||||
//! Runtime-reloadable dynamic configuration (auth policy, forwarding policy, rate limits).
|
||||
//!
|
||||
//! See [ADR-030](docs/architecture/decisions/030-dynamic-config.md).
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use russh::keys::ssh_key::HashAlg;
|
||||
|
||||
use crate::auth::identity::Identity;
|
||||
use crate::auth::ServerAuthConfig;
|
||||
use crate::config::forwarding::ForwardingPolicy;
|
||||
use crate::credentials::CredentialSet;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[non_exhaustive]
|
||||
pub struct ApiKeyEntry {
|
||||
pub prefix: String,
|
||||
pub hash: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub description: String,
|
||||
pub expires_at: Option<u64>,
|
||||
}
|
||||
|
||||
pub const API_KEY_PREFIX: &str = "alk_";
|
||||
|
||||
pub struct AuthPolicy {
|
||||
pub authorized_keys: std::collections::HashSet<russh::keys::PublicKey>,
|
||||
pub cert_authorities: Vec<crate::auth::keys::CertAuthorityEntry>,
|
||||
pub api_keys: Vec<ApiKeyEntry>,
|
||||
encoded_keys: std::collections::HashSet<Vec<u8>>,
|
||||
fingerprint_to_key: HashMap<String, russh::keys::PublicKey>,
|
||||
}
|
||||
|
||||
fn encode_key_data(key: &russh::keys::PublicKey) -> Vec<u8> {
|
||||
use russh::keys::helpers::EncodedExt;
|
||||
key.key_data().encoded().unwrap_or_default()
|
||||
}
|
||||
|
||||
impl AuthPolicy {
|
||||
pub fn new(
|
||||
authorized_keys: std::collections::HashSet<russh::keys::PublicKey>,
|
||||
cert_authorities: Vec<crate::auth::keys::CertAuthorityEntry>,
|
||||
) -> Self {
|
||||
Self::with_api_keys(authorized_keys, cert_authorities, Vec::new())
|
||||
}
|
||||
|
||||
pub fn with_api_keys(
|
||||
authorized_keys: std::collections::HashSet<russh::keys::PublicKey>,
|
||||
cert_authorities: Vec<crate::auth::keys::CertAuthorityEntry>,
|
||||
api_keys: Vec<ApiKeyEntry>,
|
||||
) -> Self {
|
||||
let encoded_keys = authorized_keys.iter().map(encode_key_data).collect();
|
||||
let fingerprint_to_key = authorized_keys
|
||||
.iter()
|
||||
.map(|k| (format!("{}", k.fingerprint(HashAlg::Sha256)), k.clone()))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
authorized_keys,
|
||||
cert_authorities,
|
||||
api_keys,
|
||||
encoded_keys,
|
||||
fingerprint_to_key,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_server_auth_config(config: ServerAuthConfig) -> Self {
|
||||
Self::new(config.authorized_keys, config.cert_authorities)
|
||||
}
|
||||
|
||||
pub fn empty() -> Self {
|
||||
Self::new(std::collections::HashSet::new(), Vec::new())
|
||||
}
|
||||
|
||||
pub fn resolve_identity_from_fingerprint(&self, fingerprint: &str) -> Option<Identity> {
|
||||
if self.fingerprint_to_key.contains_key(fingerprint) {
|
||||
Some(Identity {
|
||||
id: fingerprint.to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resolve_api_key(&self, token: &str) -> Option<Identity> {
|
||||
if !token.starts_with(API_KEY_PREFIX) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let prefix_part = &token[..token.len().min(8)];
|
||||
|
||||
let entry = self
|
||||
.api_keys
|
||||
.iter()
|
||||
.find(|e| prefix_part.starts_with(&e.prefix))?;
|
||||
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(token.as_bytes());
|
||||
let result = hasher.finalize();
|
||||
let expected_hash = format!("sha256:{}", hex::encode(result));
|
||||
|
||||
if entry.hash != expected_hash {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(expires_at) = entry.expires_at {
|
||||
let now_secs = std::time::SystemTime::now()
|
||||
.duration_since(std::time::SystemTime::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
if now_secs >= expires_at {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(Identity {
|
||||
id: entry.prefix.clone(),
|
||||
scopes: entry.scopes.clone(),
|
||||
resources: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn authenticate_publickey(
|
||||
&self,
|
||||
key: &russh::keys::PublicKey,
|
||||
) -> Result<(), crate::error::AuthError> {
|
||||
let encoded = encode_key_data(key);
|
||||
if self.encoded_keys.contains(&encoded) {
|
||||
return Ok(());
|
||||
}
|
||||
Err(crate::error::AuthError::KeyRejected)
|
||||
}
|
||||
|
||||
pub fn authenticate_certificate(
|
||||
&self,
|
||||
cert: &russh::keys::Certificate,
|
||||
user: &str,
|
||||
client_ip: Option<std::net::IpAddr>,
|
||||
) -> Result<(), crate::error::AuthError> {
|
||||
use std::time::SystemTime;
|
||||
|
||||
let matching_ca = self
|
||||
.cert_authorities
|
||||
.iter()
|
||||
.find(|ca| cert.signature_key() == ca.public_key.key_data());
|
||||
|
||||
let ca_entry = match matching_ca {
|
||||
Some(entry) => entry,
|
||||
None => return Err(crate::error::AuthError::CertInvalid),
|
||||
};
|
||||
|
||||
if cert.verify_signature().is_err() {
|
||||
return Err(crate::error::AuthError::CertInvalid);
|
||||
}
|
||||
|
||||
let now = SystemTime::now();
|
||||
let now_secs = now
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
|
||||
if now_secs < cert.valid_after() || now_secs >= cert.valid_before() {
|
||||
return Err(crate::error::AuthError::CertExpired);
|
||||
}
|
||||
|
||||
let principals = cert.valid_principals();
|
||||
if !principals.is_empty() && !principals.iter().any(|p| p == user) {
|
||||
return Err(crate::error::AuthError::CertPrincipalMismatch);
|
||||
}
|
||||
|
||||
check_critical_options(cert, ca_entry, client_ip)?;
|
||||
check_extensions(cert, ca_entry)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn check_critical_options(
|
||||
cert: &russh::keys::Certificate,
|
||||
ca_entry: &crate::auth::keys::CertAuthorityEntry,
|
||||
client_ip: Option<std::net::IpAddr>,
|
||||
) -> Result<(), crate::error::AuthError> {
|
||||
let ca_has_no_pty = ca_entry.options.iter().any(|o| o == "no-pty");
|
||||
|
||||
for (name, data) in cert.critical_options().iter() {
|
||||
match name.as_str() {
|
||||
"source-address" => {
|
||||
if !check_source_address(data, client_ip) {
|
||||
return Err(crate::error::AuthError::CertInvalid);
|
||||
}
|
||||
}
|
||||
"force-command" => {}
|
||||
"no-pty" => {}
|
||||
_ => {
|
||||
let _ = ca_has_no_pty;
|
||||
return Err(crate::error::AuthError::CertInvalid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_extensions(
|
||||
cert: &russh::keys::Certificate,
|
||||
ca_entry: &crate::auth::keys::CertAuthorityEntry,
|
||||
) -> Result<(), crate::error::AuthError> {
|
||||
let ca_permit_port_forwarding = ca_entry
|
||||
.options
|
||||
.iter()
|
||||
.any(|o| o == "permit-port-forwarding");
|
||||
|
||||
if ca_permit_port_forwarding {
|
||||
let cert_allows = cert
|
||||
.extensions()
|
||||
.iter()
|
||||
.any(|(n, _)| n == "permit-port-forwarding");
|
||||
if !cert_allows {
|
||||
return Err(crate::error::AuthError::CertInvalid);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_source_address(allowed: &str, client_ip: Option<std::net::IpAddr>) -> bool {
|
||||
use ipnetwork::IpNetwork;
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
let Some(ip) = client_ip else {
|
||||
return false;
|
||||
};
|
||||
|
||||
for pattern in allowed.split(',') {
|
||||
let pattern = pattern.trim();
|
||||
if pattern.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(cidr) = IpNetwork::from_str(pattern) {
|
||||
if cidr.contains(ip) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(net_ip) = IpAddr::from_str(pattern) {
|
||||
if net_ip == ip {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AuthPolicy {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AuthPolicy")
|
||||
.field("authorized_keys_count", &self.authorized_keys.len())
|
||||
.field("cert_authorities_count", &self.cert_authorities.len())
|
||||
.field("api_keys_count", &self.api_keys.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for AuthPolicy {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
authorized_keys: self.authorized_keys.clone(),
|
||||
cert_authorities: self.cert_authorities.clone(),
|
||||
api_keys: self.api_keys.clone(),
|
||||
encoded_keys: self.encoded_keys.clone(),
|
||||
fingerprint_to_key: self.fingerprint_to_key.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RateLimitConfig {
|
||||
pub max_connections_per_ip: usize,
|
||||
pub max_auth_attempts: usize,
|
||||
}
|
||||
|
||||
impl Default for RateLimitConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_connections_per_ip: 0,
|
||||
max_auth_attempts: 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[non_exhaustive]
|
||||
pub struct DynamicConfig {
|
||||
pub auth: AuthPolicy,
|
||||
pub forwarding: ForwardingPolicy,
|
||||
pub rate_limits: RateLimitConfig,
|
||||
pub credentials: HashMap<String, CredentialSet>,
|
||||
}
|
||||
|
||||
impl DynamicConfig {
|
||||
pub fn new(auth: AuthPolicy) -> Self {
|
||||
Self {
|
||||
auth,
|
||||
forwarding: ForwardingPolicy::allow_all(),
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
credentials: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_parts(
|
||||
auth: AuthPolicy,
|
||||
forwarding: ForwardingPolicy,
|
||||
rate_limits: RateLimitConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
auth,
|
||||
forwarding,
|
||||
rate_limits,
|
||||
credentials: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_forwarding_policy(mut self, policy: ForwardingPolicy) -> Self {
|
||||
self.forwarding = policy;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_rate_limits(mut self, limits: RateLimitConfig) -> Self {
|
||||
self.rate_limits = limits;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_credentials(mut self, credentials: HashMap<String, CredentialSet>) -> Self {
|
||||
self.credentials = credentials;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DynamicConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
auth: AuthPolicy::empty(),
|
||||
forwarding: ForwardingPolicy::allow_all(),
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
credentials: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConfigReloadHandle {
|
||||
pub(crate) dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
}
|
||||
|
||||
impl ConfigReloadHandle {
|
||||
pub fn reload(&self, new_config: DynamicConfig) {
|
||||
self.dynamic.store(Arc::new(new_config));
|
||||
}
|
||||
|
||||
pub fn dynamic(&self) -> Arc<DynamicConfig> {
|
||||
self.dynamic.load_full()
|
||||
}
|
||||
|
||||
pub fn dynamic_arc(&self) -> Arc<ArcSwap<DynamicConfig>> {
|
||||
Arc::clone(&self.dynamic)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ConfigReloadHandle {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ConfigReloadHandle").finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_dynamic_config() -> (Arc<ArcSwap<DynamicConfig>>, ConfigReloadHandle) {
|
||||
let inner = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||
let handle = ConfigReloadHandle {
|
||||
dynamic: Arc::clone(&inner),
|
||||
};
|
||||
(inner, handle)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::forwarding::ForwardingAction;
|
||||
|
||||
#[test]
|
||||
fn forwarding_policy_allow_all_default() {
|
||||
let policy = ForwardingPolicy::allow_all();
|
||||
assert_eq!(policy.default, ForwardingAction::Allow);
|
||||
assert!(policy.rules.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarding_policy_deny_all() {
|
||||
let policy = ForwardingPolicy::deny_all();
|
||||
assert_eq!(policy.default, ForwardingAction::Deny);
|
||||
assert!(policy.rules.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dynamic_config_default() {
|
||||
let config = DynamicConfig::default();
|
||||
assert_eq!(config.forwarding.default, ForwardingAction::Allow);
|
||||
assert_eq!(config.rate_limits.max_connections_per_ip, 0);
|
||||
assert_eq!(config.rate_limits.max_auth_attempts, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_reload_handle_updates_dynamic() {
|
||||
let (arc_swap, handle) = new_dynamic_config();
|
||||
let initial = arc_swap.load();
|
||||
assert_eq!(initial.forwarding.default, ForwardingAction::Allow);
|
||||
|
||||
let new_config = DynamicConfig {
|
||||
auth: AuthPolicy::empty(),
|
||||
forwarding: ForwardingPolicy::deny_all(),
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
credentials: HashMap::new(),
|
||||
};
|
||||
handle.reload(new_config);
|
||||
|
||||
let updated = arc_swap.load();
|
||||
assert_eq!(updated.forwarding.default, ForwardingAction::Deny);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dynamic_config_with_forwarding_policy_builder() {
|
||||
let config = DynamicConfig::new(AuthPolicy::empty())
|
||||
.with_forwarding_policy(ForwardingPolicy::deny_all());
|
||||
assert_eq!(config.forwarding.default, ForwardingAction::Deny);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rate_limit_config_custom() {
|
||||
let limits = RateLimitConfig {
|
||||
max_connections_per_ip: 5,
|
||||
max_auth_attempts: 3,
|
||||
};
|
||||
assert_eq!(limits.max_connections_per_ip, 5);
|
||||
assert_eq!(limits.max_auth_attempts, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarding_action_equality() {
|
||||
assert_eq!(ForwardingAction::Allow, ForwardingAction::Allow);
|
||||
assert_eq!(ForwardingAction::Deny, ForwardingAction::Deny);
|
||||
assert_ne!(ForwardingAction::Allow, ForwardingAction::Deny);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_policy_empty_rejects_all() {
|
||||
let policy = AuthPolicy::empty();
|
||||
let key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
||||
let other_ssh_key =
|
||||
russh::keys::parse_public_key_base64(key_text.split_whitespace().nth(1).unwrap())
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
policy.authenticate_publickey(&other_ssh_key),
|
||||
Err(crate::error::AuthError::KeyRejected)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_policy_debug_redacts_keys() {
|
||||
let policy = AuthPolicy::empty();
|
||||
let debug_str = format!("{:?}", policy);
|
||||
assert!(debug_str.contains("authorized_keys_count"));
|
||||
assert!(debug_str.contains("cert_authorities_count"));
|
||||
assert!(debug_str.contains("api_keys_count"));
|
||||
}
|
||||
|
||||
fn compute_api_key_hash(token: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(token.as_bytes());
|
||||
let result = hasher.finalize();
|
||||
format!("sha256:{}", hex::encode(result))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_key_valid_authenticates() {
|
||||
let token = "alk_testsecret123";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_test".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "test key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let policy =
|
||||
AuthPolicy::with_api_keys(std::collections::HashSet::new(), Vec::new(), vec![entry]);
|
||||
let identity = policy.resolve_api_key(token);
|
||||
assert!(identity.is_some());
|
||||
let identity = identity.unwrap();
|
||||
assert_eq!(identity.id, "alk_test");
|
||||
assert_eq!(identity.scopes, vec!["relay:connect"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_key_expired_rejected() {
|
||||
let token = "alk_expiredkey1";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_expi".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "expired key".to_string(),
|
||||
expires_at: Some(1),
|
||||
};
|
||||
let policy =
|
||||
AuthPolicy::with_api_keys(std::collections::HashSet::new(), Vec::new(), vec![entry]);
|
||||
assert!(policy.resolve_api_key(token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_key_wrong_hash_rejected() {
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_test".to_string(),
|
||||
hash: "sha256:0000000000000000000000000000000000000000000000000000000000000000"
|
||||
.to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "bad hash".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let policy =
|
||||
AuthPolicy::with_api_keys(std::collections::HashSet::new(), Vec::new(), vec![entry]);
|
||||
assert!(policy.resolve_api_key("alk_testsecret123").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_key_unknown_prefix_falls_through() {
|
||||
let token = "alk_testsecret123";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_other".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "other key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let policy =
|
||||
AuthPolicy::with_api_keys(std::collections::HashSet::new(), Vec::new(), vec![entry]);
|
||||
assert!(policy.resolve_api_key(token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_key_scopes_propagate() {
|
||||
let token = "alk_scopesecret";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_sco".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string(), "secrets:derive".to_string()],
|
||||
description: "scoped key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let policy =
|
||||
AuthPolicy::with_api_keys(std::collections::HashSet::new(), Vec::new(), vec![entry]);
|
||||
let identity = policy.resolve_api_key(token).unwrap();
|
||||
assert_eq!(identity.scopes, vec!["relay:connect", "secrets:derive"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_api_key_prefix_returns_none() {
|
||||
let policy = AuthPolicy::empty();
|
||||
assert!(policy.resolve_api_key("bearer-some-token").is_none());
|
||||
assert!(policy.resolve_api_key("regular-token").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_key_entry_default_empty() {
|
||||
let config = DynamicConfig::default();
|
||||
assert!(config.auth.api_keys.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_policy_with_api_keys_preserves_entries() {
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_abc".to_string(),
|
||||
hash: "sha256:abcdef".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "test".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let policy = AuthPolicy::with_api_keys(
|
||||
std::collections::HashSet::new(),
|
||||
Vec::new(),
|
||||
vec![entry.clone()],
|
||||
);
|
||||
assert_eq!(policy.api_keys.len(), 1);
|
||||
assert_eq!(policy.api_keys[0], entry);
|
||||
}
|
||||
}
|
||||
@@ -1,534 +0,0 @@
|
||||
//! Forwarding policy engine for per-identity and per-transport access control.
|
||||
//!
|
||||
//! See [ADR-031](docs/architecture/decisions/031-forwarding-policy.md).
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::ops::Range;
|
||||
use std::str::FromStr;
|
||||
|
||||
use ipnetwork::IpNetwork;
|
||||
|
||||
use crate::auth::identity::Identity;
|
||||
use crate::transport::TransportKind;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[non_exhaustive]
|
||||
pub enum ForwardingAction {
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[non_exhaustive]
|
||||
pub enum TargetPattern {
|
||||
Any,
|
||||
Host(String),
|
||||
Cidr(IpNetwork),
|
||||
PortRange(String, Range<u16>),
|
||||
AlknetPrefix,
|
||||
}
|
||||
|
||||
impl TargetPattern {
|
||||
pub fn matches(&self, target: &str, port: u16) -> bool {
|
||||
match self {
|
||||
TargetPattern::Any => true,
|
||||
TargetPattern::Host(pattern) => match_host_pattern(pattern, target),
|
||||
TargetPattern::Cidr(network) => match_cidr(network, target),
|
||||
TargetPattern::PortRange(host_pattern, port_range) => {
|
||||
match_host_pattern(host_pattern, target) && port_range.contains(&port)
|
||||
}
|
||||
TargetPattern::AlknetPrefix => {
|
||||
target.starts_with(crate::server::control_channel::ALKNET_PREFIX)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn match_host_pattern(pattern: &str, target: &str) -> bool {
|
||||
if pattern == target {
|
||||
return true;
|
||||
}
|
||||
if pattern.contains('*') {
|
||||
if let Some(pos) = pattern.find('*') {
|
||||
let prefix = &pattern[..pos];
|
||||
let suffix = &pattern[pos + 1..];
|
||||
return target.starts_with(prefix)
|
||||
&& target.ends_with(suffix)
|
||||
&& target.len() >= prefix.len() + suffix.len();
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn match_cidr(network: &IpNetwork, target: &str) -> bool {
|
||||
let Ok(addr) = IpAddr::from_str(target) else {
|
||||
return false;
|
||||
};
|
||||
network.contains(addr)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[non_exhaustive]
|
||||
pub struct ForwardingRule {
|
||||
pub target: TargetPattern,
|
||||
pub action: ForwardingAction,
|
||||
pub principals: Vec<String>,
|
||||
pub transports: Vec<TransportKind>,
|
||||
}
|
||||
|
||||
impl ForwardingRule {
|
||||
pub fn new(
|
||||
target: TargetPattern,
|
||||
action: ForwardingAction,
|
||||
principals: Vec<String>,
|
||||
transports: Vec<TransportKind>,
|
||||
) -> Self {
|
||||
Self {
|
||||
target,
|
||||
action,
|
||||
principals,
|
||||
transports,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ForwardingRule {
|
||||
fn matches_principal(&self, identity: &Identity) -> bool {
|
||||
if self.principals.is_empty() {
|
||||
return true;
|
||||
}
|
||||
self.principals
|
||||
.iter()
|
||||
.any(|p| p == &identity.id || identity.scopes.contains(p))
|
||||
}
|
||||
|
||||
fn matches_transport(&self, transport: &TransportKind) -> bool {
|
||||
if self.transports.is_empty() {
|
||||
return true;
|
||||
}
|
||||
self.transports.contains(transport)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ForwardingPolicy {
|
||||
pub default: ForwardingAction,
|
||||
pub rules: Vec<ForwardingRule>,
|
||||
}
|
||||
|
||||
impl ForwardingPolicy {
|
||||
pub fn allow_all() -> Self {
|
||||
Self {
|
||||
default: ForwardingAction::Allow,
|
||||
rules: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deny_all() -> Self {
|
||||
Self {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(
|
||||
&self,
|
||||
target: &str,
|
||||
port: u16,
|
||||
identity: &Identity,
|
||||
transport: TransportKind,
|
||||
) -> bool {
|
||||
for rule in &self.rules {
|
||||
if rule.target.matches(target, port)
|
||||
&& rule.matches_principal(identity)
|
||||
&& rule.matches_transport(&transport)
|
||||
{
|
||||
return rule.action == ForwardingAction::Allow;
|
||||
}
|
||||
}
|
||||
self.default == ForwardingAction::Allow
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_identity(id: &str, scopes: Vec<&str>) -> Identity {
|
||||
Identity {
|
||||
id: id.to_string(),
|
||||
scopes: scopes.into_iter().map(|s| s.to_string()).collect(),
|
||||
resources: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarding_action_equality() {
|
||||
assert_eq!(ForwardingAction::Allow, ForwardingAction::Allow);
|
||||
assert_eq!(ForwardingAction::Deny, ForwardingAction::Deny);
|
||||
assert_ne!(ForwardingAction::Allow, ForwardingAction::Deny);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_all_allows_everything() {
|
||||
let policy = ForwardingPolicy::allow_all();
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(policy.check("example.com", 80, &identity, TransportKind::Tcp));
|
||||
assert!(policy.check(
|
||||
"10.0.0.1",
|
||||
22,
|
||||
&identity,
|
||||
TransportKind::Tls { server_name: None }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deny_all_denies_everything() {
|
||||
let policy = ForwardingPolicy::deny_all();
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(!policy.check("example.com", 80, &identity, TransportKind::Tcp));
|
||||
assert!(!policy.check(
|
||||
"10.0.0.1",
|
||||
22,
|
||||
&identity,
|
||||
TransportKind::Tls { server_name: None }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_match_wins_allowlist() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Host("allowed.example.com".to_string()),
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(policy.check("allowed.example.com", 80, &identity, TransportKind::Tcp));
|
||||
assert!(!policy.check("denied.example.com", 80, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_match_wins_blocklist() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Allow,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Host("blocked.example.com".to_string()),
|
||||
action: ForwardingAction::Deny,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(!policy.check("blocked.example.com", 80, &identity, TransportKind::Tcp));
|
||||
assert!(policy.check("allowed.example.com", 80, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_match_wins_ordering() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![
|
||||
ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
},
|
||||
ForwardingRule {
|
||||
target: TargetPattern::Host("blocked.example.com".to_string()),
|
||||
action: ForwardingAction::Deny,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
},
|
||||
],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(policy.check("blocked.example.com", 80, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_principals_matches_all() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
let identity1 = make_identity("user1", vec![]);
|
||||
let identity2 = make_identity("user2", vec![]);
|
||||
assert!(policy.check("example.com", 80, &identity1, TransportKind::Tcp));
|
||||
assert!(policy.check("example.com", 80, &identity2, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn principal_matching_by_id() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec!["SHA256:abc123".to_string()],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
let allowed = make_identity("SHA256:abc123", vec![]);
|
||||
let denied = make_identity("SHA256:other", vec![]);
|
||||
assert!(policy.check("example.com", 80, &allowed, TransportKind::Tcp));
|
||||
assert!(!policy.check("example.com", 80, &denied, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn principal_matching_by_scope() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec!["admin".to_string()],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
let allowed = make_identity("user1", vec!["admin"]);
|
||||
let denied = make_identity("user2", vec!["viewer"]);
|
||||
assert!(policy.check("example.com", 80, &allowed, TransportKind::Tcp));
|
||||
assert!(!policy.check("example.com", 80, &denied, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_transports_matches_all() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(policy.check("example.com", 80, &identity, TransportKind::Tcp));
|
||||
assert!(policy.check(
|
||||
"example.com",
|
||||
80,
|
||||
&identity,
|
||||
TransportKind::Tls { server_name: None }
|
||||
));
|
||||
assert!(policy.check(
|
||||
"example.com",
|
||||
80,
|
||||
&identity,
|
||||
TransportKind::Iroh {
|
||||
endpoint_id: String::new()
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_matching() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec![],
|
||||
transports: vec![TransportKind::Tls { server_name: None }],
|
||||
}],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(!policy.check("example.com", 443, &identity, TransportKind::Tcp));
|
||||
assert!(policy.check(
|
||||
"example.com",
|
||||
443,
|
||||
&identity,
|
||||
TransportKind::Tls { server_name: None }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_any_matches_all() {
|
||||
let pattern = TargetPattern::Any;
|
||||
assert!(pattern.matches("example.com", 80));
|
||||
assert!(pattern.matches("10.0.0.1", 22));
|
||||
assert!(pattern.matches("alknet-control", 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_host_exact_match() {
|
||||
let pattern = TargetPattern::Host("example.com".to_string());
|
||||
assert!(pattern.matches("example.com", 80));
|
||||
assert!(!pattern.matches("other.com", 80));
|
||||
assert!(!pattern.matches("sub.example.com", 80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_host_glob_match() {
|
||||
let pattern = TargetPattern::Host("*.example.com".to_string());
|
||||
assert!(pattern.matches("sub.example.com", 80));
|
||||
assert!(pattern.matches("a.example.com", 443));
|
||||
assert!(!pattern.matches("example.com", 80));
|
||||
assert!(!pattern.matches("xsub.example.com.org", 80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_host_glob_prefix() {
|
||||
let pattern = TargetPattern::Host("db-*".to_string());
|
||||
assert!(pattern.matches("db-primary", 5432));
|
||||
assert!(pattern.matches("db-replica", 5432));
|
||||
assert!(!pattern.matches("web-primary", 5432));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_host_glob_suffix() {
|
||||
let pattern = TargetPattern::Host("*.internal".to_string());
|
||||
assert!(pattern.matches("app.internal", 8080));
|
||||
assert!(pattern.matches("db.internal", 5432));
|
||||
assert!(!pattern.matches("app.external", 80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_cidr_matches_ip() {
|
||||
let network: IpNetwork = "10.0.0.0/8".parse().unwrap();
|
||||
let pattern = TargetPattern::Cidr(network);
|
||||
assert!(pattern.matches("10.0.0.1", 22));
|
||||
assert!(pattern.matches("10.255.255.255", 22));
|
||||
assert!(!pattern.matches("192.168.1.1", 22));
|
||||
assert!(!pattern.matches("not-an-ip", 22));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_cidr_ipv6() {
|
||||
let network: IpNetwork = "fd00::/8".parse().unwrap();
|
||||
let pattern = TargetPattern::Cidr(network);
|
||||
assert!(pattern.matches("fd00::1", 22));
|
||||
assert!(!pattern.matches("10.0.0.1", 22));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_port_range_matches() {
|
||||
let pattern = TargetPattern::PortRange("localhost".to_string(), 8080..8090);
|
||||
assert!(pattern.matches("localhost", 8080));
|
||||
assert!(pattern.matches("localhost", 8085));
|
||||
assert!(pattern.matches("localhost", 8089));
|
||||
assert!(!pattern.matches("localhost", 8079));
|
||||
assert!(!pattern.matches("localhost", 8090));
|
||||
assert!(!pattern.matches("otherhost", 8080));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_port_range_with_glob() {
|
||||
let pattern = TargetPattern::PortRange("*.internal".to_string(), 3000..4000);
|
||||
assert!(pattern.matches("app.internal", 3000));
|
||||
assert!(pattern.matches("app.internal", 3999));
|
||||
assert!(!pattern.matches("app.internal", 2999));
|
||||
assert!(!pattern.matches("app.internal", 4000));
|
||||
assert!(!pattern.matches("app.external", 3000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_alknet_prefix() {
|
||||
let pattern = TargetPattern::AlknetPrefix;
|
||||
assert!(pattern.matches("alknet-control", 0));
|
||||
assert!(pattern.matches("alknet-status", 0));
|
||||
assert!(pattern.matches("alknet-", 0));
|
||||
assert!(!pattern.matches("example.com", 0));
|
||||
assert!(!pattern.matches("alknet.example.com", 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_fallthrough_allow() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Allow,
|
||||
rules: vec![],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(policy.check("anything", 80, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_fallthrough_deny() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(!policy.check("anything", 80, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn combined_principal_and_transport_matching() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Host("restricted.example.com".to_string()),
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec!["admin".to_string()],
|
||||
transports: vec![TransportKind::Tls { server_name: None }],
|
||||
}],
|
||||
};
|
||||
let admin = make_identity("admin-user", vec!["admin"]);
|
||||
let viewer = make_identity("viewer-user", vec!["viewer"]);
|
||||
assert!(policy.check(
|
||||
"restricted.example.com",
|
||||
443,
|
||||
&admin,
|
||||
TransportKind::Tls { server_name: None }
|
||||
));
|
||||
assert!(!policy.check("restricted.example.com", 443, &admin, TransportKind::Tcp));
|
||||
assert!(!policy.check(
|
||||
"restricted.example.com",
|
||||
443,
|
||||
&viewer,
|
||||
TransportKind::Tls { server_name: None }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn webtransport_restricted_to_alknet() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Allow,
|
||||
rules: vec![
|
||||
ForwardingRule {
|
||||
target: TargetPattern::AlknetPrefix,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec![],
|
||||
transports: vec![TransportKind::WebTransport { server_name: None }],
|
||||
},
|
||||
ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Deny,
|
||||
principals: vec![],
|
||||
transports: vec![TransportKind::WebTransport { server_name: None }],
|
||||
},
|
||||
],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(policy.check(
|
||||
"alknet-control",
|
||||
0,
|
||||
&identity,
|
||||
TransportKind::WebTransport { server_name: None }
|
||||
));
|
||||
assert!(!policy.check(
|
||||
"example.com",
|
||||
443,
|
||||
&identity,
|
||||
TransportKind::WebTransport { server_name: None }
|
||||
));
|
||||
assert!(policy.check("example.com", 443, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cidr_does_not_match_hostname() {
|
||||
let network: IpNetwork = "10.0.0.0/8".parse().unwrap();
|
||||
let pattern = TargetPattern::Cidr(network);
|
||||
assert!(!pattern.matches("example.com", 22));
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
pub mod config_service;
|
||||
pub mod dynamic_config;
|
||||
pub mod forwarding;
|
||||
pub mod static_config;
|
||||
|
||||
pub use config_service::ConfigServiceImpl;
|
||||
pub use dynamic_config::{
|
||||
new_dynamic_config, ApiKeyEntry, AuthPolicy, ConfigReloadHandle, DynamicConfig,
|
||||
RateLimitConfig, API_KEY_PREFIX,
|
||||
};
|
||||
pub use forwarding::{ForwardingAction, ForwardingPolicy, ForwardingRule, TargetPattern};
|
||||
pub use static_config::StaticConfig;
|
||||
@@ -1,281 +0,0 @@
|
||||
//! Static (immutable) server configuration resolved at startup.
|
||||
//!
|
||||
//! See [ADR-030](docs/architecture/decisions/030-dynamic-config.md).
|
||||
|
||||
use crate::interface::StreamInterfaceKind;
|
||||
use crate::server::handler::{ProxyConfig, ProxyMode};
|
||||
use crate::server::serve::{ListenerConfig, ServeTransportMode, StreamListenerConfig};
|
||||
use crate::transport::TransportKind;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
pub struct StaticConfig {
|
||||
pub transport_mode: ServeTransportMode,
|
||||
pub listen_addr: String,
|
||||
pub tls_cert: Option<String>,
|
||||
pub tls_key: Option<String>,
|
||||
pub acme_domain: Option<String>,
|
||||
pub stealth: bool,
|
||||
pub host_key: russh::keys::PrivateKey,
|
||||
pub host_key_algorithm: russh::keys::Algorithm,
|
||||
pub max_auth_attempts: usize,
|
||||
pub max_connections_per_ip: usize,
|
||||
pub proxy_config: Option<ProxyConfig>,
|
||||
pub iroh_relay: Option<String>,
|
||||
pub listeners: Vec<ListenerConfig>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for StaticConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("StaticConfig")
|
||||
.field("transport_mode", &self.transport_mode)
|
||||
.field("listen_addr", &self.listen_addr)
|
||||
.field("tls_cert", &self.tls_cert.as_ref().map(|_| "<redacted>"))
|
||||
.field("tls_key", &self.tls_key.as_ref().map(|_| "<redacted>"))
|
||||
.field("acme_domain", &self.acme_domain)
|
||||
.field("stealth", &self.stealth)
|
||||
.field("host_key_algorithm", &self.host_key_algorithm)
|
||||
.field("max_auth_attempts", &self.max_auth_attempts)
|
||||
.field("max_connections_per_ip", &self.max_connections_per_ip)
|
||||
.field("proxy_config", &self.proxy_config)
|
||||
.field("iroh_relay", &self.iroh_relay)
|
||||
.field("listeners", &self.listeners)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl StaticConfig {
|
||||
pub fn from_serve_options(
|
||||
opts: crate::server::serve::ServeOptions,
|
||||
) -> Result<(Self, crate::config::DynamicConfig), crate::error::ConfigError> {
|
||||
opts.validate()?;
|
||||
|
||||
let host_key = crate::auth::keys::load_private_key(opts.key.clone())?;
|
||||
let host_key_algorithm = host_key.algorithm();
|
||||
|
||||
let auth_config = crate::auth::ServerAuthConfig::from_keys_and_ca(
|
||||
opts.authorized_keys.clone(),
|
||||
opts.cert_authority.clone(),
|
||||
)?;
|
||||
|
||||
let auth_policy = crate::config::AuthPolicy::from_server_auth_config(auth_config);
|
||||
|
||||
let dynamic = crate::config::DynamicConfig::new(auth_policy);
|
||||
|
||||
let proxy_config = parse_proxy_config(opts.proxy.as_deref())?;
|
||||
|
||||
let listeners = if let Some(listeners) = opts.listeners {
|
||||
listeners
|
||||
} else {
|
||||
vec![ListenerConfig::Stream {
|
||||
config: StreamListenerConfig {
|
||||
transport_kind: match opts.transport_mode {
|
||||
ServeTransportMode::Tcp => TransportKind::Tcp,
|
||||
ServeTransportMode::Tls => TransportKind::Tls { server_name: None },
|
||||
ServeTransportMode::Iroh => TransportKind::Iroh {
|
||||
endpoint_id: String::new(),
|
||||
},
|
||||
},
|
||||
interface: StreamInterfaceKind::Ssh,
|
||||
listen_addr: opts.listen_addr.clone(),
|
||||
tls_cert: opts.tls_cert.clone(),
|
||||
tls_key: opts.tls_key.clone(),
|
||||
acme_domain: opts.acme_domain.clone(),
|
||||
stealth: opts.stealth,
|
||||
iroh_relay: opts.iroh_relay.clone(),
|
||||
},
|
||||
}]
|
||||
};
|
||||
|
||||
let static_config = StaticConfig {
|
||||
transport_mode: opts.transport_mode,
|
||||
listen_addr: opts.listen_addr,
|
||||
tls_cert: opts.tls_cert,
|
||||
tls_key: opts.tls_key,
|
||||
acme_domain: opts.acme_domain,
|
||||
stealth: opts.stealth,
|
||||
host_key,
|
||||
host_key_algorithm,
|
||||
max_auth_attempts: opts.max_auth_attempts,
|
||||
max_connections_per_ip: opts.max_connections_per_ip,
|
||||
proxy_config,
|
||||
iroh_relay: opts.iroh_relay,
|
||||
listeners,
|
||||
};
|
||||
|
||||
Ok((static_config, dynamic))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_proxy_config(
|
||||
proxy: Option<&str>,
|
||||
) -> Result<Option<ProxyConfig>, crate::error::ConfigError> {
|
||||
match proxy {
|
||||
None => Ok(None),
|
||||
Some(url) => {
|
||||
if let Some(rest) = url.strip_prefix("socks5://") {
|
||||
let addr: SocketAddr =
|
||||
rest.parse()
|
||||
.map_err(|e| crate::error::ConfigError::ProxyConfigInvalid {
|
||||
message: format!("invalid socks5 proxy address '{}': {}", rest, e),
|
||||
})?;
|
||||
Ok(Some(ProxyConfig {
|
||||
mode: ProxyMode::Socks5(addr),
|
||||
}))
|
||||
} else if let Some(rest) = url.strip_prefix("http://") {
|
||||
let addr: SocketAddr =
|
||||
rest.parse()
|
||||
.map_err(|e| crate::error::ConfigError::ProxyConfigInvalid {
|
||||
message: format!(
|
||||
"invalid http connect proxy address '{}': {}",
|
||||
rest, e
|
||||
),
|
||||
})?;
|
||||
Ok(Some(ProxyConfig {
|
||||
mode: ProxyMode::HttpConnect(addr),
|
||||
}))
|
||||
} else {
|
||||
Err(crate::error::ConfigError::ProxyConfigInvalid {
|
||||
message: format!("unsupported proxy URL scheme: {}", url),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::server::serve::ServeOptions;
|
||||
use crate::transport::TransportKind;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
|
||||
|
||||
fn make_key_source() -> KeySource {
|
||||
KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec())
|
||||
}
|
||||
|
||||
fn make_authorized_keys_source() -> KeySource {
|
||||
KeySource::Memory(ED25519_PUBLIC_KEY.as_bytes().to_vec())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_proxy_config_socks5() {
|
||||
let config = parse_proxy_config(Some("socks5://127.0.0.1:9050")).unwrap();
|
||||
assert!(config.is_some());
|
||||
match config.unwrap().mode {
|
||||
ProxyMode::Socks5(addr) => {
|
||||
assert_eq!(addr, "127.0.0.1:9050".parse().unwrap());
|
||||
}
|
||||
_ => panic!("expected Socks5"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_proxy_config_http() {
|
||||
let config = parse_proxy_config(Some("http://127.0.0.1:8080")).unwrap();
|
||||
assert!(config.is_some());
|
||||
match config.unwrap().mode {
|
||||
ProxyMode::HttpConnect(addr) => {
|
||||
assert_eq!(addr, "127.0.0.1:8080".parse().unwrap());
|
||||
}
|
||||
_ => panic!("expected HttpConnect"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_proxy_config_none() {
|
||||
assert!(parse_proxy_config(None).unwrap().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_proxy_config_invalid_scheme() {
|
||||
let result = parse_proxy_config(Some("ftp://127.0.0.1:9050"));
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
crate::error::ConfigError::ProxyConfigInvalid { message } => {
|
||||
assert!(message.contains("unsupported proxy URL scheme"));
|
||||
}
|
||||
e => panic!("expected ProxyConfigInvalid, got {:?}", e),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_proxy_config_invalid_address() {
|
||||
let result = parse_proxy_config(Some("socks5://not-an-address"));
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
crate::error::ConfigError::ProxyConfigInvalid { message } => {
|
||||
assert!(message.contains("invalid socks5 proxy address"));
|
||||
}
|
||||
e => panic!("expected ProxyConfigInvalid, got {:?}", e),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn static_config_from_serve_options_basic() {
|
||||
let opts =
|
||||
ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source());
|
||||
let (static_config, dynamic) = StaticConfig::from_serve_options(opts).unwrap();
|
||||
assert_eq!(static_config.listen_addr, "0.0.0.0:22");
|
||||
assert_eq!(static_config.max_auth_attempts, 10);
|
||||
assert!(dynamic.auth.authorized_keys.len() > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn static_config_from_serve_options_with_proxy() {
|
||||
let opts = ServeOptions::new(make_key_source())
|
||||
.authorized_keys(make_authorized_keys_source())
|
||||
.proxy("socks5://127.0.0.1:9050");
|
||||
let (static_config, _) = StaticConfig::from_serve_options(opts).unwrap();
|
||||
assert!(static_config.proxy_config.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn static_config_from_serve_options_with_listeners() {
|
||||
let listeners = vec![ListenerConfig::tcp("0.0.0.0:22")];
|
||||
let opts = ServeOptions::new(make_key_source())
|
||||
.authorized_keys(make_authorized_keys_source())
|
||||
.listeners(listeners);
|
||||
let (static_config, _) = StaticConfig::from_serve_options(opts).unwrap();
|
||||
assert_eq!(static_config.listeners.len(), 1);
|
||||
match &static_config.listeners[0] {
|
||||
ListenerConfig::Stream { config } => {
|
||||
assert_eq!(config.transport_kind, TransportKind::Tcp);
|
||||
}
|
||||
_ => panic!("expected Stream variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn static_config_from_serve_options_invalid_proxy_returns_err() {
|
||||
let opts = ServeOptions::new(make_key_source())
|
||||
.authorized_keys(make_authorized_keys_source())
|
||||
.proxy("ftp://bad-scheme");
|
||||
let result = StaticConfig::from_serve_options(opts);
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
crate::error::ConfigError::ProxyConfigInvalid { message } => {
|
||||
assert!(message.contains("unsupported proxy URL scheme"));
|
||||
}
|
||||
e => panic!("expected ProxyConfigInvalid, got {:?}", e),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn static_config_from_serve_options_malformed_proxy_address_returns_err() {
|
||||
let opts = ServeOptions::new(make_key_source())
|
||||
.authorized_keys(make_authorized_keys_source())
|
||||
.proxy("socks5://not-a-valid-addr");
|
||||
let result = StaticConfig::from_serve_options(opts);
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
crate::error::ConfigError::ProxyConfigInvalid { message } => {
|
||||
assert!(message.contains("invalid socks5 proxy address"));
|
||||
}
|
||||
e => panic!("expected ProxyConfigInvalid, got {:?}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,241 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::config::DynamicConfig;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub enum CredentialSet {
|
||||
ApiKey {
|
||||
header_name: String,
|
||||
token: String,
|
||||
},
|
||||
Basic {
|
||||
username: String,
|
||||
password: String,
|
||||
},
|
||||
Bearer {
|
||||
token: String,
|
||||
},
|
||||
S3AccessKey {
|
||||
access_key: String,
|
||||
secret_key: String,
|
||||
session_token: Option<String>,
|
||||
},
|
||||
OidcToken {
|
||||
access_token: String,
|
||||
refresh_token: Option<String>,
|
||||
expires_at: Option<u64>,
|
||||
},
|
||||
Custom {
|
||||
scheme: String,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
}
|
||||
|
||||
pub trait CredentialProvider: Send + Sync + 'static {
|
||||
fn get_credentials(&self, service: &str) -> Option<CredentialSet>;
|
||||
fn refresh_credentials(&self, service: &str) -> Option<CredentialSet>;
|
||||
}
|
||||
|
||||
pub struct ConfigCredentialProvider {
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
}
|
||||
|
||||
impl ConfigCredentialProvider {
|
||||
pub fn new(dynamic: Arc<ArcSwap<DynamicConfig>>) -> Self {
|
||||
Self { dynamic }
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialProvider for ConfigCredentialProvider {
|
||||
fn get_credentials(&self, service: &str) -> Option<CredentialSet> {
|
||||
let config = self.dynamic.load();
|
||||
config.credentials.get(service).cloned()
|
||||
}
|
||||
|
||||
fn refresh_credentials(&self, service: &str) -> Option<CredentialSet> {
|
||||
self.get_credentials(service)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SecretStoreCredentialProvider;
|
||||
|
||||
impl SecretStoreCredentialProvider {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SecretStoreCredentialProvider {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialProvider for SecretStoreCredentialProvider {
|
||||
fn get_credentials(&self, _service: &str) -> Option<CredentialSet> {
|
||||
None
|
||||
}
|
||||
|
||||
fn refresh_credentials(&self, _service: &str) -> Option<CredentialSet> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::AuthPolicy;
|
||||
|
||||
fn make_dynamic_with_credentials() -> Arc<ArcSwap<DynamicConfig>> {
|
||||
let mut credentials = HashMap::new();
|
||||
credentials.insert(
|
||||
"vast-ai".to_string(),
|
||||
CredentialSet::Bearer {
|
||||
token: "secret-token".to_string(),
|
||||
},
|
||||
);
|
||||
credentials.insert(
|
||||
"custom-service".to_string(),
|
||||
CredentialSet::ApiKey {
|
||||
header_name: "X-API-Key".to_string(),
|
||||
token: "api-key-123".to_string(),
|
||||
},
|
||||
);
|
||||
let config = DynamicConfig::new(AuthPolicy::empty()).with_credentials(credentials);
|
||||
Arc::new(ArcSwap::new(Arc::new(config)))
|
||||
}
|
||||
|
||||
fn make_dynamic_empty() -> Arc<ArcSwap<DynamicConfig>> {
|
||||
let config = DynamicConfig::default();
|
||||
Arc::new(ArcSwap::new(Arc::new(config)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_credential_provider_returns_configured_credentials() {
|
||||
let dynamic = make_dynamic_with_credentials();
|
||||
let provider = ConfigCredentialProvider::new(dynamic);
|
||||
let creds = provider.get_credentials("vast-ai");
|
||||
assert!(creds.is_some());
|
||||
match creds.unwrap() {
|
||||
CredentialSet::Bearer { token } => assert_eq!(token, "secret-token"),
|
||||
_ => panic!("expected Bearer variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_credential_provider_returns_api_key_variant() {
|
||||
let dynamic = make_dynamic_with_credentials();
|
||||
let provider = ConfigCredentialProvider::new(dynamic);
|
||||
let creds = provider.get_credentials("custom-service");
|
||||
assert!(creds.is_some());
|
||||
match creds.unwrap() {
|
||||
CredentialSet::ApiKey { header_name, token } => {
|
||||
assert_eq!(header_name, "X-API-Key");
|
||||
assert_eq!(token, "api-key-123");
|
||||
}
|
||||
_ => panic!("expected ApiKey variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_credential_provider_returns_none_for_unknown_service() {
|
||||
let dynamic = make_dynamic_with_credentials();
|
||||
let provider = ConfigCredentialProvider::new(dynamic);
|
||||
let creds = provider.get_credentials("nonexistent");
|
||||
assert!(creds.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_credential_provider_empty_config_returns_none() {
|
||||
let dynamic = make_dynamic_empty();
|
||||
let provider = ConfigCredentialProvider::new(dynamic);
|
||||
let creds = provider.get_credentials("vast-ai");
|
||||
assert!(creds.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secret_store_credential_provider_returns_none() {
|
||||
let provider = SecretStoreCredentialProvider::new();
|
||||
assert!(provider.get_credentials("vast-ai").is_none());
|
||||
assert!(provider.get_credentials("rustfs").is_none());
|
||||
assert!(provider.get_credentials("gitea").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secret_store_credential_provider_refresh_returns_none() {
|
||||
let provider = SecretStoreCredentialProvider::new();
|
||||
assert!(provider.refresh_credentials("vast-ai").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credential_set_bearer_serialization() {
|
||||
let creds = CredentialSet::Bearer {
|
||||
token: "tok".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&creds).unwrap();
|
||||
let deserialized: CredentialSet = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(creds, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credential_set_s3_access_key_serialization() {
|
||||
let creds = CredentialSet::S3AccessKey {
|
||||
access_key: "AKIA123".to_string(),
|
||||
secret_key: "secret".to_string(),
|
||||
session_token: Some("session".to_string()),
|
||||
};
|
||||
let json = serde_json::to_string(&creds).unwrap();
|
||||
let deserialized: CredentialSet = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(creds, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credential_set_oidc_token_serialization() {
|
||||
let creds = CredentialSet::OidcToken {
|
||||
access_token: "access".to_string(),
|
||||
refresh_token: Some("refresh".to_string()),
|
||||
expires_at: Some(1234567890),
|
||||
};
|
||||
let json = serde_json::to_string(&creds).unwrap();
|
||||
let deserialized: CredentialSet = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(creds, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credential_set_custom_serialization() {
|
||||
let mut params = HashMap::new();
|
||||
params.insert("key1".to_string(), "val1".to_string());
|
||||
let creds = CredentialSet::Custom {
|
||||
scheme: "X-Custom".to_string(),
|
||||
params,
|
||||
};
|
||||
let json = serde_json::to_string(&creds).unwrap();
|
||||
let deserialized: CredentialSet = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(creds, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credential_set_basic_serialization() {
|
||||
let creds = CredentialSet::Basic {
|
||||
username: "user".to_string(),
|
||||
password: "pass".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&creds).unwrap();
|
||||
let deserialized: CredentialSet = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(creds, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credential_set_clone() {
|
||||
let creds = CredentialSet::Bearer {
|
||||
token: "tok".to_string(),
|
||||
};
|
||||
let cloned = creds.clone();
|
||||
assert_eq!(creds, cloned);
|
||||
}
|
||||
}
|
||||
@@ -1,241 +0,0 @@
|
||||
//! Error types for alknet-core.
|
||||
//!
|
||||
//! Layered error hierarchy:
|
||||
//! - `TransportError` — connection/handshake/timeout errors (trigger reconnection on client)
|
||||
//! - `AuthError` — key rejection, certificate validation failures
|
||||
//! - `ChannelError` — per-channel failures (target unreachable, channel closed)
|
||||
//! - `ConfigError` — invalid configuration (flags, key files, bind failures)
|
||||
//! - `ForwardError` — port forward setup and connection failures
|
||||
|
||||
use std::io;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum TransportError {
|
||||
#[error("connection failed")]
|
||||
ConnectionFailed,
|
||||
#[error("handshake failed")]
|
||||
HandshakeFailed {
|
||||
#[source]
|
||||
source: io::Error,
|
||||
},
|
||||
#[error("transport timeout")]
|
||||
Timeout,
|
||||
#[error("proxy failed")]
|
||||
ProxyFailed {
|
||||
#[source]
|
||||
source: io::Error,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, thiserror::Error)]
|
||||
pub enum AuthError {
|
||||
#[error("key rejected")]
|
||||
KeyRejected,
|
||||
#[error("certificate invalid")]
|
||||
CertInvalid,
|
||||
#[error("certificate expired")]
|
||||
CertExpired,
|
||||
#[error("certificate principal mismatch")]
|
||||
CertPrincipalMismatch,
|
||||
#[error("no matching key")]
|
||||
NoMatchingKey,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ChannelError {
|
||||
#[error("target unreachable")]
|
||||
TargetUnreachable,
|
||||
#[error("proxy connect failed")]
|
||||
ProxyConnectFailed {
|
||||
#[source]
|
||||
source: io::Error,
|
||||
},
|
||||
#[error("channel closed")]
|
||||
ChannelClosed,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConfigError {
|
||||
#[error("invalid flag: {name}")]
|
||||
InvalidFlag { name: String },
|
||||
#[error("key file not found: {path}")]
|
||||
KeyFileNotFound { path: String },
|
||||
#[error("bind failed")]
|
||||
BindFailed {
|
||||
#[source]
|
||||
source: io::Error,
|
||||
},
|
||||
#[error("incompatible options")]
|
||||
IncompatibleOptions,
|
||||
#[error("invalid proxy config: {message}")]
|
||||
ProxyConfigInvalid { message: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ForwardError {
|
||||
#[error("invalid port forward spec: {spec}")]
|
||||
InvalidSpec { spec: String },
|
||||
#[error("bind failed")]
|
||||
BindFailed {
|
||||
#[source]
|
||||
source: io::Error,
|
||||
},
|
||||
#[error("channel open failed")]
|
||||
ChannelOpenFailed {
|
||||
#[source]
|
||||
source: Box<dyn std::error::Error + Send + Sync>,
|
||||
},
|
||||
#[error("connect to local target failed")]
|
||||
LocalConnectFailed {
|
||||
#[source]
|
||||
source: io::Error,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::error::Error;
|
||||
|
||||
#[test]
|
||||
fn transport_error_display() {
|
||||
assert_eq!(
|
||||
TransportError::ConnectionFailed.to_string(),
|
||||
"connection failed"
|
||||
);
|
||||
assert_eq!(
|
||||
TransportError::HandshakeFailed {
|
||||
source: io::Error::new(io::ErrorKind::ConnectionRefused, "tls failed")
|
||||
}
|
||||
.to_string(),
|
||||
"handshake failed"
|
||||
);
|
||||
assert_eq!(TransportError::Timeout.to_string(), "transport timeout");
|
||||
assert_eq!(
|
||||
TransportError::ProxyFailed {
|
||||
source: io::Error::new(io::ErrorKind::ConnectionRefused, "proxy err")
|
||||
}
|
||||
.to_string(),
|
||||
"proxy failed"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_error_display() {
|
||||
assert_eq!(AuthError::KeyRejected.to_string(), "key rejected");
|
||||
assert_eq!(AuthError::CertInvalid.to_string(), "certificate invalid");
|
||||
assert_eq!(AuthError::CertExpired.to_string(), "certificate expired");
|
||||
assert_eq!(
|
||||
AuthError::CertPrincipalMismatch.to_string(),
|
||||
"certificate principal mismatch"
|
||||
);
|
||||
assert_eq!(AuthError::NoMatchingKey.to_string(), "no matching key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_error_display() {
|
||||
assert_eq!(
|
||||
ChannelError::TargetUnreachable.to_string(),
|
||||
"target unreachable"
|
||||
);
|
||||
assert_eq!(
|
||||
ChannelError::ProxyConnectFailed {
|
||||
source: io::Error::new(io::ErrorKind::ConnectionRefused, "refused")
|
||||
}
|
||||
.to_string(),
|
||||
"proxy connect failed"
|
||||
);
|
||||
assert_eq!(ChannelError::ChannelClosed.to_string(), "channel closed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_error_display() {
|
||||
assert_eq!(
|
||||
ConfigError::InvalidFlag {
|
||||
name: "--bad".to_string()
|
||||
}
|
||||
.to_string(),
|
||||
"invalid flag: --bad"
|
||||
);
|
||||
assert_eq!(
|
||||
ConfigError::KeyFileNotFound {
|
||||
path: "/missing".to_string()
|
||||
}
|
||||
.to_string(),
|
||||
"key file not found: /missing"
|
||||
);
|
||||
assert_eq!(
|
||||
ConfigError::BindFailed {
|
||||
source: io::Error::new(io::ErrorKind::AddrInUse, "in use")
|
||||
}
|
||||
.to_string(),
|
||||
"bind failed"
|
||||
);
|
||||
assert_eq!(
|
||||
ConfigError::IncompatibleOptions.to_string(),
|
||||
"incompatible options"
|
||||
);
|
||||
assert_eq!(
|
||||
ConfigError::ProxyConfigInvalid {
|
||||
message: "bad proxy".to_string()
|
||||
}
|
||||
.to_string(),
|
||||
"invalid proxy config: bad proxy"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_source_chaining() {
|
||||
let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "refused");
|
||||
let transport_err = TransportError::HandshakeFailed { source: io_err };
|
||||
assert!(transport_err.source().is_some());
|
||||
|
||||
let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "proxy");
|
||||
let channel_err = ChannelError::ProxyConnectFailed { source: io_err };
|
||||
assert!(channel_err.source().is_some());
|
||||
|
||||
let io_err = io::Error::new(io::ErrorKind::AddrInUse, "addr");
|
||||
let config_err = ConfigError::BindFailed { source: io_err };
|
||||
assert!(config_err.source().is_some());
|
||||
|
||||
let plain = AuthError::KeyRejected;
|
||||
assert!(plain.source().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_error_display() {
|
||||
assert_eq!(
|
||||
ForwardError::InvalidSpec {
|
||||
spec: "bad".to_string()
|
||||
}
|
||||
.to_string(),
|
||||
"invalid port forward spec: bad"
|
||||
);
|
||||
assert_eq!(
|
||||
ForwardError::BindFailed {
|
||||
source: io::Error::new(io::ErrorKind::AddrInUse, "in use")
|
||||
}
|
||||
.to_string(),
|
||||
"bind failed"
|
||||
);
|
||||
assert_eq!(
|
||||
ForwardError::LocalConnectFailed {
|
||||
source: io::Error::new(io::ErrorKind::ConnectionRefused, "refused")
|
||||
}
|
||||
.to_string(),
|
||||
"connect to local target failed"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_error_source_chaining() {
|
||||
let io_err = io::Error::new(io::ErrorKind::AddrInUse, "in use");
|
||||
let forward_err = ForwardError::BindFailed { source: io_err };
|
||||
assert!(forward_err.source().is_some());
|
||||
|
||||
let plain = ForwardError::InvalidSpec {
|
||||
spec: "bad".to_string(),
|
||||
};
|
||||
assert!(plain.source().is_none());
|
||||
}
|
||||
}
|
||||
@@ -1,182 +0,0 @@
|
||||
use axum::extract::Request;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
|
||||
use crate::auth::{AuthToken, Identity, IdentityProvider};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct IdentityExt(pub Identity);
|
||||
|
||||
pub async fn auth_middleware(
|
||||
axum::extract::State(identity_provider): axum::extract::State<
|
||||
std::sync::Arc<dyn IdentityProvider>,
|
||||
>,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let auth_header = request
|
||||
.headers()
|
||||
.get(axum::http::header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
let token_str = match auth_header {
|
||||
Some(h) if h.starts_with("Bearer ") => &h[7..],
|
||||
_ => {
|
||||
return axum::http::StatusCode::UNAUTHORIZED.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let token = AuthToken {
|
||||
raw: token_str.as_bytes().to_vec(),
|
||||
};
|
||||
|
||||
match identity_provider.resolve_from_token(&token) {
|
||||
Some(identity) => {
|
||||
request.extensions_mut().insert(IdentityExt(identity));
|
||||
next.run(request).await
|
||||
}
|
||||
None => axum::http::StatusCode::UNAUTHORIZED.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::body::Body;
|
||||
use axum::http::{Request as HttpRequest, StatusCode};
|
||||
use axum::routing::get;
|
||||
use axum::Router;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tower::ServiceExt;
|
||||
|
||||
struct MockIdentityProvider {
|
||||
valid_token: String,
|
||||
identity: Identity,
|
||||
}
|
||||
|
||||
impl IdentityProvider for MockIdentityProvider {
|
||||
fn resolve_from_fingerprint(&self, _fingerprint: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, token: &AuthToken) -> Option<Identity> {
|
||||
let token_str = String::from_utf8_lossy(&token.raw);
|
||||
if token_str == self.valid_token {
|
||||
Some(self.identity.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn make_provider(valid_token: &str) -> Arc<dyn IdentityProvider> {
|
||||
let identity = Identity {
|
||||
id: "test-user".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
Arc::new(MockIdentityProvider {
|
||||
valid_token: valid_token.to_string(),
|
||||
identity,
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_middleware_extracts_bearer_token() {
|
||||
let provider = make_provider("alk_validtoken1");
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/test",
|
||||
get(|request: Request| async move {
|
||||
let has_identity = request.extensions().get::<IdentityExt>().is_some();
|
||||
if has_identity {
|
||||
StatusCode::OK.into_response()
|
||||
} else {
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}),
|
||||
)
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
provider,
|
||||
auth_middleware,
|
||||
));
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/test")
|
||||
.header("authorization", "Bearer alk_validtoken1")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_middleware_returns_401_for_missing_token() {
|
||||
let provider = make_provider("alk_validtoken1");
|
||||
let app = Router::new()
|
||||
.route("/test", get(|| async { StatusCode::OK.into_response() }))
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
provider,
|
||||
auth_middleware,
|
||||
));
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/test")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_middleware_returns_401_for_invalid_token() {
|
||||
let provider = make_provider("alk_validtoken1");
|
||||
let app = Router::new()
|
||||
.route("/test", get(|| async { StatusCode::OK.into_response() }))
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
provider,
|
||||
auth_middleware,
|
||||
));
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/test")
|
||||
.header("authorization", "Bearer alk_wrongtoken1")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_middleware_attaches_identity_to_extensions() {
|
||||
let provider = make_provider("alk_testidentity1");
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/test",
|
||||
get(|request: Request| async move {
|
||||
let identity = request.extensions().get::<IdentityExt>().unwrap();
|
||||
identity.0.id.clone()
|
||||
}),
|
||||
)
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
provider,
|
||||
auth_middleware,
|
||||
));
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/test")
|
||||
.header("authorization", "Bearer alk_testidentity1")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
|
||||
assert_eq!(&body[..], b"test-user");
|
||||
}
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
pub mod auth;
|
||||
pub mod router;
|
||||
|
||||
pub use auth::IdentityExt;
|
||||
pub use router::{build_router, serve_connection};
|
||||
@@ -1,150 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::response::IntoResponse;
|
||||
use axum::Router;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
use hyper_util::service::TowerToHyperService;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, BufReader};
|
||||
|
||||
use crate::auth::IdentityProvider;
|
||||
use crate::http::auth::auth_middleware;
|
||||
|
||||
async fn default_404() -> impl IntoResponse {
|
||||
axum::http::StatusCode::NOT_FOUND
|
||||
}
|
||||
|
||||
pub fn build_router(identity_provider: Arc<dyn IdentityProvider>) -> Router {
|
||||
Router::new()
|
||||
.fallback(default_404)
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
identity_provider,
|
||||
auth_middleware,
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn serve_connection<S>(stream: S, identity_provider: Arc<dyn IdentityProvider>)
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let app = build_router(identity_provider);
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let hyper_service = TowerToHyperService::new(app.into_service::<hyper::body::Incoming>());
|
||||
|
||||
let result = Builder::new(TokioExecutor::new())
|
||||
.serve_connection_with_upgrades(io, hyper_service)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
tracing::debug!("http connection error: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve_connection_from_reader<S>(
|
||||
reader: BufReader<S>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
) where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
serve_connection(reader, identity_provider).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::{AuthToken, Identity};
|
||||
use axum::body::Body;
|
||||
use axum::http::{Request as HttpRequest, StatusCode};
|
||||
use axum::response::IntoResponse;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tower::ServiceExt;
|
||||
|
||||
struct NullIdentityProvider;
|
||||
|
||||
impl IdentityProvider for NullIdentityProvider {
|
||||
fn resolve_from_fingerprint(&self, _fingerprint: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, _token: &AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn default_404_handler_returns_not_found() {
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(MockValidProvider);
|
||||
let app = build_router(provider);
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/anything")
|
||||
.header("authorization", "Bearer alk_sometoken1")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_auth_returns_401_before_404() {
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(MockValidProvider);
|
||||
let app = build_router(provider);
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/anything")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_auth_returns_401_before_404() {
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(NullIdentityProvider);
|
||||
let app = build_router(provider);
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/anything")
|
||||
.header("authorization", "Bearer alk_sometoken1")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unmatched_route_returns_404_with_valid_auth() {
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(MockValidProvider);
|
||||
let app = build_router(provider);
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/v1/unknown/op")
|
||||
.header("authorization", "Bearer alk_valid")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
struct MockValidProvider;
|
||||
|
||||
impl IdentityProvider for MockValidProvider {
|
||||
fn resolve_from_fingerprint(&self, _fingerprint: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, _token: &AuthToken) -> Option<Identity> {
|
||||
Some(Identity {
|
||||
id: "test".to_string(),
|
||||
scopes: vec![],
|
||||
resources: HashMap::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,270 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use russh::keys::PrivateKey;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::auth::IdentityProvider;
|
||||
use crate::config::DynamicConfig;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[non_exhaustive]
|
||||
pub enum StreamInterfaceKind {
|
||||
Ssh,
|
||||
RawFraming,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for StreamInterfaceKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
StreamInterfaceKind::Ssh => write!(f, "ssh"),
|
||||
StreamInterfaceKind::RawFraming => write!(f, "raw-framing"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[non_exhaustive]
|
||||
pub enum MessageInterfaceKind {
|
||||
Http,
|
||||
Dns,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for MessageInterfaceKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MessageInterfaceKind::Http => write!(f, "http"),
|
||||
MessageInterfaceKind::Dns => write!(f, "dns"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
pub enum InterfaceConfig {
|
||||
Ssh(SshInterfaceConfig),
|
||||
RawFraming(RawFramingConfig),
|
||||
}
|
||||
|
||||
impl InterfaceConfig {
|
||||
pub fn kind(&self) -> StreamInterfaceKind {
|
||||
#[allow(unreachable_patterns)]
|
||||
match self {
|
||||
InterfaceConfig::Ssh(_) => StreamInterfaceKind::Ssh,
|
||||
InterfaceConfig::RawFraming(_) => StreamInterfaceKind::RawFraming,
|
||||
_ => StreamInterfaceKind::Ssh,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
pub enum StreamInterfaceConfig {
|
||||
Ssh(SshInterfaceConfig),
|
||||
RawFraming(RawFramingConfig),
|
||||
}
|
||||
|
||||
impl StreamInterfaceConfig {
|
||||
pub fn kind(&self) -> StreamInterfaceKind {
|
||||
match self {
|
||||
StreamInterfaceConfig::Ssh(_) => StreamInterfaceKind::Ssh,
|
||||
StreamInterfaceConfig::RawFraming(_) => StreamInterfaceKind::RawFraming,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for StreamInterfaceConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
StreamInterfaceConfig::Ssh(_) => write!(f, "ssh"),
|
||||
StreamInterfaceConfig::RawFraming(_) => write!(f, "raw-framing"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[non_exhaustive]
|
||||
pub enum MessageInterfaceConfig {
|
||||
Http(HttpInterfaceConfig),
|
||||
Dns(DnsInterfaceConfig),
|
||||
}
|
||||
|
||||
impl MessageInterfaceConfig {
|
||||
pub fn kind(&self) -> MessageInterfaceKind {
|
||||
match self {
|
||||
MessageInterfaceConfig::Http(_) => MessageInterfaceKind::Http,
|
||||
MessageInterfaceConfig::Dns(_) => MessageInterfaceKind::Dns,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for MessageInterfaceConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MessageInterfaceConfig::Http(_) => write!(f, "http"),
|
||||
MessageInterfaceConfig::Dns(_) => write!(f, "dns"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SshInterfaceConfig {
|
||||
pub auth: Arc<dyn IdentityProvider>,
|
||||
pub forwarding: Arc<ArcSwap<DynamicConfig>>,
|
||||
pub host_key: Arc<PrivateKey>,
|
||||
}
|
||||
|
||||
pub struct RawFramingConfig {
|
||||
pub auth: Arc<dyn IdentityProvider>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct HttpInterfaceConfig {
|
||||
pub bind_addr: std::net::SocketAddr,
|
||||
pub tls: bool,
|
||||
pub stealth: bool,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for HttpInterfaceConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "http {}", self.bind_addr)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct DnsInterfaceConfig {
|
||||
pub bind_addr: std::net::SocketAddr,
|
||||
pub tls: bool,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DnsInterfaceConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "dns {}", self.bind_addr)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::ConfigIdentityProvider;
|
||||
|
||||
#[test]
|
||||
fn stream_interface_kind_display() {
|
||||
assert_eq!(StreamInterfaceKind::Ssh.to_string(), "ssh");
|
||||
assert_eq!(StreamInterfaceKind::RawFraming.to_string(), "raw-framing");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn message_interface_kind_display() {
|
||||
assert_eq!(MessageInterfaceKind::Http.to_string(), "http");
|
||||
assert_eq!(MessageInterfaceKind::Dns.to_string(), "dns");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_interface_config_kind() {
|
||||
let auth = Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new(
|
||||
ArcSwap::new(Arc::new(DynamicConfig::default())),
|
||||
)));
|
||||
let ssh_config = StreamInterfaceConfig::Ssh(SshInterfaceConfig {
|
||||
auth,
|
||||
forwarding: Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default()))),
|
||||
host_key: Arc::new(
|
||||
russh::keys::PrivateKey::random(
|
||||
&mut rand_core::OsRng,
|
||||
russh::keys::Algorithm::Ed25519,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
});
|
||||
assert_eq!(ssh_config.kind(), StreamInterfaceKind::Ssh);
|
||||
|
||||
let raw_config = StreamInterfaceConfig::RawFraming(RawFramingConfig {
|
||||
auth: Arc::new(ConfigIdentityProvider::new(Arc::new(ArcSwap::new(
|
||||
Arc::new(DynamicConfig::default()),
|
||||
)))),
|
||||
});
|
||||
assert_eq!(raw_config.kind(), StreamInterfaceKind::RawFraming);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn message_interface_config_kind() {
|
||||
let http_config = MessageInterfaceConfig::Http(HttpInterfaceConfig {
|
||||
bind_addr: "127.0.0.1:8080".parse().unwrap(),
|
||||
tls: false,
|
||||
stealth: false,
|
||||
});
|
||||
assert_eq!(http_config.kind(), MessageInterfaceKind::Http);
|
||||
|
||||
let dns_config = MessageInterfaceConfig::Dns(DnsInterfaceConfig {
|
||||
bind_addr: "127.0.0.1:53".parse().unwrap(),
|
||||
tls: false,
|
||||
});
|
||||
assert_eq!(dns_config.kind(), MessageInterfaceKind::Dns);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_interface_kind_equality() {
|
||||
assert_eq!(StreamInterfaceKind::Ssh, StreamInterfaceKind::Ssh);
|
||||
assert_eq!(
|
||||
StreamInterfaceKind::RawFraming,
|
||||
StreamInterfaceKind::RawFraming
|
||||
);
|
||||
assert_ne!(StreamInterfaceKind::Ssh, StreamInterfaceKind::RawFraming);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn message_interface_kind_equality() {
|
||||
assert_eq!(MessageInterfaceKind::Http, MessageInterfaceKind::Http);
|
||||
assert_eq!(MessageInterfaceKind::Dns, MessageInterfaceKind::Dns);
|
||||
assert_ne!(MessageInterfaceKind::Http, MessageInterfaceKind::Dns);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn raw_framing_config_minimal() {
|
||||
let auth: Arc<dyn IdentityProvider> = Arc::new(ConfigIdentityProvider::new(Arc::new(
|
||||
ArcSwap::new(Arc::new(DynamicConfig::default())),
|
||||
)));
|
||||
let _config = RawFramingConfig { auth };
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn http_interface_config_display() {
|
||||
let config = HttpInterfaceConfig {
|
||||
bind_addr: "127.0.0.1:8080".parse().unwrap(),
|
||||
tls: true,
|
||||
stealth: true,
|
||||
};
|
||||
assert_eq!(config.to_string(), "http 127.0.0.1:8080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dns_interface_config_display() {
|
||||
let config = DnsInterfaceConfig {
|
||||
bind_addr: "127.0.0.1:53".parse().unwrap(),
|
||||
tls: false,
|
||||
};
|
||||
assert_eq!(config.to_string(), "dns 127.0.0.1:53");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn http_interface_config_serialization() {
|
||||
let config = HttpInterfaceConfig {
|
||||
bind_addr: "127.0.0.1:8080".parse().unwrap(),
|
||||
tls: true,
|
||||
stealth: false,
|
||||
};
|
||||
let serialized = serde_json::to_string(&config).unwrap();
|
||||
let deserialized: HttpInterfaceConfig = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.bind_addr, config.bind_addr);
|
||||
assert_eq!(deserialized.tls, config.tls);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dns_interface_config_serialization() {
|
||||
let config = DnsInterfaceConfig {
|
||||
bind_addr: "0.0.0.0:53".parse().unwrap(),
|
||||
tls: true,
|
||||
};
|
||||
let serialized = serde_json::to_string(&config).unwrap();
|
||||
let deserialized: DnsInterfaceConfig = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.bind_addr, config.bind_addr);
|
||||
assert_eq!(deserialized.tls, config.tls);
|
||||
}
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::call::OperationEnv;
|
||||
use crate::interface::{InterfaceRequest, InterfaceResponse, MessageInterface};
|
||||
|
||||
pub struct DnsInterface {
|
||||
pub domain: String,
|
||||
pub identity_provider: Arc<dyn crate::auth::IdentityProvider>,
|
||||
pub registry: Arc<crate::call::OperationRegistry>,
|
||||
pub env: OperationEnv,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MessageInterface for DnsInterface {
|
||||
async fn handle_request(&self, _request: InterfaceRequest) -> Result<InterfaceResponse> {
|
||||
Ok(InterfaceResponse {
|
||||
result: Err(crate::call::CallError::new(
|
||||
"NOT_IMPLEMENTED",
|
||||
"DnsInterface is not yet implemented",
|
||||
false,
|
||||
)),
|
||||
status: 501,
|
||||
headers: std::collections::HashMap::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn dns_interface_type_exists() {
|
||||
let registry = Arc::new(crate::call::OperationRegistry::new());
|
||||
let _iface = DnsInterface {
|
||||
domain: "alk.dev".to_string(),
|
||||
identity_provider: Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new(
|
||||
arc_swap::ArcSwap::new(Arc::new(crate::config::DynamicConfig::default())),
|
||||
))),
|
||||
env: OperationEnv::local(crate::call::OperationRegistry::new()),
|
||||
registry,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::call::OperationEnv;
|
||||
use crate::interface::{InterfaceRequest, InterfaceResponse, MessageInterface};
|
||||
|
||||
pub struct HttpInterface {
|
||||
pub identity_provider: Arc<dyn crate::auth::IdentityProvider>,
|
||||
pub registry: Arc<crate::call::OperationRegistry>,
|
||||
pub env: OperationEnv,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MessageInterface for HttpInterface {
|
||||
async fn handle_request(&self, _request: InterfaceRequest) -> Result<InterfaceResponse> {
|
||||
Ok(InterfaceResponse {
|
||||
result: Err(crate::call::CallError::new(
|
||||
"NOT_IMPLEMENTED",
|
||||
"HttpInterface is not yet implemented",
|
||||
false,
|
||||
)),
|
||||
status: 501,
|
||||
headers: std::collections::HashMap::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "http")]
|
||||
impl HttpInterface {
|
||||
pub fn build_router(&self) -> axum::Router {
|
||||
crate::http::router::build_router(Arc::clone(&self.identity_provider))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn http_interface_type_exists() {
|
||||
let registry = Arc::new(crate::call::OperationRegistry::new());
|
||||
let _iface = HttpInterface {
|
||||
identity_provider: Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new(
|
||||
arc_swap::ArcSwap::new(Arc::new(crate::config::DynamicConfig::default())),
|
||||
))),
|
||||
env: OperationEnv::local(crate::call::OperationRegistry::new()),
|
||||
registry,
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(feature = "http")]
|
||||
#[test]
|
||||
fn http_interface_builds_router() {
|
||||
let registry = Arc::new(crate::call::OperationRegistry::new());
|
||||
let iface = HttpInterface {
|
||||
identity_provider: Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new(
|
||||
arc_swap::ArcSwap::new(Arc::new(crate::config::DynamicConfig::default())),
|
||||
))),
|
||||
env: OperationEnv::local(crate::call::OperationRegistry::new()),
|
||||
registry,
|
||||
};
|
||||
let _router = iface.build_router();
|
||||
}
|
||||
}
|
||||
@@ -1,140 +0,0 @@
|
||||
//! Interface layer (Layer 2) of the three-layer model (ADR-026, ADR-035).
|
||||
//!
|
||||
//! The Interface layer sits between Transport (Layer 1) and Protocol (Layer 3).
|
||||
//! It has two distinct patterns:
|
||||
//!
|
||||
//! - **StreamInterface** — consumes a `TransportStream`, produces a long-lived
|
||||
//! `Session` that yields `InterfaceEvent` frames. SSH and raw framing are
|
||||
//! `StreamInterface` implementations.
|
||||
//!
|
||||
//! - **MessageInterface** — handles individual `InterfaceRequest` →
|
||||
//! `InterfaceResponse` pairs. Manages its own transport (HTTP server, DNS
|
||||
//! server). HTTP and DNS are `MessageInterface` implementations.
|
||||
|
||||
pub mod config;
|
||||
pub mod dns;
|
||||
pub mod http;
|
||||
pub mod pairs;
|
||||
pub mod raw_framing;
|
||||
pub mod session;
|
||||
pub mod ssh;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub use config::{
|
||||
DnsInterfaceConfig, HttpInterfaceConfig, InterfaceConfig, MessageInterfaceConfig,
|
||||
MessageInterfaceKind, RawFramingConfig, SshInterfaceConfig, StreamInterfaceConfig,
|
||||
StreamInterfaceKind,
|
||||
};
|
||||
pub use dns::DnsInterface;
|
||||
pub use http::HttpInterface;
|
||||
pub use pairs::{is_valid_pair, TransportKindBase, VALID_TRANSPORT_INTERFACE_PAIRS};
|
||||
pub use raw_framing::{RawFramingInterface, RawFramingSession};
|
||||
pub use session::{InterfaceEvent, InterfaceSession};
|
||||
pub use ssh::{ControlChannelBridge, SshInterface, SshSession};
|
||||
|
||||
pub trait TransportStream: AsyncRead + AsyncWrite + Unpin + Send + 'static {}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin + Send + 'static> TransportStream for T {}
|
||||
|
||||
#[async_trait]
|
||||
pub trait StreamInterface: Send + Sync + 'static {
|
||||
type Session: InterfaceSession;
|
||||
|
||||
async fn accept(
|
||||
&self,
|
||||
stream: Box<dyn TransportStream>,
|
||||
config: &StreamInterfaceConfig,
|
||||
) -> Result<Self::Session>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait MessageInterface: Send + Sync + 'static {
|
||||
async fn handle_request(&self, request: InterfaceRequest) -> Result<InterfaceResponse>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InterfaceRequest {
|
||||
pub operation_path: String,
|
||||
pub input: serde_json::Value,
|
||||
pub auth_token: Option<crate::auth::AuthToken>,
|
||||
pub metadata: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InterfaceResponse {
|
||||
pub result: Result<serde_json::Value, crate::call::CallError>,
|
||||
pub status: u16,
|
||||
pub headers: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::duplex;
|
||||
|
||||
#[test]
|
||||
fn transport_stream_trait_bounds() {
|
||||
fn assert_transport_stream<S: TransportStream>() {}
|
||||
assert_transport_stream::<tokio::io::DuplexStream>();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_stream_from_duplex() {
|
||||
let (client, server) = duplex(1024);
|
||||
let _boxed: Box<dyn TransportStream> = Box::new(server);
|
||||
let _: Box<dyn TransportStream> = Box::new(client);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_request_fields() {
|
||||
let req = InterfaceRequest {
|
||||
operation_path: "/v1/head/auth/verify".to_string(),
|
||||
input: serde_json::json!({"key": "value"}),
|
||||
auth_token: None,
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
assert_eq!(req.operation_path, "/v1/head/auth/verify");
|
||||
assert!(req.auth_token.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_response_fields() {
|
||||
let resp = InterfaceResponse {
|
||||
result: Ok(serde_json::json!({"status": "ok"})),
|
||||
status: 200,
|
||||
headers: HashMap::new(),
|
||||
};
|
||||
assert_eq!(resp.status, 200);
|
||||
}
|
||||
|
||||
struct MockMessageInterface;
|
||||
|
||||
#[async_trait]
|
||||
impl MessageInterface for MockMessageInterface {
|
||||
async fn handle_request(&self, _request: InterfaceRequest) -> Result<InterfaceResponse> {
|
||||
Ok(InterfaceResponse {
|
||||
result: Ok(serde_json::json!({})),
|
||||
status: 200,
|
||||
headers: HashMap::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn message_interface_trait_compiles() {
|
||||
let iface = MockMessageInterface;
|
||||
let req = InterfaceRequest {
|
||||
operation_path: "/test".to_string(),
|
||||
input: serde_json::json!({}),
|
||||
auth_token: None,
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
let resp = iface.handle_request(req).await.unwrap();
|
||||
assert_eq!(resp.status, 200);
|
||||
}
|
||||
}
|
||||
@@ -1,122 +0,0 @@
|
||||
use crate::transport::TransportKind;
|
||||
|
||||
use super::config::StreamInterfaceKind;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TransportKindBase {
|
||||
Tcp,
|
||||
Tls,
|
||||
Iroh,
|
||||
WebTransport,
|
||||
}
|
||||
|
||||
fn transport_base(kind: &TransportKind) -> TransportKindBase {
|
||||
match kind {
|
||||
TransportKind::Tcp => TransportKindBase::Tcp,
|
||||
TransportKind::Tls { .. } => TransportKindBase::Tls,
|
||||
TransportKind::Iroh { .. } => TransportKindBase::Iroh,
|
||||
TransportKind::WebTransport { .. } => TransportKindBase::WebTransport,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_valid_pair(transport: &TransportKind, interface: StreamInterfaceKind) -> bool {
|
||||
let base = transport_base(transport);
|
||||
matches!(
|
||||
(base, interface),
|
||||
(TransportKindBase::Tcp, StreamInterfaceKind::Ssh)
|
||||
| (TransportKindBase::Tls, StreamInterfaceKind::Ssh)
|
||||
| (TransportKindBase::Iroh, StreamInterfaceKind::Ssh)
|
||||
| (TransportKindBase::WebTransport, StreamInterfaceKind::Ssh)
|
||||
| (
|
||||
TransportKindBase::WebTransport,
|
||||
StreamInterfaceKind::RawFraming
|
||||
)
|
||||
| (TransportKindBase::Tcp, StreamInterfaceKind::RawFraming)
|
||||
)
|
||||
}
|
||||
|
||||
pub const VALID_TRANSPORT_INTERFACE_PAIRS: &[(TransportKindBase, StreamInterfaceKind)] = &[
|
||||
(TransportKindBase::Tcp, StreamInterfaceKind::Ssh),
|
||||
(TransportKindBase::Tls, StreamInterfaceKind::Ssh),
|
||||
(TransportKindBase::Iroh, StreamInterfaceKind::Ssh),
|
||||
(TransportKindBase::WebTransport, StreamInterfaceKind::Ssh),
|
||||
(
|
||||
TransportKindBase::WebTransport,
|
||||
StreamInterfaceKind::RawFraming,
|
||||
),
|
||||
(TransportKindBase::Tcp, StreamInterfaceKind::RawFraming),
|
||||
];
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn valid_ssh_pairs() {
|
||||
assert!(is_valid_pair(&TransportKind::Tcp, StreamInterfaceKind::Ssh));
|
||||
assert!(is_valid_pair(
|
||||
&TransportKind::Tls { server_name: None },
|
||||
StreamInterfaceKind::Ssh
|
||||
));
|
||||
assert!(is_valid_pair(
|
||||
&TransportKind::Iroh {
|
||||
endpoint_id: String::new()
|
||||
},
|
||||
StreamInterfaceKind::Ssh
|
||||
));
|
||||
assert!(is_valid_pair(
|
||||
&TransportKind::WebTransport { server_name: None },
|
||||
StreamInterfaceKind::Ssh
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_raw_framing_pairs() {
|
||||
assert!(is_valid_pair(
|
||||
&TransportKind::Tcp,
|
||||
StreamInterfaceKind::RawFraming
|
||||
));
|
||||
assert!(is_valid_pair(
|
||||
&TransportKind::WebTransport { server_name: None },
|
||||
StreamInterfaceKind::RawFraming
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_pairs() {
|
||||
assert!(!is_valid_pair(
|
||||
&TransportKind::Iroh {
|
||||
endpoint_id: String::new()
|
||||
},
|
||||
StreamInterfaceKind::RawFraming
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_kind_base_classification() {
|
||||
assert_eq!(transport_base(&TransportKind::Tcp), TransportKindBase::Tcp);
|
||||
assert_eq!(
|
||||
transport_base(&TransportKind::Tls {
|
||||
server_name: Some("example.com".to_string())
|
||||
}),
|
||||
TransportKindBase::Tls
|
||||
);
|
||||
assert_eq!(
|
||||
transport_base(&TransportKind::Iroh {
|
||||
endpoint_id: "abc".to_string()
|
||||
}),
|
||||
TransportKindBase::Iroh
|
||||
);
|
||||
assert_eq!(
|
||||
transport_base(&TransportKind::WebTransport {
|
||||
server_name: Some("example.com".to_string())
|
||||
}),
|
||||
TransportKindBase::WebTransport
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_pairs_table_complete() {
|
||||
assert_eq!(VALID_TRANSPORT_INTERFACE_PAIRS.len(), 6);
|
||||
}
|
||||
}
|
||||
@@ -1,399 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
|
||||
|
||||
use crate::auth::{AuthToken, Identity, IdentityProvider};
|
||||
use crate::call::frame::{decode_with_remainder, encode};
|
||||
use crate::call::EventEnvelope;
|
||||
use crate::interface::session::{InterfaceEvent, InterfaceSession};
|
||||
use crate::interface::{StreamInterface, StreamInterfaceConfig, TransportStream};
|
||||
|
||||
const READ_BUF_SIZE: usize = 8192;
|
||||
|
||||
pub struct RawFramingInterface;
|
||||
|
||||
#[async_trait]
|
||||
impl StreamInterface for RawFramingInterface {
|
||||
type Session = RawFramingSession;
|
||||
|
||||
async fn accept(
|
||||
&self,
|
||||
stream: Box<dyn TransportStream>,
|
||||
config: &StreamInterfaceConfig,
|
||||
) -> Result<Self::Session> {
|
||||
let raw_config = match config {
|
||||
StreamInterfaceConfig::RawFraming(c) => c,
|
||||
StreamInterfaceConfig::Ssh(_) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"RawFramingInterface received SshInterfaceConfig"
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(RawFramingSession::new(stream, Arc::clone(&raw_config.auth)))
|
||||
}
|
||||
}
|
||||
|
||||
enum AuthState {
|
||||
Pending,
|
||||
Authenticated(Identity),
|
||||
Failed,
|
||||
}
|
||||
|
||||
pub struct RawFramingSession {
|
||||
reader: BufReader<tokio::io::ReadHalf<Box<dyn TransportStream>>>,
|
||||
writer: BufWriter<tokio::io::WriteHalf<Box<dyn TransportStream>>>,
|
||||
auth_state: AuthState,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
read_buf: Vec<u8>,
|
||||
}
|
||||
|
||||
impl RawFramingSession {
|
||||
pub fn new(
|
||||
stream: Box<dyn TransportStream>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
) -> Self {
|
||||
let (read_half, write_half) = tokio::io::split(stream);
|
||||
Self {
|
||||
reader: BufReader::new(read_half),
|
||||
writer: BufWriter::new(write_half),
|
||||
auth_state: AuthState::Pending,
|
||||
identity_provider,
|
||||
read_buf: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_frame(&mut self) -> Result<EventEnvelope> {
|
||||
loop {
|
||||
match decode_with_remainder(&self.read_buf) {
|
||||
Ok((envelope, consumed)) => {
|
||||
self.read_buf.drain(..consumed);
|
||||
return Ok(envelope);
|
||||
}
|
||||
Err(crate::call::frame::FrameDecodeError::TooShort { .. })
|
||||
| Err(crate::call::frame::FrameDecodeError::Incomplete { .. }) => {
|
||||
let mut tmp = [0u8; READ_BUF_SIZE];
|
||||
let n = self.reader.read(&mut tmp).await?;
|
||||
if n == 0 {
|
||||
return Err(anyhow::anyhow!("stream closed while reading frame"));
|
||||
}
|
||||
self.read_buf.extend_from_slice(&tmp[..n]);
|
||||
}
|
||||
Err(crate::call::frame::FrameDecodeError::Json(e)) => {
|
||||
return Err(anyhow::anyhow!("frame JSON decode error: {e}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn write_frame(&mut self, envelope: &EventEnvelope) -> Result<()> {
|
||||
let frame = encode(envelope);
|
||||
self.writer.write_all(&frame).await?;
|
||||
self.writer.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl InterfaceSession for RawFramingSession {
|
||||
async fn recv(&mut self) -> Option<InterfaceEvent> {
|
||||
match &self.auth_state {
|
||||
AuthState::Failed => return None,
|
||||
AuthState::Authenticated(_) => {
|
||||
let identity = match &self.auth_state {
|
||||
AuthState::Authenticated(id) => id.clone(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let envelope = match self.read_frame().await {
|
||||
Ok(e) => e,
|
||||
Err(_) => return None,
|
||||
};
|
||||
return Some(InterfaceEvent::with_identity(envelope, identity));
|
||||
}
|
||||
AuthState::Pending => {}
|
||||
}
|
||||
|
||||
let envelope = match self.read_frame().await {
|
||||
Ok(e) => e,
|
||||
Err(_) => {
|
||||
self.auth_state = AuthState::Failed;
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let token_raw = envelope.payload.as_str().unwrap_or("").as_bytes().to_vec();
|
||||
let token = AuthToken { raw: token_raw };
|
||||
|
||||
match self.identity_provider.resolve_from_token(&token) {
|
||||
Some(identity) => {
|
||||
self.auth_state = AuthState::Authenticated(identity.clone());
|
||||
Some(InterfaceEvent::with_identity(envelope, identity))
|
||||
}
|
||||
None => {
|
||||
self.auth_state = AuthState::Failed;
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send(&mut self, envelope: EventEnvelope) -> Result<()> {
|
||||
match self.auth_state {
|
||||
AuthState::Failed => Err(anyhow::anyhow!("session authentication failed")),
|
||||
_ => self.write_frame(&envelope).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::ConfigIdentityProvider;
|
||||
use crate::config::DynamicConfig;
|
||||
use crate::interface::RawFramingConfig;
|
||||
use arc_swap::ArcSwap;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_provider() -> Arc<dyn IdentityProvider> {
|
||||
Arc::new(ConfigIdentityProvider::new(Arc::new(ArcSwap::new(
|
||||
Arc::new(DynamicConfig::default()),
|
||||
))))
|
||||
}
|
||||
|
||||
fn make_provider_with_identity(
|
||||
identity: Identity,
|
||||
valid_token: &str,
|
||||
) -> (Arc<dyn IdentityProvider>, String) {
|
||||
struct MockProvider {
|
||||
identity: Identity,
|
||||
valid_token: String,
|
||||
}
|
||||
impl IdentityProvider for MockProvider {
|
||||
fn resolve_from_fingerprint(&self, _fp: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
fn resolve_from_token(&self, token: &AuthToken) -> Option<Identity> {
|
||||
if token.raw == self.valid_token.as_bytes() {
|
||||
Some(self.identity.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
let provider = Arc::new(MockProvider {
|
||||
identity,
|
||||
valid_token: valid_token.to_string(),
|
||||
});
|
||||
(provider, valid_token.to_string())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn raw_framing_interface_accept_succeeds() {
|
||||
let iface = RawFramingInterface;
|
||||
let (_client, server) = tokio::io::duplex(1024);
|
||||
let stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let config = StreamInterfaceConfig::RawFraming(RawFramingConfig {
|
||||
auth: make_provider(),
|
||||
});
|
||||
let result = iface.accept(stream, &config).await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn raw_framing_interface_rejects_ssh_config() {
|
||||
let iface = RawFramingInterface;
|
||||
let (_client, server) = tokio::io::duplex(1024);
|
||||
let stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let config = StreamInterfaceConfig::Ssh(crate::interface::SshInterfaceConfig {
|
||||
auth: make_provider(),
|
||||
forwarding: Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default()))),
|
||||
host_key: Arc::new(
|
||||
russh::keys::PrivateKey::random(
|
||||
&mut rand_core::OsRng,
|
||||
russh::keys::Algorithm::Ed25519,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
});
|
||||
let result = iface.accept(stream, &config).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn raw_framing_session_round_trip() {
|
||||
let identity = Identity {
|
||||
id: "test-id".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let (provider, token_str) =
|
||||
make_provider_with_identity(identity.clone(), "valid-test-token");
|
||||
|
||||
let (client, server) = tokio::io::duplex(4096);
|
||||
let server_stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let client_stream: Box<dyn TransportStream> = Box::new(client);
|
||||
|
||||
let mut server_session = RawFramingSession::new(server_stream, provider);
|
||||
|
||||
let auth_envelope = EventEnvelope::new("auth", "auth-1", serde_json::json!(token_str));
|
||||
let auth_frame = encode(&auth_envelope);
|
||||
|
||||
let mut client_writer = tokio::io::BufWriter::new(client_stream);
|
||||
client_writer.write_all(&auth_frame).await.unwrap();
|
||||
client_writer.flush().await.unwrap();
|
||||
|
||||
let event = server_session.recv().await;
|
||||
assert!(event.is_some());
|
||||
let event = event.unwrap();
|
||||
assert!(event.identity.is_some());
|
||||
assert_eq!(event.identity.as_ref().unwrap().id, "test-id");
|
||||
|
||||
let data_envelope =
|
||||
EventEnvelope::call_requested("req-2", serde_json::json!({"op": "test"}));
|
||||
let data_frame = encode(&data_envelope);
|
||||
client_writer.write_all(&data_frame).await.unwrap();
|
||||
client_writer.flush().await.unwrap();
|
||||
|
||||
let event = server_session.recv().await;
|
||||
assert!(event.is_some());
|
||||
let event = event.unwrap();
|
||||
assert_eq!(event.envelope.r#type, "call.requested");
|
||||
assert_eq!(event.envelope.id, "req-2");
|
||||
assert!(event.identity.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn first_frame_auth_valid_token() {
|
||||
let identity = Identity {
|
||||
id: "auth-user".to_string(),
|
||||
scopes: vec!["admin".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let (provider, token_str) = make_provider_with_identity(identity, "my-valid-token");
|
||||
|
||||
let (client, server) = tokio::io::duplex(4096);
|
||||
let server_stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let client_stream: Box<dyn TransportStream> = Box::new(client);
|
||||
|
||||
let mut session = RawFramingSession::new(server_stream, provider);
|
||||
|
||||
let auth_envelope = EventEnvelope::new("auth", "auth-1", serde_json::json!(token_str));
|
||||
let frame = encode(&auth_envelope);
|
||||
let mut writer = tokio::io::BufWriter::new(client_stream);
|
||||
writer.write_all(&frame).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let event = session.recv().await;
|
||||
assert!(event.is_some());
|
||||
let event = event.unwrap();
|
||||
assert!(event.identity.is_some());
|
||||
assert_eq!(event.identity.as_ref().unwrap().id, "auth-user");
|
||||
assert_eq!(event.identity.as_ref().unwrap().scopes, vec!["admin"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn first_frame_auth_invalid_token() {
|
||||
let identity = Identity {
|
||||
id: "auth-user".to_string(),
|
||||
scopes: vec![],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let (provider, _) = make_provider_with_identity(identity, "correct-token");
|
||||
|
||||
let (client, server) = tokio::io::duplex(4096);
|
||||
let server_stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let client_stream: Box<dyn TransportStream> = Box::new(client);
|
||||
|
||||
let mut session = RawFramingSession::new(server_stream, provider);
|
||||
|
||||
let bad_envelope =
|
||||
EventEnvelope::new("auth", "auth-1", serde_json::json!("bad-token-value"));
|
||||
let frame = encode(&bad_envelope);
|
||||
let mut writer = tokio::io::BufWriter::new(client_stream);
|
||||
writer.write_all(&frame).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let event = session.recv().await;
|
||||
assert!(event.is_none());
|
||||
|
||||
let data_envelope = EventEnvelope::call_requested("req-2", serde_json::json!({}));
|
||||
let result = session.send(data_envelope).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn raw_framing_session_send() {
|
||||
let identity = Identity {
|
||||
id: "send-user".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let (provider, token_str) = make_provider_with_identity(identity, "send-token");
|
||||
|
||||
let (client, server) = tokio::io::duplex(4096);
|
||||
let server_stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let client_stream: Box<dyn TransportStream> = Box::new(client);
|
||||
|
||||
let mut server_session = RawFramingSession::new(server_stream, provider);
|
||||
|
||||
let auth_envelope = EventEnvelope::new("auth", "auth-1", serde_json::json!(token_str));
|
||||
let auth_frame = encode(&auth_envelope);
|
||||
let mut client_writer = tokio::io::BufWriter::new(client_stream);
|
||||
client_writer.write_all(&auth_frame).await.unwrap();
|
||||
client_writer.flush().await.unwrap();
|
||||
|
||||
let _ = server_session.recv().await;
|
||||
|
||||
let response = EventEnvelope::call_responded("req-1", serde_json::json!({"result": "ok"}));
|
||||
let send_result = server_session.send(response).await;
|
||||
assert!(send_result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn raw_framing_multiple_frames_over_duplex() {
|
||||
let identity = Identity {
|
||||
id: "multi-user".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let (provider, token_str) = make_provider_with_identity(identity, "multi-token");
|
||||
|
||||
let (client, server) = tokio::io::duplex(8192);
|
||||
let server_stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let client_stream: Box<dyn TransportStream> = Box::new(client);
|
||||
|
||||
let mut session = RawFramingSession::new(server_stream, provider);
|
||||
let mut client_writer = tokio::io::BufWriter::new(client_stream);
|
||||
|
||||
let auth_envelope = EventEnvelope::new("auth", "auth-0", serde_json::json!(token_str));
|
||||
client_writer
|
||||
.write_all(&encode(&auth_envelope))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
for i in 1..=5 {
|
||||
let envelope =
|
||||
EventEnvelope::call_requested(format!("req-{i}"), serde_json::json!({"seq": i}));
|
||||
client_writer.write_all(&encode(&envelope)).await.unwrap();
|
||||
}
|
||||
client_writer.flush().await.unwrap();
|
||||
|
||||
let auth_event = session.recv().await;
|
||||
assert!(auth_event.is_some());
|
||||
assert!(auth_event.unwrap().identity.is_some());
|
||||
|
||||
for i in 1..=5 {
|
||||
let event = session.recv().await;
|
||||
assert!(event.is_some());
|
||||
let event = event.unwrap();
|
||||
assert_eq!(event.envelope.id, format!("req-{i}"));
|
||||
assert!(event.identity.is_some());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn raw_framing_interface_type_exists() {
|
||||
let _iface = RawFramingInterface;
|
||||
}
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::auth::Identity;
|
||||
use crate::call::EventEnvelope;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InterfaceEvent {
|
||||
pub envelope: EventEnvelope,
|
||||
pub identity: Option<Identity>,
|
||||
}
|
||||
|
||||
impl InterfaceEvent {
|
||||
pub fn new(envelope: EventEnvelope) -> Self {
|
||||
Self {
|
||||
envelope,
|
||||
identity: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_identity(envelope: EventEnvelope, identity: Identity) -> Self {
|
||||
Self {
|
||||
envelope,
|
||||
identity: Some(identity),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait InterfaceSession: Send {
|
||||
async fn recv(&mut self) -> Option<InterfaceEvent>;
|
||||
|
||||
async fn send(&mut self, envelope: EventEnvelope) -> Result<()>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn interface_event_new() {
|
||||
let envelope = EventEnvelope::call_requested("req-1", serde_json::json!({"op": "test"}));
|
||||
let event = InterfaceEvent::new(envelope.clone());
|
||||
assert_eq!(event.envelope, envelope);
|
||||
assert!(event.identity.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_event_with_identity() {
|
||||
let envelope = EventEnvelope::call_requested("req-1", serde_json::json!({"op": "test"}));
|
||||
let identity = Identity {
|
||||
id: "SHA256:abc123".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let event = InterfaceEvent::with_identity(envelope.clone(), identity.clone());
|
||||
assert_eq!(event.envelope, envelope);
|
||||
assert!(event.identity.is_some());
|
||||
assert_eq!(event.identity.as_ref().unwrap().id, "SHA256:abc123");
|
||||
}
|
||||
}
|
||||
@@ -1,982 +0,0 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::Result;
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use russh::keys::ssh_key::HashAlg;
|
||||
use russh::server::{self, Config};
|
||||
use russh::Channel;
|
||||
use russh::ChannelId;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::auth::identity::{Identity, IdentityProvider};
|
||||
use crate::call::frame::{FrameFramedReader, FrameFramedWriter};
|
||||
use crate::call::EventEnvelope;
|
||||
use crate::config::DynamicConfig;
|
||||
use crate::interface::session::{InterfaceEvent, InterfaceSession};
|
||||
use crate::interface::{StreamInterface, StreamInterfaceConfig, TransportStream};
|
||||
use crate::server::control_channel::{
|
||||
ControlChannelHandler, ControlChannelRouter, DuplexStream, ALKNET_CONTROL_DESTINATION,
|
||||
ALKNET_PREFIX,
|
||||
};
|
||||
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||
use crate::transport::TransportKind;
|
||||
|
||||
struct SshHandler {
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
outbound_proxy: Option<crate::server::handler::ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
transport: TransportKind,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
connection_allowed: bool,
|
||||
auth_limiter: AuthAttemptLimiter,
|
||||
authenticated_identity: Option<Identity>,
|
||||
control_channel_router: ControlChannelRouter,
|
||||
bridge_event_tx: Option<mpsc::Sender<InterfaceEvent>>,
|
||||
bridge_envelope_rx: Option<mpsc::Receiver<EventEnvelope>>,
|
||||
connected_at: Instant,
|
||||
}
|
||||
|
||||
impl SshHandler {
|
||||
fn new(
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
outbound_proxy: Option<crate::server::handler::ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
transport: TransportKind,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
max_auth_attempts: usize,
|
||||
) -> Self {
|
||||
let allowed = if let Some(addr) = remote_addr {
|
||||
let ip = addr.ip();
|
||||
if connection_limiter.check(ip) {
|
||||
connection_limiter.on_connect(ip);
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
transport = %transport,
|
||||
"connection opened"
|
||||
);
|
||||
true
|
||||
} else {
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
transport = %transport,
|
||||
"connection rejected"
|
||||
);
|
||||
false
|
||||
}
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
Self {
|
||||
dynamic,
|
||||
identity_provider,
|
||||
outbound_proxy,
|
||||
remote_addr,
|
||||
transport,
|
||||
connection_limiter,
|
||||
connection_allowed: allowed,
|
||||
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
|
||||
authenticated_identity: None,
|
||||
control_channel_router: ControlChannelRouter::without_handler(),
|
||||
bridge_event_tx: None,
|
||||
bridge_envelope_rx: None,
|
||||
connected_at: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn with_control_channel_router(mut self, router: ControlChannelRouter) -> Self {
|
||||
self.control_channel_router = router;
|
||||
self
|
||||
}
|
||||
|
||||
fn with_bridge_channels(
|
||||
mut self,
|
||||
event_tx: mpsc::Sender<InterfaceEvent>,
|
||||
envelope_rx: mpsc::Receiver<EventEnvelope>,
|
||||
) -> Self {
|
||||
self.bridge_event_tx = Some(event_tx);
|
||||
self.bridge_envelope_rx = Some(envelope_rx);
|
||||
self
|
||||
}
|
||||
|
||||
fn has_control_channel_bridge(&self) -> bool {
|
||||
self.bridge_event_tx.is_some() && self.bridge_envelope_rx.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SshHandler {
|
||||
fn drop(&mut self) {
|
||||
if let Some(addr) = self.remote_addr {
|
||||
if self.connection_allowed {
|
||||
self.connection_limiter.on_disconnect(addr.ip());
|
||||
let duration = self.connected_at.elapsed();
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
duration_secs = duration.as_secs_f64(),
|
||||
"connection closed"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl server::Handler for SshHandler {
|
||||
type Error = russh::Error;
|
||||
|
||||
async fn auth_publickey(
|
||||
&mut self,
|
||||
user: &str,
|
||||
public_key: &russh::keys::ssh_key::PublicKey,
|
||||
) -> Result<server::Auth, Self::Error> {
|
||||
if !self.auth_limiter.check() {
|
||||
let remote_addr_display = self
|
||||
.remote_addr
|
||||
.map_or("unknown".to_string(), |a| a.to_string());
|
||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "reject",
|
||||
"auth attempt"
|
||||
);
|
||||
return Ok(server::Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
});
|
||||
}
|
||||
|
||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||
let remote_addr_display = self
|
||||
.remote_addr
|
||||
.map_or("unknown".to_string(), |a| a.to_string());
|
||||
|
||||
let identity = self
|
||||
.identity_provider
|
||||
.resolve_from_fingerprint(&fingerprint);
|
||||
|
||||
match identity {
|
||||
Some(id) => {
|
||||
self.authenticated_identity = Some(id);
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "accept",
|
||||
"auth attempt"
|
||||
);
|
||||
Ok(server::Auth::Accept)
|
||||
}
|
||||
None => {
|
||||
self.auth_limiter.on_failure();
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "reject",
|
||||
"auth attempt"
|
||||
);
|
||||
Ok(server::Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn channel_open_direct_tcpip(
|
||||
&mut self,
|
||||
channel: Channel<server::Msg>,
|
||||
host_to_connect: &str,
|
||||
port_to_connect: u32,
|
||||
originator_address: &str,
|
||||
originator_port: u32,
|
||||
_session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
if host_to_connect.starts_with(ALKNET_PREFIX) {
|
||||
if host_to_connect == ALKNET_CONTROL_DESTINATION && self.has_control_channel_bridge() {
|
||||
let event_tx = self.bridge_event_tx.take().unwrap();
|
||||
let envelope_rx = self.bridge_envelope_rx.take().unwrap();
|
||||
let identity = self.authenticated_identity.clone();
|
||||
tokio::spawn(async move {
|
||||
let stream = channel.into_stream();
|
||||
let (read_half, write_half) = tokio::io::split(stream);
|
||||
run_control_channel_bridge(
|
||||
read_half,
|
||||
write_half,
|
||||
identity,
|
||||
event_tx,
|
||||
envelope_rx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
let _ = (originator_address, originator_port);
|
||||
return Ok(true);
|
||||
}
|
||||
if self.control_channel_router.has_handler() {
|
||||
if let Some(handler) = self.control_channel_router.take_handler() {
|
||||
let stream: Box<dyn DuplexStream> = Box::new(channel.into_stream());
|
||||
tokio::spawn(async move {
|
||||
handler.handle_channel(stream).await;
|
||||
});
|
||||
}
|
||||
let _ = (originator_address, originator_port);
|
||||
return Ok(true);
|
||||
}
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let identity = self
|
||||
.authenticated_identity
|
||||
.clone()
|
||||
.unwrap_or_else(|| Identity {
|
||||
id: String::new(),
|
||||
scopes: vec![],
|
||||
resources: std::collections::HashMap::new(),
|
||||
});
|
||||
|
||||
let policy = self.dynamic.load();
|
||||
let allowed = policy.forwarding.check(
|
||||
host_to_connect,
|
||||
port_to_connect as u16,
|
||||
&identity,
|
||||
self.transport.clone(),
|
||||
);
|
||||
|
||||
if !allowed {
|
||||
tracing::info!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
target = %format!("{host_to_connect}:{port_to_connect}"),
|
||||
identity = %identity.id,
|
||||
transport = %self.transport,
|
||||
"forwarding denied by policy"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let target_host = host_to_connect.to_string();
|
||||
let target_port = port_to_connect;
|
||||
let proxy_config =
|
||||
self.outbound_proxy
|
||||
.clone()
|
||||
.unwrap_or(crate::server::handler::ProxyConfig {
|
||||
mode: crate::server::handler::ProxyMode::Direct,
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
let target = match format!("{target_host}:{target_port}")
|
||||
.parse::<std::net::SocketAddr>()
|
||||
{
|
||||
Ok(addr) => addr,
|
||||
Err(_) => {
|
||||
match tokio::net::lookup_host((&target_host[..], target_port as u16)).await {
|
||||
Ok(mut addrs) => match addrs.next() {
|
||||
Some(addr) => addr,
|
||||
None => return,
|
||||
},
|
||||
Err(_) => return,
|
||||
}
|
||||
}
|
||||
};
|
||||
crate::server::channel_proxy::proxy_channel(
|
||||
channel.into_stream(),
|
||||
target,
|
||||
&proxy_config,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let _ = (originator_address, originator_port);
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn channel_open_session(
|
||||
&mut self,
|
||||
_channel: Channel<server::Msg>,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
"rejected session channel (shell/exec not supported)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn channel_open_x11(
|
||||
&mut self,
|
||||
_channel: Channel<server::Msg>,
|
||||
_originator_address: &str,
|
||||
_originator_port: u32,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
"rejected x11 channel"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn channel_open_forwarded_tcpip(
|
||||
&mut self,
|
||||
_channel: Channel<server::Msg>,
|
||||
host_to_connect: &str,
|
||||
port_to_connect: u32,
|
||||
_originator_address: &str,
|
||||
_originator_port: u32,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
target = %format!("{host_to_connect}:{port_to_connect}"),
|
||||
"rejected forwarded-tcpip channel (remote port forwarding not supported)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn exec_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
data: &[u8],
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
data_len = data.len(),
|
||||
"rejected exec request on channel (shell/exec not supported)"
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn shell_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
"rejected shell request on channel"
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn subsystem_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
name: &str,
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
subsystem = name,
|
||||
"rejected subsystem request on channel"
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn pty_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
term: &str,
|
||||
col_width: u32,
|
||||
row_height: u32,
|
||||
pix_width: u32,
|
||||
pix_height: u32,
|
||||
modes: &[(russh::Pty, u32)],
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
term = term,
|
||||
"rejected pty request on channel"
|
||||
);
|
||||
let _ = (col_width, row_height, pix_width, pix_height, modes);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn env_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
variable_name: &str,
|
||||
variable_value: &str,
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
variable = variable_name,
|
||||
"rejected env request on channel"
|
||||
);
|
||||
let _ = variable_value;
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn x11_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
single_connection: bool,
|
||||
x11_auth_protocol: &str,
|
||||
x11_auth_cookie: &str,
|
||||
x11_screen_number: u32,
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
"rejected x11 request on channel"
|
||||
);
|
||||
let _ = (
|
||||
single_connection,
|
||||
x11_auth_protocol,
|
||||
x11_auth_cookie,
|
||||
x11_screen_number,
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn agent_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
"rejected agent forwarding request on channel"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn tcpip_forward(
|
||||
&mut self,
|
||||
address: &str,
|
||||
port: &mut u32,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
address = address,
|
||||
port = *port,
|
||||
"rejected tcpip-forward request (remote port forwarding not supported)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn cancel_tcpip_forward(
|
||||
&mut self,
|
||||
address: &str,
|
||||
port: u32,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
let _ = (address, port, session);
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn streamlocal_forward(
|
||||
&mut self,
|
||||
socket_path: &str,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
socket_path = socket_path,
|
||||
"rejected streamlocal-forward request"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn signal(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
signal: russh::Sig,
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::debug!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
signal = ?signal,
|
||||
"received signal on channel (ignored)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SshInterface {
|
||||
config: Arc<Config>,
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
outbound_proxy: Option<crate::server::handler::ProxyConfig>,
|
||||
max_auth_attempts: usize,
|
||||
}
|
||||
|
||||
impl SshInterface {
|
||||
pub fn new(config: Arc<Config>, dynamic: Arc<ArcSwap<DynamicConfig>>) -> Self {
|
||||
Self {
|
||||
config,
|
||||
dynamic,
|
||||
connection_limiter: Arc::new(ConnectionRateLimiter::new(0)),
|
||||
outbound_proxy: None,
|
||||
max_auth_attempts: 10,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_connection_limiter(mut self, limiter: Arc<ConnectionRateLimiter>) -> Self {
|
||||
self.connection_limiter = limiter;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_outbound_proxy(
|
||||
mut self,
|
||||
proxy: Option<crate::server::handler::ProxyConfig>,
|
||||
) -> Self {
|
||||
self.outbound_proxy = proxy;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_auth_attempts(mut self, max: usize) -> Self {
|
||||
self.max_auth_attempts = max;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &Arc<Config> {
|
||||
&self.config
|
||||
}
|
||||
|
||||
pub fn dynamic(&self) -> &Arc<ArcSwap<DynamicConfig>> {
|
||||
&self.dynamic
|
||||
}
|
||||
|
||||
async fn accept_inner(
|
||||
&self,
|
||||
stream: Box<dyn TransportStream>,
|
||||
ssh_config: &crate::interface::SshInterfaceConfig,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
transport: TransportKind,
|
||||
) -> Result<SshSession> {
|
||||
let identity_provider = Arc::clone(&ssh_config.auth);
|
||||
let _forwarding = Arc::clone(&ssh_config.forwarding);
|
||||
|
||||
let (event_tx, event_rx) = mpsc::channel::<InterfaceEvent>(256);
|
||||
let (envelope_tx, envelope_rx) = mpsc::channel::<EventEnvelope>(256);
|
||||
|
||||
let handler = SshHandler::new(
|
||||
Arc::clone(&self.dynamic),
|
||||
identity_provider,
|
||||
self.outbound_proxy.clone(),
|
||||
remote_addr,
|
||||
transport,
|
||||
Arc::clone(&self.connection_limiter),
|
||||
self.max_auth_attempts,
|
||||
)
|
||||
.with_bridge_channels(event_tx, envelope_rx);
|
||||
|
||||
let running = server::run_stream(Arc::clone(&self.config), stream, handler).await?;
|
||||
let handle = running.handle();
|
||||
let join = tokio::spawn(async {
|
||||
let _ = running.await;
|
||||
});
|
||||
|
||||
Ok(SshSession {
|
||||
handle,
|
||||
_join: join,
|
||||
event_rx,
|
||||
envelope_tx,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StreamInterface for SshInterface {
|
||||
type Session = SshSession;
|
||||
|
||||
async fn accept(
|
||||
&self,
|
||||
stream: Box<dyn TransportStream>,
|
||||
config: &StreamInterfaceConfig,
|
||||
) -> Result<Self::Session> {
|
||||
let ssh_config = match config {
|
||||
StreamInterfaceConfig::Ssh(c) => c,
|
||||
StreamInterfaceConfig::RawFraming(_) => {
|
||||
return Err(anyhow::anyhow!("SshInterface received RawFramingConfig"));
|
||||
}
|
||||
};
|
||||
|
||||
self.accept_inner(stream, ssh_config, None, TransportKind::Tcp)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SshSession {
|
||||
handle: server::Handle,
|
||||
_join: tokio::task::JoinHandle<()>,
|
||||
event_rx: mpsc::Receiver<InterfaceEvent>,
|
||||
envelope_tx: mpsc::Sender<EventEnvelope>,
|
||||
}
|
||||
|
||||
impl SshSession {
|
||||
pub fn handle(&self) -> &server::Handle {
|
||||
&self.handle
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl InterfaceSession for SshSession {
|
||||
async fn recv(&mut self) -> Option<InterfaceEvent> {
|
||||
self.event_rx.recv().await
|
||||
}
|
||||
|
||||
async fn send(&mut self, envelope: EventEnvelope) -> Result<()> {
|
||||
self.envelope_tx
|
||||
.send(envelope)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("control channel bridge closed"))
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_control_channel_bridge<R, W>(
|
||||
read_half: R,
|
||||
write_half: W,
|
||||
identity: Option<Identity>,
|
||||
event_tx: mpsc::Sender<InterfaceEvent>,
|
||||
mut envelope_rx: mpsc::Receiver<EventEnvelope>,
|
||||
) where
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let mut reader = FrameFramedReader::new(read_half);
|
||||
let mut writer = FrameFramedWriter::new(write_half);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
frame = reader.read_frame() => {
|
||||
match frame {
|
||||
Ok(Some(envelope)) => {
|
||||
let event = match &identity {
|
||||
Some(id) => InterfaceEvent::with_identity(envelope, id.clone()),
|
||||
None => InterfaceEvent::new(envelope),
|
||||
};
|
||||
if event_tx.send(event).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
Ok(None) => return,
|
||||
Err(_) => return,
|
||||
}
|
||||
}
|
||||
envelope = envelope_rx.recv() => {
|
||||
match envelope {
|
||||
Some(envelope) => {
|
||||
if writer.write_frame(&envelope).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
None => return,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ControlChannelBridge {
|
||||
identity: Option<Identity>,
|
||||
}
|
||||
|
||||
impl ControlChannelBridge {
|
||||
pub fn new(identity: Option<Identity>) -> Self {
|
||||
Self { identity }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for ControlChannelBridge {
|
||||
async fn handle_channel(&self, stream: Box<dyn DuplexStream>) {
|
||||
let (event_tx, _event_rx) = mpsc::channel::<InterfaceEvent>(256);
|
||||
let (_envelope_tx, envelope_rx) = mpsc::channel::<EventEnvelope>(256);
|
||||
|
||||
let identity = self.identity.clone();
|
||||
let (read_half, write_half) = tokio::io::split(stream);
|
||||
tokio::spawn(run_control_channel_bridge(
|
||||
read_half,
|
||||
write_half,
|
||||
identity,
|
||||
event_tx,
|
||||
envelope_rx,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::call::frame::{FrameFramedReader, FrameFramedWriter};
|
||||
use tokio::io::duplex;
|
||||
|
||||
#[test]
|
||||
fn ssh_interface_constructs_with_config() {
|
||||
let config = Arc::new(Config {
|
||||
keys: vec![russh::keys::PrivateKey::random(
|
||||
&mut rand_core::OsRng,
|
||||
russh::keys::Algorithm::Ed25519,
|
||||
)
|
||||
.unwrap()],
|
||||
..Default::default()
|
||||
});
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||
|
||||
let iface = SshInterface::new(config, dynamic);
|
||||
assert!(iface.config().keys.len() >= 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssh_interface_builder_pattern() {
|
||||
let config = Arc::new(Config {
|
||||
keys: vec![russh::keys::PrivateKey::random(
|
||||
&mut rand_core::OsRng,
|
||||
russh::keys::Algorithm::Ed25519,
|
||||
)
|
||||
.unwrap()],
|
||||
..Default::default()
|
||||
});
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(5));
|
||||
|
||||
let iface = SshInterface::new(config, dynamic)
|
||||
.with_connection_limiter(limiter)
|
||||
.with_max_auth_attempts(3);
|
||||
|
||||
assert!(iface.config().keys.len() >= 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssh_handler_auth_delegates_to_identity_provider() {
|
||||
use std::collections::HashMap;
|
||||
|
||||
struct MockProvider {
|
||||
identities: HashMap<String, Identity>,
|
||||
}
|
||||
|
||||
impl IdentityProvider for MockProvider {
|
||||
fn resolve_from_fingerprint(&self, fp: &str) -> Option<Identity> {
|
||||
self.identities.get(fp).cloned()
|
||||
}
|
||||
fn resolve_from_token(&self, _t: &crate::auth::AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
let mut ids = HashMap::new();
|
||||
ids.insert(
|
||||
"SHA256:testkey".to_string(),
|
||||
Identity {
|
||||
id: "SHA256:testkey".to_string(),
|
||||
scopes: vec!["admin".to_string()],
|
||||
resources: HashMap::new(),
|
||||
},
|
||||
);
|
||||
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(MockProvider { identities: ids });
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(0));
|
||||
|
||||
let handler = SshHandler::new(
|
||||
dynamic,
|
||||
provider,
|
||||
None,
|
||||
None,
|
||||
TransportKind::Tcp,
|
||||
limiter,
|
||||
10,
|
||||
);
|
||||
|
||||
assert!(handler.authenticated_identity.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssh_handler_connection_rate_limiting() {
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(
|
||||
crate::auth::identity::ConfigIdentityProvider::new(Arc::clone(&dynamic)),
|
||||
);
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(1));
|
||||
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
|
||||
|
||||
let h1 = SshHandler::new(
|
||||
Arc::clone(&dynamic),
|
||||
Arc::clone(&provider),
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
Arc::clone(&limiter),
|
||||
10,
|
||||
);
|
||||
assert!(h1.connection_allowed);
|
||||
|
||||
let h2 = SshHandler::new(
|
||||
dynamic,
|
||||
provider,
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter,
|
||||
10,
|
||||
);
|
||||
assert!(!h2.connection_allowed);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_interface_rejects_raw_framing_config() {
|
||||
let config = Arc::new(Config {
|
||||
keys: vec![russh::keys::PrivateKey::random(
|
||||
&mut rand_core::OsRng,
|
||||
russh::keys::Algorithm::Ed25519,
|
||||
)
|
||||
.unwrap()],
|
||||
..Default::default()
|
||||
});
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||
let iface = SshInterface::new(config, dynamic);
|
||||
let (_client, server) = tokio::io::duplex(1024);
|
||||
let stream: Box<dyn TransportStream> = Box::new(server);
|
||||
|
||||
let raw_config = StreamInterfaceConfig::RawFraming(crate::interface::RawFramingConfig {
|
||||
auth: Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new(
|
||||
ArcSwap::new(Arc::new(DynamicConfig::default())),
|
||||
))),
|
||||
});
|
||||
let result = iface.accept(stream, &raw_config).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_session_round_trip_event_envelope() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<InterfaceEvent>(256);
|
||||
let (envelope_tx, envelope_rx) = mpsc::channel::<EventEnvelope>(256);
|
||||
|
||||
let identity = Identity {
|
||||
id: "SHA256:test".to_string(),
|
||||
scopes: vec![],
|
||||
resources: std::collections::HashMap::new(),
|
||||
};
|
||||
let identity_clone = identity.clone();
|
||||
|
||||
let (server_read, server_write) = tokio::io::split(server);
|
||||
tokio::spawn(run_control_channel_bridge(
|
||||
server_read,
|
||||
server_write,
|
||||
Some(identity_clone),
|
||||
event_tx,
|
||||
envelope_rx,
|
||||
));
|
||||
|
||||
let (client_read, client_write) = tokio::io::split(client);
|
||||
let mut client_reader = FrameFramedReader::new(client_read);
|
||||
let mut client_writer = FrameFramedWriter::new(client_write);
|
||||
|
||||
let envelope = EventEnvelope::call_requested("req-1", serde_json::json!({"op": "test"}));
|
||||
client_writer.write_frame(&envelope).await.unwrap();
|
||||
|
||||
let received_event =
|
||||
tokio::time::timeout(std::time::Duration::from_secs(2), event_rx.recv())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(received_event.envelope, envelope);
|
||||
assert_eq!(received_event.identity.as_ref().unwrap().id, "SHA256:test");
|
||||
|
||||
let response = EventEnvelope::call_responded("req-1", serde_json::json!({"result": 42}));
|
||||
envelope_tx.send(response.clone()).await.unwrap();
|
||||
|
||||
let read_back = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
client_reader.read_frame(),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(read_back, response);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_session_recv_without_identity() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<InterfaceEvent>(256);
|
||||
let (_envelope_tx, envelope_rx) = mpsc::channel::<EventEnvelope>(256);
|
||||
|
||||
let (server_read, server_write) = tokio::io::split(server);
|
||||
tokio::spawn(run_control_channel_bridge(
|
||||
server_read,
|
||||
server_write,
|
||||
None,
|
||||
event_tx,
|
||||
envelope_rx,
|
||||
));
|
||||
|
||||
let (client_read, client_write) = tokio::io::split(client);
|
||||
let mut client_writer = FrameFramedWriter::new(client_write);
|
||||
let _client_reader = FrameFramedReader::new(client_read);
|
||||
|
||||
let envelope = EventEnvelope::call_requested("req-2", serde_json::json!({"op": "no-id"}));
|
||||
client_writer.write_frame(&envelope).await.unwrap();
|
||||
|
||||
let received_event =
|
||||
tokio::time::timeout(std::time::Duration::from_secs(2), event_rx.recv())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(received_event.envelope, envelope);
|
||||
assert!(received_event.identity.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn control_channel_router_with_handler_routes_data() {
|
||||
let called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
|
||||
let called_clone = called.clone();
|
||||
|
||||
struct TrackingHandler {
|
||||
called: std::sync::Arc<std::sync::atomic::AtomicBool>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for TrackingHandler {
|
||||
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {
|
||||
self.called.store(true, std::sync::atomic::Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
let router = ControlChannelRouter::with_handler(Box::new(TrackingHandler {
|
||||
called: called_clone,
|
||||
}));
|
||||
assert!(router.has_handler());
|
||||
|
||||
let (_client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
let result = router.route(stream).await;
|
||||
assert!(result.is_ok());
|
||||
assert!(called.load(std::sync::atomic::Ordering::SeqCst));
|
||||
}
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
//! # alknet-core
|
||||
//!
|
||||
//! Core library for [Alknet](https://git.alk.dev/alkdev/alknet), a self-hostable SSH-based
|
||||
//! tunnel tool. This crate provides the transport abstraction, SOCKS5 server, port forwarding,
|
||||
//! authentication, and server handler — everything needed to build an alknet client or server
|
||||
//! on top of pluggable transports.
|
||||
//!
|
||||
//! > **Alpha software.** This crate depends on solid libraries (russh, tokio, rustls, iroh)
|
||||
//! > for core functionality, but the integration layer has not been battle-tested. Use with
|
||||
//! > caution and report issues.
|
||||
//!
|
||||
//! # Key concepts
|
||||
//!
|
||||
//! - **Transport trait** — produces a duplex byte stream (`AsyncRead + AsyncWrite + Unpin + Send`)
|
||||
//! that SSH consumes. Implementations: TCP, TLS, iroh (QUIC P2P).
|
||||
//! - **SOCKS5 server** — the primary client interface, listening on a local port and routing
|
||||
//! traffic through SSH channels.
|
||||
//! - **Port forwarding** — `-L` local and `-R` remote port forwards over SSH channels.
|
||||
//! - **Authentication** — Ed25519 public key and OpenSSH certificate authority. No passwords.
|
||||
//! - **Server handler** — accepts SSH connections via a `TransportAcceptor` and proxies
|
||||
//! `direct-tcpip` channel requests to targets (directly or via outbound proxy).
|
||||
//!
|
||||
//! # Feature flags
|
||||
//!
|
||||
//! | Feature | Default | Description |
|
||||
//! |---------|---------|-------------|
|
||||
//! | `tls` | yes | TLS transport via `tokio-rustls` |
|
||||
//! | `iroh` | yes | iroh QUIC P2P transport |
|
||||
//! | `acme` | no | ACME/Let's Encrypt auto-cert provisioning (implies `tls`) |
|
||||
//! | `irpc` | no | irpc service layer (AuthProtocol, AuthServiceImpl) |
|
||||
//! | `testutil` | no | Test utilities (for internal use) |
|
||||
//!
|
||||
//! # Quick example
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use std::sync::Arc;
|
||||
//! use alknet_core::transport::TcpTransport;
|
||||
//! use alknet_core::client::{ClientSession, ConnectOptions, TransportMode};
|
||||
//! use alknet_core::auth::keys::KeySource;
|
||||
//! use alknet_core::Transport;
|
||||
//!
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> anyhow::Result<()> {
|
||||
//! let opts = ConnectOptions::new(KeySource::File("/path/to/key".into()))
|
||||
//! .server("example.com:22")
|
||||
//! .transport_mode(TransportMode::Tcp);
|
||||
//! let transport = Arc::new(TcpTransport::new("example.com:22".parse()?));
|
||||
//! let session = ClientSession::new(opts, transport).await?;
|
||||
//! session.run().await?;
|
||||
//! Ok(())
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
pub mod auth;
|
||||
pub mod call;
|
||||
pub mod client;
|
||||
pub mod config;
|
||||
pub mod credentials;
|
||||
pub mod error;
|
||||
pub mod interface;
|
||||
pub mod server;
|
||||
pub mod socks5;
|
||||
pub mod transport;
|
||||
|
||||
#[cfg(feature = "http")]
|
||||
pub mod http;
|
||||
|
||||
#[cfg(feature = "http")]
|
||||
pub use http::IdentityExt;
|
||||
|
||||
#[cfg(feature = "testutil")]
|
||||
pub mod testutil;
|
||||
|
||||
#[cfg(feature = "irpc")]
|
||||
pub use auth::{AuthProtocol, AuthResult, AuthServiceImpl};
|
||||
pub use auth::{AuthToken, ConfigIdentityProvider, Identity, IdentityProvider};
|
||||
pub use call::{
|
||||
decode as decode_frame, decode_with_remainder as decode_frame_with_remainder,
|
||||
encode as encode_frame,
|
||||
};
|
||||
pub use call::{
|
||||
register_default_operations, services_list_spec, services_schema_spec, AccessControl,
|
||||
CallError, EventEnvelope, FrameDecodeError, Handler, OperationContext, OperationEnv,
|
||||
OperationRegistry, OperationRegistryBuilder, OperationSpec, OperationType, PendingRequestMap,
|
||||
ResponseEnvelope,
|
||||
};
|
||||
pub use call::{CALL_ABORTED, CALL_COMPLETED, CALL_ERROR, CALL_REQUESTED, CALL_RESPONDED};
|
||||
pub use client::channel_manager::{ChannelManager, ForwardRequest};
|
||||
pub use client::connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
|
||||
pub use config::{
|
||||
AuthPolicy, ConfigReloadHandle, ConfigServiceImpl, DynamicConfig, ForwardingAction,
|
||||
ForwardingPolicy, ForwardingRule, RateLimitConfig, StaticConfig, TargetPattern,
|
||||
};
|
||||
pub use credentials::{
|
||||
ConfigCredentialProvider, CredentialProvider, CredentialSet, SecretStoreCredentialProvider,
|
||||
};
|
||||
pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError};
|
||||
pub use interface::{
|
||||
is_valid_pair, DnsInterface, DnsInterfaceConfig, HttpInterface, HttpInterfaceConfig,
|
||||
InterfaceConfig, InterfaceEvent, InterfaceRequest, InterfaceResponse, InterfaceSession,
|
||||
MessageInterface, MessageInterfaceConfig, MessageInterfaceKind, RawFramingConfig,
|
||||
RawFramingInterface, RawFramingSession, SshInterface, SshInterfaceConfig, SshSession,
|
||||
StreamInterface, StreamInterfaceConfig, StreamInterfaceKind, TransportKindBase,
|
||||
TransportStream, VALID_TRANSPORT_INTERFACE_PAIRS,
|
||||
};
|
||||
pub use server::serve::{
|
||||
DnsListenerConfig, HttpListenerConfig, ListenerConfig, ServeError, ServeOptions,
|
||||
ServeTransportMode, Server, StreamListenerConfig,
|
||||
};
|
||||
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
@@ -1,555 +0,0 @@
|
||||
//! Outbound connection proxy for SSH channel targets.
|
||||
//!
|
||||
//! Connects to the requested `host:port` either directly, via SOCKS5 proxy, or
|
||||
//! via HTTP CONNECT proxy, then proxies bytes bidirectionally between the SSH
|
||||
//! channel and the outbound TCP stream.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use super::handler::{ProxyConfig, ProxyMode};
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ChannelProxyError {
|
||||
#[error("connection refused")]
|
||||
ConnectionRefused,
|
||||
#[error("target unreachable")]
|
||||
TargetUnreachable,
|
||||
#[error("socks5 proxy handshake failed")]
|
||||
Socks5HandshakeFailed,
|
||||
#[error("socks5 proxy rejected connection")]
|
||||
Socks5ProxyRejected,
|
||||
#[error("http connect proxy handshake failed")]
|
||||
HttpConnectHandshakeFailed,
|
||||
#[error("http connect proxy rejected: {0}")]
|
||||
HttpConnectProxyRejected(String),
|
||||
#[error("io error")]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
pub async fn connect_outbound(
|
||||
target: SocketAddr,
|
||||
proxy: &ProxyConfig,
|
||||
) -> Result<TcpStream, ChannelProxyError> {
|
||||
match &proxy.mode {
|
||||
ProxyMode::Direct => connect_direct(target).await,
|
||||
ProxyMode::Socks5(addr) => connect_socks5(target, *addr).await,
|
||||
ProxyMode::HttpConnect(addr) => connect_http_connect(target, *addr).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_direct(target: SocketAddr) -> Result<TcpStream, ChannelProxyError> {
|
||||
TcpStream::connect(target)
|
||||
.await
|
||||
.map_err(|e| map_connection_error(e, target))
|
||||
}
|
||||
|
||||
async fn connect_socks5(
|
||||
target: SocketAddr,
|
||||
proxy_addr: SocketAddr,
|
||||
) -> Result<TcpStream, ChannelProxyError> {
|
||||
let mut stream = TcpStream::connect(proxy_addr)
|
||||
.await
|
||||
.map_err(ChannelProxyError::from)?;
|
||||
|
||||
stream.write_all(&[0x05, 0x01, 0x00]).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let mut resp = [0u8; 2];
|
||||
stream.read_exact(&mut resp).await?;
|
||||
if resp[0] != 0x05 || resp[1] != 0x00 {
|
||||
return Err(ChannelProxyError::Socks5HandshakeFailed);
|
||||
}
|
||||
|
||||
let ip_bytes = target.ip().to_string();
|
||||
let mut connect_req = vec![0x05, 0x01, 0x00, 0x03];
|
||||
connect_req.push(ip_bytes.len() as u8);
|
||||
connect_req.extend_from_slice(ip_bytes.as_bytes());
|
||||
connect_req.extend_from_slice(&target.port().to_be_bytes());
|
||||
stream.write_all(&connect_req).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let mut reply_header = [0u8; 4];
|
||||
stream.read_exact(&mut reply_header).await?;
|
||||
if reply_header[0] != 0x05 {
|
||||
return Err(ChannelProxyError::Socks5HandshakeFailed);
|
||||
}
|
||||
if reply_header[1] != 0x00 {
|
||||
return Err(ChannelProxyError::Socks5ProxyRejected);
|
||||
}
|
||||
|
||||
let atyp = reply_header[3];
|
||||
match atyp {
|
||||
0x01 => {
|
||||
let mut _addr = [0u8; 4];
|
||||
stream.read_exact(&mut _addr).await?;
|
||||
}
|
||||
0x04 => {
|
||||
let mut _addr = [0u8; 16];
|
||||
stream.read_exact(&mut _addr).await?;
|
||||
}
|
||||
0x03 => {
|
||||
let len = stream.read_u8().await?;
|
||||
let mut _domain = vec![0u8; len as usize];
|
||||
stream.read_exact(&mut _domain).await?;
|
||||
}
|
||||
_ => {
|
||||
return Err(ChannelProxyError::Socks5HandshakeFailed);
|
||||
}
|
||||
}
|
||||
let mut _port = [0u8; 2];
|
||||
stream.read_exact(&mut _port).await?;
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
async fn connect_http_connect(
|
||||
target: SocketAddr,
|
||||
proxy_addr: SocketAddr,
|
||||
) -> Result<TcpStream, ChannelProxyError> {
|
||||
let mut stream = TcpStream::connect(proxy_addr)
|
||||
.await
|
||||
.map_err(ChannelProxyError::from)?;
|
||||
|
||||
let connect_request = format!(
|
||||
"CONNECT {}:{} HTTP/1.1\r\nHost: {}:{}\r\n\r\n",
|
||||
target.ip(),
|
||||
target.port(),
|
||||
target.ip(),
|
||||
target.port()
|
||||
);
|
||||
stream.write_all(connect_request.as_bytes()).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let mut response = Vec::new();
|
||||
let mut buf = [0u8; 1024];
|
||||
loop {
|
||||
let n = stream.read(&mut buf).await?;
|
||||
if n == 0 {
|
||||
return Err(ChannelProxyError::HttpConnectHandshakeFailed);
|
||||
}
|
||||
response.extend_from_slice(&buf[..n]);
|
||||
if response.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let response_str = String::from_utf8_lossy(&response);
|
||||
let status_line = response_str.lines().next().unwrap_or("");
|
||||
|
||||
if status_line.contains("200") {
|
||||
Ok(stream)
|
||||
} else {
|
||||
Err(ChannelProxyError::HttpConnectProxyRejected(
|
||||
status_line.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn map_connection_error(e: std::io::Error, _target: SocketAddr) -> ChannelProxyError {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::ConnectionRefused => ChannelProxyError::ConnectionRefused,
|
||||
std::io::ErrorKind::AddrNotAvailable
|
||||
| std::io::ErrorKind::NetworkUnreachable
|
||||
| std::io::ErrorKind::HostUnreachable => ChannelProxyError::TargetUnreachable,
|
||||
_ => ChannelProxyError::Io(e),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn proxy_channel<S>(channel: S, target: SocketAddr, proxy: &ProxyConfig)
|
||||
where
|
||||
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
if let Ok(outbound) = connect_outbound(target, proxy).await {
|
||||
let (mut read_chan, mut write_chan) = tokio::io::split(channel);
|
||||
let (mut read_out, mut write_out) = outbound.into_split();
|
||||
|
||||
let client_to_target = tokio::spawn(async move {
|
||||
let _ = tokio::io::copy(&mut read_chan, &mut write_out).await;
|
||||
let _ = write_out.shutdown().await;
|
||||
});
|
||||
|
||||
let target_to_client = tokio::spawn(async move {
|
||||
let _ = tokio::io::copy(&mut read_out, &mut write_chan).await;
|
||||
let _ = write_chan.shutdown().await;
|
||||
});
|
||||
|
||||
let _ = client_to_target.await;
|
||||
let _ = target_to_client.await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
fn direct_config() -> ProxyConfig {
|
||||
ProxyConfig {
|
||||
mode: ProxyMode::Direct,
|
||||
}
|
||||
}
|
||||
|
||||
fn socks5_config(addr: SocketAddr) -> ProxyConfig {
|
||||
ProxyConfig {
|
||||
mode: ProxyMode::Socks5(addr),
|
||||
}
|
||||
}
|
||||
|
||||
fn http_connect_config(addr: SocketAddr) -> ProxyConfig {
|
||||
ProxyConfig {
|
||||
mode: ProxyMode::HttpConnect(addr),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn direct_connection_to_echo_server() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (mut sock, _) = listener.accept().await.unwrap();
|
||||
let mut buf = [0u8; 64];
|
||||
let n = sock.read(&mut buf).await.unwrap();
|
||||
sock.write_all(&buf[..n]).await.unwrap();
|
||||
});
|
||||
|
||||
let stream = connect_outbound(addr, &direct_config()).await.unwrap();
|
||||
let (mut read, mut write) = stream.into_split();
|
||||
write.write_all(b"hello").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
read.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello");
|
||||
|
||||
let _ = server.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn direct_connection_target_unreachable() {
|
||||
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
|
||||
let result = connect_outbound(target, &direct_config()).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn socks5_proxy_handshake() {
|
||||
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||
|
||||
let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let target_addr = target_listener.local_addr().unwrap();
|
||||
|
||||
let target_server = tokio::spawn(async move {
|
||||
let (mut sock, _) = target_listener.accept().await.unwrap();
|
||||
let mut buf = [0u8; 64];
|
||||
let n = sock.read(&mut buf).await.unwrap();
|
||||
sock.write_all(&buf[..n]).await.unwrap();
|
||||
});
|
||||
|
||||
let proxy_server = tokio::spawn(async move {
|
||||
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||
|
||||
let mut greeting = [0u8; 3];
|
||||
proxy_sock.read_exact(&mut greeting).await.unwrap();
|
||||
assert_eq!(greeting[0], 0x05);
|
||||
proxy_sock.write_all(&[0x05, 0x00]).await.unwrap();
|
||||
|
||||
let mut req_header = [0u8; 4];
|
||||
proxy_sock.read_exact(&mut req_header).await.unwrap();
|
||||
assert_eq!(req_header[0], 0x05);
|
||||
assert_eq!(req_header[1], 0x01);
|
||||
|
||||
let atyp = req_header[3];
|
||||
assert_eq!(atyp, 0x03);
|
||||
|
||||
let domain_len = proxy_sock.read_u8().await.unwrap() as usize;
|
||||
let mut domain = vec![0u8; domain_len];
|
||||
proxy_sock.read_exact(&mut domain).await.unwrap();
|
||||
let mut port_bytes = [0u8; 2];
|
||||
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
|
||||
|
||||
let target: SocketAddr = format!(
|
||||
"{}:{}",
|
||||
String::from_utf8_lossy(&domain),
|
||||
u16::from_be_bytes(port_bytes)
|
||||
)
|
||||
.parse()
|
||||
.unwrap();
|
||||
|
||||
let reply = vec![0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0];
|
||||
proxy_sock.write_all(&reply).await.unwrap();
|
||||
|
||||
let mut target_stream = TcpStream::connect(target).await.unwrap();
|
||||
let _ = tokio::io::copy_bidirectional(&mut proxy_sock, &mut target_stream).await;
|
||||
});
|
||||
|
||||
let config = socks5_config(proxy_addr);
|
||||
let mut stream = connect_outbound(target_addr, &config).await.unwrap();
|
||||
stream.write_all(b"hello socks").await.unwrap();
|
||||
let mut buf = [0u8; 11];
|
||||
stream.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello socks");
|
||||
drop(stream);
|
||||
|
||||
let _ = target_server.await;
|
||||
let _ = proxy_server.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn socks5_proxy_rejected() {
|
||||
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||
|
||||
let proxy_server = tokio::spawn(async move {
|
||||
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||
|
||||
let mut greeting = [0u8; 3];
|
||||
proxy_sock.read_exact(&mut greeting).await.unwrap();
|
||||
proxy_sock.write_all(&[0x05, 0x00]).await.unwrap();
|
||||
|
||||
let mut req_header = [0u8; 4];
|
||||
proxy_sock.read_exact(&mut req_header).await.unwrap();
|
||||
|
||||
let domain_len = proxy_sock.read_u8().await.unwrap() as usize;
|
||||
let mut domain = vec![0u8; domain_len];
|
||||
proxy_sock.read_exact(&mut domain).await.unwrap();
|
||||
let mut port_bytes = [0u8; 2];
|
||||
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
|
||||
|
||||
let reply = vec![0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0];
|
||||
proxy_sock.write_all(&reply).await.unwrap();
|
||||
});
|
||||
|
||||
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||
let config = socks5_config(proxy_addr);
|
||||
let result = connect_outbound(target, &config).await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
ChannelProxyError::Socks5ProxyRejected
|
||||
));
|
||||
|
||||
let _ = proxy_server.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_connect_proxy_handshake() {
|
||||
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||
|
||||
let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let target_addr = target_listener.local_addr().unwrap();
|
||||
|
||||
let target_server = tokio::spawn(async move {
|
||||
let (mut sock, _) = target_listener.accept().await.unwrap();
|
||||
let mut buf = [0u8; 64];
|
||||
let n = sock.read(&mut buf).await.unwrap();
|
||||
sock.write_all(&buf[..n]).await.unwrap();
|
||||
});
|
||||
|
||||
let proxy_server = tokio::spawn(async move {
|
||||
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||
|
||||
let mut request = Vec::new();
|
||||
let mut buf = [0u8; 1024];
|
||||
loop {
|
||||
let n = proxy_sock.read(&mut buf).await.unwrap();
|
||||
request.extend_from_slice(&buf[..n]);
|
||||
if request.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
|
||||
proxy_sock.write_all(response.as_bytes()).await.unwrap();
|
||||
|
||||
let target_str = extract_connect_target(&String::from_utf8_lossy(&request));
|
||||
let mut target_stream = TcpStream::connect(target_str).await.unwrap();
|
||||
let _ = tokio::io::copy_bidirectional(&mut proxy_sock, &mut target_stream).await;
|
||||
});
|
||||
|
||||
let config = http_connect_config(proxy_addr);
|
||||
let mut stream = connect_outbound(target_addr, &config).await.unwrap();
|
||||
stream.write_all(b"hello http").await.unwrap();
|
||||
let mut buf = [0u8; 10];
|
||||
stream.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello http");
|
||||
drop(stream);
|
||||
|
||||
let _ = target_server.await;
|
||||
let _ = proxy_server.await;
|
||||
}
|
||||
|
||||
fn extract_connect_target(request: &str) -> String {
|
||||
let connect_line = request.lines().next().unwrap_or("");
|
||||
let parts: Vec<&str> = connect_line.split_whitespace().collect();
|
||||
if parts.len() >= 2 {
|
||||
parts[1].to_string()
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_connect_proxy_rejected() {
|
||||
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||
|
||||
let proxy_server = tokio::spawn(async move {
|
||||
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||
|
||||
let mut request = Vec::new();
|
||||
let mut buf = [0u8; 1024];
|
||||
loop {
|
||||
let n = proxy_sock.read(&mut buf).await.unwrap();
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
request.extend_from_slice(&buf[..n]);
|
||||
if request.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let response = "HTTP/1.1 403 Forbidden\r\n\r\n";
|
||||
proxy_sock.write_all(response.as_bytes()).await.unwrap();
|
||||
});
|
||||
|
||||
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||
let config = http_connect_config(proxy_addr);
|
||||
let result = connect_outbound(target, &config).await;
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
ChannelProxyError::HttpConnectProxyRejected(msg) => {
|
||||
assert!(msg.contains("403"));
|
||||
}
|
||||
other => panic!("expected HttpConnectProxyRejected, got {:?}", other),
|
||||
}
|
||||
|
||||
let _ = proxy_server.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn target_unreachable_returns_appropriate_error() {
|
||||
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
|
||||
let result = connect_outbound(target, &direct_config()).await;
|
||||
match result.unwrap_err() {
|
||||
ChannelProxyError::TargetUnreachable
|
||||
| ChannelProxyError::ConnectionRefused
|
||||
| ChannelProxyError::Io(_) => {}
|
||||
other => panic!("unexpected error type: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn socks5_proxy_unreachable() {
|
||||
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||
let bad_proxy: SocketAddr = "127.0.0.1:1".parse().unwrap();
|
||||
let config = socks5_config(bad_proxy);
|
||||
let result = connect_outbound(target, &config).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_connect_proxy_unreachable() {
|
||||
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||
let bad_proxy: SocketAddr = "127.0.0.1:1".parse().unwrap();
|
||||
let config = http_connect_config(bad_proxy);
|
||||
let result = connect_outbound(target, &config).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
struct MockChannel {
|
||||
read_half: tokio::io::ReadHalf<DuplexStream>,
|
||||
write_half: tokio::io::WriteHalf<DuplexStream>,
|
||||
}
|
||||
|
||||
impl tokio::io::AsyncRead for MockChannel {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().read_half).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl tokio::io::AsyncWrite for MockChannel {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<std::io::Result<usize>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().write_half).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().write_half).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().write_half).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
fn make_mock_channel() -> (MockChannel, DuplexStream) {
|
||||
let (client, server) = duplex(4096);
|
||||
let (read_half, write_half) = tokio::io::split(client);
|
||||
(
|
||||
MockChannel {
|
||||
read_half,
|
||||
write_half,
|
||||
},
|
||||
server,
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn proxy_channel_bidirectional_data_flow() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let target_addr = listener.local_addr().unwrap();
|
||||
|
||||
let echo_server = tokio::spawn(async move {
|
||||
let (mut sock, _) = listener.accept().await.unwrap();
|
||||
let mut buf = [0u8; 64];
|
||||
let n = sock.read(&mut buf).await.unwrap();
|
||||
sock.write_all(&buf[..n]).await.unwrap();
|
||||
});
|
||||
|
||||
let (channel, mut channel_peer) = make_mock_channel();
|
||||
|
||||
let target = target_addr;
|
||||
let proxy = direct_config();
|
||||
tokio::spawn(async move {
|
||||
proxy_channel(channel, target, &proxy).await;
|
||||
});
|
||||
|
||||
channel_peer.write_all(b"ping").await.unwrap();
|
||||
channel_peer.flush().await.unwrap();
|
||||
|
||||
let mut buf = [0u8; 4];
|
||||
channel_peer.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"ping");
|
||||
|
||||
drop(channel_peer);
|
||||
let _ = echo_server.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn proxy_channel_target_unreachable_closes_cleanly() {
|
||||
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
|
||||
let (channel, _channel_peer) = make_mock_channel();
|
||||
|
||||
let proxy = direct_config();
|
||||
proxy_channel(channel, target, &proxy).await;
|
||||
}
|
||||
}
|
||||
@@ -1,196 +0,0 @@
|
||||
//! Control channel routing for reserved `alknet-*` destinations.
|
||||
//!
|
||||
//! SSH channels opened with a destination starting with `alknet-` are intercepted
|
||||
//! by the server and routed to a `ControlChannelHandler` instead of proxied to a
|
||||
//! TCP target. See ADR-018 for the design rationale.
|
||||
|
||||
use std::io;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub const ALKNET_CONTROL_DESTINATION: &str = "alknet-control";
|
||||
pub const ALKNET_PREFIX: &str = "alknet-";
|
||||
|
||||
pub fn is_reserved_destination(host: &str) -> bool {
|
||||
host.starts_with(ALKNET_PREFIX)
|
||||
}
|
||||
|
||||
pub trait DuplexStream: AsyncRead + AsyncWrite + Unpin + Send {}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin + Send> DuplexStream for T {}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ControlChannelHandler: Send + Sync {
|
||||
async fn handle_channel(&self, stream: Box<dyn DuplexStream>);
|
||||
}
|
||||
|
||||
pub struct ControlChannelRouter {
|
||||
handler: Option<Box<dyn ControlChannelHandler>>,
|
||||
}
|
||||
|
||||
impl ControlChannelRouter {
|
||||
pub fn new(handler: Option<Box<dyn ControlChannelHandler>>) -> Self {
|
||||
Self { handler }
|
||||
}
|
||||
|
||||
pub fn without_handler() -> Self {
|
||||
Self { handler: None }
|
||||
}
|
||||
|
||||
pub fn with_handler(handler: Box<dyn ControlChannelHandler>) -> Self {
|
||||
Self {
|
||||
handler: Some(handler),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_handler(&self) -> bool {
|
||||
self.handler.is_some()
|
||||
}
|
||||
|
||||
pub async fn route(&self, stream: Box<dyn DuplexStream>) -> io::Result<()> {
|
||||
match &self.handler {
|
||||
Some(handler) => {
|
||||
handler.handle_channel(stream).await;
|
||||
Ok(())
|
||||
}
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionRefused,
|
||||
"no control channel handler configured",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take_handler(&mut self) -> Option<Box<dyn ControlChannelHandler>> {
|
||||
self.handler.take()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::duplex;
|
||||
|
||||
#[test]
|
||||
fn alknet_control_destination_constant() {
|
||||
assert_eq!(ALKNET_CONTROL_DESTINATION, "alknet-control");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn alknet_prefix_constant() {
|
||||
assert_eq!(ALKNET_PREFIX, "alknet-");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reserved_destination_detected() {
|
||||
assert!(is_reserved_destination("alknet-control"));
|
||||
assert!(is_reserved_destination("alknet-status"));
|
||||
assert!(is_reserved_destination("alknet-events"));
|
||||
assert!(is_reserved_destination("alknet-"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_reserved_destination_passes_through() {
|
||||
assert!(!is_reserved_destination("example.com"));
|
||||
assert!(!is_reserved_destination("localhost"));
|
||||
assert!(!is_reserved_destination("192.168.1.1"));
|
||||
assert!(!is_reserved_destination("alknet.example.com"));
|
||||
assert!(!is_reserved_destination(""));
|
||||
assert!(!is_reserved_destination("alkne-control"));
|
||||
assert!(!is_reserved_destination("ALKNET-control"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prefix_matching_case_sensitive() {
|
||||
assert!(!is_reserved_destination("Alknet-control"));
|
||||
assert!(!is_reserved_destination("ALKNET-control"));
|
||||
assert!(is_reserved_destination("alknet-Control"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn router_without_handler_has_no_handler() {
|
||||
let router = ControlChannelRouter::without_handler();
|
||||
assert!(!router.has_handler());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn router_with_handler_has_handler() {
|
||||
struct DummyHandler;
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for DummyHandler {
|
||||
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {}
|
||||
}
|
||||
let router = ControlChannelRouter::with_handler(Box::new(DummyHandler));
|
||||
assert!(router.has_handler());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_without_handler_returns_error() {
|
||||
let router = ControlChannelRouter::without_handler();
|
||||
let (_client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
let result = router.route(stream).await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_with_handler_succeeds() {
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
struct TrackedHandler {
|
||||
called: Arc<AtomicBool>,
|
||||
}
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for TrackedHandler {
|
||||
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {
|
||||
self.called.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
let called = Arc::new(AtomicBool::new(false));
|
||||
let handler = TrackedHandler {
|
||||
called: called.clone(),
|
||||
};
|
||||
let router = ControlChannelRouter::with_handler(Box::new(handler));
|
||||
let (_client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
let result = router.route(stream).await;
|
||||
assert!(result.is_ok());
|
||||
assert!(called.load(Ordering::SeqCst));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_with_handler_can_read_write() {
|
||||
struct EchoHandler;
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for EchoHandler {
|
||||
async fn handle_channel(&self, mut stream: Box<dyn DuplexStream>) {
|
||||
let mut buf = [0u8; 64];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
stream.write_all(&buf[..n]).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
let router = ControlChannelRouter::with_handler(Box::new(EchoHandler));
|
||||
let (client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
tokio::spawn(async move {
|
||||
router.route(stream).await.unwrap();
|
||||
});
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
let mut client = client;
|
||||
client.write_all(b"hello").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
client.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn control_channel_destination_matches_prefix() {
|
||||
assert!(is_reserved_destination(ALKNET_CONTROL_DESTINATION));
|
||||
}
|
||||
}
|
||||
@@ -1,974 +0,0 @@
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use russh::keys::ssh_key::HashAlg;
|
||||
use russh::server::{Auth, Handler, Msg, Session};
|
||||
use russh::Channel;
|
||||
use russh::ChannelId;
|
||||
|
||||
use crate::auth::identity::{ConfigIdentityProvider, Identity, IdentityProvider};
|
||||
use crate::config::DynamicConfig;
|
||||
use crate::server::control_channel::{ControlChannelHandler, ControlChannelRouter, ALKNET_PREFIX};
|
||||
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||
|
||||
pub use crate::transport::TransportKind;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProxyMode {
|
||||
Direct,
|
||||
Socks5(SocketAddr),
|
||||
HttpConnect(SocketAddr),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProxyConfig {
|
||||
pub mode: ProxyMode,
|
||||
}
|
||||
|
||||
pub struct ServerHandler {
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
#[allow(dead_code)]
|
||||
outbound_proxy: Option<ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
control_channel_router: ControlChannelRouter,
|
||||
#[allow(dead_code)]
|
||||
transport: TransportKind,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
connection_allowed: bool,
|
||||
auth_limiter: AuthAttemptLimiter,
|
||||
connected_at: Instant,
|
||||
authenticated_identity: Option<Identity>,
|
||||
}
|
||||
|
||||
impl ServerHandler {
|
||||
pub fn new(
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
outbound_proxy: Option<ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
transport: TransportKind,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
max_auth_attempts: usize,
|
||||
) -> Self {
|
||||
let identity_provider: Arc<dyn IdentityProvider> =
|
||||
Arc::new(ConfigIdentityProvider::new(Arc::clone(&dynamic)));
|
||||
|
||||
let allowed = if let Some(addr) = remote_addr {
|
||||
let ip = addr.ip();
|
||||
if connection_limiter.check(ip) {
|
||||
connection_limiter.on_connect(ip);
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
transport = %transport,
|
||||
"connection opened"
|
||||
);
|
||||
true
|
||||
} else {
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
transport = %transport,
|
||||
"connection rejected"
|
||||
);
|
||||
false
|
||||
}
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
Self {
|
||||
dynamic,
|
||||
identity_provider,
|
||||
outbound_proxy,
|
||||
remote_addr,
|
||||
control_channel_router: ControlChannelRouter::without_handler(),
|
||||
transport,
|
||||
connection_limiter,
|
||||
connection_allowed: allowed,
|
||||
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
|
||||
connected_at: Instant::now(),
|
||||
authenticated_identity: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_identity_provider(mut self, provider: Arc<dyn IdentityProvider>) -> Self {
|
||||
self.identity_provider = provider;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn authenticated_identity(&self) -> Option<&Identity> {
|
||||
self.authenticated_identity.as_ref()
|
||||
}
|
||||
|
||||
pub fn is_connection_allowed(&self) -> bool {
|
||||
self.connection_allowed
|
||||
}
|
||||
|
||||
pub fn remote_ip(&self) -> Option<IpAddr> {
|
||||
self.remote_addr.map(|a| a.ip())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ServerHandler {
|
||||
fn drop(&mut self) {
|
||||
if let Some(addr) = self.remote_addr {
|
||||
if self.connection_allowed {
|
||||
self.connection_limiter.on_disconnect(addr.ip());
|
||||
}
|
||||
let duration = self.connected_at.elapsed();
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
duration_secs = duration.as_secs_f64(),
|
||||
"connection closed"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerHandler {
|
||||
pub fn with_control_channel_handler(mut self, handler: Box<dyn ControlChannelHandler>) -> Self {
|
||||
self.control_channel_router = ControlChannelRouter::with_handler(handler);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn control_channel_router(&self) -> &ControlChannelRouter {
|
||||
&self.control_channel_router
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Handler for ServerHandler {
|
||||
type Error = russh::Error;
|
||||
|
||||
async fn auth_publickey(
|
||||
&mut self,
|
||||
user: &str,
|
||||
public_key: &russh::keys::ssh_key::PublicKey,
|
||||
) -> Result<Auth, Self::Error> {
|
||||
if !self.auth_limiter.check() {
|
||||
let remote_addr_display = self
|
||||
.remote_addr
|
||||
.map_or("unknown".to_string(), |a| a.to_string());
|
||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "reject",
|
||||
"auth attempt"
|
||||
);
|
||||
return Ok(Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
});
|
||||
}
|
||||
|
||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||
let remote_addr_display = self
|
||||
.remote_addr
|
||||
.map_or("unknown".to_string(), |a| a.to_string());
|
||||
|
||||
let identity = self
|
||||
.identity_provider
|
||||
.resolve_from_fingerprint(&fingerprint);
|
||||
|
||||
match identity {
|
||||
Some(id) => {
|
||||
self.authenticated_identity = Some(id);
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "accept",
|
||||
"auth attempt"
|
||||
);
|
||||
Ok(Auth::Accept)
|
||||
}
|
||||
None => {
|
||||
self.auth_limiter.on_failure();
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "reject",
|
||||
"auth attempt"
|
||||
);
|
||||
Ok(Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn channel_open_direct_tcpip(
|
||||
&mut self,
|
||||
channel: Channel<Msg>,
|
||||
host_to_connect: &str,
|
||||
port_to_connect: u32,
|
||||
originator_address: &str,
|
||||
originator_port: u32,
|
||||
_session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
if host_to_connect.starts_with(ALKNET_PREFIX) {
|
||||
if !self.control_channel_router.has_handler() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let _ = channel;
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let identity = self
|
||||
.authenticated_identity
|
||||
.clone()
|
||||
.unwrap_or_else(|| Identity {
|
||||
id: String::new(),
|
||||
scopes: vec![],
|
||||
resources: std::collections::HashMap::new(),
|
||||
});
|
||||
|
||||
let policy = self.dynamic.load();
|
||||
let allowed = policy.forwarding.check(
|
||||
host_to_connect,
|
||||
port_to_connect as u16,
|
||||
&identity,
|
||||
self.transport.clone(),
|
||||
);
|
||||
|
||||
if !allowed {
|
||||
tracing::info!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
target = %format!("{host_to_connect}:{port_to_connect}"),
|
||||
identity = %identity.id,
|
||||
transport = %self.transport,
|
||||
"forwarding denied by policy"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let target_host = host_to_connect.to_string();
|
||||
let target_port = port_to_connect;
|
||||
let proxy_config = self.outbound_proxy.clone().unwrap_or(ProxyConfig {
|
||||
mode: ProxyMode::Direct,
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
let target =
|
||||
match format!("{target_host}:{target_port}").parse::<std::net::SocketAddr>() {
|
||||
Ok(addr) => addr,
|
||||
Err(_) => match tokio::net::lookup_host((&target_host[..], target_port as u16))
|
||||
.await
|
||||
{
|
||||
Ok(mut addrs) => match addrs.next() {
|
||||
Some(addr) => addr,
|
||||
None => return,
|
||||
},
|
||||
Err(_) => return,
|
||||
},
|
||||
};
|
||||
crate::server::channel_proxy::proxy_channel(
|
||||
channel.into_stream(),
|
||||
target,
|
||||
&proxy_config,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let _ = (originator_address, originator_port);
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn channel_open_session(
|
||||
&mut self,
|
||||
_channel: Channel<Msg>,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
"rejected session channel (shell/exec not supported)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn channel_open_x11(
|
||||
&mut self,
|
||||
_channel: Channel<Msg>,
|
||||
_originator_address: &str,
|
||||
_originator_port: u32,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
"rejected x11 channel"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn channel_open_forwarded_tcpip(
|
||||
&mut self,
|
||||
_channel: Channel<Msg>,
|
||||
host_to_connect: &str,
|
||||
port_to_connect: u32,
|
||||
_originator_address: &str,
|
||||
_originator_port: u32,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
target = %format!("{host_to_connect}:{port_to_connect}"),
|
||||
"rejected forwarded-tcpip channel (remote port forwarding not supported)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn exec_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
data: &[u8],
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
data_len = data.len(),
|
||||
"rejected exec request on channel (shell/exec not supported)"
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn shell_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
"rejected shell request on channel"
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn subsystem_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
name: &str,
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
subsystem = name,
|
||||
"rejected subsystem request on channel"
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn pty_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
term: &str,
|
||||
col_width: u32,
|
||||
row_height: u32,
|
||||
pix_width: u32,
|
||||
pix_height: u32,
|
||||
modes: &[(russh::Pty, u32)],
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
term = term,
|
||||
"rejected pty request on channel"
|
||||
);
|
||||
let _ = (col_width, row_height, pix_width, pix_height, modes);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn env_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
variable_name: &str,
|
||||
variable_value: &str,
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
variable = variable_name,
|
||||
"rejected env request on channel"
|
||||
);
|
||||
let _ = variable_value;
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn x11_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
single_connection: bool,
|
||||
x11_auth_protocol: &str,
|
||||
x11_auth_cookie: &str,
|
||||
x11_screen_number: u32,
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
"rejected x11 request on channel"
|
||||
);
|
||||
let _ = (
|
||||
single_connection,
|
||||
x11_auth_protocol,
|
||||
x11_auth_cookie,
|
||||
x11_screen_number,
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn agent_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
"rejected agent forwarding request on channel"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn tcpip_forward(
|
||||
&mut self,
|
||||
address: &str,
|
||||
port: &mut u32,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
address = address,
|
||||
port = *port,
|
||||
"rejected tcpip-forward request (remote port forwarding not supported)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn cancel_tcpip_forward(
|
||||
&mut self,
|
||||
address: &str,
|
||||
port: u32,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
let _ = (address, port, session);
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn streamlocal_forward(
|
||||
&mut self,
|
||||
socket_path: &str,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
socket_path = socket_path,
|
||||
"rejected streamlocal-forward request"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn signal(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
signal: russh::Sig,
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::debug!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
signal = ?signal,
|
||||
"received signal on channel (ignored)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::auth::ServerAuthConfig;
|
||||
use crate::config::AuthPolicy;
|
||||
use russh::keys::{decode_secret_key, PrivateKey};
|
||||
use std::io::Write;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
|
||||
|
||||
fn make_authorized_keys_file(keys_content: &str) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
f.write_all(keys_content.as_bytes()).unwrap();
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
fn load_key() -> PrivateKey {
|
||||
decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap()
|
||||
}
|
||||
|
||||
fn make_auth_config(keys_content: &str) -> Arc<ArcSwap<DynamicConfig>> {
|
||||
let f = make_authorized_keys_file(keys_content);
|
||||
let server_auth =
|
||||
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
|
||||
.unwrap();
|
||||
let auth_policy = AuthPolicy::from_server_auth_config(server_auth);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
Arc::new(ArcSwap::new(Arc::new(dynamic)))
|
||||
}
|
||||
|
||||
fn make_empty_auth_config() -> Arc<ArcSwap<DynamicConfig>> {
|
||||
let dynamic = DynamicConfig::default();
|
||||
Arc::new(ArcSwap::new(Arc::new(dynamic)))
|
||||
}
|
||||
|
||||
fn default_limiter() -> Arc<ConnectionRateLimiter> {
|
||||
Arc::new(ConnectionRateLimiter::new(0))
|
||||
}
|
||||
|
||||
fn make_handler(
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
outbound_proxy: Option<ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
) -> ServerHandler {
|
||||
ServerHandler::new(
|
||||
dynamic,
|
||||
outbound_proxy,
|
||||
remote_addr,
|
||||
TransportKind::Tcp,
|
||||
default_limiter(),
|
||||
10,
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_delegation_accepts_known_key() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let mut handler = make_handler(auth_config, None, None);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||
assert_eq!(result, Auth::Accept);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_delegation_rejects_unknown_key() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let mut handler = make_handler(auth_config, None, None);
|
||||
|
||||
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
||||
let other_ssh_key =
|
||||
russh::keys::parse_public_key_base64(other_key_text.split_whitespace().nth(1).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let result = handler
|
||||
.auth_publickey("testuser", &other_ssh_key)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
Auth::Reject {
|
||||
proceed_with_methods: None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_delegation_empty_config_rejects_all() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let mut handler = make_handler(auth_config, None, None);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
Auth::Reject {
|
||||
proceed_with_methods: None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_logging_includes_remote_addr() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
|
||||
let mut handler = make_handler(auth_config, None, Some(remote_addr));
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reserved_alknet_destination_routing() {
|
||||
use crate::server::control_channel::is_reserved_destination;
|
||||
assert!(is_reserved_destination("alknet-control"));
|
||||
assert!(is_reserved_destination("alknet-status"));
|
||||
assert!(is_reserved_destination("alknet-events"));
|
||||
assert!(!is_reserved_destination("example.com"));
|
||||
assert!(!is_reserved_destination("localhost"));
|
||||
assert!(!is_reserved_destination("alknet.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_handler_without_control_handler_rejects_alknet_destinations() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let handler = make_handler(auth_config, None, None);
|
||||
assert!(!handler.control_channel_router().has_handler());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_mode_variants() {
|
||||
let direct = ProxyMode::Direct;
|
||||
let socks5 = ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap());
|
||||
let http = ProxyMode::HttpConnect("127.0.0.1:8080".parse().unwrap());
|
||||
|
||||
match direct {
|
||||
ProxyMode::Direct => {}
|
||||
_ => panic!("expected Direct"),
|
||||
}
|
||||
match socks5 {
|
||||
ProxyMode::Socks5(_) => {}
|
||||
_ => panic!("expected Socks5"),
|
||||
}
|
||||
match http {
|
||||
ProxyMode::HttpConnect(_) => {}
|
||||
_ => panic!("expected HttpConnect"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_handler_holds_config() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let proxy = Some(ProxyConfig {
|
||||
mode: ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap()),
|
||||
});
|
||||
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
|
||||
|
||||
let handler = make_handler(auth_config, proxy.clone(), remote);
|
||||
assert!(handler.outbound_proxy.is_some());
|
||||
assert!(handler.remote_addr.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_handler_per_connection() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let handler1 = make_handler(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
Some("10.0.0.1:22".parse().unwrap()),
|
||||
);
|
||||
let handler2 = make_handler(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
Some("10.0.0.2:22".parse().unwrap()),
|
||||
);
|
||||
|
||||
assert!(handler1.remote_addr != handler2.remote_addr);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_rate_limit_rejects_after_max_failures() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(0));
|
||||
let mut handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("10.0.0.1:22".parse().unwrap()),
|
||||
TransportKind::Tcp,
|
||||
limiter,
|
||||
2,
|
||||
);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
|
||||
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||
assert_eq!(
|
||||
r1,
|
||||
Auth::Reject {
|
||||
proceed_with_methods: None
|
||||
}
|
||||
);
|
||||
|
||||
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||
assert_eq!(
|
||||
r2,
|
||||
Auth::Reject {
|
||||
proceed_with_methods: None
|
||||
}
|
||||
);
|
||||
|
||||
assert!(!handler.auth_limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_rate_limit_blocks_over_limit() {
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(1));
|
||||
let auth_config = make_empty_auth_config();
|
||||
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
|
||||
|
||||
let h1 = ServerHandler::new(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter.clone(),
|
||||
10,
|
||||
);
|
||||
assert!(h1.is_connection_allowed());
|
||||
|
||||
let h2 = ServerHandler::new(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter.clone(),
|
||||
10,
|
||||
);
|
||||
assert!(!h2.is_connection_allowed());
|
||||
|
||||
drop(h1);
|
||||
|
||||
let h3 = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter,
|
||||
10,
|
||||
);
|
||||
assert!(h3.is_connection_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_kind_display() {
|
||||
assert_eq!(TransportKind::Tcp.to_string(), "tcp");
|
||||
assert_eq!(TransportKind::Tls { server_name: None }.to_string(), "tls");
|
||||
assert_eq!(
|
||||
TransportKind::Iroh {
|
||||
endpoint_id: String::new()
|
||||
}
|
||||
.to_string(),
|
||||
"iroh"
|
||||
);
|
||||
assert_eq!(
|
||||
TransportKind::WebTransport { server_name: None }.to_string(),
|
||||
"webtransport"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_log_includes_user_field() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let mut handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("203.0.113.50:12345".parse().unwrap()),
|
||||
TransportKind::Tls { server_name: None },
|
||||
Arc::new(ConnectionRateLimiter::new(0)),
|
||||
10,
|
||||
);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_closed_logs_duration_on_drop() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let _handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("203.0.113.50:12345".parse().unwrap()),
|
||||
TransportKind::Tcp,
|
||||
Arc::new(ConnectionRateLimiter::new(0)),
|
||||
10,
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn config_reload_new_keys_take_effect() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let mut handler = ServerHandler::new(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
None,
|
||||
TransportKind::Tcp,
|
||||
default_limiter(),
|
||||
10,
|
||||
);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||
assert_eq!(result, Auth::Accept);
|
||||
drop(handler);
|
||||
|
||||
let new_dynamic = DynamicConfig::default();
|
||||
auth_config.store(Arc::new(new_dynamic));
|
||||
|
||||
let mut handler2 = ServerHandler::new(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
None,
|
||||
TransportKind::Tcp,
|
||||
default_limiter(),
|
||||
10,
|
||||
);
|
||||
|
||||
let result2 = handler2.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||
assert_eq!(
|
||||
result2,
|
||||
Auth::Reject {
|
||||
proceed_with_methods: None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn forwarding_policy_deny_blocks_channel_open() {
|
||||
use crate::config::forwarding::{
|
||||
ForwardingAction, ForwardingPolicy, ForwardingRule, TargetPattern,
|
||||
};
|
||||
|
||||
let deny_policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Host("blocked.example.com".to_string()),
|
||||
action: ForwardingAction::Deny,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
{
|
||||
let dynamic = auth_config.load();
|
||||
let new_dynamic = DynamicConfig {
|
||||
auth: dynamic.auth.clone(),
|
||||
forwarding: deny_policy,
|
||||
rate_limits: dynamic.rate_limits.clone(),
|
||||
credentials: dynamic.credentials.clone(),
|
||||
};
|
||||
drop(dynamic);
|
||||
auth_config.store(Arc::new(new_dynamic));
|
||||
}
|
||||
|
||||
let mut handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("127.0.0.1:12345".parse().unwrap()),
|
||||
TransportKind::Tcp,
|
||||
default_limiter(),
|
||||
10,
|
||||
);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||
assert_eq!(result, Auth::Accept);
|
||||
assert!(handler.authenticated_identity().is_some());
|
||||
|
||||
let identity = handler.authenticated_identity().unwrap();
|
||||
let dynamic = handler.dynamic.load();
|
||||
assert!(!dynamic.forwarding.check(
|
||||
"blocked.example.com",
|
||||
443,
|
||||
identity,
|
||||
TransportKind::Tcp
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarding_policy_deny_with_custom_identity() {
|
||||
use crate::config::forwarding::{
|
||||
ForwardingAction, ForwardingPolicy, ForwardingRule, TargetPattern,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert("service".to_string(), vec!["gitea".to_string()]);
|
||||
let identity = Identity {
|
||||
id: "SHA256:abc123".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources,
|
||||
};
|
||||
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Host("allowed.example.com".to_string()),
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec!["SHA256:abc123".to_string()],
|
||||
transports: vec![TransportKind::Tcp],
|
||||
}],
|
||||
};
|
||||
|
||||
assert!(policy.check("allowed.example.com", 443, &identity, TransportKind::Tcp));
|
||||
assert!(!policy.check("denied.example.com", 443, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_handler_with_custom_identity_provider() {
|
||||
use std::collections::HashMap;
|
||||
|
||||
struct MockIdentityProvider {
|
||||
identities: HashMap<String, Identity>,
|
||||
}
|
||||
|
||||
impl IdentityProvider for MockIdentityProvider {
|
||||
fn resolve_from_fingerprint(&self, fingerprint: &str) -> Option<Identity> {
|
||||
self.identities.get(fingerprint).cloned()
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, _token: &crate::auth::AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
let mut identities = HashMap::new();
|
||||
identities.insert(
|
||||
"SHA256:testkey".to_string(),
|
||||
Identity {
|
||||
id: "SHA256:testkey".to_string(),
|
||||
scopes: vec!["admin".to_string()],
|
||||
resources: HashMap::new(),
|
||||
},
|
||||
);
|
||||
|
||||
let provider = Arc::new(MockIdentityProvider { identities }) as Arc<dyn IdentityProvider>;
|
||||
let dynamic = make_empty_auth_config();
|
||||
|
||||
let handler = ServerHandler::new(
|
||||
dynamic,
|
||||
None,
|
||||
Some("10.0.0.1:22".parse().unwrap()),
|
||||
TransportKind::Tcp,
|
||||
default_limiter(),
|
||||
10,
|
||||
)
|
||||
.with_identity_provider(provider);
|
||||
|
||||
assert!(handler.authenticated_identity().is_none());
|
||||
}
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
//! Server-side SSH connection handling.
|
||||
//!
|
||||
//! Provides `Server` for accepting SSH connections over any transport and proxying
|
||||
//! `direct-tcpip` channel requests to targets. Supports Ed25519 and certificate-authority
|
||||
//! auth, connection rate limiting, auth attempt limiting, stealth mode (fake nginx 404),
|
||||
//! and outbound proxy routing (direct/SOCKS5/HTTP CONNECT).
|
||||
//!
|
||||
//! Destination hosts starting with `alknet-` are reserved for internal use (control channel, ADR-018).
|
||||
|
||||
pub mod channel_proxy;
|
||||
pub mod control_channel;
|
||||
pub mod handler;
|
||||
pub mod rate_limit;
|
||||
pub mod serve;
|
||||
pub mod stealth;
|
||||
|
||||
pub use channel_proxy::{connect_outbound, proxy_channel};
|
||||
pub use control_channel::{
|
||||
is_reserved_destination, ControlChannelHandler, ControlChannelRouter, DuplexStream,
|
||||
ALKNET_CONTROL_DESTINATION, ALKNET_PREFIX,
|
||||
};
|
||||
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
|
||||
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||
pub use serve::{
|
||||
DnsListenerConfig, HttpListenerConfig, ListenerConfig, ServeError, ServeOptions,
|
||||
ServeTransportMode, Server, StreamListenerConfig,
|
||||
};
|
||||
|
||||
pub use crate::transport::TransportKind;
|
||||
pub use stealth::{
|
||||
detect_protocol, handle_http_stealth, send_fake_nginx_404, validate_stealth_config,
|
||||
ProtocolDetection,
|
||||
};
|
||||
@@ -1,200 +0,0 @@
|
||||
//! Connection rate limiting and auth attempt limiting.
|
||||
//!
|
||||
//! `ConnectionRateLimiter` tracks per-IP active connections (thread-safe).
|
||||
//! `AuthAttemptLimiter` caps failed auth attempts per connection.
|
||||
//! These complement fail2ban on Linux and provide abuse protection on all platforms.
|
||||
//! See ADR-013.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Mutex;
|
||||
|
||||
pub struct ConnectionRateLimiter {
|
||||
max_per_ip: usize,
|
||||
active: Mutex<HashMap<IpAddr, usize>>,
|
||||
}
|
||||
|
||||
impl ConnectionRateLimiter {
|
||||
pub fn new(max_per_ip: usize) -> Self {
|
||||
Self {
|
||||
max_per_ip,
|
||||
active: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(&self, ip: IpAddr) -> bool {
|
||||
if self.max_per_ip == 0 {
|
||||
return true;
|
||||
}
|
||||
let active = self.active.lock().unwrap();
|
||||
let count = active.get(&ip).copied().unwrap_or(0);
|
||||
count < self.max_per_ip
|
||||
}
|
||||
|
||||
pub fn on_connect(&self, ip: IpAddr) {
|
||||
let mut active = self.active.lock().unwrap();
|
||||
*active.entry(ip).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
pub fn on_disconnect(&self, ip: IpAddr) {
|
||||
let mut active = self.active.lock().unwrap();
|
||||
if let Some(count) = active.get_mut(&ip) {
|
||||
if *count > 1 {
|
||||
*count -= 1;
|
||||
} else {
|
||||
active.remove(&ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AuthAttemptLimiter {
|
||||
max_attempts: usize,
|
||||
failures: usize,
|
||||
}
|
||||
|
||||
impl AuthAttemptLimiter {
|
||||
pub fn new(max_attempts: usize) -> Self {
|
||||
Self {
|
||||
max_attempts,
|
||||
failures: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(&self) -> bool {
|
||||
if self.max_attempts == 0 {
|
||||
return true;
|
||||
}
|
||||
self.failures < self.max_attempts
|
||||
}
|
||||
|
||||
pub fn on_failure(&mut self) {
|
||||
self.failures += 1;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
|
||||
fn ip(n: u8) -> IpAddr {
|
||||
IpAddr::V4(Ipv4Addr::new(192, 168, 1, n))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_allows_when_under_limit() {
|
||||
let limiter = ConnectionRateLimiter::new(3);
|
||||
assert!(limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_blocks_when_at_limit() {
|
||||
let limiter = ConnectionRateLimiter::new(2);
|
||||
limiter.on_connect(ip(1));
|
||||
limiter.on_connect(ip(1));
|
||||
assert!(!limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_allows_after_disconnect() {
|
||||
let limiter = ConnectionRateLimiter::new(2);
|
||||
limiter.on_connect(ip(1));
|
||||
limiter.on_connect(ip(1));
|
||||
assert!(!limiter.check(ip(1)));
|
||||
limiter.on_disconnect(ip(1));
|
||||
assert!(limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_unlimited_when_zero() {
|
||||
let limiter = ConnectionRateLimiter::new(0);
|
||||
for _ in 0..100 {
|
||||
limiter.on_connect(ip(1));
|
||||
}
|
||||
assert!(limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_tracks_per_ip_independently() {
|
||||
let limiter = ConnectionRateLimiter::new(1);
|
||||
limiter.on_connect(ip(1));
|
||||
assert!(!limiter.check(ip(1)));
|
||||
assert!(limiter.check(ip(2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_ipv6() {
|
||||
let limiter = ConnectionRateLimiter::new(1);
|
||||
let ip6 = IpAddr::V6(Ipv6Addr::LOCALHOST);
|
||||
limiter.on_connect(ip6);
|
||||
assert!(!limiter.check(ip6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_disconnect_removes_zero_entry() {
|
||||
let limiter = ConnectionRateLimiter::new(3);
|
||||
limiter.on_connect(ip(1));
|
||||
limiter.on_disconnect(ip(1));
|
||||
{
|
||||
let active = limiter.active.lock().unwrap();
|
||||
assert!(!active.contains_key(&ip(1)));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_allows_when_under_limit() {
|
||||
let limiter = AuthAttemptLimiter::new(3);
|
||||
assert!(limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_blocks_after_max_failures() {
|
||||
let mut limiter = AuthAttemptLimiter::new(2);
|
||||
limiter.on_failure();
|
||||
limiter.on_failure();
|
||||
assert!(!limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_unlimited_when_zero() {
|
||||
let mut limiter = AuthAttemptLimiter::new(0);
|
||||
for _ in 0..100 {
|
||||
limiter.on_failure();
|
||||
}
|
||||
assert!(limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_still_allows_at_one_below_limit() {
|
||||
let mut limiter = AuthAttemptLimiter::new(3);
|
||||
limiter.on_failure();
|
||||
limiter.on_failure();
|
||||
assert!(limiter.check());
|
||||
limiter.on_failure();
|
||||
assert!(!limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_thread_safety() {
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(100));
|
||||
let mut handles = vec![];
|
||||
|
||||
for i in 0..10 {
|
||||
let lim = Arc::clone(&limiter);
|
||||
handles.push(thread::spawn(move || {
|
||||
let ip_addr = ip((i % 3) as u8 + 1);
|
||||
lim.on_connect(ip_addr);
|
||||
assert!(lim.check(ip_addr));
|
||||
lim.on_disconnect(ip_addr);
|
||||
}));
|
||||
}
|
||||
|
||||
for h in handles {
|
||||
h.join().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,316 +0,0 @@
|
||||
//! Stealth mode: protocol detection on TLS connections.
|
||||
//!
|
||||
//! When stealth mode is enabled with TLS transport, the server peeks at the first
|
||||
//! bytes after the TLS handshake to determine whether the client is speaking SSH
|
||||
//! or HTTP. When the `http` feature is enabled, detected HTTP traffic is routed to
|
||||
//! the axum router. When `http` is disabled, non-SSH connections receive a fake
|
||||
//! nginx 404 response, making the server appear as an ordinary web server to port
|
||||
//! scanners and DPI systems. See ADR-017.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
|
||||
|
||||
use crate::auth::IdentityProvider;
|
||||
|
||||
const SSH_BANNER_PREFIX: &[u8] = b"SSH-2.0-";
|
||||
const FAKE_NGINX_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nServer: nginx\r\n\r\n";
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ProtocolDetection {
|
||||
Ssh,
|
||||
Http,
|
||||
}
|
||||
|
||||
pub async fn detect_protocol<S>(stream: S) -> (ProtocolDetection, BufReader<S>)
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
let detection = match reader.fill_buf().await {
|
||||
Ok(buf) if buf.len() >= SSH_BANNER_PREFIX.len() => {
|
||||
if &buf[..SSH_BANNER_PREFIX.len()] == SSH_BANNER_PREFIX {
|
||||
ProtocolDetection::Ssh
|
||||
} else {
|
||||
ProtocolDetection::Http
|
||||
}
|
||||
}
|
||||
Ok(buf) if !buf.is_empty() => {
|
||||
if buf.starts_with(SSH_BANNER_PREFIX) {
|
||||
ProtocolDetection::Ssh
|
||||
} else {
|
||||
ProtocolDetection::Http
|
||||
}
|
||||
}
|
||||
_ => ProtocolDetection::Http,
|
||||
};
|
||||
|
||||
(detection, reader)
|
||||
}
|
||||
|
||||
pub async fn send_fake_nginx_404<S>(reader: &mut BufReader<S>)
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let _ = reader.get_mut().write_all(FAKE_NGINX_404).await;
|
||||
let _ = reader.get_mut().shutdown().await;
|
||||
}
|
||||
|
||||
#[cfg(feature = "http")]
|
||||
pub async fn handle_http_stealth<S>(
|
||||
reader: BufReader<S>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
) where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
crate::http::router::serve_connection_from_reader(reader, identity_provider).await
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "http"))]
|
||||
pub async fn handle_http_stealth<S>(
|
||||
mut reader: BufReader<S>,
|
||||
_identity_provider: Arc<dyn IdentityProvider>,
|
||||
) where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
send_fake_nginx_404(&mut reader).await
|
||||
}
|
||||
|
||||
pub fn validate_stealth_config(stealth: bool, transport_is_tls: bool) -> Result<(), &'static str> {
|
||||
if stealth && !transport_is_tls {
|
||||
return Err("stealth mode requires TLS transport (--transport tls)");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
async fn write_and_detect(data: &[u8]) -> ProtocolDetection {
|
||||
let (client, server) = duplex(1024);
|
||||
let mut client = client;
|
||||
|
||||
client.write_all(data).await.unwrap();
|
||||
drop(client);
|
||||
|
||||
let (detection, _) = detect_protocol(server).await;
|
||||
detection
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_banner_detected() {
|
||||
let detection = write_and_detect(b"SSH-2.0-OpenSSH_9.0\r\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_banner_other_implementation() {
|
||||
let detection = write_and_detect(b"SSH-2.0-russh_0.49\r\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_banner_minimal() {
|
||||
let detection = write_and_detect(b"SSH-2.0-X\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_get_detected_as_http() {
|
||||
let detection = write_and_detect(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_post_detected_as_http() {
|
||||
let detection = write_and_detect(b"POST /api HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn random_data_detected_as_http() {
|
||||
let detection = write_and_detect(b"\x01\x02\x03\x04\x05\x06\x07\x08").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_stream_detected_as_http() {
|
||||
let (client, server) = duplex(1024);
|
||||
drop(client);
|
||||
let (detection, _) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_banner_bytes_preserved_by_bufreader() {
|
||||
let (client, server) = duplex(1024);
|
||||
let mut client = client;
|
||||
|
||||
let banner = b"SSH-2.0-OpenSSH_9.0\r\n";
|
||||
client.write_all(banner).await.unwrap();
|
||||
client.write_all(b"subsequent data").await.unwrap();
|
||||
drop(client);
|
||||
|
||||
let (detection, mut reader) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||
|
||||
let mut all_data = Vec::new();
|
||||
reader.read_to_end(&mut all_data).await.unwrap();
|
||||
assert!(
|
||||
all_data.starts_with(banner),
|
||||
"banner bytes must be preserved after detection"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fake_nginx_404_response() {
|
||||
let (client, server) = duplex(1024);
|
||||
let (mut client_read, mut client_write) = tokio::io::split(client);
|
||||
|
||||
client_write
|
||||
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
drop(client_write);
|
||||
|
||||
let (detection, mut reader) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
|
||||
send_fake_nginx_404(&mut reader).await;
|
||||
|
||||
let mut buf = [0u8; 256];
|
||||
let n = client_read.read(&mut buf).await.unwrap();
|
||||
let response = String::from_utf8_lossy(&buf[..n]);
|
||||
assert!(response.contains("HTTP/1.1 404 Not Found"));
|
||||
assert!(response.contains("Server: nginx"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn protocol_detection_enum_equality() {
|
||||
assert_eq!(ProtocolDetection::Ssh, ProtocolDetection::Ssh);
|
||||
assert_eq!(ProtocolDetection::Http, ProtocolDetection::Http);
|
||||
assert_ne!(ProtocolDetection::Ssh, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_stealth_without_tls_rejected() {
|
||||
let result = validate_stealth_config(true, false);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("TLS transport"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_stealth_with_tls_accepted() {
|
||||
let result = validate_stealth_config(true, true);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_no_stealth_with_tcp_accepted() {
|
||||
let result = validate_stealth_config(false, false);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_no_stealth_with_tls_accepted() {
|
||||
let result = validate_stealth_config(false, true);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn short_data_detected_as_http() {
|
||||
let detection = write_and_detect(b"GE").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn partial_ssh_prefix_detected_as_http() {
|
||||
let detection = write_and_detect(b"SSH-1.").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_request_gets_404_then_closed() {
|
||||
let (client, server) = duplex(1024);
|
||||
let mut client = client;
|
||||
|
||||
client
|
||||
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (detection, mut reader) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
|
||||
send_fake_nginx_404(&mut reader).await;
|
||||
|
||||
let mut buf = [0u8; 256];
|
||||
let n = client.read(&mut buf).await.unwrap();
|
||||
let response = String::from_utf8_lossy(&buf[..n]);
|
||||
assert!(response.starts_with("HTTP/1.1 404 Not Found"));
|
||||
assert!(response.contains("Server: nginx"));
|
||||
|
||||
let mut extra = [0u8; 16];
|
||||
let result = client.read(&mut extra).await;
|
||||
assert!(result.is_err() || result.unwrap() == 0);
|
||||
}
|
||||
|
||||
#[cfg(feature = "http")]
|
||||
#[tokio::test]
|
||||
async fn stealth_handoff_routes_http_to_axum() {
|
||||
use crate::auth::{AuthToken, IdentityProvider};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
struct NullProvider;
|
||||
|
||||
impl IdentityProvider for NullProvider {
|
||||
fn resolve_from_fingerprint(
|
||||
&self,
|
||||
_fingerprint: &str,
|
||||
) -> Option<crate::auth::Identity> {
|
||||
None
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, _token: &AuthToken) -> Option<crate::auth::Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
let (client, server) = duplex(4096);
|
||||
let (mut client_read, mut client_write) = tokio::io::split(client);
|
||||
|
||||
client_write
|
||||
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
drop(client_write);
|
||||
|
||||
let (detection, reader) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(NullProvider);
|
||||
let handle = tokio::spawn(async move {
|
||||
handle_http_stealth(reader, provider).await;
|
||||
});
|
||||
|
||||
let mut buf = Vec::new();
|
||||
tokio::io::AsyncReadExt::read_to_end(&mut client_read, &mut buf)
|
||||
.await
|
||||
.unwrap();
|
||||
let response = String::from_utf8_lossy(&buf);
|
||||
assert!(
|
||||
response.contains("401"),
|
||||
"expected 401 from axum auth middleware, got: {response}"
|
||||
);
|
||||
assert!(
|
||||
!response.contains("nginx"),
|
||||
"should not contain fake nginx response when http feature is enabled"
|
||||
);
|
||||
|
||||
let _ = handle.await;
|
||||
}
|
||||
}
|
||||
@@ -1,490 +0,0 @@
|
||||
//! SOCKS5 proxy server.
|
||||
//!
|
||||
//! Listens on a local port and routes each SOCKS5 connection through an SSH
|
||||
//! `direct-tcpip` channel. Supports SOCKS5h (domain names resolved server-side)
|
||||
//! to prevent DNS leaks. Uses the `ChannelOpener` trait to abstract over the
|
||||
//! SSH channel mechanism, making it testable without a real SSH session.
|
||||
|
||||
mod protocol;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::debug;
|
||||
|
||||
use protocol::{Socks5Reply, Socks5Request, Socks5VersionMethod};
|
||||
|
||||
pub use protocol::Socks5Address;
|
||||
|
||||
const DEFAULT_SOCKS5_ADDR: &str = "127.0.0.1:1080";
|
||||
|
||||
pub trait ChannelOpener: Send + Sync + 'static {
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
|
||||
fn open_channel(
|
||||
&self,
|
||||
host: String,
|
||||
port: u16,
|
||||
) -> impl std::future::Future<Output = Result<Self::Stream, ChannelOpenError>> + Send;
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ChannelOpenError {
|
||||
#[error("session closed")]
|
||||
SessionClosed,
|
||||
#[error("channel open failed")]
|
||||
ChannelOpenFailed,
|
||||
#[error("connection refused")]
|
||||
ConnectionRefused,
|
||||
}
|
||||
|
||||
pub struct Socks5Server<C: ChannelOpener> {
|
||||
listen_addr: SocketAddr,
|
||||
channel_opener: Arc<C>,
|
||||
}
|
||||
|
||||
impl<C: ChannelOpener> Socks5Server<C> {
|
||||
pub fn new(channel_opener: C) -> Self {
|
||||
Self::with_addr(channel_opener, DEFAULT_SOCKS5_ADDR)
|
||||
}
|
||||
|
||||
pub fn with_addr(channel_opener: C, addr: &str) -> Self {
|
||||
let listen_addr: SocketAddr = addr.parse().expect("invalid SOCKS5 listen address");
|
||||
Self {
|
||||
listen_addr,
|
||||
channel_opener: Arc::new(channel_opener),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn listen_addr(&self) -> SocketAddr {
|
||||
self.listen_addr
|
||||
}
|
||||
|
||||
pub async fn run(self) -> Result<(), std::io::Error> {
|
||||
let listener = TcpListener::bind(self.listen_addr).await?;
|
||||
debug!("socks5 server listening on {}", self.listen_addr);
|
||||
loop {
|
||||
let (socket, _peer) = listener.accept().await?;
|
||||
let opener = Arc::clone(&self.channel_opener);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_socks5_connection(socket, opener).await {
|
||||
debug!("socks5 connection error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_socks5_connection<S, C>(mut socket: S, opener: Arc<C>) -> Result<(), Socks5Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
C: ChannelOpener,
|
||||
{
|
||||
let vm = Socks5VersionMethod::read_from(&mut socket).await?;
|
||||
if vm.version != 0x05 {
|
||||
return Err(Socks5Error::InvalidVersion(vm.version));
|
||||
}
|
||||
if !vm.methods.contains(&0x00) {
|
||||
let reply = [0x05, 0xFF];
|
||||
socket.write_all(&reply).await?;
|
||||
socket.shutdown().await?;
|
||||
return Err(Socks5Error::NoAcceptableAuth);
|
||||
}
|
||||
let reply = [0x05, 0x00];
|
||||
socket.write_all(&reply).await?;
|
||||
|
||||
let request = Socks5Request::read_from(&mut socket).await?;
|
||||
if request.version != 0x05 {
|
||||
return Err(Socks5Error::InvalidVersion(request.version));
|
||||
}
|
||||
if request.command != 0x01 {
|
||||
send_error_reply(&mut socket, Socks5Reply::command_not_supported()).await?;
|
||||
return Err(Socks5Error::UnsupportedCommand(request.command));
|
||||
}
|
||||
|
||||
let (host, port) = match &request.address {
|
||||
Socks5Address::Ipv4(addr) => (addr.to_string(), request.port),
|
||||
Socks5Address::Ipv6(addr) => (addr.to_string(), request.port),
|
||||
Socks5Address::Domain(name) => (name.clone(), request.port),
|
||||
};
|
||||
|
||||
match opener.open_channel(host, port).await {
|
||||
Ok(mut ssh_stream) => {
|
||||
let bind_addr = Socks5Address::Ipv4(std::net::Ipv4Addr::UNSPECIFIED);
|
||||
let reply = Socks5Reply::success(bind_addr, 0);
|
||||
reply.write_to(&mut socket).await?;
|
||||
tokio::io::copy_bidirectional(&mut socket, &mut ssh_stream).await?;
|
||||
Ok(())
|
||||
}
|
||||
Err(_) => {
|
||||
send_error_reply(&mut socket, Socks5Reply::connection_refused()).await?;
|
||||
Err(Socks5Error::ChannelOpenFailed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_error_reply<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
socket: &mut S,
|
||||
reply: Socks5Reply,
|
||||
) -> Result<(), Socks5Error> {
|
||||
reply.write_to(socket).await?;
|
||||
let _ = socket.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Socks5Error {
|
||||
#[error("invalid SOCKS version: {0}")]
|
||||
InvalidVersion(u8),
|
||||
#[error("no acceptable auth method")]
|
||||
NoAcceptableAuth,
|
||||
#[error("unsupported command: {0}")]
|
||||
UnsupportedCommand(u8),
|
||||
#[error("channel open failed")]
|
||||
ChannelOpenFailed,
|
||||
#[error("io error")]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
pub struct HandleChannelOpener<H: russh::client::Handler> {
|
||||
handle: Arc<Mutex<russh::client::Handle<H>>>,
|
||||
}
|
||||
|
||||
impl<H: russh::client::Handler> HandleChannelOpener<H> {
|
||||
pub fn new(handle: russh::client::Handle<H>) -> Self {
|
||||
Self {
|
||||
handle: Arc::new(Mutex::new(handle)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_arc(handle: Arc<Mutex<russh::client::Handle<H>>>) -> Self {
|
||||
Self { handle }
|
||||
}
|
||||
}
|
||||
|
||||
impl<H: russh::client::Handler + Send + Sync + 'static> ChannelOpener for HandleChannelOpener<H> {
|
||||
type Stream = russh::ChannelStream<russh::client::Msg>;
|
||||
|
||||
async fn open_channel(
|
||||
&self,
|
||||
host: String,
|
||||
port: u16,
|
||||
) -> Result<Self::Stream, ChannelOpenError> {
|
||||
let handle = self.handle.lock().await;
|
||||
if handle.is_closed() {
|
||||
return Err(ChannelOpenError::SessionClosed);
|
||||
}
|
||||
let channel = handle
|
||||
.channel_open_direct_tcpip(host, port as u32, "127.0.0.1", 0)
|
||||
.await
|
||||
.map_err(|_| ChannelOpenError::ChannelOpenFailed)?;
|
||||
Ok(channel.into_stream())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
|
||||
|
||||
struct MockChannelOpener {
|
||||
fail: bool,
|
||||
}
|
||||
|
||||
impl ChannelOpener for MockChannelOpener {
|
||||
type Stream = DuplexStream;
|
||||
|
||||
async fn open_channel(
|
||||
&self,
|
||||
_host: String,
|
||||
_port: u16,
|
||||
) -> Result<Self::Stream, ChannelOpenError> {
|
||||
if self.fail {
|
||||
Err(ChannelOpenError::ChannelOpenFailed)
|
||||
} else {
|
||||
let (client, _server) = duplex(4096);
|
||||
Ok(client)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_socks5_greeting(methods: &[u8]) -> Vec<u8> {
|
||||
let mut buf = vec![0x05, methods.len() as u8];
|
||||
buf.extend_from_slice(methods);
|
||||
buf
|
||||
}
|
||||
|
||||
fn build_socks5_connect_ipv4(addr: [u8; 4], port: u16) -> Vec<u8> {
|
||||
let mut buf = vec![0x05, 0x01, 0x00, 0x01];
|
||||
buf.extend_from_slice(&addr);
|
||||
buf.extend_from_slice(&port.to_be_bytes());
|
||||
buf
|
||||
}
|
||||
|
||||
fn build_socks5_connect_domain(domain: &str, port: u16) -> Vec<u8> {
|
||||
let mut buf = vec![0x05, 0x01, 0x00, 0x03];
|
||||
buf.push(domain.len() as u8);
|
||||
buf.extend_from_slice(domain.as_bytes());
|
||||
buf.extend_from_slice(&port.to_be_bytes());
|
||||
buf
|
||||
}
|
||||
|
||||
fn build_socks5_connect_ipv6(addr: [u8; 16], port: u16) -> Vec<u8> {
|
||||
let mut buf = vec![0x05, 0x01, 0x00, 0x04];
|
||||
buf.extend_from_slice(&addr);
|
||||
buf.extend_from_slice(&port.to_be_bytes());
|
||||
buf
|
||||
}
|
||||
|
||||
async fn do_handshake(client: &mut DuplexStream) -> [u8; 2] {
|
||||
client
|
||||
.write_all(&build_socks5_greeting(&[0x00]))
|
||||
.await
|
||||
.unwrap();
|
||||
client.flush().await.unwrap();
|
||||
let mut resp = [0u8; 2];
|
||||
client.read_exact(&mut resp).await.unwrap();
|
||||
resp
|
||||
}
|
||||
|
||||
async fn do_connect_ipv4(client: &mut DuplexStream, addr: [u8; 4], port: u16) -> Vec<u8> {
|
||||
client
|
||||
.write_all(&build_socks5_connect_ipv4(addr, port))
|
||||
.await
|
||||
.unwrap();
|
||||
client.flush().await.unwrap();
|
||||
let mut reply_buf = [0u8; 10];
|
||||
client.read_exact(&mut reply_buf).await.unwrap();
|
||||
reply_buf.to_vec()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_no_auth_method() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
let resp = do_handshake(&mut client).await;
|
||||
assert_eq!(resp, [0x05, 0x00]);
|
||||
|
||||
let reply_buf = do_connect_ipv4(&mut client, [127, 0, 0, 1], 80).await;
|
||||
assert_eq!(reply_buf[0], 0x05);
|
||||
assert_eq!(reply_buf[1], 0x00);
|
||||
|
||||
drop(client);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_rejects_no_acceptable_method() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
client
|
||||
.write_all(&build_socks5_greeting(&[0x02]))
|
||||
.await
|
||||
.unwrap();
|
||||
client.flush().await.unwrap();
|
||||
|
||||
let mut resp = [0u8; 2];
|
||||
client.read_exact(&mut resp).await.unwrap();
|
||||
assert_eq!(resp, [0x05, 0xFF]);
|
||||
|
||||
drop(client);
|
||||
let result = server_handle.await.unwrap();
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), Socks5Error::NoAcceptableAuth));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn address_type_ipv4() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
do_handshake(&mut client).await;
|
||||
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 443).await;
|
||||
assert_eq!(reply_buf[1], 0x00);
|
||||
|
||||
drop(client);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn address_type_domain() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
do_handshake(&mut client).await;
|
||||
|
||||
client
|
||||
.write_all(&build_socks5_connect_domain("example.com", 443))
|
||||
.await
|
||||
.unwrap();
|
||||
client.flush().await.unwrap();
|
||||
|
||||
let mut reply_buf = [0u8; 10];
|
||||
client.read_exact(&mut reply_buf).await.unwrap();
|
||||
assert_eq!(reply_buf[1], 0x00);
|
||||
|
||||
drop(client);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn address_type_ipv6() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
do_handshake(&mut client).await;
|
||||
|
||||
let ipv6_addr = [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
|
||||
client
|
||||
.write_all(&build_socks5_connect_ipv6(ipv6_addr, 443))
|
||||
.await
|
||||
.unwrap();
|
||||
client.flush().await.unwrap();
|
||||
|
||||
let mut reply_buf = [0u8; 10];
|
||||
client.read_exact(&mut reply_buf).await.unwrap();
|
||||
assert_eq!(reply_buf[0], 0x05);
|
||||
assert_eq!(reply_buf[1], 0x00);
|
||||
|
||||
drop(client);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn channel_open_failure_returns_socks5_error() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: true };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
do_handshake(&mut client).await;
|
||||
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 80).await;
|
||||
assert_eq!(reply_buf[0], 0x05);
|
||||
assert_eq!(reply_buf[1], 0x05);
|
||||
|
||||
drop(client);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unsupported_command_returns_error() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
do_handshake(&mut client).await;
|
||||
|
||||
let mut bind_req = vec![0x05, 0x02, 0x00, 0x01];
|
||||
bind_req.extend_from_slice(&[127, 0, 0, 1]);
|
||||
bind_req.extend_from_slice(&80u16.to_be_bytes());
|
||||
client.write_all(&bind_req).await.unwrap();
|
||||
client.flush().await.unwrap();
|
||||
|
||||
let mut reply_buf = [0u8; 10];
|
||||
client.read_exact(&mut reply_buf).await.unwrap();
|
||||
assert_eq!(reply_buf[1], 0x07);
|
||||
|
||||
drop(client);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn bidirectional_proxy_flow() {
|
||||
let (mut client_sock, server_sock) = duplex(4096);
|
||||
let (ssh_client, mut ssh_server) = duplex(4096);
|
||||
|
||||
let ssh_stream = Arc::new(Mutex::new(Some(ssh_client)));
|
||||
|
||||
struct ProxyOpener {
|
||||
stream: Arc<Mutex<Option<DuplexStream>>>,
|
||||
}
|
||||
|
||||
impl ChannelOpener for ProxyOpener {
|
||||
type Stream = DuplexStream;
|
||||
|
||||
async fn open_channel(
|
||||
&self,
|
||||
_host: String,
|
||||
_port: u16,
|
||||
) -> Result<Self::Stream, ChannelOpenError> {
|
||||
self.stream
|
||||
.lock()
|
||||
.await
|
||||
.take()
|
||||
.ok_or(ChannelOpenError::ChannelOpenFailed)
|
||||
}
|
||||
}
|
||||
|
||||
let opener = ProxyOpener {
|
||||
stream: Arc::clone(&ssh_stream),
|
||||
};
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(
|
||||
async move { handle_socks5_connection(server_sock, Arc::new(opener)).await },
|
||||
);
|
||||
|
||||
do_handshake(&mut client_sock).await;
|
||||
let reply_buf = do_connect_ipv4(&mut client_sock, [127, 0, 0, 1], 80).await;
|
||||
assert_eq!(reply_buf[1], 0x00);
|
||||
|
||||
let test_data = b"hello through tunnel";
|
||||
client_sock.write_all(test_data).await.unwrap();
|
||||
client_sock.flush().await.unwrap();
|
||||
|
||||
let mut received = vec![0u8; test_data.len()];
|
||||
AsyncReadExt::read_exact(&mut ssh_server, &mut received)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(&received, test_data);
|
||||
|
||||
let echo_data = b"response from tunnel";
|
||||
ssh_server.write_all(echo_data).await.unwrap();
|
||||
ssh_server.flush().await.unwrap();
|
||||
|
||||
let mut received_back = vec![0u8; echo_data.len()];
|
||||
client_sock.read_exact(&mut received_back).await.unwrap();
|
||||
assert_eq!(&received_back, echo_data);
|
||||
|
||||
drop(client_sock);
|
||||
drop(ssh_server);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn default_listen_address() {
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
let server = Socks5Server::new(opener);
|
||||
assert_eq!(server.listen_addr(), "127.0.0.1:1080".parse().unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn custom_listen_address() {
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
let server = Socks5Server::with_addr(opener, "127.0.0.1:9050");
|
||||
assert_eq!(server.listen_addr(), "127.0.0.1:9050".parse().unwrap());
|
||||
}
|
||||
}
|
||||
@@ -1,304 +0,0 @@
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Socks5Address {
|
||||
Ipv4(Ipv4Addr),
|
||||
Ipv6(Ipv6Addr),
|
||||
Domain(String),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Socks5VersionMethod {
|
||||
pub version: u8,
|
||||
pub methods: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Socks5VersionMethod {
|
||||
pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Self> {
|
||||
let version = reader.read_u8().await?;
|
||||
let nmethods = reader.read_u8().await?;
|
||||
let mut methods = vec![0u8; nmethods as usize];
|
||||
reader.read_exact(&mut methods).await?;
|
||||
Ok(Self { version, methods })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Socks5Request {
|
||||
pub version: u8,
|
||||
pub command: u8,
|
||||
pub address: Socks5Address,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
impl Socks5Request {
|
||||
pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Self> {
|
||||
let version = reader.read_u8().await?;
|
||||
let command = reader.read_u8().await?;
|
||||
let _rsv = reader.read_u8().await?;
|
||||
let atyp = reader.read_u8().await?;
|
||||
|
||||
let address = match atyp {
|
||||
0x01 => {
|
||||
let mut octets = [0u8; 4];
|
||||
reader.read_exact(&mut octets).await?;
|
||||
Socks5Address::Ipv4(Ipv4Addr::from(octets))
|
||||
}
|
||||
0x04 => {
|
||||
let mut octets = [0u8; 16];
|
||||
reader.read_exact(&mut octets).await?;
|
||||
Socks5Address::Ipv6(Ipv6Addr::from(octets))
|
||||
}
|
||||
0x03 => {
|
||||
let len = reader.read_u8().await?;
|
||||
let mut domain = vec![0u8; len as usize];
|
||||
reader.read_exact(&mut domain).await?;
|
||||
Socks5Address::Domain(String::from_utf8_lossy(&domain).into_owned())
|
||||
}
|
||||
_ => {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("unsupported address type: {atyp}"),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let port = reader.read_u16().await?;
|
||||
|
||||
Ok(Self {
|
||||
version,
|
||||
command,
|
||||
address,
|
||||
port,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Socks5Reply {
|
||||
pub version: u8,
|
||||
pub reply: u8,
|
||||
pub address: Socks5Address,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
impl Socks5Reply {
|
||||
pub fn success(address: Socks5Address, port: u16) -> Self {
|
||||
Self {
|
||||
version: 0x05,
|
||||
reply: 0x00,
|
||||
address,
|
||||
port,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn connection_refused() -> Self {
|
||||
Self {
|
||||
version: 0x05,
|
||||
reply: 0x05,
|
||||
address: Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED),
|
||||
port: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn command_not_supported() -> Self {
|
||||
Self {
|
||||
version: 0x05,
|
||||
reply: 0x07,
|
||||
address: Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED),
|
||||
port: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write_to<W: AsyncWrite + Unpin>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||
writer.write_u8(self.version).await?;
|
||||
writer.write_u8(self.reply).await?;
|
||||
writer.write_u8(0x00).await?;
|
||||
match &self.address {
|
||||
Socks5Address::Ipv4(addr) => {
|
||||
writer.write_u8(0x01).await?;
|
||||
writer.write_all(&addr.octets()).await?;
|
||||
}
|
||||
Socks5Address::Ipv6(addr) => {
|
||||
writer.write_u8(0x04).await?;
|
||||
writer.write_all(&addr.octets()).await?;
|
||||
}
|
||||
Socks5Address::Domain(name) => {
|
||||
writer.write_u8(0x03).await?;
|
||||
writer.write_u8(name.len() as u8).await?;
|
||||
writer.write_all(name.as_bytes()).await?;
|
||||
}
|
||||
}
|
||||
writer.write_u16(self.port).await?;
|
||||
writer.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Cursor;
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_version_method_no_auth() {
|
||||
let data = [0x05, 0x01, 0x00];
|
||||
let mut cursor = Cursor::new(&data[..]);
|
||||
let vm = Socks5VersionMethod::read_from(&mut cursor).await.unwrap();
|
||||
assert_eq!(vm.version, 0x05);
|
||||
assert_eq!(vm.methods, vec![0x00]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_version_method_multiple() {
|
||||
let data = [0x05, 0x02, 0x00, 0x02];
|
||||
let mut cursor = Cursor::new(&data[..]);
|
||||
let vm = Socks5VersionMethod::read_from(&mut cursor).await.unwrap();
|
||||
assert_eq!(vm.version, 0x05);
|
||||
assert_eq!(vm.methods, vec![0x00, 0x02]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_request_ipv4() {
|
||||
let mut data = vec![0x05, 0x01, 0x00, 0x01];
|
||||
data.extend_from_slice(&[10, 0, 0, 1]);
|
||||
data.extend_from_slice(&443u16.to_be_bytes());
|
||||
let mut cursor = Cursor::new(&data[..]);
|
||||
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||
assert_eq!(req.version, 0x05);
|
||||
assert_eq!(req.command, 0x01);
|
||||
assert_eq!(req.address, Socks5Address::Ipv4(Ipv4Addr::new(10, 0, 0, 1)));
|
||||
assert_eq!(req.port, 443);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_request_ipv6() {
|
||||
let mut data = vec![0x05, 0x01, 0x00, 0x04];
|
||||
let octets: [u8; 16] = [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
|
||||
data.extend_from_slice(&octets);
|
||||
data.extend_from_slice(&443u16.to_be_bytes());
|
||||
let mut cursor = Cursor::new(&data[..]);
|
||||
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||
assert_eq!(req.version, 0x05);
|
||||
assert_eq!(req.command, 0x01);
|
||||
assert!(matches!(req.address, Socks5Address::Ipv6(_)));
|
||||
assert_eq!(req.port, 443);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_request_domain() {
|
||||
let domain = "example.com";
|
||||
let mut data = vec![0x05, 0x01, 0x00, 0x03];
|
||||
data.push(domain.len() as u8);
|
||||
data.extend_from_slice(domain.as_bytes());
|
||||
data.extend_from_slice(&443u16.to_be_bytes());
|
||||
let mut cursor = Cursor::new(&data[..]);
|
||||
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||
assert_eq!(req.version, 0x05);
|
||||
assert_eq!(req.command, 0x01);
|
||||
assert_eq!(
|
||||
req.address,
|
||||
Socks5Address::Domain("example.com".to_string())
|
||||
);
|
||||
assert_eq!(req.port, 443);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_request_unsupported_address_type() {
|
||||
let data = [0x05, 0x01, 0x00, 0x05];
|
||||
let mut cursor = Cursor::new(&data[..]);
|
||||
let result = Socks5Request::read_from(&mut cursor).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reply_success_ipv4() {
|
||||
let reply = Socks5Reply::success(Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED), 0);
|
||||
let mut buf = Vec::new();
|
||||
reply.write_to(&mut buf).await.unwrap();
|
||||
assert_eq!(buf[0], 0x05);
|
||||
assert_eq!(buf[1], 0x00);
|
||||
assert_eq!(buf[2], 0x00);
|
||||
assert_eq!(buf[3], 0x01);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reply_connection_refused() {
|
||||
let reply = Socks5Reply::connection_refused();
|
||||
let mut buf = Vec::new();
|
||||
reply.write_to(&mut buf).await.unwrap();
|
||||
assert_eq!(buf[0], 0x05);
|
||||
assert_eq!(buf[1], 0x05);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reply_command_not_supported() {
|
||||
let reply = Socks5Reply::command_not_supported();
|
||||
let mut buf = Vec::new();
|
||||
reply.write_to(&mut buf).await.unwrap();
|
||||
assert_eq!(buf[0], 0x05);
|
||||
assert_eq!(buf[1], 0x07);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn roundtrip_ipv4_reply() {
|
||||
let reply = Socks5Reply::success(Socks5Address::Ipv4(Ipv4Addr::new(127, 0, 0, 1)), 1080);
|
||||
let mut buf = Vec::new();
|
||||
reply.write_to(&mut buf).await.unwrap();
|
||||
|
||||
let mut cursor = Cursor::new(&buf[..]);
|
||||
let version = cursor.read_u8().await.unwrap();
|
||||
let _reply_code = cursor.read_u8().await.unwrap();
|
||||
let _rsv = cursor.read_u8().await.unwrap();
|
||||
let atyp = cursor.read_u8().await.unwrap();
|
||||
assert_eq!(version, 0x05);
|
||||
assert_eq!(atyp, 0x01);
|
||||
let mut octets = [0u8; 4];
|
||||
cursor.read_exact(&mut octets).await.unwrap();
|
||||
assert_eq!(Ipv4Addr::from(octets), Ipv4Addr::new(127, 0, 0, 1));
|
||||
let port = cursor.read_u16().await.unwrap();
|
||||
assert_eq!(port, 1080);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn roundtrip_ipv6_reply() {
|
||||
let addr = Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1);
|
||||
let reply = Socks5Reply::success(Socks5Address::Ipv6(addr), 443);
|
||||
let mut buf = Vec::new();
|
||||
reply.write_to(&mut buf).await.unwrap();
|
||||
|
||||
let mut cursor = Cursor::new(&buf[..]);
|
||||
let _version = cursor.read_u8().await.unwrap();
|
||||
let _reply_code = cursor.read_u8().await.unwrap();
|
||||
let _rsv = cursor.read_u8().await.unwrap();
|
||||
let atyp = cursor.read_u8().await.unwrap();
|
||||
assert_eq!(atyp, 0x04);
|
||||
let mut octets = [0u8; 16];
|
||||
cursor.read_exact(&mut octets).await.unwrap();
|
||||
assert_eq!(Ipv6Addr::from(octets), addr);
|
||||
let port = cursor.read_u16().await.unwrap();
|
||||
assert_eq!(port, 443);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn roundtrip_domain_reply() {
|
||||
let reply = Socks5Reply::success(Socks5Address::Domain("example.com".to_string()), 8080);
|
||||
let mut buf = Vec::new();
|
||||
reply.write_to(&mut buf).await.unwrap();
|
||||
|
||||
let mut cursor = Cursor::new(&buf[..]);
|
||||
let _version = cursor.read_u8().await.unwrap();
|
||||
let _reply_code = cursor.read_u8().await.unwrap();
|
||||
let _rsv = cursor.read_u8().await.unwrap();
|
||||
let atyp = cursor.read_u8().await.unwrap();
|
||||
assert_eq!(atyp, 0x03);
|
||||
let len = cursor.read_u8().await.unwrap();
|
||||
let mut domain = vec![0u8; len as usize];
|
||||
cursor.read_exact(&mut domain).await.unwrap();
|
||||
assert_eq!(String::from_utf8(domain).unwrap(), "example.com");
|
||||
let port = cursor.read_u16().await.unwrap();
|
||||
assert_eq!(port, 8080);
|
||||
}
|
||||
}
|
||||
@@ -1,141 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, DuplexStream};
|
||||
|
||||
#[cfg(feature = "transport-traits")]
|
||||
pub use crate::transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
#[cfg(not(feature = "transport-traits"))]
|
||||
pub use local_traits::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
#[cfg(not(feature = "transport-traits"))]
|
||||
mod local_traits {
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
#[async_trait]
|
||||
pub trait Transport: Send + Sync + 'static {
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
async fn connect(&self) -> Result<Self::Stream>;
|
||||
fn describe(&self) -> String;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait TransportAcceptor: Send + Sync + 'static {
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TransportInfo {
|
||||
pub remote_addr: Option<SocketAddr>,
|
||||
pub transport_kind: TransportKind,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TransportKind {
|
||||
Tcp,
|
||||
Tls { server_name: Option<String> },
|
||||
Iroh { endpoint_id: String },
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MockStream {
|
||||
inner: DuplexStream,
|
||||
}
|
||||
|
||||
impl MockStream {
|
||||
pub fn new(inner: DuplexStream) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for MockStream {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for MockStream {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<std::io::Result<usize>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpin for MockStream {}
|
||||
|
||||
pub struct MockTransport {
|
||||
buf_size: usize,
|
||||
}
|
||||
|
||||
impl MockTransport {
|
||||
pub fn new(buf_size: usize) -> Self {
|
||||
Self { buf_size }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for MockTransport {
|
||||
type Stream = MockStream;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream> {
|
||||
let (client, _) = tokio::io::duplex(self.buf_size);
|
||||
Ok(MockStream::new(client))
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"mock".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MockTransportAcceptor {
|
||||
buf_size: usize,
|
||||
}
|
||||
|
||||
impl MockTransportAcceptor {
|
||||
pub fn new(buf_size: usize) -> Self {
|
||||
Self { buf_size }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TransportAcceptor for MockTransportAcceptor {
|
||||
type Stream = MockStream;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let (_, server) = tokio::io::duplex(self.buf_size);
|
||||
let info = TransportInfo {
|
||||
remote_addr: None,
|
||||
transport_kind: TransportKind::Tcp,
|
||||
};
|
||||
Ok((MockStream::new(server), info))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mock_pair(buf_size: usize) -> (MockStream, MockStream) {
|
||||
let (client, server) = tokio::io::duplex(buf_size);
|
||||
(MockStream::new(client), MockStream::new(server))
|
||||
}
|
||||
@@ -1,352 +0,0 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use rustls::crypto::aws_lc_rs::default_provider;
|
||||
use rustls::ServerConfig;
|
||||
use rustls_acme::caches::DirCache;
|
||||
use rustls_acme::{AcmeConfig, AcmeState, ResolvesServerCertAcme};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsAcceptor as TokioTlsAcceptor;
|
||||
use tracing::{error, info};
|
||||
|
||||
use super::{TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AcmeMode {
|
||||
Domain { domain: String },
|
||||
Ip,
|
||||
}
|
||||
|
||||
pub struct AcmeCertProvider {
|
||||
mode: AcmeMode,
|
||||
cache_dir: Option<PathBuf>,
|
||||
directory_url: String,
|
||||
contact: Vec<String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AcmeCertProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AcmeCertProvider")
|
||||
.field("mode", &self.mode)
|
||||
.field("cache_dir", &self.cache_dir)
|
||||
.field("directory_url", &self.directory_url)
|
||||
.field("contact", &self.contact)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl AcmeCertProvider {
|
||||
pub fn new(mode: AcmeMode) -> Self {
|
||||
Self {
|
||||
mode,
|
||||
cache_dir: None,
|
||||
directory_url: rustls_acme::acme::LETS_ENCRYPT_STAGING_DIRECTORY.to_string(),
|
||||
contact: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn domain(domain: impl Into<String>) -> Self {
|
||||
Self::new(AcmeMode::Domain {
|
||||
domain: domain.into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn ip() -> Self {
|
||||
Self::new(AcmeMode::Ip)
|
||||
}
|
||||
|
||||
pub fn with_cache_dir(mut self, dir: impl Into<PathBuf>) -> Self {
|
||||
self.cache_dir = Some(dir.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_directory(mut self, url: impl Into<String>) -> Self {
|
||||
self.directory_url = url.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_production_directory(mut self) -> Self {
|
||||
self.directory_url = rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY.to_string();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_contact(mut self, contact: impl Into<String>) -> Self {
|
||||
self.contact.push(contact.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn mode(&self) -> &AcmeMode {
|
||||
&self.mode
|
||||
}
|
||||
|
||||
fn build_acme_state(&self) -> (AcmeState<std::io::Error>, Arc<ResolvesServerCertAcme>) {
|
||||
let domains: Vec<String> = match &self.mode {
|
||||
AcmeMode::Domain { domain } => vec![domain.clone()],
|
||||
AcmeMode::Ip => vec![],
|
||||
};
|
||||
|
||||
let base_config = AcmeConfig::new(domains)
|
||||
.directory(&self.directory_url)
|
||||
.contact(self.contact.clone());
|
||||
|
||||
let state = match &self.cache_dir {
|
||||
Some(cache_dir) => base_config.cache(DirCache::new(cache_dir.clone())).state(),
|
||||
None => base_config
|
||||
.cache(rustls_acme::caches::NoCache::default())
|
||||
.state(),
|
||||
};
|
||||
|
||||
let resolver = state.resolver();
|
||||
(state, resolver)
|
||||
}
|
||||
|
||||
pub fn build_server_config_with_resolver(
|
||||
&self,
|
||||
resolver: Arc<ResolvesServerCertAcme>,
|
||||
) -> Result<Arc<ServerConfig>> {
|
||||
let provider = default_provider().into();
|
||||
let mut config = ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()
|
||||
.map_err(|e| anyhow!("failed to set protocol versions: {}", e))?
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(resolver);
|
||||
config.alpn_protocols.push(ACME_TLS_ALPN_NAME.to_vec());
|
||||
Ok(Arc::new(config))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AcmeTlsAcceptor {
|
||||
listener: TcpListener,
|
||||
listen_addr: SocketAddr,
|
||||
#[allow(dead_code)]
|
||||
server_config: Arc<ServerConfig>,
|
||||
tokio_acceptor: TokioTlsAcceptor,
|
||||
}
|
||||
|
||||
impl AcmeTlsAcceptor {
|
||||
pub async fn bind_acme(addr: SocketAddr, provider: Arc<AcmeCertProvider>) -> Result<Self> {
|
||||
let (state, resolver) = provider.build_acme_state();
|
||||
|
||||
let server_config = provider.build_server_config_with_resolver(resolver.clone())?;
|
||||
|
||||
Self::spawn_state_worker(state, resolver);
|
||||
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let listen_addr = listener.local_addr()?;
|
||||
|
||||
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
|
||||
|
||||
Ok(Self {
|
||||
listener,
|
||||
listen_addr,
|
||||
server_config,
|
||||
tokio_acceptor,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn listen_addr(&self) -> SocketAddr {
|
||||
self.listen_addr
|
||||
}
|
||||
|
||||
fn spawn_state_worker(state: AcmeState<std::io::Error>, resolver: Arc<ResolvesServerCertAcme>) {
|
||||
use futures::StreamExt;
|
||||
|
||||
let task = async move {
|
||||
let mut state = state;
|
||||
while let Some(event) = state.next().await {
|
||||
match event {
|
||||
Ok(ok) => {
|
||||
if let rustls_acme::EventOk::DeployedNewCert = ok {
|
||||
info!("ACME: new certificate deployed");
|
||||
} else {
|
||||
info!("ACME event: {:?}", ok);
|
||||
}
|
||||
}
|
||||
Err(err) => error!("ACME event error: {:?}", err),
|
||||
}
|
||||
if Arc::strong_count(&resolver) == 1 {
|
||||
info!("ACME resolver dropped, stopping background task");
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
tokio::spawn(task);
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TransportAcceptor for AcmeTlsAcceptor {
|
||||
type Stream = tokio_rustls::server::TlsStream<tokio::net::TcpStream>;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let (tcp_stream, remote_addr) = self.listener.accept().await?;
|
||||
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
|
||||
|
||||
let server_name = tls_stream.get_ref().1.server_name().map(|s| s.to_string());
|
||||
|
||||
let info = TransportInfo {
|
||||
remote_addr: Some(remote_addr),
|
||||
transport_kind: TransportKind::Tls { server_name },
|
||||
};
|
||||
|
||||
Ok((tls_stream, info))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_domain_mode() {
|
||||
let provider = AcmeCertProvider::domain("example.com");
|
||||
assert!(matches!(provider.mode(), AcmeMode::Domain { .. }));
|
||||
if let AcmeMode::Domain { domain } = provider.mode() {
|
||||
assert_eq!(domain, "example.com");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_ip_mode() {
|
||||
let provider = AcmeCertProvider::ip();
|
||||
assert!(matches!(provider.mode(), AcmeMode::Ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_default_staging_directory() {
|
||||
let provider = AcmeCertProvider::domain("example.com");
|
||||
assert_eq!(
|
||||
provider.directory_url,
|
||||
rustls_acme::acme::LETS_ENCRYPT_STAGING_DIRECTORY
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_production_directory() {
|
||||
let provider = AcmeCertProvider::domain("example.com").with_production_directory();
|
||||
assert_eq!(
|
||||
provider.directory_url,
|
||||
rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_custom_directory() {
|
||||
let provider =
|
||||
AcmeCertProvider::domain("example.com").with_directory("https://custom.acme.dir/");
|
||||
assert_eq!(provider.directory_url, "https://custom.acme.dir/");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_with_cache_dir() {
|
||||
let provider = AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/acme_cache");
|
||||
assert_eq!(provider.cache_dir, Some(PathBuf::from("/tmp/acme_cache")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_with_contact() {
|
||||
let provider =
|
||||
AcmeCertProvider::domain("example.com").with_contact("mailto:admin@example.com");
|
||||
assert_eq!(
|
||||
provider.contact,
|
||||
vec!["mailto:admin@example.com".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_build_state_domain() {
|
||||
let provider = AcmeCertProvider::domain("example.com");
|
||||
let (_state, resolver) = provider.build_acme_state();
|
||||
assert!(Arc::strong_count(&resolver) >= 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_build_state_with_cache() {
|
||||
let provider = AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/test_cache");
|
||||
let (_state, resolver) = provider.build_acme_state();
|
||||
assert!(Arc::strong_count(&resolver) >= 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_build_server_config() {
|
||||
let _ = default_provider().install_default();
|
||||
let provider = AcmeCertProvider::domain("example.com");
|
||||
let (_, resolver) = provider.build_acme_state();
|
||||
let config = provider
|
||||
.build_server_config_with_resolver(resolver)
|
||||
.unwrap();
|
||||
assert!(!config.alpn_protocols.is_empty());
|
||||
assert!(config
|
||||
.alpn_protocols
|
||||
.iter()
|
||||
.any(|p| p == ACME_TLS_ALPN_NAME));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_mode_domain_debug() {
|
||||
let mode = AcmeMode::Domain {
|
||||
domain: "test.example.com".to_string(),
|
||||
};
|
||||
let debug_str = format!("{:?}", mode);
|
||||
assert!(debug_str.contains("test.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_mode_ip_debug() {
|
||||
let mode = AcmeMode::Ip;
|
||||
let debug_str = format!("{:?}", mode);
|
||||
assert!(debug_str.contains("Ip"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_builder_chain() {
|
||||
let provider = AcmeCertProvider::domain("test.example.com")
|
||||
.with_production_directory()
|
||||
.with_cache_dir("/tmp/cache")
|
||||
.with_contact("mailto:admin@test.example.com");
|
||||
assert!(matches!(provider.mode(), AcmeMode::Domain { .. }));
|
||||
assert_eq!(
|
||||
provider.directory_url,
|
||||
rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY
|
||||
);
|
||||
assert_eq!(provider.cache_dir, Some(PathBuf::from("/tmp/cache")));
|
||||
assert_eq!(provider.contact.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn acme_tls_acceptor_bind_acme() {
|
||||
let _ = default_provider().install_default();
|
||||
let provider = Arc::new(AcmeCertProvider::domain("example.com"));
|
||||
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
|
||||
let acceptor = AcmeTlsAcceptor::bind_acme(addr, provider).await.unwrap();
|
||||
assert_ne!(acceptor.listen_addr().port(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn acme_staging_domain_cert_provisioning() {
|
||||
let _ = default_provider().install_default();
|
||||
|
||||
let cache_dir = tempfile::tempdir().unwrap();
|
||||
let provider = Arc::new(
|
||||
AcmeCertProvider::domain("acme-test.example.com")
|
||||
.with_cache_dir(cache_dir.path())
|
||||
.with_contact("mailto:admin@example.com"),
|
||||
);
|
||||
|
||||
let addr: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
||||
let result = AcmeTlsAcceptor::bind_acme(addr, provider).await;
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"ACME TlsAcceptor should bind: {:?}",
|
||||
result.err()
|
||||
);
|
||||
|
||||
let acceptor = result.unwrap();
|
||||
assert_eq!(acceptor.listen_addr().port(), 443);
|
||||
}
|
||||
}
|
||||
@@ -1,328 +0,0 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use iroh::{
|
||||
endpoint::RecvStream, node_info::NodeIdExt, Endpoint, NodeId, RelayMap, RelayMode, RelayUrl,
|
||||
};
|
||||
use tokio::io;
|
||||
|
||||
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
pub const ALPN: &[u8] = b"alknet-ssh";
|
||||
const DEFAULT_RELAY_URL: &str = "https://relay.iroh.network/";
|
||||
|
||||
/// A client-side iroh QUIC P2P transport that connects to a remote iroh endpoint.
|
||||
///
|
||||
/// Connects via `Endpoint::connect(node_id, alpn)`, opens a bidirectional
|
||||
/// QUIC stream with `conn.open_bi()`, and joins the halves with
|
||||
/// `tokio::io::join(recv, send)` to produce a duplex stream for russh.
|
||||
/// Per ADR-003, `tokio::io::join` is used instead of a custom wrapper.
|
||||
///
|
||||
/// Use [`IrohTransport::new`] to create a standalone endpoint, or
|
||||
/// [`IrohTransport::from_endpoint`] to share an existing iroh `Endpoint`
|
||||
/// with other protocol handlers (blobs, gossip, docs).
|
||||
pub struct IrohTransport {
|
||||
node_id: NodeId,
|
||||
endpoint: Endpoint,
|
||||
owned: bool,
|
||||
}
|
||||
|
||||
impl IrohTransport {
|
||||
/// Create a new iroh transport with its own dedicated endpoint.
|
||||
///
|
||||
/// The endpoint is created with the `alknet-ssh` ALPN and the provided
|
||||
/// relay URL. Use this when alknet is the only iroh service on this node.
|
||||
pub async fn new(
|
||||
node_id: NodeId,
|
||||
relay_url: Option<RelayUrl>,
|
||||
proxy_url: Option<url::Url>,
|
||||
) -> Result<Self> {
|
||||
let relay_url = relay_url.unwrap_or_else(|| {
|
||||
DEFAULT_RELAY_URL
|
||||
.parse()
|
||||
.expect("default relay URL is valid")
|
||||
});
|
||||
let relay_map = RelayMap::from_url(relay_url);
|
||||
let mut builder = Endpoint::builder()
|
||||
.relay_mode(RelayMode::Custom(relay_map))
|
||||
.alpns(vec![ALPN.to_vec()]);
|
||||
if let Some(ref proxy) = proxy_url {
|
||||
builder = builder.proxy_url(proxy.clone());
|
||||
}
|
||||
let endpoint = builder.bind().await?;
|
||||
Ok(Self {
|
||||
node_id,
|
||||
endpoint,
|
||||
owned: true,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create an iroh transport using an existing shared endpoint.
|
||||
///
|
||||
/// The endpoint must already have the `alknet-ssh` ALPN registered
|
||||
/// (typically via [`iroh::protocol::Router::builder`]). This enables
|
||||
/// running alknet alongside iroh-blobs, iroh-gossip, iroh-docs, and
|
||||
/// other protocol handlers on the same QUIC endpoint — one connection
|
||||
/// per peer, multiplexed by ALPN.
|
||||
pub fn from_endpoint(node_id: NodeId, endpoint: Endpoint) -> Self {
|
||||
Self {
|
||||
node_id,
|
||||
endpoint,
|
||||
owned: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn endpoint_id(&self) -> String {
|
||||
self.endpoint.node_id().to_z32()
|
||||
}
|
||||
|
||||
pub fn endpoint(&self) -> &Endpoint {
|
||||
&self.endpoint
|
||||
}
|
||||
|
||||
pub fn owned(&self) -> bool {
|
||||
self.owned
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for IrohTransport {
|
||||
type Stream = io::Join<RecvStream, iroh::endpoint::SendStream>;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream> {
|
||||
let conn = self.endpoint.connect(self.node_id, ALPN).await?;
|
||||
let (send, recv) = conn.open_bi().await?;
|
||||
Ok(io::join(recv, send))
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
format!("iroh://{}", self.node_id.to_z32())
|
||||
}
|
||||
}
|
||||
|
||||
/// A server-side iroh QUIC P2P transport acceptor that listens for incoming connections.
|
||||
///
|
||||
/// Binds an iroh `Endpoint` with the configured relay URL and optional proxy
|
||||
/// (ADR-010). Accepts incoming connections, accepts bidirectional QUIC streams,
|
||||
/// and joins the halves with `tokio::io::join(recv, send)`. Exposes
|
||||
/// `endpoint_id()` for CLI display of the server's z-base-32 node ID.
|
||||
///
|
||||
/// Use [`IrohAcceptor::bind`] to create a standalone endpoint, or
|
||||
/// [`IrohAcceptor::from_endpoint`] to share an existing iroh `Endpoint`
|
||||
/// with other protocol handlers (blobs, gossip, docs).
|
||||
///
|
||||
/// When using `from_endpoint`, the alknet-ssh ALPN must be registered
|
||||
/// via an iroh `Router` that calls `Handler::accept()` on incoming
|
||||
/// connections with the `alknet-ssh` ALPN, then passes the accepted
|
||||
/// bidirectional stream to `russh::server::run_stream()`.
|
||||
pub struct IrohAcceptor {
|
||||
endpoint: Endpoint,
|
||||
owned: bool,
|
||||
}
|
||||
|
||||
impl IrohAcceptor {
|
||||
/// Bind a new iroh endpoint with a dedicated `alknet-ssh` ALPN.
|
||||
///
|
||||
/// Use this when alknet is the only iroh service on this node.
|
||||
pub async fn bind(relay_url: Option<RelayUrl>, proxy_url: Option<url::Url>) -> Result<Self> {
|
||||
let relay_url = relay_url.unwrap_or_else(|| {
|
||||
DEFAULT_RELAY_URL
|
||||
.parse()
|
||||
.expect("default relay URL is valid")
|
||||
});
|
||||
let relay_map = RelayMap::from_url(relay_url);
|
||||
let mut builder = Endpoint::builder()
|
||||
.relay_mode(RelayMode::Custom(relay_map))
|
||||
.alpns(vec![ALPN.to_vec()]);
|
||||
if let Some(ref proxy) = proxy_url {
|
||||
builder = builder.proxy_url(proxy.clone());
|
||||
}
|
||||
let endpoint = builder.bind().await?;
|
||||
Ok(Self {
|
||||
endpoint,
|
||||
owned: true,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create an iroh acceptor using an existing shared endpoint.
|
||||
///
|
||||
/// The endpoint must already have the `alknet-ssh` ALPN registered
|
||||
/// (typically via [`iroh::protocol::Router::builder`]). When using a
|
||||
/// shared endpoint, incoming connections with the `alknet-ssh` ALPN
|
||||
/// are routed by the Router to a `ProtocolHandler` that this acceptor
|
||||
/// does not manage — the caller is responsible for bridging the
|
||||
/// Router's `accept()` callback to this acceptor's stream handling.
|
||||
///
|
||||
/// For the standalone case where alknet owns the endpoint, use
|
||||
/// [`IrohAcceptor::bind`] instead, which handles the accept loop
|
||||
/// internally.
|
||||
pub fn from_endpoint(endpoint: Endpoint) -> Self {
|
||||
Self {
|
||||
endpoint,
|
||||
owned: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn endpoint_id(&self) -> String {
|
||||
self.endpoint.node_id().to_z32()
|
||||
}
|
||||
|
||||
pub fn endpoint(&self) -> &Endpoint {
|
||||
&self.endpoint
|
||||
}
|
||||
|
||||
pub fn owned(&self) -> bool {
|
||||
self.owned
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportAcceptor for IrohAcceptor {
|
||||
type Stream = io::Join<RecvStream, iroh::endpoint::SendStream>;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let incoming = self
|
||||
.endpoint
|
||||
.accept()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("endpoint closed"))?;
|
||||
let conn = incoming.await?;
|
||||
let node_id = conn.remote_node_id()?;
|
||||
let (send, recv) = conn.accept_bi().await?;
|
||||
let stream = io::join(recv, send);
|
||||
let info = TransportInfo {
|
||||
remote_addr: None,
|
||||
transport_kind: TransportKind::Iroh {
|
||||
endpoint_id: node_id.to_z32(),
|
||||
},
|
||||
};
|
||||
Ok((stream, info))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn iroh_acceptor_bind_creates_endpoint() {
|
||||
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
|
||||
let endpoint_id = acceptor.endpoint_id();
|
||||
assert!(!endpoint_id.is_empty());
|
||||
let parsed = NodeId::from_z32(&endpoint_id);
|
||||
assert!(parsed.is_ok());
|
||||
assert!(acceptor.owned());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn iroh_acceptor_bind_with_custom_relay() {
|
||||
let relay: RelayUrl = "https://relay.iroh.network/".parse().unwrap();
|
||||
let acceptor = IrohAcceptor::bind(Some(relay), None).await.unwrap();
|
||||
assert!(!acceptor.endpoint_id().is_empty());
|
||||
assert!(acceptor.owned());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn iroh_acceptor_from_endpoint() {
|
||||
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
|
||||
let endpoint = acceptor.endpoint.clone();
|
||||
let shared = IrohAcceptor::from_endpoint(endpoint);
|
||||
assert_eq!(shared.endpoint_id(), acceptor.endpoint_id());
|
||||
assert!(!shared.owned());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn iroh_transport_describe_format() {
|
||||
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into();
|
||||
let desc = format!("iroh://{}", node_id.to_z32());
|
||||
assert!(desc.starts_with("iroh://"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn iroh_transport_connect_builds_endpoint() {
|
||||
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into();
|
||||
let transport = IrohTransport::new(node_id, None, None).await.unwrap();
|
||||
assert!(transport.describe().starts_with("iroh://"));
|
||||
assert!(!transport.endpoint_id().is_empty());
|
||||
assert!(transport.owned());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn iroh_transport_from_endpoint() {
|
||||
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into();
|
||||
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
|
||||
let endpoint = acceptor.endpoint.clone();
|
||||
let transport = IrohTransport::from_endpoint(node_id, endpoint);
|
||||
assert!(transport.describe().starts_with("iroh://"));
|
||||
assert_eq!(transport.endpoint_id(), acceptor.endpoint_id());
|
||||
assert!(!transport.owned());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn iroh_client_connects_to_iroh_server() {
|
||||
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
|
||||
let server_node_id = acceptor.endpoint().node_id();
|
||||
|
||||
let transport = IrohTransport::new(server_node_id, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut addrs_watcher = acceptor.endpoint().direct_addresses();
|
||||
addrs_watcher.initialized().await.unwrap();
|
||||
let addr_set = addrs_watcher.get().unwrap().unwrap_or_default();
|
||||
for addr in addr_set {
|
||||
transport
|
||||
.endpoint
|
||||
.add_node_addr(iroh::NodeAddr::from_parts(
|
||||
server_node_id,
|
||||
None,
|
||||
vec![addr.addr],
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let accept_handle = tokio::spawn(async move {
|
||||
let (stream, info) = acceptor.accept().await.unwrap();
|
||||
assert!(matches!(info.transport_kind, TransportKind::Iroh { .. }));
|
||||
stream
|
||||
});
|
||||
|
||||
let _client_stream: io::Join<RecvStream, iroh::endpoint::SendStream> =
|
||||
transport.connect().await.unwrap();
|
||||
let _server_stream = accept_handle.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn iroh_shared_endpoint_client_connects_to_server() {
|
||||
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
|
||||
let server_node_id = acceptor.endpoint().node_id();
|
||||
let shared_endpoint = acceptor.endpoint().clone();
|
||||
|
||||
let transport = IrohTransport::from_endpoint(server_node_id, shared_endpoint);
|
||||
|
||||
let mut addrs_watcher = acceptor.endpoint().direct_addresses();
|
||||
addrs_watcher.initialized().await.unwrap();
|
||||
let addr_set = addrs_watcher.get().unwrap().unwrap_or_default();
|
||||
for addr in addr_set {
|
||||
transport
|
||||
.endpoint
|
||||
.add_node_addr(iroh::NodeAddr::from_parts(
|
||||
server_node_id,
|
||||
None,
|
||||
vec![addr.addr],
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let accept_handle = tokio::spawn(async move {
|
||||
let (stream, info) = acceptor.accept().await.unwrap();
|
||||
assert!(matches!(info.transport_kind, TransportKind::Iroh { .. }));
|
||||
stream
|
||||
});
|
||||
|
||||
let _client_stream: io::Join<RecvStream, iroh::endpoint::SendStream> =
|
||||
transport.connect().await.unwrap();
|
||||
let _server_stream = accept_handle.await.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -1,203 +0,0 @@
|
||||
//! Pluggable transport layer for Alknet.
|
||||
//!
|
||||
//! The transport layer produces a duplex byte stream (`AsyncRead + AsyncWrite + Unpin + Send`)
|
||||
//! that SSH consumes. This is the core architectural abstraction — SSH never opens its own
|
||||
//! network connections; it runs entirely over whatever stream the transport provides.
|
||||
//!
|
||||
//! Available transports (feature-gated):
|
||||
//! - `TcpTransport` / `TcpAcceptor` — always available, direct TCP
|
||||
//! - `TlsTransport` / `TlsAcceptor` — behind the `tls` feature, TCP + rustls
|
||||
//! - `IrohTransport` / `IrohAcceptor` — behind the `iroh` feature, QUIC P2P via iroh
|
||||
//! - `AcmeTlsAcceptor` — behind the `acme` feature, auto-provision TLS certs via Let's Encrypt
|
||||
//!
|
||||
//! See [ADR-001](docs/architecture/decisions/001-pluggable-transport.md) and
|
||||
//! [ADR-004](docs/architecture/decisions/004-ssh-over-transport.md) for design rationale.
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
mod iroh_transport;
|
||||
mod tcp;
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
pub use iroh_transport::{IrohAcceptor, IrohTransport, ALPN as IROH_ALPN};
|
||||
pub use tcp::{TcpAcceptor, TcpTransport};
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
mod tls;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
pub use tls::{AcmeConfig, TlsAcceptor, TlsTransport};
|
||||
|
||||
#[cfg(feature = "acme")]
|
||||
mod acme;
|
||||
|
||||
#[cfg(feature = "acme")]
|
||||
pub use acme::{AcmeCertProvider, AcmeMode, AcmeTlsAcceptor};
|
||||
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
/// Client-side transport trait. Produces a single duplex stream per connection.
|
||||
///
|
||||
/// Implementations connect to a remote endpoint and return a stream that SSH
|
||||
/// runs over via `russh::client::connect_stream()`. Each call to `connect()` creates
|
||||
/// a new stream — multiple sessions need multiple calls or multiple transports.
|
||||
#[async_trait]
|
||||
pub trait Transport: Send + Sync + 'static {
|
||||
/// The duplex stream type produced by this transport.
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
|
||||
/// Connect to the remote endpoint and return a duplex stream.
|
||||
async fn connect(&self) -> Result<Self::Stream>;
|
||||
|
||||
/// Return a human-readable description of this transport for logging.
|
||||
fn describe(&self) -> String;
|
||||
}
|
||||
|
||||
/// Server-side transport acceptor. Accepts incoming connections and returns streams.
|
||||
///
|
||||
/// Implementations bind to a local endpoint and produce streams that SSH
|
||||
/// runs over via `russh::server::run_stream()`.
|
||||
#[async_trait]
|
||||
pub trait TransportAcceptor: Send + Sync + 'static {
|
||||
/// The duplex stream type produced by this acceptor.
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
|
||||
/// Accept an incoming connection and return a duplex stream with metadata.
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>;
|
||||
}
|
||||
|
||||
/// Metadata about an incoming transport connection.
|
||||
///
|
||||
/// Carries the remote address (if available) and the kind of transport
|
||||
/// used. The server handler uses this for logging and auth decisions.
|
||||
/// See ADR-001 for the pluggable transport rationale and ADR-004
|
||||
/// for why SSH runs entirely over the transport stream.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TransportInfo {
|
||||
pub remote_addr: Option<SocketAddr>,
|
||||
pub transport_kind: TransportKind,
|
||||
}
|
||||
|
||||
/// The kind of transport that produced a connection.
|
||||
///
|
||||
/// Each variant identifies the transport mechanism. Used by the
|
||||
/// server handler for logging and authorization decisions.
|
||||
/// See ADR-001 and ADR-004.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TransportKind {
|
||||
Tcp,
|
||||
Tls { server_name: Option<String> },
|
||||
Iroh { endpoint_id: String },
|
||||
WebTransport { server_name: Option<String> },
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TransportKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TransportKind::Tcp => write!(f, "tcp"),
|
||||
TransportKind::Tls { .. } => write!(f, "tls"),
|
||||
TransportKind::Iroh { .. } => write!(f, "iroh"),
|
||||
|
||||
TransportKind::WebTransport { .. } => write!(f, "webtransport"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{duplex, DuplexStream};
|
||||
|
||||
struct MockTransport;
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for MockTransport {
|
||||
type Stream = DuplexStream;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream> {
|
||||
let (stream, _) = duplex(1024);
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"mock".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
struct MockAcceptor;
|
||||
|
||||
#[async_trait]
|
||||
impl TransportAcceptor for MockAcceptor {
|
||||
type Stream = DuplexStream;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let (stream, _) = duplex(1024);
|
||||
let info = TransportInfo {
|
||||
remote_addr: None,
|
||||
transport_kind: TransportKind::Tcp,
|
||||
};
|
||||
Ok((stream, info))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_trait_object() {
|
||||
let _boxed: Box<dyn Transport<Stream = DuplexStream>> = Box::new(MockTransport);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_acceptor_trait_object() {
|
||||
let _boxed: Box<dyn TransportAcceptor<Stream = DuplexStream>> = Box::new(MockAcceptor);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_connect_returns_stream() {
|
||||
let t = MockTransport;
|
||||
let _stream = t.connect().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_describe_returns_string() {
|
||||
let t = MockTransport;
|
||||
assert_eq!(t.describe(), "mock");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn acceptor_accept_returns_stream_and_info() {
|
||||
let a = MockAcceptor;
|
||||
let (_, info) = a.accept().await.unwrap();
|
||||
assert!(info.remote_addr.is_none());
|
||||
assert!(matches!(info.transport_kind, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_kind_variants() {
|
||||
let tcp = TransportKind::Tcp;
|
||||
let tls = TransportKind::Tls {
|
||||
server_name: Some("example.com".to_string()),
|
||||
};
|
||||
let iroh = TransportKind::Iroh {
|
||||
endpoint_id: "abc123".to_string(),
|
||||
};
|
||||
let wt = TransportKind::WebTransport {
|
||||
server_name: Some("example.com".to_string()),
|
||||
};
|
||||
|
||||
if let TransportKind::Tcp = tcp {}
|
||||
if let TransportKind::Tls {
|
||||
server_name: Some(name),
|
||||
} = tls
|
||||
{
|
||||
assert_eq!(name, "example.com");
|
||||
}
|
||||
if let TransportKind::Iroh { endpoint_id } = iroh {
|
||||
assert_eq!(endpoint_id, "abc123");
|
||||
}
|
||||
if let TransportKind::WebTransport { server_name } = wt {
|
||||
assert_eq!(server_name, Some("example.com".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,162 +0,0 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
/// A TCP-based client transport that connects to a remote address.
|
||||
///
|
||||
/// Connects via `TcpStream::connect(addr)`. Uses tokio's default
|
||||
/// connect timeout behavior: the OS controls connection timeout
|
||||
/// (typically ~2 minutes on Linux via `net.ipv4.tcp_syn_retries`).
|
||||
/// For custom timeouts, wrap `TcpTransport` with
|
||||
/// `tokio::time::timeout(duration, transport.connect())`.
|
||||
pub struct TcpTransport {
|
||||
addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl TcpTransport {
|
||||
pub fn new(addr: SocketAddr) -> Self {
|
||||
Self { addr }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for TcpTransport {
|
||||
type Stream = TcpStream;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream> {
|
||||
let stream = TcpStream::connect(self.addr).await?;
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
format!("tcp://{}", self.addr)
|
||||
}
|
||||
}
|
||||
|
||||
/// A TCP-based server transport acceptor that listens for incoming connections.
|
||||
///
|
||||
/// Binds via `TcpListener::bind(addr)`. Accepts connections and returns
|
||||
/// the stream together with `TransportInfo` containing the remote address
|
||||
/// and `TransportKind::Tcp`.
|
||||
pub struct TcpAcceptor {
|
||||
listener: TcpListener,
|
||||
listen_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl TcpAcceptor {
|
||||
/// Bind a TCP listener on the given address.
|
||||
///
|
||||
/// Returns the acceptor ready to receive connections.
|
||||
/// The actual bound address may differ from the requested one
|
||||
/// (e.g., when binding to port 0 the OS assigns an ephemeral port).
|
||||
pub async fn bind(addr: SocketAddr) -> Result<Self> {
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let listen_addr = listener.local_addr()?;
|
||||
Ok(Self {
|
||||
listener,
|
||||
listen_addr,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn listen_addr(&self) -> SocketAddr {
|
||||
self.listen_addr
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportAcceptor for TcpAcceptor {
|
||||
type Stream = TcpStream;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let (stream, remote_addr) = self.listener.accept().await?;
|
||||
let info = TransportInfo {
|
||||
remote_addr: Some(remote_addr),
|
||||
transport_kind: TransportKind::Tcp,
|
||||
};
|
||||
Ok((stream, info))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_transport_connect_creates_stream() {
|
||||
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
let transport = TcpTransport::new(addr);
|
||||
|
||||
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
|
||||
|
||||
let stream = transport.connect().await.unwrap();
|
||||
assert_eq!(stream.local_addr().unwrap().ip(), addr.ip());
|
||||
|
||||
let (_server_stream, info) = accept_handle.await.unwrap();
|
||||
assert!(info.remote_addr.is_some());
|
||||
assert!(matches!(info.transport_kind, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_acceptor_accept_receives_connection() {
|
||||
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
|
||||
tokio::spawn(async move {
|
||||
TcpStream::connect(addr).await.unwrap();
|
||||
});
|
||||
|
||||
let (stream, info) = acceptor.accept().await.unwrap();
|
||||
assert!(info.remote_addr.is_some());
|
||||
assert!(matches!(info.transport_kind, TransportKind::Tcp));
|
||||
assert_eq!(
|
||||
info.remote_addr.unwrap().ip(),
|
||||
stream.peer_addr().unwrap().ip()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tcp_transport_describe_format() {
|
||||
let addr: SocketAddr = "1.2.3.4:22".parse().unwrap();
|
||||
let transport = TcpTransport::new(addr);
|
||||
assert_eq!(transport.describe(), "tcp://1.2.3.4:22");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_stream_is_duplex() {
|
||||
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
|
||||
let mut client = TcpStream::connect(addr).await.unwrap();
|
||||
let (mut server, _) = acceptor.accept().await.unwrap();
|
||||
|
||||
client.write_all(b"hello").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
server.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello");
|
||||
|
||||
server.write_all(b"world").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
client.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_acceptor_bind_port_zero_assigns_ephemeral() {
|
||||
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_ne!(acceptor.listen_addr().port(), 0);
|
||||
}
|
||||
}
|
||||
@@ -1,429 +0,0 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
|
||||
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_rustls::{
|
||||
client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector,
|
||||
};
|
||||
|
||||
#[cfg(feature = "acme")]
|
||||
use rustls::crypto::aws_lc_rs::default_provider;
|
||||
#[cfg(feature = "acme")]
|
||||
use rustls_acme::ResolvesServerCertAcme;
|
||||
|
||||
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
#[cfg(feature = "acme")]
|
||||
const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1";
|
||||
|
||||
/// A TLS-based client transport that connects to a remote address over TLS.
|
||||
///
|
||||
/// Wraps a TCP connection with a TLS client session via `tokio_rustls::TlsConnector`.
|
||||
/// Supports insecure mode (accepts any certificate, for development) and
|
||||
/// custom root CA certificates for verification. The `tls_server_name` field
|
||||
/// overrides the SNI hostname sent during the TLS handshake (ADR-010).
|
||||
pub struct TlsTransport {
|
||||
addr: SocketAddr,
|
||||
tls_server_name: Option<String>,
|
||||
insecure: bool,
|
||||
root_cert: Option<CertificateDer<'static>>,
|
||||
}
|
||||
|
||||
impl TlsTransport {
|
||||
pub fn new(addr: SocketAddr) -> Self {
|
||||
Self {
|
||||
addr,
|
||||
tls_server_name: None,
|
||||
insecure: false,
|
||||
root_cert: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
|
||||
self.tls_server_name = Some(name.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_insecure(mut self, insecure: bool) -> Self {
|
||||
self.insecure = insecure;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_root_cert(mut self, cert: CertificateDer<'static>) -> Self {
|
||||
self.root_cert = Some(cert);
|
||||
self
|
||||
}
|
||||
|
||||
fn build_client_config(&self) -> Result<ClientConfig> {
|
||||
if self.insecure {
|
||||
let config = ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(NoVerifier))
|
||||
.with_no_client_auth();
|
||||
return Ok(config);
|
||||
}
|
||||
|
||||
let mut root_store = RootCertStore::empty();
|
||||
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
|
||||
if let Some(ref cert) = self.root_cert {
|
||||
root_store.add(cert.clone())?;
|
||||
}
|
||||
|
||||
let config = ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn resolve_server_name(&self) -> Result<ServerName<'static>> {
|
||||
let name = match &self.tls_server_name {
|
||||
Some(n) => n.clone(),
|
||||
None => self.addr.ip().to_string(),
|
||||
};
|
||||
ServerName::try_from(name.clone())
|
||||
.map_err(move |e| anyhow!("invalid server name '{}': {}", name, e))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for TlsTransport {
|
||||
type Stream = ClientTlsStream<TcpStream>;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream> {
|
||||
let tcp_stream = TcpStream::connect(self.addr).await?;
|
||||
let config = self.build_client_config()?;
|
||||
let connector = TlsConnector::from(Arc::new(config));
|
||||
let server_name = self.resolve_server_name()?;
|
||||
let tls_stream = connector.connect(server_name, tcp_stream).await?;
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
format!("tls://{}", self.addr)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stub configuration for ACME certificate provisioning (ADR-008).
|
||||
/// Feature-gated behind the `acme` feature. When implemented, this will
|
||||
/// hold the ACME domain and challenge responder configuration.
|
||||
#[derive(Debug)]
|
||||
pub struct AcmeConfig {
|
||||
pub domain: String,
|
||||
}
|
||||
|
||||
/// A TLS-based server transport acceptor that accepts TCP connections
|
||||
/// and wraps them with TLS server sessions via `tokio_rustls::TlsAcceptor`.
|
||||
///
|
||||
/// Supports three certificate modes (ADR-008):
|
||||
/// - Manual certs via `bind()` with explicit cert/key
|
||||
/// - ACME certs via `bind_acme()` with an `AcmeCertProvider`
|
||||
/// - The stub `AcmeConfig` parameter in `bind()` is kept for backward compat
|
||||
pub struct TlsAcceptor {
|
||||
listener: TcpListener,
|
||||
listen_addr: SocketAddr,
|
||||
#[allow(dead_code)]
|
||||
server_config: Arc<ServerConfig>,
|
||||
tokio_acceptor: TokioTlsAcceptor,
|
||||
}
|
||||
|
||||
impl TlsAcceptor {
|
||||
pub async fn bind(
|
||||
addr: SocketAddr,
|
||||
tls_certs: Vec<CertificateDer<'static>>,
|
||||
tls_key: PrivateKeyDer<'static>,
|
||||
_acme_config: Option<AcmeConfig>,
|
||||
) -> Result<Self> {
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let listen_addr = listener.local_addr()?;
|
||||
|
||||
let server_config = ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(tls_certs, tls_key)?;
|
||||
|
||||
let server_config = Arc::new(server_config);
|
||||
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
|
||||
|
||||
Ok(Self {
|
||||
listener,
|
||||
listen_addr,
|
||||
server_config,
|
||||
tokio_acceptor,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "acme")]
|
||||
pub async fn bind_acme(
|
||||
addr: SocketAddr,
|
||||
acme_resolver: Arc<ResolvesServerCertAcme>,
|
||||
) -> Result<Self> {
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let listen_addr = listener.local_addr()?;
|
||||
|
||||
let provider = default_provider().into();
|
||||
let mut server_config = ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()
|
||||
.map_err(|e| anyhow!("failed to set protocol versions: {}", e))?
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(acme_resolver);
|
||||
server_config
|
||||
.alpn_protocols
|
||||
.push(ACME_TLS_ALPN_NAME.to_vec());
|
||||
|
||||
let server_config = Arc::new(server_config);
|
||||
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
|
||||
|
||||
Ok(Self {
|
||||
listener,
|
||||
listen_addr,
|
||||
server_config,
|
||||
tokio_acceptor,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn listen_addr(&self) -> SocketAddr {
|
||||
self.listen_addr
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportAcceptor for TlsAcceptor {
|
||||
type Stream = tokio_rustls::server::TlsStream<TcpStream>;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let (tcp_stream, remote_addr) = self.listener.accept().await?;
|
||||
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
|
||||
|
||||
let server_name = tls_stream.get_ref().1.server_name().map(|s| s.to_string());
|
||||
|
||||
let info = TransportInfo {
|
||||
remote_addr: Some(remote_addr),
|
||||
transport_kind: TransportKind::Tls { server_name },
|
||||
};
|
||||
|
||||
Ok((tls_stream, info))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NoVerifier;
|
||||
|
||||
impl ServerCertVerifier for NoVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> std::result::Result<ServerCertVerified, rustls::Error> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_doc: &DigitallySignedStruct,
|
||||
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_doc: &DigitallySignedStruct,
|
||||
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA512,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rcgen::{CertificateParams, KeyPair};
|
||||
use rustls::crypto::aws_lc_rs::default_provider;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
fn ensure_crypto_provider() {
|
||||
let _ = default_provider().install_default();
|
||||
}
|
||||
|
||||
fn generate_self_signed_cert() -> (CertificateDer<'static>, PrivateKeyDer<'static>) {
|
||||
let params = CertificateParams::new(vec!["localhost".to_string()]).unwrap();
|
||||
let key_pair = KeyPair::generate().unwrap();
|
||||
let cert = params.self_signed(&key_pair).unwrap();
|
||||
let cert_der: CertificateDer<'static> = cert.into();
|
||||
let key_der = PrivateKeyDer::Pkcs8(key_pair.serialize_der().into());
|
||||
(cert_der, key_der)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_transport_describe_format() {
|
||||
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
|
||||
let transport = TlsTransport::new(addr).with_server_name("example.com");
|
||||
assert_eq!(transport.describe(), "tls://1.2.3.4:443");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_transport_describe_with_ip() {
|
||||
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
|
||||
let transport = TlsTransport::new(addr);
|
||||
assert_eq!(transport.describe(), "tls://1.2.3.4:443");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_transport_builder_methods() {
|
||||
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
|
||||
let transport = TlsTransport::new(addr)
|
||||
.with_server_name("alknet.test")
|
||||
.with_insecure(true);
|
||||
assert_eq!(transport.tls_server_name, Some("alknet.test".to_string()));
|
||||
assert!(transport.insecure);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_connect_insecure_self_signed() {
|
||||
ensure_crypto_provider();
|
||||
let (cert_der, key_der) = generate_self_signed_cert();
|
||||
|
||||
let acceptor = TlsAcceptor::bind(
|
||||
"127.0.0.1:0".parse().unwrap(),
|
||||
vec![cert_der],
|
||||
key_der,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
|
||||
let transport = TlsTransport::new(addr)
|
||||
.with_server_name("localhost")
|
||||
.with_insecure(true);
|
||||
|
||||
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
|
||||
|
||||
let mut client = transport.connect().await.unwrap();
|
||||
|
||||
let (mut server, info) = accept_handle.await.unwrap();
|
||||
assert!(info.remote_addr.is_some());
|
||||
assert!(matches!(info.transport_kind, TransportKind::Tls { .. }));
|
||||
|
||||
client.write_all(b"hello tls").await.unwrap();
|
||||
let mut buf = [0u8; 9];
|
||||
server.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello tls");
|
||||
|
||||
server.write_all(b"reply").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
client.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"reply");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_acceptor_returns_server_name() {
|
||||
ensure_crypto_provider();
|
||||
let (cert_der, key_der) = generate_self_signed_cert();
|
||||
|
||||
let acceptor = TlsAcceptor::bind(
|
||||
"127.0.0.1:0".parse().unwrap(),
|
||||
vec![cert_der],
|
||||
key_der,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
|
||||
let transport = TlsTransport::new(addr)
|
||||
.with_server_name("localhost")
|
||||
.with_insecure(true);
|
||||
|
||||
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
|
||||
|
||||
let _client = transport.connect().await.unwrap();
|
||||
|
||||
let (_, info) = accept_handle.await.unwrap();
|
||||
if let TransportKind::Tls { server_name } = info.transport_kind {
|
||||
assert_eq!(server_name, Some("localhost".to_string()));
|
||||
} else {
|
||||
panic!("expected TransportKind::Tls");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_full_client_to_server_connection() {
|
||||
ensure_crypto_provider();
|
||||
let (cert_der, key_der) = generate_self_signed_cert();
|
||||
|
||||
let acceptor = TlsAcceptor::bind(
|
||||
"127.0.0.1:0".parse().unwrap(),
|
||||
vec![cert_der],
|
||||
key_der,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
|
||||
let transport = TlsTransport::new(addr)
|
||||
.with_server_name("localhost")
|
||||
.with_insecure(true);
|
||||
|
||||
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
|
||||
|
||||
let mut client = transport.connect().await.unwrap();
|
||||
let (mut server, _info) = accept_handle.await.unwrap();
|
||||
|
||||
let msg = b"alknet integration test";
|
||||
client.write_all(msg).await.unwrap();
|
||||
let mut buf = vec![0u8; msg.len()];
|
||||
server.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf[..], msg);
|
||||
|
||||
let reply = b"ok";
|
||||
server.write_all(reply).await.unwrap();
|
||||
let mut buf = [0u8; 2];
|
||||
client.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, reply);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_acceptor_bind_port_zero_assigns_ephemeral() {
|
||||
ensure_crypto_provider();
|
||||
let (cert_der, key_der) = generate_self_signed_cert();
|
||||
|
||||
let acceptor = TlsAcceptor::bind(
|
||||
"127.0.0.1:0".parse().unwrap(),
|
||||
vec![cert_der],
|
||||
key_der,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_ne!(acceptor.listen_addr().port(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_verifier_accepts_any_cert() {
|
||||
let verifier = NoVerifier;
|
||||
assert!(verifier.supported_verify_schemes().len() > 0);
|
||||
}
|
||||
}
|
||||
@@ -1,2 +0,0 @@
|
||||
#[tokio::test]
|
||||
async fn auth_placeholder() {}
|
||||
@@ -1,2 +0,0 @@
|
||||
#[tokio::test]
|
||||
async fn client_placeholder() {}
|
||||
@@ -1,2 +0,0 @@
|
||||
#[tokio::test]
|
||||
async fn server_placeholder() {}
|
||||
@@ -1,28 +0,0 @@
|
||||
use alknet_core::testutil::{
|
||||
mock_pair, MockTransport, MockTransportAcceptor, Transport, TransportAcceptor,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_transport_connect() {
|
||||
let transport = MockTransport::new(1024);
|
||||
let stream = transport.connect().await.unwrap();
|
||||
drop(stream);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_transport_acceptor_accept() {
|
||||
let acceptor = MockTransportAcceptor::new(1024);
|
||||
let (stream, info) = acceptor.accept().await.unwrap();
|
||||
drop(stream);
|
||||
drop(info);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_pair_communicates() {
|
||||
let (mut client, mut server) = mock_pair(1024);
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
client.write_all(b"hello").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
server.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello");
|
||||
}
|
||||
Reference in New Issue
Block a user