feat(core): implement AlknetEndpoint, HandlerRegistry, accept loops (quinn + iroh), TLS identity (RawKey/X509/SelfSigned), and graceful shutdown (task: core/endpoint)

This commit is contained in:
2026-06-23 15:12:14 +00:00
parent dabb0d8b68
commit 8d056a2b59
3 changed files with 994 additions and 2 deletions

View File

@@ -20,6 +20,7 @@ quinn = { version = "0.11", optional = true }
iroh = { version = "0.35", optional = true }
rustls = "0.23"
rustls-pki-types = "1"
rustls-pemfile = "2"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
toml = "0.8"
@@ -31,4 +32,6 @@ zeroize = { version = "1", features = ["alloc", "derive"] }
bytes = "1"
futures = "0.3"
sha2 = "0.10"
hex = "0.4"
hex = "0.4"
rand = "0.8"
rcgen = "0.13"

View File

@@ -2,4 +2,981 @@
//!
//! See `docs/architecture/crates/core/endpoint.md` for the full specification.
// TODO: implement
use std::collections::HashMap;
use std::io;
#[cfg(any(feature = "quinn", feature = "iroh"))]
use std::net::SocketAddr;
#[cfg(any(feature = "quinn", feature = "iroh"))]
use std::path::Path;
use std::sync::Arc;
#[cfg(any(feature = "quinn", feature = "iroh"))]
use std::time::Duration;
#[cfg(any(feature = "quinn", feature = "iroh"))]
use arc_swap::ArcSwap;
#[cfg(any(feature = "quinn", feature = "iroh"))]
use tokio::sync::watch;
#[cfg(any(feature = "quinn", feature = "iroh"))]
use tracing::{debug, error, warn};
#[cfg(any(feature = "quinn", feature = "iroh"))]
use crate::auth::{AuthContext, IdentityProvider};
#[cfg(any(feature = "quinn", feature = "iroh"))]
use crate::config::{DynamicConfig, StaticConfig, TlsIdentity};
#[cfg(not(any(feature = "quinn", feature = "iroh")))]
use crate::types::ProtocolHandler;
#[cfg(any(feature = "quinn", feature = "iroh"))]
use crate::types::{Connection, ProtocolHandler};
#[derive(Debug, thiserror::Error)]
pub enum EndpointError {
#[error("bind failed: {0}")]
BindFailed(io::Error),
#[error("tls config error: {0}")]
TlsConfig(io::Error),
#[error("handler not found for ALPN {0:?}")]
HandlerNotFound(Vec<u8>),
}
pub struct HandlerRegistry {
handlers: HashMap<&'static [u8], Arc<dyn ProtocolHandler>>,
}
impl HandlerRegistry {
pub fn new() -> Self {
Self {
handlers: HashMap::new(),
}
}
pub fn register(&mut self, handler: Arc<dyn ProtocolHandler>) {
let alpn = handler.alpn();
if self.handlers.contains_key(alpn) {
panic!(
"HandlerRegistry: ALPN already registered: {:?}",
String::from_utf8_lossy(alpn)
);
}
self.handlers.insert(alpn, handler);
}
pub fn get(&self, alpn: &[u8]) -> Option<&Arc<dyn ProtocolHandler>> {
self.handlers.get(alpn)
}
pub fn alpn_strings(&self) -> Vec<Vec<u8>> {
self.handlers.keys().map(|k| k.to_vec()).collect()
}
}
impl Default for HandlerRegistry {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for HandlerRegistry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("HandlerRegistry")
.field(
"alpns",
&self
.handlers
.keys()
.map(|k| String::from_utf8_lossy(k).to_string())
.collect::<Vec<_>>(),
)
.finish()
}
}
#[cfg(any(feature = "quinn", feature = "iroh"))]
pub struct AlknetEndpoint {
#[cfg(feature = "quinn")]
quinn: Option<quinn::Endpoint>,
#[cfg(feature = "iroh")]
iroh: Option<iroh::Endpoint>,
handlers: Arc<HandlerRegistry>,
#[allow(dead_code)]
dynamic: Arc<ArcSwap<DynamicConfig>>,
identity_provider: Arc<dyn IdentityProvider>,
shutdown_tx: watch::Sender<bool>,
shutdown_rx: watch::Receiver<bool>,
drain_timeout: Duration,
}
#[cfg(any(feature = "quinn", feature = "iroh"))]
impl std::fmt::Debug for AlknetEndpoint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AlknetEndpoint")
.field("handlers", &self.handlers)
.field("drain_timeout", &self.drain_timeout)
.finish()
}
}
#[cfg(any(feature = "quinn", feature = "iroh"))]
impl AlknetEndpoint {
pub async fn new(
static_config: &StaticConfig,
handlers: HandlerRegistry,
dynamic: Arc<ArcSwap<DynamicConfig>>,
identity_provider: Arc<dyn IdentityProvider>,
) -> Result<Self, EndpointError> {
let handlers = Arc::new(handlers);
let alpns = handlers.alpn_strings();
let (shutdown_tx, shutdown_rx) = watch::channel(false);
let drain_timeout = static_config.drain_timeout;
#[cfg(feature = "quinn")]
let quinn = if let Some(listen_addr) = static_config.listen_addr {
let tls_identity = static_config.tls_identity.as_ref().ok_or_else(|| {
EndpointError::TlsConfig(io::Error::new(
io::ErrorKind::InvalidInput,
"quinn endpoint requires tls_identity in static config",
))
})?;
let server_config = build_quinn_server_config(tls_identity, &alpns)?;
let endpoint = quinn::Endpoint::server(server_config, listen_addr)
.map_err(EndpointError::BindFailed)?;
Some(endpoint)
} else {
None
};
#[cfg(not(feature = "quinn"))]
if static_config.listen_addr.is_some() {
return Err(EndpointError::TlsConfig(io::Error::new(
io::ErrorKind::Unsupported,
"quinn feature is not enabled but listen_addr was set",
)));
}
#[cfg(feature = "iroh")]
let iroh = if static_config.iroh_relay.is_some() || has_iroh_identity(static_config) {
Some(build_iroh_endpoint(static_config, &alpns).await?)
} else {
None
};
Ok(Self {
#[cfg(feature = "quinn")]
quinn,
#[cfg(feature = "iroh")]
iroh,
handlers,
dynamic,
identity_provider,
shutdown_tx,
shutdown_rx,
drain_timeout,
})
}
pub fn shutdown_sender(&self) -> watch::Sender<bool> {
self.shutdown_tx.clone()
}
pub async fn shutdown(&self) -> Result<(), EndpointError> {
let _ = self.shutdown_tx.send(true);
#[cfg(feature = "quinn")]
if let Some(quinn) = &self.quinn {
quinn.close(0u32.into(), b"shutdown");
}
#[cfg(feature = "iroh")]
if let Some(iroh) = &self.iroh {
iroh.close().await;
}
tokio::time::sleep(self.drain_timeout).await;
#[cfg(feature = "quinn")]
if let Some(quinn) = &self.quinn {
quinn.wait_idle().await;
}
Ok(())
}
pub async fn run(self: Arc<Self>) {
let mut tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new();
#[cfg(feature = "quinn")]
if let Some(quinn) = &self.quinn {
let quinn = quinn.clone();
let handlers = self.handlers.clone();
let identity_provider = self.identity_provider.clone();
let mut shutdown_rx = self.shutdown_rx.clone();
let task = tokio::spawn(async move {
run_quinn_accept_loop(quinn, handlers, identity_provider, &mut shutdown_rx).await;
});
tasks.push(task);
}
#[cfg(feature = "iroh")]
if let Some(iroh) = &self.iroh {
let iroh = iroh.clone();
let handlers = self.handlers.clone();
let identity_provider = self.identity_provider.clone();
let mut shutdown_rx = self.shutdown_rx.clone();
let task = tokio::spawn(async move {
run_iroh_accept_loop(iroh, handlers, identity_provider, &mut shutdown_rx).await;
});
tasks.push(task);
}
for task in tasks {
let _ = task.await;
}
}
}
#[cfg(feature = "iroh")]
fn has_iroh_identity(static_config: &StaticConfig) -> bool {
matches!(
static_config.tls_identity.as_ref(),
Some(TlsIdentity::RawKey(_))
)
}
#[cfg(feature = "quinn")]
async fn run_quinn_accept_loop(
quinn: quinn::Endpoint,
handlers: Arc<HandlerRegistry>,
identity_provider: Arc<dyn IdentityProvider>,
shutdown_rx: &mut watch::Receiver<bool>,
) {
loop {
tokio::select! {
_ = shutdown_rx.changed() => {
debug!("quinn accept loop: shutdown signaled");
break;
}
incoming = quinn.accept() => {
let Some(incoming) = incoming else {
debug!("quinn accept loop: endpoint closed");
break;
};
let connecting = match incoming.accept() {
Ok(c) => c,
Err(e) => {
warn!("quinn accept failed: {e}");
continue;
}
};
let handlers = handlers.clone();
let identity_provider = identity_provider.clone();
tokio::spawn(async move {
let connection = match connecting.await {
Ok(conn) => conn,
Err(e) => {
warn!("quinn TLS handshake failure: {e}");
return;
}
};
dispatch_quinn(connection, &handlers, &identity_provider);
});
}
}
}
}
#[cfg(feature = "quinn")]
fn dispatch_quinn(
connection: quinn::Connection,
handlers: &HandlerRegistry,
identity_provider: &Arc<dyn IdentityProvider>,
) {
let alpn = extract_quinn_alpn(&connection);
let handler = match handlers.get(&alpn) {
Some(h) => h.clone(),
None => {
connection.close(0u32.into(), b"no handler");
warn!(
"quinn dispatch: no handler for ALPN {:?}",
String::from_utf8_lossy(&alpn)
);
return;
}
};
let remote_addr = Some(connection.remote_address());
let auth = build_auth_context(&alpn, remote_addr, None, identity_provider);
let conn = Connection::from_quinn_with_alpn(connection, alpn.clone());
tokio::spawn(async move {
if let Err(e) = handler.handle(conn, &auth).await {
error!("handler returned error: {e}");
}
});
}
#[cfg(feature = "quinn")]
fn extract_quinn_alpn(connection: &quinn::Connection) -> Vec<u8> {
use quinn::crypto::rustls::HandshakeData;
if let Some(data) = connection.handshake_data() {
if let Ok(hs) = data.downcast::<HandshakeData>() {
if let Some(protocol) = hs.protocol {
return protocol;
}
}
}
Vec::new()
}
#[cfg(feature = "iroh")]
async fn run_iroh_accept_loop(
iroh: iroh::Endpoint,
handlers: Arc<HandlerRegistry>,
identity_provider: Arc<dyn IdentityProvider>,
shutdown_rx: &mut watch::Receiver<bool>,
) {
loop {
tokio::select! {
_ = shutdown_rx.changed() => {
debug!("iroh accept loop: shutdown signaled");
break;
}
incoming = iroh.accept() => {
let Some(incoming) = incoming else {
debug!("iroh accept loop: endpoint closed");
break;
};
let handlers = handlers.clone();
let identity_provider = identity_provider.clone();
tokio::spawn(async move {
let mut connecting = match incoming.accept() {
Ok(c) => c,
Err(e) => {
warn!("iroh accept failed: {e}");
return;
}
};
let alpn = match connecting.alpn().await {
Ok(alpn) => alpn,
Err(e) => {
warn!("iroh ALPN negotiation failed: {e}");
return;
}
};
let connection = match connecting.await {
Ok(conn) => conn,
Err(e) => {
warn!("iroh handshake completion failed: {e}");
return;
}
};
dispatch_iroh(connection, alpn, &handlers, &identity_provider);
});
}
}
}
}
#[cfg(feature = "iroh")]
fn dispatch_iroh(
connection: iroh::endpoint::Connection,
alpn: Vec<u8>,
handlers: &HandlerRegistry,
identity_provider: &Arc<dyn IdentityProvider>,
) {
let handler = match handlers.get(&alpn) {
Some(h) => h.clone(),
None => {
connection.close(0u32.into(), b"no handler");
warn!(
"iroh dispatch: no handler for ALPN {:?}",
String::from_utf8_lossy(&alpn)
);
return;
}
};
let auth = build_auth_context(&alpn, None, None, identity_provider);
let conn = Connection::from_iroh(connection);
tokio::spawn(async move {
if let Err(e) = handler.handle(conn, &auth).await {
error!("handler returned error: {e}");
}
});
}
#[cfg(any(feature = "quinn", feature = "iroh"))]
fn build_auth_context(
alpn: &[u8],
remote_addr: Option<SocketAddr>,
tls_client_fingerprint: Option<String>,
identity_provider: &Arc<dyn IdentityProvider>,
) -> AuthContext {
let identity = tls_client_fingerprint
.as_ref()
.and_then(|fp| identity_provider.resolve_from_fingerprint(fp));
AuthContext {
identity,
alpn: alpn.to_vec(),
remote_addr,
tls_client_fingerprint,
}
}
#[cfg(feature = "quinn")]
fn build_quinn_server_config(
tls_identity: &TlsIdentity,
alpns: &[Vec<u8>],
) -> Result<quinn::ServerConfig, EndpointError> {
use quinn::crypto::rustls::QuicServerConfig;
let rustls_config = build_rustls_server_config(tls_identity, alpns)?;
let quic_server_config = QuicServerConfig::try_from(rustls_config)
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
Ok(quinn::ServerConfig::with_crypto(Arc::new(
quic_server_config,
)))
}
#[cfg(feature = "quinn")]
fn build_rustls_server_config(
tls_identity: &TlsIdentity,
alpns: &[Vec<u8>],
) -> Result<rustls::ServerConfig, EndpointError> {
let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
match tls_identity {
TlsIdentity::X509 { cert, key } => {
let cert_chain = load_cert_chain(cert)?;
let private_key = load_private_key(key)?;
let mut config = rustls::ServerConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?
.with_no_client_auth()
.with_single_cert(cert_chain, private_key)
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
config.alpn_protocols = alpns.to_vec();
config.max_early_data_size = u32::MAX;
Ok(config)
}
#[cfg(feature = "iroh")]
TlsIdentity::RawKey(secret_key) => {
let resolver = Arc::new(RawKeyCertResolver::new(secret_key));
let mut config = rustls::ServerConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?
.with_no_client_auth()
.with_cert_resolver(resolver);
config.alpn_protocols = alpns.to_vec();
config.max_early_data_size = u32::MAX;
Ok(config)
}
TlsIdentity::SelfSigned => {
let cert = generate_self_signed_cert()?;
let mut config = rustls::ServerConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?
.with_no_client_auth()
.with_single_cert(cert.cert_chain, cert.private_key)
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
config.alpn_protocols = alpns.to_vec();
config.max_early_data_size = u32::MAX;
Ok(config)
}
}
}
#[cfg(feature = "iroh")]
async fn build_iroh_endpoint(
static_config: &StaticConfig,
alpns: &[Vec<u8>],
) -> Result<iroh::Endpoint, EndpointError> {
let mut builder = iroh::Endpoint::builder();
if let Some(TlsIdentity::RawKey(secret_key)) = static_config.tls_identity.as_ref() {
builder = builder.secret_key(secret_key.clone());
} else {
let mut csprng = rand::rngs::OsRng;
builder = builder.secret_key(iroh::SecretKey::generate(&mut csprng));
}
builder = builder.alpns(alpns.to_vec());
if let Some(relay_url) = static_config.iroh_relay.as_ref() {
let relay_map = iroh::RelayMap::from(relay_url.clone());
builder = builder.relay_mode(iroh::RelayMode::Custom(relay_map));
} else {
builder = builder.relay_mode(iroh::RelayMode::Disabled);
}
builder
.bind()
.await
.map_err(|e| EndpointError::BindFailed(io::Error::other(e)))
}
#[cfg(feature = "quinn")]
fn load_cert_chain(
path: &Path,
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>, EndpointError> {
let bytes = std::fs::read(path).map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
let mut reader = std::io::BufReader::new(bytes.as_slice());
rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))
}
#[cfg(feature = "quinn")]
fn load_private_key(
path: &Path,
) -> Result<rustls::pki_types::PrivateKeyDer<'static>, EndpointError> {
let bytes = std::fs::read(path).map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
let mut reader = std::io::BufReader::new(bytes.as_slice());
match rustls_pemfile::private_key(&mut reader) {
Ok(Some(key)) => Ok(key),
Ok(None) => Err(EndpointError::TlsConfig(io::Error::new(
io::ErrorKind::InvalidData,
"no private key found in file",
))),
Err(e) => Err(EndpointError::TlsConfig(io::Error::other(e))),
}
}
#[cfg(feature = "quinn")]
struct SelfSignedCert {
cert_chain: Vec<rustls::pki_types::CertificateDer<'static>>,
private_key: rustls::pki_types::PrivateKeyDer<'static>,
}
#[cfg(feature = "quinn")]
fn generate_self_signed_cert() -> Result<SelfSignedCert, EndpointError> {
use rcgen::{CertificateParams, KeyPair};
let key_pair =
KeyPair::generate().map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
let params = CertificateParams::default();
let cert = params
.self_signed(&key_pair)
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
let cert_der = cert.der().clone();
let key_der = rustls::pki_types::PrivateKeyDer::Pkcs8(
rustls::pki_types::PrivatePkcs8KeyDer::from(key_pair.serialize_der()),
);
Ok(SelfSignedCert {
cert_chain: vec![cert_der],
private_key: key_der,
})
}
#[cfg(feature = "iroh")]
struct RawKeyCertResolver {
key: Arc<rustls::sign::CertifiedKey>,
}
#[cfg(feature = "iroh")]
impl RawKeyCertResolver {
fn new(secret_key: &iroh::SecretKey) -> Self {
let signing_key = Arc::new(Ed25519SigningKey::new(secret_key.clone()));
let public_key = signing_key.spki_public_key();
let cert = rustls::pki_types::CertificateDer::from(public_key.to_vec());
let certified_key = rustls::sign::CertifiedKey::new(vec![cert], signing_key);
Self {
key: Arc::new(certified_key),
}
}
}
#[cfg(feature = "iroh")]
impl rustls::server::ResolvesServerCert for RawKeyCertResolver {
fn resolve(
&self,
_client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
Some(Arc::clone(&self.key))
}
fn only_raw_public_keys(&self) -> bool {
true
}
}
#[cfg(feature = "iroh")]
impl std::fmt::Debug for RawKeyCertResolver {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RawKeyCertResolver").finish()
}
}
#[cfg(feature = "iroh")]
#[derive(Clone)]
struct Ed25519SigningKey {
key: iroh::SecretKey,
}
#[cfg(feature = "iroh")]
impl std::fmt::Debug for Ed25519SigningKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Ed25519SigningKey").finish()
}
}
#[cfg(feature = "iroh")]
impl Ed25519SigningKey {
fn new(key: iroh::SecretKey) -> Self {
Self { key }
}
fn spki_public_key(&self) -> rustls::pki_types::SubjectPublicKeyInfoDer<'static> {
rustls::sign::public_key_to_spki(
&rustls::pki_types::alg_id::ED25519,
self.key.public().as_bytes(),
)
}
}
#[cfg(feature = "iroh")]
impl rustls::sign::SigningKey for Ed25519SigningKey {
fn choose_scheme(
&self,
offered: &[rustls::SignatureScheme],
) -> Option<Box<dyn rustls::sign::Signer>> {
if offered.contains(&rustls::SignatureScheme::ED25519) {
Some(Box::new(self.clone()))
} else {
None
}
}
fn algorithm(&self) -> rustls::SignatureAlgorithm {
rustls::SignatureAlgorithm::ED25519
}
fn public_key(&self) -> Option<rustls::pki_types::SubjectPublicKeyInfoDer<'_>> {
Some(self.spki_public_key())
}
}
#[cfg(feature = "iroh")]
impl rustls::sign::Signer for Ed25519SigningKey {
fn sign(&self, message: &[u8]) -> Result<Vec<u8>, rustls::Error> {
Ok(self.key.sign(message).to_bytes().to_vec())
}
fn scheme(&self) -> rustls::SignatureScheme {
rustls::SignatureScheme::ED25519
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::AuthContext;
#[cfg(any(feature = "quinn", feature = "iroh"))]
use crate::auth::{AuthToken, Identity, IdentityProvider};
use crate::types::{Connection, HandlerError};
use async_trait::async_trait;
struct DummyHandler {
alpn: &'static [u8],
}
#[async_trait]
impl ProtocolHandler for DummyHandler {
fn alpn(&self) -> &'static [u8] {
self.alpn
}
async fn handle(
&self,
_connection: Connection,
_auth: &AuthContext,
) -> Result<(), HandlerError> {
Ok(())
}
}
fn make_handler(alpn: &'static [u8]) -> Arc<dyn ProtocolHandler> {
Arc::new(DummyHandler { alpn })
}
#[test]
fn handler_registry_new_is_empty() {
let reg = HandlerRegistry::new();
assert!(reg.alpn_strings().is_empty());
assert!(reg.get(b"alknet/test").is_none());
}
#[test]
fn handler_registry_register_then_get() {
let mut reg = HandlerRegistry::new();
reg.register(make_handler(b"alknet/test"));
assert_eq!(reg.alpn_strings(), vec![b"alknet/test".to_vec()]);
assert!(reg.get(b"alknet/test").is_some());
assert!(reg.get(b"alknet/other").is_none());
}
#[test]
fn handler_registry_multiple_alpns() {
let mut reg = HandlerRegistry::new();
reg.register(make_handler(b"alknet/ssh"));
reg.register(make_handler(b"alknet/call"));
let mut alpns = reg
.alpn_strings()
.into_iter()
.map(|a| String::from_utf8(a).unwrap())
.collect::<Vec<_>>();
alpns.sort();
assert_eq!(alpns, vec!["alknet/call", "alknet/ssh"]);
assert!(reg.get(b"alknet/ssh").is_some());
assert!(reg.get(b"alknet/call").is_some());
}
#[test]
#[should_panic(expected = "ALPN already registered")]
fn handler_registry_register_panics_on_duplicate() {
let mut reg = HandlerRegistry::new();
reg.register(make_handler(b"alknet/test"));
reg.register(make_handler(b"alknet/test"));
}
#[test]
fn handler_registry_debug_lists_alpns() {
let mut reg = HandlerRegistry::new();
reg.register(make_handler(b"alknet/test"));
let s = format!("{:?}", reg);
assert!(s.contains("alknet/test"));
}
#[test]
fn endpoint_error_display() {
let e = EndpointError::BindFailed(io::Error::new(io::ErrorKind::AddrInUse, "busy"));
assert!(format!("{e}").contains("bind failed"));
let e = EndpointError::TlsConfig(io::Error::new(io::ErrorKind::InvalidData, "bad"));
assert!(format!("{e}").contains("tls config error"));
let e = EndpointError::HandlerNotFound(b"alknet/test".to_vec());
assert!(format!("{e}").contains("handler not found"));
}
#[cfg(any(feature = "quinn", feature = "iroh"))]
#[test]
fn build_auth_context_resolves_identity_from_fingerprint() {
struct StaticProvider;
impl IdentityProvider for StaticProvider {
fn resolve_from_fingerprint(&self, fp: &str) -> Option<Identity> {
if fp == "SHA256:known" {
Some(Identity {
id: "SHA256:known".to_string(),
scopes: vec![],
resources: HashMap::new(),
})
} else {
None
}
}
fn resolve_from_token(&self, _token: &AuthToken) -> Option<Identity> {
None
}
}
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticProvider);
let auth = build_auth_context(
b"alknet/test",
None,
Some("SHA256:known".to_string()),
&provider,
);
assert_eq!(auth.identity.as_ref().unwrap().id, "SHA256:known");
assert_eq!(auth.alpn, b"alknet/test");
assert_eq!(auth.tls_client_fingerprint.as_deref(), Some("SHA256:known"));
}
#[cfg(any(feature = "quinn", feature = "iroh"))]
#[test]
fn build_auth_context_no_fingerprint_no_identity() {
struct NoProvider;
impl IdentityProvider for NoProvider {
fn resolve_from_fingerprint(&self, _fp: &str) -> Option<Identity> {
None
}
fn resolve_from_token(&self, _token: &AuthToken) -> Option<Identity> {
None
}
}
let provider: Arc<dyn IdentityProvider> = Arc::new(NoProvider);
let auth = build_auth_context(b"alknet/test", None, None, &provider);
assert!(auth.identity.is_none());
assert!(auth.tls_client_fingerprint.is_none());
}
#[cfg(any(feature = "quinn", feature = "iroh"))]
#[test]
fn build_auth_context_fingerprint_unknown_identity_none() {
struct StaticProvider;
impl IdentityProvider for StaticProvider {
fn resolve_from_fingerprint(&self, _fp: &str) -> Option<Identity> {
None
}
fn resolve_from_token(&self, _token: &AuthToken) -> Option<Identity> {
None
}
}
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticProvider);
let auth = build_auth_context(
b"alknet/test",
None,
Some("SHA256:unknown".to_string()),
&provider,
);
assert!(auth.identity.is_none());
assert!(auth.tls_client_fingerprint.is_some());
}
#[cfg(feature = "iroh")]
#[test]
fn raw_key_cert_resolver_only_raw_public_keys() {
use rustls::server::ResolvesServerCert;
let mut csprng = rand::rngs::OsRng;
let sk = iroh::SecretKey::generate(&mut csprng);
let resolver = RawKeyCertResolver::new(&sk);
assert!(resolver.only_raw_public_keys());
}
#[cfg(feature = "iroh")]
#[tokio::test]
async fn endpoint_constructs_with_iroh_raw_key_identity() {
let mut csprng = rand::rngs::OsRng;
let static_config = StaticConfig {
listen_addr: None,
tls_identity: Some(TlsIdentity::RawKey(iroh::SecretKey::generate(&mut csprng))),
iroh_relay: None,
drain_timeout: Duration::from_millis(10),
};
let dynamic = Arc::new(ArcSwap::from_pointee(DynamicConfig::default()));
struct NoProvider;
impl IdentityProvider for NoProvider {
fn resolve_from_fingerprint(&self, _: &str) -> Option<Identity> {
None
}
fn resolve_from_token(&self, _: &AuthToken) -> Option<Identity> {
None
}
}
let provider: Arc<dyn IdentityProvider> = Arc::new(NoProvider);
let mut registry = HandlerRegistry::new();
registry.register(make_handler(b"alknet/test"));
let endpoint = AlknetEndpoint::new(&static_config, registry, dynamic, provider)
.await
.expect("endpoint constructs");
assert!(endpoint.shutdown_sender().send(true).is_ok());
endpoint.shutdown().await.expect("shutdown ok");
}
#[cfg(feature = "iroh")]
#[tokio::test]
async fn iroh_endpoint_runs_accept_loop_and_shutdown() {
use std::sync::Mutex;
let server_sk = iroh::SecretKey::generate(&mut rand::rngs::OsRng);
let static_config = StaticConfig {
listen_addr: None,
tls_identity: Some(TlsIdentity::RawKey(server_sk)),
iroh_relay: None,
drain_timeout: Duration::from_millis(20),
};
let dynamic = Arc::new(ArcSwap::from_pointee(DynamicConfig::default()));
struct NoProvider;
impl IdentityProvider for NoProvider {
fn resolve_from_fingerprint(&self, _: &str) -> Option<Identity> {
None
}
fn resolve_from_token(&self, _: &AuthToken) -> Option<Identity> {
None
}
}
let provider: Arc<dyn IdentityProvider> = Arc::new(NoProvider);
let connected = Arc::new(Mutex::new(false));
let connected_clone = connected.clone();
struct CountingHandler {
alpn: &'static [u8],
connected: Arc<Mutex<bool>>,
}
#[async_trait]
impl ProtocolHandler for CountingHandler {
fn alpn(&self) -> &'static [u8] {
self.alpn
}
async fn handle(
&self,
_conn: Connection,
_auth: &AuthContext,
) -> Result<(), HandlerError> {
*self.connected.lock().unwrap() = true;
Ok(())
}
}
let mut registry = HandlerRegistry::new();
registry.register(Arc::new(CountingHandler {
alpn: b"alknet/test",
connected: connected_clone,
}));
let endpoint = Arc::new(
AlknetEndpoint::new(&static_config, registry, dynamic, provider)
.await
.expect("endpoint constructs"),
);
let run_endpoint = endpoint.clone();
let run_task = tokio::spawn(async move {
run_endpoint.run().await;
});
let _ = endpoint.shutdown_sender().send(true);
endpoint.shutdown().await.ok();
let _ = run_task.await;
assert!(!*connected.lock().unwrap());
}
#[cfg(feature = "quinn")]
#[test]
fn self_signed_cert_generation_produces_cert_and_key() {
let cert = generate_self_signed_cert().expect("self-signed cert generates");
assert!(!cert.cert_chain.is_empty());
assert!(!cert.private_key.secret_der().is_empty());
}
#[cfg(any(feature = "quinn", feature = "iroh"))]
#[test]
fn dispatch_decision_logic_lookup_and_auth() {
let mut registry = HandlerRegistry::new();
registry.register(make_handler(b"alknet/ssh"));
registry.register(make_handler(b"alknet/call"));
struct StaticProvider;
impl IdentityProvider for StaticProvider {
fn resolve_from_fingerprint(&self, fp: &str) -> Option<Identity> {
if fp == "SHA256:caller" {
Some(Identity {
id: "SHA256:caller".to_string(),
scopes: vec!["relay:connect".to_string()],
resources: HashMap::new(),
})
} else {
None
}
}
fn resolve_from_token(&self, _: &AuthToken) -> Option<Identity> {
None
}
}
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticProvider);
let ssh_handler = registry.get(b"alknet/ssh").expect("ssh handler registered");
assert_eq!(ssh_handler.alpn(), b"alknet/ssh");
let auth = build_auth_context(
b"alknet/ssh",
Some(std::net::SocketAddr::new(
std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST),
1234,
)),
Some("SHA256:caller".to_string()),
&provider,
);
assert_eq!(auth.identity.as_ref().unwrap().id, "SHA256:caller");
assert_eq!(auth.alpn, b"alknet/ssh");
let unknown = registry.get(b"alknet/unknown");
assert!(unknown.is_none(), "unknown ALPN has no handler");
}
}