From 8d056a2b596f5d2a7ce76e86f09e7fa4321b1f73 Mon Sep 17 00:00:00 2001 From: "glm-5.2" Date: Tue, 23 Jun 2026 15:12:14 +0000 Subject: [PATCH] feat(core): implement AlknetEndpoint, HandlerRegistry, accept loops (quinn + iroh), TLS identity (RawKey/X509/SelfSigned), and graceful shutdown (task: core/endpoint) --- Cargo.lock | 12 + crates/alknet-core/Cargo.toml | 5 +- crates/alknet-core/src/endpoint.rs | 979 ++++++++++++++++++++++++++++- 3 files changed, 994 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index aaf2a66..75f6a95 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -74,7 +74,10 @@ dependencies = [ "hex", "iroh", "quinn", + "rand 0.8.6", + "rcgen 0.13.2", "rustls", + "rustls-pemfile", "rustls-pki-types", "serde", "serde_json", @@ -3286,6 +3289,15 @@ dependencies = [ "security-framework", ] +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" version = "1.14.1" diff --git a/crates/alknet-core/Cargo.toml b/crates/alknet-core/Cargo.toml index 5974af0..37d0a8a 100644 --- a/crates/alknet-core/Cargo.toml +++ b/crates/alknet-core/Cargo.toml @@ -20,6 +20,7 @@ quinn = { version = "0.11", optional = true } iroh = { version = "0.35", optional = true } rustls = "0.23" rustls-pki-types = "1" +rustls-pemfile = "2" serde = { version = "1", features = ["derive"] } serde_json = "1" toml = "0.8" @@ -31,4 +32,6 @@ zeroize = { version = "1", features = ["alloc", "derive"] } bytes = "1" futures = "0.3" sha2 = "0.10" -hex = "0.4" \ No newline at end of file +hex = "0.4" +rand = "0.8" +rcgen = "0.13" \ No newline at end of file diff --git a/crates/alknet-core/src/endpoint.rs b/crates/alknet-core/src/endpoint.rs index a706143..3dc3a40 100644 --- a/crates/alknet-core/src/endpoint.rs +++ b/crates/alknet-core/src/endpoint.rs @@ -2,4 +2,981 @@ //! //! See `docs/architecture/crates/core/endpoint.md` for the full specification. -// TODO: implement +use std::collections::HashMap; +use std::io; +#[cfg(any(feature = "quinn", feature = "iroh"))] +use std::net::SocketAddr; +#[cfg(any(feature = "quinn", feature = "iroh"))] +use std::path::Path; +use std::sync::Arc; +#[cfg(any(feature = "quinn", feature = "iroh"))] +use std::time::Duration; + +#[cfg(any(feature = "quinn", feature = "iroh"))] +use arc_swap::ArcSwap; +#[cfg(any(feature = "quinn", feature = "iroh"))] +use tokio::sync::watch; +#[cfg(any(feature = "quinn", feature = "iroh"))] +use tracing::{debug, error, warn}; + +#[cfg(any(feature = "quinn", feature = "iroh"))] +use crate::auth::{AuthContext, IdentityProvider}; +#[cfg(any(feature = "quinn", feature = "iroh"))] +use crate::config::{DynamicConfig, StaticConfig, TlsIdentity}; +#[cfg(not(any(feature = "quinn", feature = "iroh")))] +use crate::types::ProtocolHandler; +#[cfg(any(feature = "quinn", feature = "iroh"))] +use crate::types::{Connection, ProtocolHandler}; + +#[derive(Debug, thiserror::Error)] +pub enum EndpointError { + #[error("bind failed: {0}")] + BindFailed(io::Error), + #[error("tls config error: {0}")] + TlsConfig(io::Error), + #[error("handler not found for ALPN {0:?}")] + HandlerNotFound(Vec), +} + +pub struct HandlerRegistry { + handlers: HashMap<&'static [u8], Arc>, +} + +impl HandlerRegistry { + pub fn new() -> Self { + Self { + handlers: HashMap::new(), + } + } + + pub fn register(&mut self, handler: Arc) { + let alpn = handler.alpn(); + if self.handlers.contains_key(alpn) { + panic!( + "HandlerRegistry: ALPN already registered: {:?}", + String::from_utf8_lossy(alpn) + ); + } + self.handlers.insert(alpn, handler); + } + + pub fn get(&self, alpn: &[u8]) -> Option<&Arc> { + self.handlers.get(alpn) + } + + pub fn alpn_strings(&self) -> Vec> { + self.handlers.keys().map(|k| k.to_vec()).collect() + } +} + +impl Default for HandlerRegistry { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Debug for HandlerRegistry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HandlerRegistry") + .field( + "alpns", + &self + .handlers + .keys() + .map(|k| String::from_utf8_lossy(k).to_string()) + .collect::>(), + ) + .finish() + } +} + +#[cfg(any(feature = "quinn", feature = "iroh"))] +pub struct AlknetEndpoint { + #[cfg(feature = "quinn")] + quinn: Option, + #[cfg(feature = "iroh")] + iroh: Option, + handlers: Arc, + #[allow(dead_code)] + dynamic: Arc>, + identity_provider: Arc, + shutdown_tx: watch::Sender, + shutdown_rx: watch::Receiver, + drain_timeout: Duration, +} + +#[cfg(any(feature = "quinn", feature = "iroh"))] +impl std::fmt::Debug for AlknetEndpoint { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AlknetEndpoint") + .field("handlers", &self.handlers) + .field("drain_timeout", &self.drain_timeout) + .finish() + } +} + +#[cfg(any(feature = "quinn", feature = "iroh"))] +impl AlknetEndpoint { + pub async fn new( + static_config: &StaticConfig, + handlers: HandlerRegistry, + dynamic: Arc>, + identity_provider: Arc, + ) -> Result { + let handlers = Arc::new(handlers); + let alpns = handlers.alpn_strings(); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let drain_timeout = static_config.drain_timeout; + + #[cfg(feature = "quinn")] + let quinn = if let Some(listen_addr) = static_config.listen_addr { + let tls_identity = static_config.tls_identity.as_ref().ok_or_else(|| { + EndpointError::TlsConfig(io::Error::new( + io::ErrorKind::InvalidInput, + "quinn endpoint requires tls_identity in static config", + )) + })?; + let server_config = build_quinn_server_config(tls_identity, &alpns)?; + let endpoint = quinn::Endpoint::server(server_config, listen_addr) + .map_err(EndpointError::BindFailed)?; + Some(endpoint) + } else { + None + }; + + #[cfg(not(feature = "quinn"))] + if static_config.listen_addr.is_some() { + return Err(EndpointError::TlsConfig(io::Error::new( + io::ErrorKind::Unsupported, + "quinn feature is not enabled but listen_addr was set", + ))); + } + + #[cfg(feature = "iroh")] + let iroh = if static_config.iroh_relay.is_some() || has_iroh_identity(static_config) { + Some(build_iroh_endpoint(static_config, &alpns).await?) + } else { + None + }; + + Ok(Self { + #[cfg(feature = "quinn")] + quinn, + #[cfg(feature = "iroh")] + iroh, + handlers, + dynamic, + identity_provider, + shutdown_tx, + shutdown_rx, + drain_timeout, + }) + } + + pub fn shutdown_sender(&self) -> watch::Sender { + self.shutdown_tx.clone() + } + + pub async fn shutdown(&self) -> Result<(), EndpointError> { + let _ = self.shutdown_tx.send(true); + + #[cfg(feature = "quinn")] + if let Some(quinn) = &self.quinn { + quinn.close(0u32.into(), b"shutdown"); + } + + #[cfg(feature = "iroh")] + if let Some(iroh) = &self.iroh { + iroh.close().await; + } + + tokio::time::sleep(self.drain_timeout).await; + + #[cfg(feature = "quinn")] + if let Some(quinn) = &self.quinn { + quinn.wait_idle().await; + } + + Ok(()) + } + + pub async fn run(self: Arc) { + let mut tasks: Vec> = Vec::new(); + + #[cfg(feature = "quinn")] + if let Some(quinn) = &self.quinn { + let quinn = quinn.clone(); + let handlers = self.handlers.clone(); + let identity_provider = self.identity_provider.clone(); + let mut shutdown_rx = self.shutdown_rx.clone(); + let task = tokio::spawn(async move { + run_quinn_accept_loop(quinn, handlers, identity_provider, &mut shutdown_rx).await; + }); + tasks.push(task); + } + + #[cfg(feature = "iroh")] + if let Some(iroh) = &self.iroh { + let iroh = iroh.clone(); + let handlers = self.handlers.clone(); + let identity_provider = self.identity_provider.clone(); + let mut shutdown_rx = self.shutdown_rx.clone(); + let task = tokio::spawn(async move { + run_iroh_accept_loop(iroh, handlers, identity_provider, &mut shutdown_rx).await; + }); + tasks.push(task); + } + + for task in tasks { + let _ = task.await; + } + } +} + +#[cfg(feature = "iroh")] +fn has_iroh_identity(static_config: &StaticConfig) -> bool { + matches!( + static_config.tls_identity.as_ref(), + Some(TlsIdentity::RawKey(_)) + ) +} + +#[cfg(feature = "quinn")] +async fn run_quinn_accept_loop( + quinn: quinn::Endpoint, + handlers: Arc, + identity_provider: Arc, + shutdown_rx: &mut watch::Receiver, +) { + loop { + tokio::select! { + _ = shutdown_rx.changed() => { + debug!("quinn accept loop: shutdown signaled"); + break; + } + incoming = quinn.accept() => { + let Some(incoming) = incoming else { + debug!("quinn accept loop: endpoint closed"); + break; + }; + let connecting = match incoming.accept() { + Ok(c) => c, + Err(e) => { + warn!("quinn accept failed: {e}"); + continue; + } + }; + let handlers = handlers.clone(); + let identity_provider = identity_provider.clone(); + tokio::spawn(async move { + let connection = match connecting.await { + Ok(conn) => conn, + Err(e) => { + warn!("quinn TLS handshake failure: {e}"); + return; + } + }; + dispatch_quinn(connection, &handlers, &identity_provider); + }); + } + } + } +} + +#[cfg(feature = "quinn")] +fn dispatch_quinn( + connection: quinn::Connection, + handlers: &HandlerRegistry, + identity_provider: &Arc, +) { + let alpn = extract_quinn_alpn(&connection); + let handler = match handlers.get(&alpn) { + Some(h) => h.clone(), + None => { + connection.close(0u32.into(), b"no handler"); + warn!( + "quinn dispatch: no handler for ALPN {:?}", + String::from_utf8_lossy(&alpn) + ); + return; + } + }; + + let remote_addr = Some(connection.remote_address()); + let auth = build_auth_context(&alpn, remote_addr, None, identity_provider); + let conn = Connection::from_quinn_with_alpn(connection, alpn.clone()); + tokio::spawn(async move { + if let Err(e) = handler.handle(conn, &auth).await { + error!("handler returned error: {e}"); + } + }); +} + +#[cfg(feature = "quinn")] +fn extract_quinn_alpn(connection: &quinn::Connection) -> Vec { + use quinn::crypto::rustls::HandshakeData; + if let Some(data) = connection.handshake_data() { + if let Ok(hs) = data.downcast::() { + if let Some(protocol) = hs.protocol { + return protocol; + } + } + } + Vec::new() +} + +#[cfg(feature = "iroh")] +async fn run_iroh_accept_loop( + iroh: iroh::Endpoint, + handlers: Arc, + identity_provider: Arc, + shutdown_rx: &mut watch::Receiver, +) { + loop { + tokio::select! { + _ = shutdown_rx.changed() => { + debug!("iroh accept loop: shutdown signaled"); + break; + } + incoming = iroh.accept() => { + let Some(incoming) = incoming else { + debug!("iroh accept loop: endpoint closed"); + break; + }; + let handlers = handlers.clone(); + let identity_provider = identity_provider.clone(); + tokio::spawn(async move { + let mut connecting = match incoming.accept() { + Ok(c) => c, + Err(e) => { + warn!("iroh accept failed: {e}"); + return; + } + }; + let alpn = match connecting.alpn().await { + Ok(alpn) => alpn, + Err(e) => { + warn!("iroh ALPN negotiation failed: {e}"); + return; + } + }; + let connection = match connecting.await { + Ok(conn) => conn, + Err(e) => { + warn!("iroh handshake completion failed: {e}"); + return; + } + }; + dispatch_iroh(connection, alpn, &handlers, &identity_provider); + }); + } + } + } +} + +#[cfg(feature = "iroh")] +fn dispatch_iroh( + connection: iroh::endpoint::Connection, + alpn: Vec, + handlers: &HandlerRegistry, + identity_provider: &Arc, +) { + let handler = match handlers.get(&alpn) { + Some(h) => h.clone(), + None => { + connection.close(0u32.into(), b"no handler"); + warn!( + "iroh dispatch: no handler for ALPN {:?}", + String::from_utf8_lossy(&alpn) + ); + return; + } + }; + + let auth = build_auth_context(&alpn, None, None, identity_provider); + let conn = Connection::from_iroh(connection); + tokio::spawn(async move { + if let Err(e) = handler.handle(conn, &auth).await { + error!("handler returned error: {e}"); + } + }); +} + +#[cfg(any(feature = "quinn", feature = "iroh"))] +fn build_auth_context( + alpn: &[u8], + remote_addr: Option, + tls_client_fingerprint: Option, + identity_provider: &Arc, +) -> AuthContext { + let identity = tls_client_fingerprint + .as_ref() + .and_then(|fp| identity_provider.resolve_from_fingerprint(fp)); + AuthContext { + identity, + alpn: alpn.to_vec(), + remote_addr, + tls_client_fingerprint, + } +} + +#[cfg(feature = "quinn")] +fn build_quinn_server_config( + tls_identity: &TlsIdentity, + alpns: &[Vec], +) -> Result { + use quinn::crypto::rustls::QuicServerConfig; + let rustls_config = build_rustls_server_config(tls_identity, alpns)?; + let quic_server_config = QuicServerConfig::try_from(rustls_config) + .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; + Ok(quinn::ServerConfig::with_crypto(Arc::new( + quic_server_config, + ))) +} + +#[cfg(feature = "quinn")] +fn build_rustls_server_config( + tls_identity: &TlsIdentity, + alpns: &[Vec], +) -> Result { + let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider()); + match tls_identity { + TlsIdentity::X509 { cert, key } => { + let cert_chain = load_cert_chain(cert)?; + let private_key = load_private_key(key)?; + let mut config = rustls::ServerConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))? + .with_no_client_auth() + .with_single_cert(cert_chain, private_key) + .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; + config.alpn_protocols = alpns.to_vec(); + config.max_early_data_size = u32::MAX; + Ok(config) + } + #[cfg(feature = "iroh")] + TlsIdentity::RawKey(secret_key) => { + let resolver = Arc::new(RawKeyCertResolver::new(secret_key)); + let mut config = rustls::ServerConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))? + .with_no_client_auth() + .with_cert_resolver(resolver); + config.alpn_protocols = alpns.to_vec(); + config.max_early_data_size = u32::MAX; + Ok(config) + } + TlsIdentity::SelfSigned => { + let cert = generate_self_signed_cert()?; + let mut config = rustls::ServerConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))? + .with_no_client_auth() + .with_single_cert(cert.cert_chain, cert.private_key) + .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; + config.alpn_protocols = alpns.to_vec(); + config.max_early_data_size = u32::MAX; + Ok(config) + } + } +} + +#[cfg(feature = "iroh")] +async fn build_iroh_endpoint( + static_config: &StaticConfig, + alpns: &[Vec], +) -> Result { + let mut builder = iroh::Endpoint::builder(); + + if let Some(TlsIdentity::RawKey(secret_key)) = static_config.tls_identity.as_ref() { + builder = builder.secret_key(secret_key.clone()); + } else { + let mut csprng = rand::rngs::OsRng; + builder = builder.secret_key(iroh::SecretKey::generate(&mut csprng)); + } + + builder = builder.alpns(alpns.to_vec()); + + if let Some(relay_url) = static_config.iroh_relay.as_ref() { + let relay_map = iroh::RelayMap::from(relay_url.clone()); + builder = builder.relay_mode(iroh::RelayMode::Custom(relay_map)); + } else { + builder = builder.relay_mode(iroh::RelayMode::Disabled); + } + + builder + .bind() + .await + .map_err(|e| EndpointError::BindFailed(io::Error::other(e))) +} + +#[cfg(feature = "quinn")] +fn load_cert_chain( + path: &Path, +) -> Result>, EndpointError> { + let bytes = std::fs::read(path).map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; + let mut reader = std::io::BufReader::new(bytes.as_slice()); + rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .map_err(|e| EndpointError::TlsConfig(io::Error::other(e))) +} + +#[cfg(feature = "quinn")] +fn load_private_key( + path: &Path, +) -> Result, EndpointError> { + let bytes = std::fs::read(path).map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; + let mut reader = std::io::BufReader::new(bytes.as_slice()); + match rustls_pemfile::private_key(&mut reader) { + Ok(Some(key)) => Ok(key), + Ok(None) => Err(EndpointError::TlsConfig(io::Error::new( + io::ErrorKind::InvalidData, + "no private key found in file", + ))), + Err(e) => Err(EndpointError::TlsConfig(io::Error::other(e))), + } +} + +#[cfg(feature = "quinn")] +struct SelfSignedCert { + cert_chain: Vec>, + private_key: rustls::pki_types::PrivateKeyDer<'static>, +} + +#[cfg(feature = "quinn")] +fn generate_self_signed_cert() -> Result { + use rcgen::{CertificateParams, KeyPair}; + let key_pair = + KeyPair::generate().map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; + let params = CertificateParams::default(); + let cert = params + .self_signed(&key_pair) + .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; + let cert_der = cert.der().clone(); + let key_der = rustls::pki_types::PrivateKeyDer::Pkcs8( + rustls::pki_types::PrivatePkcs8KeyDer::from(key_pair.serialize_der()), + ); + Ok(SelfSignedCert { + cert_chain: vec![cert_der], + private_key: key_der, + }) +} + +#[cfg(feature = "iroh")] +struct RawKeyCertResolver { + key: Arc, +} + +#[cfg(feature = "iroh")] +impl RawKeyCertResolver { + fn new(secret_key: &iroh::SecretKey) -> Self { + let signing_key = Arc::new(Ed25519SigningKey::new(secret_key.clone())); + let public_key = signing_key.spki_public_key(); + let cert = rustls::pki_types::CertificateDer::from(public_key.to_vec()); + let certified_key = rustls::sign::CertifiedKey::new(vec![cert], signing_key); + Self { + key: Arc::new(certified_key), + } + } +} + +#[cfg(feature = "iroh")] +impl rustls::server::ResolvesServerCert for RawKeyCertResolver { + fn resolve( + &self, + _client_hello: rustls::server::ClientHello<'_>, + ) -> Option> { + Some(Arc::clone(&self.key)) + } + + fn only_raw_public_keys(&self) -> bool { + true + } +} + +#[cfg(feature = "iroh")] +impl std::fmt::Debug for RawKeyCertResolver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RawKeyCertResolver").finish() + } +} + +#[cfg(feature = "iroh")] +#[derive(Clone)] +struct Ed25519SigningKey { + key: iroh::SecretKey, +} + +#[cfg(feature = "iroh")] +impl std::fmt::Debug for Ed25519SigningKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Ed25519SigningKey").finish() + } +} + +#[cfg(feature = "iroh")] +impl Ed25519SigningKey { + fn new(key: iroh::SecretKey) -> Self { + Self { key } + } + + fn spki_public_key(&self) -> rustls::pki_types::SubjectPublicKeyInfoDer<'static> { + rustls::sign::public_key_to_spki( + &rustls::pki_types::alg_id::ED25519, + self.key.public().as_bytes(), + ) + } +} + +#[cfg(feature = "iroh")] +impl rustls::sign::SigningKey for Ed25519SigningKey { + fn choose_scheme( + &self, + offered: &[rustls::SignatureScheme], + ) -> Option> { + if offered.contains(&rustls::SignatureScheme::ED25519) { + Some(Box::new(self.clone())) + } else { + None + } + } + + fn algorithm(&self) -> rustls::SignatureAlgorithm { + rustls::SignatureAlgorithm::ED25519 + } + + fn public_key(&self) -> Option> { + Some(self.spki_public_key()) + } +} + +#[cfg(feature = "iroh")] +impl rustls::sign::Signer for Ed25519SigningKey { + fn sign(&self, message: &[u8]) -> Result, rustls::Error> { + Ok(self.key.sign(message).to_bytes().to_vec()) + } + + fn scheme(&self) -> rustls::SignatureScheme { + rustls::SignatureScheme::ED25519 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::auth::AuthContext; + #[cfg(any(feature = "quinn", feature = "iroh"))] + use crate::auth::{AuthToken, Identity, IdentityProvider}; + use crate::types::{Connection, HandlerError}; + use async_trait::async_trait; + + struct DummyHandler { + alpn: &'static [u8], + } + + #[async_trait] + impl ProtocolHandler for DummyHandler { + fn alpn(&self) -> &'static [u8] { + self.alpn + } + async fn handle( + &self, + _connection: Connection, + _auth: &AuthContext, + ) -> Result<(), HandlerError> { + Ok(()) + } + } + + fn make_handler(alpn: &'static [u8]) -> Arc { + Arc::new(DummyHandler { alpn }) + } + + #[test] + fn handler_registry_new_is_empty() { + let reg = HandlerRegistry::new(); + assert!(reg.alpn_strings().is_empty()); + assert!(reg.get(b"alknet/test").is_none()); + } + + #[test] + fn handler_registry_register_then_get() { + let mut reg = HandlerRegistry::new(); + reg.register(make_handler(b"alknet/test")); + assert_eq!(reg.alpn_strings(), vec![b"alknet/test".to_vec()]); + assert!(reg.get(b"alknet/test").is_some()); + assert!(reg.get(b"alknet/other").is_none()); + } + + #[test] + fn handler_registry_multiple_alpns() { + let mut reg = HandlerRegistry::new(); + reg.register(make_handler(b"alknet/ssh")); + reg.register(make_handler(b"alknet/call")); + let mut alpns = reg + .alpn_strings() + .into_iter() + .map(|a| String::from_utf8(a).unwrap()) + .collect::>(); + alpns.sort(); + assert_eq!(alpns, vec!["alknet/call", "alknet/ssh"]); + assert!(reg.get(b"alknet/ssh").is_some()); + assert!(reg.get(b"alknet/call").is_some()); + } + + #[test] + #[should_panic(expected = "ALPN already registered")] + fn handler_registry_register_panics_on_duplicate() { + let mut reg = HandlerRegistry::new(); + reg.register(make_handler(b"alknet/test")); + reg.register(make_handler(b"alknet/test")); + } + + #[test] + fn handler_registry_debug_lists_alpns() { + let mut reg = HandlerRegistry::new(); + reg.register(make_handler(b"alknet/test")); + let s = format!("{:?}", reg); + assert!(s.contains("alknet/test")); + } + + #[test] + fn endpoint_error_display() { + let e = EndpointError::BindFailed(io::Error::new(io::ErrorKind::AddrInUse, "busy")); + assert!(format!("{e}").contains("bind failed")); + let e = EndpointError::TlsConfig(io::Error::new(io::ErrorKind::InvalidData, "bad")); + assert!(format!("{e}").contains("tls config error")); + let e = EndpointError::HandlerNotFound(b"alknet/test".to_vec()); + assert!(format!("{e}").contains("handler not found")); + } + + #[cfg(any(feature = "quinn", feature = "iroh"))] + #[test] + fn build_auth_context_resolves_identity_from_fingerprint() { + struct StaticProvider; + impl IdentityProvider for StaticProvider { + fn resolve_from_fingerprint(&self, fp: &str) -> Option { + if fp == "SHA256:known" { + Some(Identity { + id: "SHA256:known".to_string(), + scopes: vec![], + resources: HashMap::new(), + }) + } else { + None + } + } + fn resolve_from_token(&self, _token: &AuthToken) -> Option { + None + } + } + let provider: Arc = Arc::new(StaticProvider); + let auth = build_auth_context( + b"alknet/test", + None, + Some("SHA256:known".to_string()), + &provider, + ); + assert_eq!(auth.identity.as_ref().unwrap().id, "SHA256:known"); + assert_eq!(auth.alpn, b"alknet/test"); + assert_eq!(auth.tls_client_fingerprint.as_deref(), Some("SHA256:known")); + } + + #[cfg(any(feature = "quinn", feature = "iroh"))] + #[test] + fn build_auth_context_no_fingerprint_no_identity() { + struct NoProvider; + impl IdentityProvider for NoProvider { + fn resolve_from_fingerprint(&self, _fp: &str) -> Option { + None + } + fn resolve_from_token(&self, _token: &AuthToken) -> Option { + None + } + } + let provider: Arc = Arc::new(NoProvider); + let auth = build_auth_context(b"alknet/test", None, None, &provider); + assert!(auth.identity.is_none()); + assert!(auth.tls_client_fingerprint.is_none()); + } + + #[cfg(any(feature = "quinn", feature = "iroh"))] + #[test] + fn build_auth_context_fingerprint_unknown_identity_none() { + struct StaticProvider; + impl IdentityProvider for StaticProvider { + fn resolve_from_fingerprint(&self, _fp: &str) -> Option { + None + } + fn resolve_from_token(&self, _token: &AuthToken) -> Option { + None + } + } + let provider: Arc = Arc::new(StaticProvider); + let auth = build_auth_context( + b"alknet/test", + None, + Some("SHA256:unknown".to_string()), + &provider, + ); + assert!(auth.identity.is_none()); + assert!(auth.tls_client_fingerprint.is_some()); + } + + #[cfg(feature = "iroh")] + #[test] + fn raw_key_cert_resolver_only_raw_public_keys() { + use rustls::server::ResolvesServerCert; + let mut csprng = rand::rngs::OsRng; + let sk = iroh::SecretKey::generate(&mut csprng); + let resolver = RawKeyCertResolver::new(&sk); + assert!(resolver.only_raw_public_keys()); + } + + #[cfg(feature = "iroh")] + #[tokio::test] + async fn endpoint_constructs_with_iroh_raw_key_identity() { + let mut csprng = rand::rngs::OsRng; + let static_config = StaticConfig { + listen_addr: None, + tls_identity: Some(TlsIdentity::RawKey(iroh::SecretKey::generate(&mut csprng))), + iroh_relay: None, + drain_timeout: Duration::from_millis(10), + }; + let dynamic = Arc::new(ArcSwap::from_pointee(DynamicConfig::default())); + struct NoProvider; + impl IdentityProvider for NoProvider { + fn resolve_from_fingerprint(&self, _: &str) -> Option { + None + } + fn resolve_from_token(&self, _: &AuthToken) -> Option { + None + } + } + let provider: Arc = Arc::new(NoProvider); + let mut registry = HandlerRegistry::new(); + registry.register(make_handler(b"alknet/test")); + let endpoint = AlknetEndpoint::new(&static_config, registry, dynamic, provider) + .await + .expect("endpoint constructs"); + assert!(endpoint.shutdown_sender().send(true).is_ok()); + endpoint.shutdown().await.expect("shutdown ok"); + } + + #[cfg(feature = "iroh")] + #[tokio::test] + async fn iroh_endpoint_runs_accept_loop_and_shutdown() { + use std::sync::Mutex; + let server_sk = iroh::SecretKey::generate(&mut rand::rngs::OsRng); + let static_config = StaticConfig { + listen_addr: None, + tls_identity: Some(TlsIdentity::RawKey(server_sk)), + iroh_relay: None, + drain_timeout: Duration::from_millis(20), + }; + let dynamic = Arc::new(ArcSwap::from_pointee(DynamicConfig::default())); + struct NoProvider; + impl IdentityProvider for NoProvider { + fn resolve_from_fingerprint(&self, _: &str) -> Option { + None + } + fn resolve_from_token(&self, _: &AuthToken) -> Option { + None + } + } + let provider: Arc = Arc::new(NoProvider); + + let connected = Arc::new(Mutex::new(false)); + let connected_clone = connected.clone(); + struct CountingHandler { + alpn: &'static [u8], + connected: Arc>, + } + #[async_trait] + impl ProtocolHandler for CountingHandler { + fn alpn(&self) -> &'static [u8] { + self.alpn + } + async fn handle( + &self, + _conn: Connection, + _auth: &AuthContext, + ) -> Result<(), HandlerError> { + *self.connected.lock().unwrap() = true; + Ok(()) + } + } + let mut registry = HandlerRegistry::new(); + registry.register(Arc::new(CountingHandler { + alpn: b"alknet/test", + connected: connected_clone, + })); + let endpoint = Arc::new( + AlknetEndpoint::new(&static_config, registry, dynamic, provider) + .await + .expect("endpoint constructs"), + ); + + let run_endpoint = endpoint.clone(); + let run_task = tokio::spawn(async move { + run_endpoint.run().await; + }); + + let _ = endpoint.shutdown_sender().send(true); + endpoint.shutdown().await.ok(); + let _ = run_task.await; + assert!(!*connected.lock().unwrap()); + } + + #[cfg(feature = "quinn")] + #[test] + fn self_signed_cert_generation_produces_cert_and_key() { + let cert = generate_self_signed_cert().expect("self-signed cert generates"); + assert!(!cert.cert_chain.is_empty()); + assert!(!cert.private_key.secret_der().is_empty()); + } + + #[cfg(any(feature = "quinn", feature = "iroh"))] + #[test] + fn dispatch_decision_logic_lookup_and_auth() { + let mut registry = HandlerRegistry::new(); + registry.register(make_handler(b"alknet/ssh")); + registry.register(make_handler(b"alknet/call")); + + struct StaticProvider; + impl IdentityProvider for StaticProvider { + fn resolve_from_fingerprint(&self, fp: &str) -> Option { + if fp == "SHA256:caller" { + Some(Identity { + id: "SHA256:caller".to_string(), + scopes: vec!["relay:connect".to_string()], + resources: HashMap::new(), + }) + } else { + None + } + } + fn resolve_from_token(&self, _: &AuthToken) -> Option { + None + } + } + let provider: Arc = Arc::new(StaticProvider); + + let ssh_handler = registry.get(b"alknet/ssh").expect("ssh handler registered"); + assert_eq!(ssh_handler.alpn(), b"alknet/ssh"); + let auth = build_auth_context( + b"alknet/ssh", + Some(std::net::SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), + 1234, + )), + Some("SHA256:caller".to_string()), + &provider, + ); + assert_eq!(auth.identity.as_ref().unwrap().id, "SHA256:caller"); + assert_eq!(auth.alpn, b"alknet/ssh"); + + let unknown = registry.get(b"alknet/unknown"); + assert!(unknown.is_none(), "unknown ALPN has no handler"); + } +}