//! Endpoint: `AlknetEndpoint`, `HandlerRegistry`, `EndpointError`. //! //! See `docs/architecture/crates/core/endpoint.md` for the full specification. use std::collections::HashMap; use std::io; #[cfg(any(feature = "quinn", feature = "iroh"))] use std::net::SocketAddr; #[cfg(feature = "quinn")] use std::path::Path; use std::sync::Arc; #[cfg(any(feature = "quinn", feature = "iroh"))] use std::time::Duration; #[cfg(any(feature = "quinn", feature = "iroh"))] use arc_swap::ArcSwap; #[cfg(any(feature = "quinn", feature = "iroh"))] use tokio::sync::watch; #[cfg(any(feature = "quinn", feature = "iroh"))] use tracing::{debug, error, warn}; #[cfg(any(feature = "quinn", feature = "iroh"))] use crate::auth::{AuthContext, IdentityProvider}; #[cfg(any(feature = "quinn", feature = "iroh"))] use crate::config::{DynamicConfig, StaticConfig, TlsIdentity}; #[cfg(not(any(feature = "quinn", feature = "iroh")))] use crate::types::ProtocolHandler; #[cfg(any(feature = "quinn", feature = "iroh"))] use crate::types::{Connection, ProtocolHandler}; #[derive(Debug, thiserror::Error)] pub enum EndpointError { #[error("bind failed: {0}")] BindFailed(io::Error), #[error("tls config error: {0}")] TlsConfig(io::Error), #[error("handler not found for ALPN {0:?}")] HandlerNotFound(Vec), } pub struct HandlerRegistry { handlers: HashMap<&'static [u8], Arc>, } impl HandlerRegistry { pub fn new() -> Self { Self { handlers: HashMap::new(), } } pub fn register(&mut self, handler: Arc) { let alpn = handler.alpn(); if self.handlers.contains_key(alpn) { panic!( "HandlerRegistry: ALPN already registered: {:?}", String::from_utf8_lossy(alpn) ); } self.handlers.insert(alpn, handler); } pub fn get(&self, alpn: &[u8]) -> Option<&Arc> { self.handlers.get(alpn) } pub fn alpn_strings(&self) -> Vec> { self.handlers.keys().map(|k| k.to_vec()).collect() } } impl Default for HandlerRegistry { fn default() -> Self { Self::new() } } impl std::fmt::Debug for HandlerRegistry { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("HandlerRegistry") .field( "alpns", &self .handlers .keys() .map(|k| String::from_utf8_lossy(k).to_string()) .collect::>(), ) .finish() } } #[cfg(any(feature = "quinn", feature = "iroh"))] pub struct AlknetEndpoint { #[cfg(feature = "quinn")] quinn: Option, #[cfg(feature = "iroh")] iroh: Option, handlers: Arc, #[allow(dead_code)] dynamic: Arc>, identity_provider: Arc, shutdown_tx: watch::Sender, shutdown_rx: watch::Receiver, drain_timeout: Duration, } #[cfg(any(feature = "quinn", feature = "iroh"))] impl std::fmt::Debug for AlknetEndpoint { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("AlknetEndpoint") .field("handlers", &self.handlers) .field("drain_timeout", &self.drain_timeout) .finish() } } #[cfg(any(feature = "quinn", feature = "iroh"))] impl AlknetEndpoint { pub async fn new( static_config: &StaticConfig, handlers: HandlerRegistry, dynamic: Arc>, identity_provider: Arc, ) -> Result { let handlers = Arc::new(handlers); let alpns = handlers.alpn_strings(); let (shutdown_tx, shutdown_rx) = watch::channel(false); let drain_timeout = static_config.drain_timeout; #[cfg(feature = "quinn")] let quinn = if let Some(listen_addr) = static_config.listen_addr { let tls_identity = static_config.tls_identity.as_ref().ok_or_else(|| { EndpointError::TlsConfig(io::Error::new( io::ErrorKind::InvalidInput, "quinn endpoint requires tls_identity in static config", )) })?; let server_config = build_quinn_server_config(tls_identity, &alpns)?; let endpoint = quinn::Endpoint::server(server_config, listen_addr) .map_err(EndpointError::BindFailed)?; Some(endpoint) } else { None }; #[cfg(not(feature = "quinn"))] if static_config.listen_addr.is_some() { return Err(EndpointError::TlsConfig(io::Error::new( io::ErrorKind::Unsupported, "quinn feature is not enabled but listen_addr was set", ))); } #[cfg(feature = "iroh")] let iroh = if static_config.iroh_relay.is_some() || has_iroh_identity(static_config) { Some(build_iroh_endpoint(static_config, &alpns).await?) } else { None }; Ok(Self { #[cfg(feature = "quinn")] quinn, #[cfg(feature = "iroh")] iroh, handlers, dynamic, identity_provider, shutdown_tx, shutdown_rx, drain_timeout, }) } pub fn shutdown_sender(&self) -> watch::Sender { self.shutdown_tx.clone() } pub async fn shutdown(&self) -> Result<(), EndpointError> { let _ = self.shutdown_tx.send(true); #[cfg(feature = "quinn")] if let Some(quinn) = &self.quinn { quinn.close(0u32.into(), b"shutdown"); } #[cfg(feature = "iroh")] if let Some(iroh) = &self.iroh { iroh.close().await; } tokio::time::sleep(self.drain_timeout).await; #[cfg(feature = "quinn")] if let Some(quinn) = &self.quinn { quinn.wait_idle().await; } Ok(()) } pub async fn run(self: Arc) { let mut tasks: Vec> = Vec::new(); #[cfg(feature = "quinn")] if let Some(quinn) = &self.quinn { let quinn = quinn.clone(); let handlers = self.handlers.clone(); let identity_provider = self.identity_provider.clone(); let mut shutdown_rx = self.shutdown_rx.clone(); let task = tokio::spawn(async move { run_quinn_accept_loop(quinn, handlers, identity_provider, &mut shutdown_rx).await; }); tasks.push(task); } #[cfg(feature = "iroh")] if let Some(iroh) = &self.iroh { let iroh = iroh.clone(); let handlers = self.handlers.clone(); let identity_provider = self.identity_provider.clone(); let mut shutdown_rx = self.shutdown_rx.clone(); let task = tokio::spawn(async move { run_iroh_accept_loop(iroh, handlers, identity_provider, &mut shutdown_rx).await; }); tasks.push(task); } for task in tasks { let _ = task.await; } } } #[cfg(feature = "iroh")] fn has_iroh_identity(static_config: &StaticConfig) -> bool { matches!( static_config.tls_identity.as_ref(), Some(TlsIdentity::RawKey(_)) ) } #[cfg(feature = "quinn")] async fn run_quinn_accept_loop( quinn: quinn::Endpoint, handlers: Arc, identity_provider: Arc, shutdown_rx: &mut watch::Receiver, ) { loop { tokio::select! { _ = shutdown_rx.changed() => { debug!("quinn accept loop: shutdown signaled"); break; } incoming = quinn.accept() => { let Some(incoming) = incoming else { debug!("quinn accept loop: endpoint closed"); break; }; let connecting = match incoming.accept() { Ok(c) => c, Err(e) => { warn!("quinn accept failed: {e}"); continue; } }; let handlers = handlers.clone(); let identity_provider = identity_provider.clone(); tokio::spawn(async move { let connection = match connecting.await { Ok(conn) => conn, Err(e) => { warn!("quinn TLS handshake failure: {e}"); return; } }; dispatch_quinn(connection, &handlers, &identity_provider); }); } } } } #[cfg(feature = "quinn")] fn dispatch_quinn( connection: quinn::Connection, handlers: &HandlerRegistry, identity_provider: &Arc, ) { let alpn = extract_quinn_alpn(&connection); let handler = match handlers.get(&alpn) { Some(h) => h.clone(), None => { connection.close(0u32.into(), b"no handler"); warn!( "quinn dispatch: no handler for ALPN {:?}", String::from_utf8_lossy(&alpn) ); return; } }; let remote_addr = Some(connection.remote_address()); let auth = build_auth_context(&alpn, remote_addr, None, identity_provider); let conn = Connection::from_quinn_with_alpn(connection, alpn.clone()); tokio::spawn(async move { if let Err(e) = handler.handle(conn, &auth).await { error!("handler returned error: {e}"); } }); } #[cfg(feature = "quinn")] fn extract_quinn_alpn(connection: &quinn::Connection) -> Vec { use quinn::crypto::rustls::HandshakeData; if let Some(data) = connection.handshake_data() { if let Ok(hs) = data.downcast::() { if let Some(protocol) = hs.protocol { return protocol; } } } Vec::new() } #[cfg(feature = "iroh")] async fn run_iroh_accept_loop( iroh: iroh::Endpoint, handlers: Arc, identity_provider: Arc, shutdown_rx: &mut watch::Receiver, ) { loop { tokio::select! { _ = shutdown_rx.changed() => { debug!("iroh accept loop: shutdown signaled"); break; } incoming = iroh.accept() => { let Some(incoming) = incoming else { debug!("iroh accept loop: endpoint closed"); break; }; let handlers = handlers.clone(); let identity_provider = identity_provider.clone(); tokio::spawn(async move { let mut connecting = match incoming.accept() { Ok(c) => c, Err(e) => { warn!("iroh accept failed: {e}"); return; } }; let alpn = match connecting.alpn().await { Ok(alpn) => alpn, Err(e) => { warn!("iroh ALPN negotiation failed: {e}"); return; } }; let connection = match connecting.await { Ok(conn) => conn, Err(e) => { warn!("iroh handshake completion failed: {e}"); return; } }; dispatch_iroh(connection, alpn, &handlers, &identity_provider); }); } } } } #[cfg(feature = "iroh")] fn dispatch_iroh( connection: iroh::endpoint::Connection, alpn: Vec, handlers: &HandlerRegistry, identity_provider: &Arc, ) { let handler = match handlers.get(&alpn) { Some(h) => h.clone(), None => { connection.close(0u32.into(), b"no handler"); warn!( "iroh dispatch: no handler for ALPN {:?}", String::from_utf8_lossy(&alpn) ); return; } }; let auth = build_auth_context(&alpn, None, None, identity_provider); let conn = Connection::from_iroh(connection); tokio::spawn(async move { if let Err(e) = handler.handle(conn, &auth).await { error!("handler returned error: {e}"); } }); } #[cfg(any(feature = "quinn", feature = "iroh"))] fn build_auth_context( alpn: &[u8], remote_addr: Option, tls_client_fingerprint: Option, identity_provider: &Arc, ) -> AuthContext { let identity = tls_client_fingerprint .as_ref() .and_then(|fp| identity_provider.resolve_from_fingerprint(fp)); AuthContext { identity, alpn: alpn.to_vec(), remote_addr, tls_client_fingerprint, } } #[cfg(feature = "quinn")] fn build_quinn_server_config( tls_identity: &TlsIdentity, alpns: &[Vec], ) -> Result { use quinn::crypto::rustls::QuicServerConfig; let rustls_config = build_rustls_server_config(tls_identity, alpns)?; let quic_server_config = QuicServerConfig::try_from(rustls_config) .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; Ok(quinn::ServerConfig::with_crypto(Arc::new( quic_server_config, ))) } #[cfg(feature = "quinn")] fn build_rustls_server_config( tls_identity: &TlsIdentity, alpns: &[Vec], ) -> Result { let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider()); match tls_identity { TlsIdentity::X509 { cert, key } => { let cert_chain = load_cert_chain(cert)?; let private_key = load_private_key(key)?; let mut config = rustls::ServerConfig::builder_with_provider(provider) .with_safe_default_protocol_versions() .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))? .with_no_client_auth() .with_single_cert(cert_chain, private_key) .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; config.alpn_protocols = alpns.to_vec(); config.max_early_data_size = u32::MAX; Ok(config) } #[cfg(feature = "iroh")] TlsIdentity::RawKey(secret_key) => { let resolver = Arc::new(RawKeyCertResolver::new(secret_key)); let mut config = rustls::ServerConfig::builder_with_provider(provider) .with_safe_default_protocol_versions() .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))? .with_no_client_auth() .with_cert_resolver(resolver); config.alpn_protocols = alpns.to_vec(); config.max_early_data_size = u32::MAX; Ok(config) } TlsIdentity::SelfSigned => { let cert = generate_self_signed_cert()?; let mut config = rustls::ServerConfig::builder_with_provider(provider) .with_safe_default_protocol_versions() .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))? .with_no_client_auth() .with_single_cert(cert.cert_chain, cert.private_key) .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; config.alpn_protocols = alpns.to_vec(); config.max_early_data_size = u32::MAX; Ok(config) } } } #[cfg(feature = "iroh")] async fn build_iroh_endpoint( static_config: &StaticConfig, alpns: &[Vec], ) -> Result { let mut builder = iroh::Endpoint::builder(); if let Some(TlsIdentity::RawKey(secret_key)) = static_config.tls_identity.as_ref() { builder = builder.secret_key(secret_key.clone()); } else { let mut csprng = rand::rngs::OsRng; builder = builder.secret_key(iroh::SecretKey::generate(&mut csprng)); } builder = builder.alpns(alpns.to_vec()); if let Some(relay_url) = static_config.iroh_relay.as_ref() { let relay_map = iroh::RelayMap::from(relay_url.clone()); builder = builder.relay_mode(iroh::RelayMode::Custom(relay_map)); } else { builder = builder.relay_mode(iroh::RelayMode::Disabled); } builder .bind() .await .map_err(|e| EndpointError::BindFailed(io::Error::other(e))) } #[cfg(feature = "quinn")] fn load_cert_chain( path: &Path, ) -> Result>, EndpointError> { let bytes = std::fs::read(path).map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; let mut reader = std::io::BufReader::new(bytes.as_slice()); rustls_pemfile::certs(&mut reader) .collect::, _>>() .map_err(|e| EndpointError::TlsConfig(io::Error::other(e))) } #[cfg(feature = "quinn")] fn load_private_key( path: &Path, ) -> Result, EndpointError> { let bytes = std::fs::read(path).map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; let mut reader = std::io::BufReader::new(bytes.as_slice()); match rustls_pemfile::private_key(&mut reader) { Ok(Some(key)) => Ok(key), Ok(None) => Err(EndpointError::TlsConfig(io::Error::new( io::ErrorKind::InvalidData, "no private key found in file", ))), Err(e) => Err(EndpointError::TlsConfig(io::Error::other(e))), } } #[cfg(feature = "quinn")] struct SelfSignedCert { cert_chain: Vec>, private_key: rustls::pki_types::PrivateKeyDer<'static>, } #[cfg(feature = "quinn")] fn generate_self_signed_cert() -> Result { use rcgen::{CertificateParams, KeyPair}; let key_pair = KeyPair::generate().map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; let params = CertificateParams::default(); let cert = params .self_signed(&key_pair) .map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?; let cert_der = cert.der().clone(); let key_der = rustls::pki_types::PrivateKeyDer::Pkcs8( rustls::pki_types::PrivatePkcs8KeyDer::from(key_pair.serialize_der()), ); Ok(SelfSignedCert { cert_chain: vec![cert_der], private_key: key_der, }) } #[cfg(all(feature = "quinn", feature = "iroh"))] struct RawKeyCertResolver { key: Arc, } #[cfg(all(feature = "quinn", feature = "iroh"))] impl RawKeyCertResolver { fn new(secret_key: &iroh::SecretKey) -> Self { let signing_key = Arc::new(Ed25519SigningKey::new(secret_key.clone())); let public_key = signing_key.spki_public_key(); let cert = rustls::pki_types::CertificateDer::from(public_key.to_vec()); let certified_key = rustls::sign::CertifiedKey::new(vec![cert], signing_key); Self { key: Arc::new(certified_key), } } } #[cfg(all(feature = "quinn", feature = "iroh"))] impl rustls::server::ResolvesServerCert for RawKeyCertResolver { fn resolve( &self, _client_hello: rustls::server::ClientHello<'_>, ) -> Option> { Some(Arc::clone(&self.key)) } fn only_raw_public_keys(&self) -> bool { true } } #[cfg(all(feature = "quinn", feature = "iroh"))] impl std::fmt::Debug for RawKeyCertResolver { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RawKeyCertResolver").finish() } } #[cfg(all(feature = "quinn", feature = "iroh"))] #[derive(Clone)] struct Ed25519SigningKey { key: iroh::SecretKey, } #[cfg(all(feature = "quinn", feature = "iroh"))] impl std::fmt::Debug for Ed25519SigningKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Ed25519SigningKey").finish() } } #[cfg(all(feature = "quinn", feature = "iroh"))] impl Ed25519SigningKey { fn new(key: iroh::SecretKey) -> Self { Self { key } } fn spki_public_key(&self) -> rustls::pki_types::SubjectPublicKeyInfoDer<'static> { rustls::sign::public_key_to_spki( &rustls::pki_types::alg_id::ED25519, self.key.public().as_bytes(), ) } } #[cfg(all(feature = "quinn", feature = "iroh"))] impl rustls::sign::SigningKey for Ed25519SigningKey { fn choose_scheme( &self, offered: &[rustls::SignatureScheme], ) -> Option> { if offered.contains(&rustls::SignatureScheme::ED25519) { Some(Box::new(self.clone())) } else { None } } fn algorithm(&self) -> rustls::SignatureAlgorithm { rustls::SignatureAlgorithm::ED25519 } fn public_key(&self) -> Option> { Some(self.spki_public_key()) } } #[cfg(all(feature = "quinn", feature = "iroh"))] impl rustls::sign::Signer for Ed25519SigningKey { fn sign(&self, message: &[u8]) -> Result, rustls::Error> { Ok(self.key.sign(message).to_bytes().to_vec()) } fn scheme(&self) -> rustls::SignatureScheme { rustls::SignatureScheme::ED25519 } } #[cfg(test)] mod tests { use super::*; use crate::auth::AuthContext; #[cfg(any(feature = "quinn", feature = "iroh"))] use crate::auth::{AuthToken, Identity, IdentityProvider}; use crate::types::{Connection, HandlerError}; use async_trait::async_trait; struct DummyHandler { alpn: &'static [u8], } #[async_trait] impl ProtocolHandler for DummyHandler { fn alpn(&self) -> &'static [u8] { self.alpn } async fn handle( &self, _connection: Connection, _auth: &AuthContext, ) -> Result<(), HandlerError> { Ok(()) } } fn make_handler(alpn: &'static [u8]) -> Arc { Arc::new(DummyHandler { alpn }) } #[test] fn handler_registry_new_is_empty() { let reg = HandlerRegistry::new(); assert!(reg.alpn_strings().is_empty()); assert!(reg.get(b"alknet/test").is_none()); } #[test] fn handler_registry_register_then_get() { let mut reg = HandlerRegistry::new(); reg.register(make_handler(b"alknet/test")); assert_eq!(reg.alpn_strings(), vec![b"alknet/test".to_vec()]); assert!(reg.get(b"alknet/test").is_some()); assert!(reg.get(b"alknet/other").is_none()); } #[test] fn handler_registry_multiple_alpns() { let mut reg = HandlerRegistry::new(); reg.register(make_handler(b"alknet/ssh")); reg.register(make_handler(b"alknet/call")); let mut alpns = reg .alpn_strings() .into_iter() .map(|a| String::from_utf8(a).unwrap()) .collect::>(); alpns.sort(); assert_eq!(alpns, vec!["alknet/call", "alknet/ssh"]); assert!(reg.get(b"alknet/ssh").is_some()); assert!(reg.get(b"alknet/call").is_some()); } #[test] #[should_panic(expected = "ALPN already registered")] fn handler_registry_register_panics_on_duplicate() { let mut reg = HandlerRegistry::new(); reg.register(make_handler(b"alknet/test")); reg.register(make_handler(b"alknet/test")); } #[test] fn handler_registry_debug_lists_alpns() { let mut reg = HandlerRegistry::new(); reg.register(make_handler(b"alknet/test")); let s = format!("{:?}", reg); assert!(s.contains("alknet/test")); } #[test] fn endpoint_error_display() { let e = EndpointError::BindFailed(io::Error::new(io::ErrorKind::AddrInUse, "busy")); assert!(format!("{e}").contains("bind failed")); let e = EndpointError::TlsConfig(io::Error::new(io::ErrorKind::InvalidData, "bad")); assert!(format!("{e}").contains("tls config error")); let e = EndpointError::HandlerNotFound(b"alknet/test".to_vec()); assert!(format!("{e}").contains("handler not found")); } #[cfg(any(feature = "quinn", feature = "iroh"))] #[test] fn build_auth_context_resolves_identity_from_fingerprint() { struct StaticProvider; impl IdentityProvider for StaticProvider { fn resolve_from_fingerprint(&self, fp: &str) -> Option { if fp == "SHA256:known" { Some(Identity { id: "SHA256:known".to_string(), scopes: vec![], resources: HashMap::new(), }) } else { None } } fn resolve_from_token(&self, _token: &AuthToken) -> Option { None } } let provider: Arc = Arc::new(StaticProvider); let auth = build_auth_context( b"alknet/test", None, Some("SHA256:known".to_string()), &provider, ); assert_eq!(auth.identity.as_ref().unwrap().id, "SHA256:known"); assert_eq!(auth.alpn, b"alknet/test"); assert_eq!(auth.tls_client_fingerprint.as_deref(), Some("SHA256:known")); } #[cfg(any(feature = "quinn", feature = "iroh"))] #[test] fn build_auth_context_no_fingerprint_no_identity() { struct NoProvider; impl IdentityProvider for NoProvider { fn resolve_from_fingerprint(&self, _fp: &str) -> Option { None } fn resolve_from_token(&self, _token: &AuthToken) -> Option { None } } let provider: Arc = Arc::new(NoProvider); let auth = build_auth_context(b"alknet/test", None, None, &provider); assert!(auth.identity.is_none()); assert!(auth.tls_client_fingerprint.is_none()); } #[cfg(any(feature = "quinn", feature = "iroh"))] #[test] fn build_auth_context_fingerprint_unknown_identity_none() { struct StaticProvider; impl IdentityProvider for StaticProvider { fn resolve_from_fingerprint(&self, _fp: &str) -> Option { None } fn resolve_from_token(&self, _token: &AuthToken) -> Option { None } } let provider: Arc = Arc::new(StaticProvider); let auth = build_auth_context( b"alknet/test", None, Some("SHA256:unknown".to_string()), &provider, ); assert!(auth.identity.is_none()); assert!(auth.tls_client_fingerprint.is_some()); } #[cfg(all(feature = "quinn", feature = "iroh"))] #[test] fn raw_key_cert_resolver_only_raw_public_keys() { use rustls::server::ResolvesServerCert; let mut csprng = rand::rngs::OsRng; let sk = iroh::SecretKey::generate(&mut csprng); let resolver = RawKeyCertResolver::new(&sk); assert!(resolver.only_raw_public_keys()); } #[cfg(feature = "iroh")] #[tokio::test] async fn endpoint_constructs_with_iroh_raw_key_identity() { let mut csprng = rand::rngs::OsRng; let static_config = StaticConfig { listen_addr: None, tls_identity: Some(TlsIdentity::RawKey(iroh::SecretKey::generate(&mut csprng))), iroh_relay: None, drain_timeout: Duration::from_millis(10), }; let dynamic = Arc::new(ArcSwap::from_pointee(DynamicConfig::default())); struct NoProvider; impl IdentityProvider for NoProvider { fn resolve_from_fingerprint(&self, _: &str) -> Option { None } fn resolve_from_token(&self, _: &AuthToken) -> Option { None } } let provider: Arc = Arc::new(NoProvider); let mut registry = HandlerRegistry::new(); registry.register(make_handler(b"alknet/test")); let endpoint = AlknetEndpoint::new(&static_config, registry, dynamic, provider) .await .expect("endpoint constructs"); assert!(endpoint.shutdown_sender().send(true).is_ok()); endpoint.shutdown().await.expect("shutdown ok"); } #[cfg(feature = "iroh")] #[tokio::test] async fn iroh_endpoint_runs_accept_loop_and_shutdown() { use std::sync::Mutex; let server_sk = iroh::SecretKey::generate(&mut rand::rngs::OsRng); let static_config = StaticConfig { listen_addr: None, tls_identity: Some(TlsIdentity::RawKey(server_sk)), iroh_relay: None, drain_timeout: Duration::from_millis(20), }; let dynamic = Arc::new(ArcSwap::from_pointee(DynamicConfig::default())); struct NoProvider; impl IdentityProvider for NoProvider { fn resolve_from_fingerprint(&self, _: &str) -> Option { None } fn resolve_from_token(&self, _: &AuthToken) -> Option { None } } let provider: Arc = Arc::new(NoProvider); let connected = Arc::new(Mutex::new(false)); let connected_clone = connected.clone(); struct CountingHandler { alpn: &'static [u8], connected: Arc>, } #[async_trait] impl ProtocolHandler for CountingHandler { fn alpn(&self) -> &'static [u8] { self.alpn } async fn handle( &self, _conn: Connection, _auth: &AuthContext, ) -> Result<(), HandlerError> { *self.connected.lock().unwrap() = true; Ok(()) } } let mut registry = HandlerRegistry::new(); registry.register(Arc::new(CountingHandler { alpn: b"alknet/test", connected: connected_clone, })); let endpoint = Arc::new( AlknetEndpoint::new(&static_config, registry, dynamic, provider) .await .expect("endpoint constructs"), ); let run_endpoint = endpoint.clone(); let run_task = tokio::spawn(async move { run_endpoint.run().await; }); let _ = endpoint.shutdown_sender().send(true); endpoint.shutdown().await.ok(); let _ = run_task.await; assert!(!*connected.lock().unwrap()); } #[cfg(feature = "quinn")] #[test] fn self_signed_cert_generation_produces_cert_and_key() { let cert = generate_self_signed_cert().expect("self-signed cert generates"); assert!(!cert.cert_chain.is_empty()); assert!(!cert.private_key.secret_der().is_empty()); } #[cfg(any(feature = "quinn", feature = "iroh"))] #[test] fn dispatch_decision_logic_lookup_and_auth() { let mut registry = HandlerRegistry::new(); registry.register(make_handler(b"alknet/ssh")); registry.register(make_handler(b"alknet/call")); struct StaticProvider; impl IdentityProvider for StaticProvider { fn resolve_from_fingerprint(&self, fp: &str) -> Option { if fp == "SHA256:caller" { Some(Identity { id: "SHA256:caller".to_string(), scopes: vec!["relay:connect".to_string()], resources: HashMap::new(), }) } else { None } } fn resolve_from_token(&self, _: &AuthToken) -> Option { None } } let provider: Arc = Arc::new(StaticProvider); let ssh_handler = registry.get(b"alknet/ssh").expect("ssh handler registered"); assert_eq!(ssh_handler.alpn(), b"alknet/ssh"); let auth = build_auth_context( b"alknet/ssh", Some(std::net::SocketAddr::new( std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), 1234, )), Some("SHA256:caller".to_string()), &provider, ); assert_eq!(auth.identity.as_ref().unwrap().id, "SHA256:caller"); assert_eq!(auth.alpn, b"alknet/ssh"); let unknown = registry.get(b"alknet/unknown"); assert!(unknown.is_none(), "unknown ALPN has no handler"); } }