feat(core): implement AlknetEndpoint, HandlerRegistry, accept loops (quinn + iroh), TLS identity (RawKey/X509/SelfSigned), and graceful shutdown (task: core/endpoint)
This commit is contained in:
12
Cargo.lock
generated
12
Cargo.lock
generated
@@ -74,7 +74,10 @@ dependencies = [
|
||||
"hex",
|
||||
"iroh",
|
||||
"quinn",
|
||||
"rand 0.8.6",
|
||||
"rcgen 0.13.2",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"rustls-pki-types",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -3286,6 +3289,15 @@ dependencies = [
|
||||
"security-framework",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pemfile"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.14.1"
|
||||
|
||||
@@ -20,6 +20,7 @@ quinn = { version = "0.11", optional = true }
|
||||
iroh = { version = "0.35", optional = true }
|
||||
rustls = "0.23"
|
||||
rustls-pki-types = "1"
|
||||
rustls-pemfile = "2"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
toml = "0.8"
|
||||
@@ -32,3 +33,5 @@ bytes = "1"
|
||||
futures = "0.3"
|
||||
sha2 = "0.10"
|
||||
hex = "0.4"
|
||||
rand = "0.8"
|
||||
rcgen = "0.13"
|
||||
@@ -2,4 +2,981 @@
|
||||
//!
|
||||
//! See `docs/architecture/crates/core/endpoint.md` for the full specification.
|
||||
|
||||
// TODO: implement
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
use std::net::SocketAddr;
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
use std::time::Duration;
|
||||
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
use arc_swap::ArcSwap;
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
use tokio::sync::watch;
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
use crate::auth::{AuthContext, IdentityProvider};
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
use crate::config::{DynamicConfig, StaticConfig, TlsIdentity};
|
||||
#[cfg(not(any(feature = "quinn", feature = "iroh")))]
|
||||
use crate::types::ProtocolHandler;
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
use crate::types::{Connection, ProtocolHandler};
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum EndpointError {
|
||||
#[error("bind failed: {0}")]
|
||||
BindFailed(io::Error),
|
||||
#[error("tls config error: {0}")]
|
||||
TlsConfig(io::Error),
|
||||
#[error("handler not found for ALPN {0:?}")]
|
||||
HandlerNotFound(Vec<u8>),
|
||||
}
|
||||
|
||||
pub struct HandlerRegistry {
|
||||
handlers: HashMap<&'static [u8], Arc<dyn ProtocolHandler>>,
|
||||
}
|
||||
|
||||
impl HandlerRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
handlers: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register(&mut self, handler: Arc<dyn ProtocolHandler>) {
|
||||
let alpn = handler.alpn();
|
||||
if self.handlers.contains_key(alpn) {
|
||||
panic!(
|
||||
"HandlerRegistry: ALPN already registered: {:?}",
|
||||
String::from_utf8_lossy(alpn)
|
||||
);
|
||||
}
|
||||
self.handlers.insert(alpn, handler);
|
||||
}
|
||||
|
||||
pub fn get(&self, alpn: &[u8]) -> Option<&Arc<dyn ProtocolHandler>> {
|
||||
self.handlers.get(alpn)
|
||||
}
|
||||
|
||||
pub fn alpn_strings(&self) -> Vec<Vec<u8>> {
|
||||
self.handlers.keys().map(|k| k.to_vec()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HandlerRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for HandlerRegistry {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("HandlerRegistry")
|
||||
.field(
|
||||
"alpns",
|
||||
&self
|
||||
.handlers
|
||||
.keys()
|
||||
.map(|k| String::from_utf8_lossy(k).to_string())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
pub struct AlknetEndpoint {
|
||||
#[cfg(feature = "quinn")]
|
||||
quinn: Option<quinn::Endpoint>,
|
||||
#[cfg(feature = "iroh")]
|
||||
iroh: Option<iroh::Endpoint>,
|
||||
handlers: Arc<HandlerRegistry>,
|
||||
#[allow(dead_code)]
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
shutdown_tx: watch::Sender<bool>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
drain_timeout: Duration,
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
impl std::fmt::Debug for AlknetEndpoint {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AlknetEndpoint")
|
||||
.field("handlers", &self.handlers)
|
||||
.field("drain_timeout", &self.drain_timeout)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
impl AlknetEndpoint {
|
||||
pub async fn new(
|
||||
static_config: &StaticConfig,
|
||||
handlers: HandlerRegistry,
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
) -> Result<Self, EndpointError> {
|
||||
let handlers = Arc::new(handlers);
|
||||
let alpns = handlers.alpn_strings();
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
let drain_timeout = static_config.drain_timeout;
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
let quinn = if let Some(listen_addr) = static_config.listen_addr {
|
||||
let tls_identity = static_config.tls_identity.as_ref().ok_or_else(|| {
|
||||
EndpointError::TlsConfig(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"quinn endpoint requires tls_identity in static config",
|
||||
))
|
||||
})?;
|
||||
let server_config = build_quinn_server_config(tls_identity, &alpns)?;
|
||||
let endpoint = quinn::Endpoint::server(server_config, listen_addr)
|
||||
.map_err(EndpointError::BindFailed)?;
|
||||
Some(endpoint)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "quinn"))]
|
||||
if static_config.listen_addr.is_some() {
|
||||
return Err(EndpointError::TlsConfig(io::Error::new(
|
||||
io::ErrorKind::Unsupported,
|
||||
"quinn feature is not enabled but listen_addr was set",
|
||||
)));
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
let iroh = if static_config.iroh_relay.is_some() || has_iroh_identity(static_config) {
|
||||
Some(build_iroh_endpoint(static_config, &alpns).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
#[cfg(feature = "quinn")]
|
||||
quinn,
|
||||
#[cfg(feature = "iroh")]
|
||||
iroh,
|
||||
handlers,
|
||||
dynamic,
|
||||
identity_provider,
|
||||
shutdown_tx,
|
||||
shutdown_rx,
|
||||
drain_timeout,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn shutdown_sender(&self) -> watch::Sender<bool> {
|
||||
self.shutdown_tx.clone()
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) -> Result<(), EndpointError> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
if let Some(quinn) = &self.quinn {
|
||||
quinn.close(0u32.into(), b"shutdown");
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
if let Some(iroh) = &self.iroh {
|
||||
iroh.close().await;
|
||||
}
|
||||
|
||||
tokio::time::sleep(self.drain_timeout).await;
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
if let Some(quinn) = &self.quinn {
|
||||
quinn.wait_idle().await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(self: Arc<Self>) {
|
||||
let mut tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new();
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
if let Some(quinn) = &self.quinn {
|
||||
let quinn = quinn.clone();
|
||||
let handlers = self.handlers.clone();
|
||||
let identity_provider = self.identity_provider.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
let task = tokio::spawn(async move {
|
||||
run_quinn_accept_loop(quinn, handlers, identity_provider, &mut shutdown_rx).await;
|
||||
});
|
||||
tasks.push(task);
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
if let Some(iroh) = &self.iroh {
|
||||
let iroh = iroh.clone();
|
||||
let handlers = self.handlers.clone();
|
||||
let identity_provider = self.identity_provider.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
let task = tokio::spawn(async move {
|
||||
run_iroh_accept_loop(iroh, handlers, identity_provider, &mut shutdown_rx).await;
|
||||
});
|
||||
tasks.push(task);
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
let _ = task.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
fn has_iroh_identity(static_config: &StaticConfig) -> bool {
|
||||
matches!(
|
||||
static_config.tls_identity.as_ref(),
|
||||
Some(TlsIdentity::RawKey(_))
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
async fn run_quinn_accept_loop(
|
||||
quinn: quinn::Endpoint,
|
||||
handlers: Arc<HandlerRegistry>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
shutdown_rx: &mut watch::Receiver<bool>,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
debug!("quinn accept loop: shutdown signaled");
|
||||
break;
|
||||
}
|
||||
incoming = quinn.accept() => {
|
||||
let Some(incoming) = incoming else {
|
||||
debug!("quinn accept loop: endpoint closed");
|
||||
break;
|
||||
};
|
||||
let connecting = match incoming.accept() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
warn!("quinn accept failed: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let handlers = handlers.clone();
|
||||
let identity_provider = identity_provider.clone();
|
||||
tokio::spawn(async move {
|
||||
let connection = match connecting.await {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
warn!("quinn TLS handshake failure: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
dispatch_quinn(connection, &handlers, &identity_provider);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
fn dispatch_quinn(
|
||||
connection: quinn::Connection,
|
||||
handlers: &HandlerRegistry,
|
||||
identity_provider: &Arc<dyn IdentityProvider>,
|
||||
) {
|
||||
let alpn = extract_quinn_alpn(&connection);
|
||||
let handler = match handlers.get(&alpn) {
|
||||
Some(h) => h.clone(),
|
||||
None => {
|
||||
connection.close(0u32.into(), b"no handler");
|
||||
warn!(
|
||||
"quinn dispatch: no handler for ALPN {:?}",
|
||||
String::from_utf8_lossy(&alpn)
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let remote_addr = Some(connection.remote_address());
|
||||
let auth = build_auth_context(&alpn, remote_addr, None, identity_provider);
|
||||
let conn = Connection::from_quinn_with_alpn(connection, alpn.clone());
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handler.handle(conn, &auth).await {
|
||||
error!("handler returned error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
fn extract_quinn_alpn(connection: &quinn::Connection) -> Vec<u8> {
|
||||
use quinn::crypto::rustls::HandshakeData;
|
||||
if let Some(data) = connection.handshake_data() {
|
||||
if let Ok(hs) = data.downcast::<HandshakeData>() {
|
||||
if let Some(protocol) = hs.protocol {
|
||||
return protocol;
|
||||
}
|
||||
}
|
||||
}
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
async fn run_iroh_accept_loop(
|
||||
iroh: iroh::Endpoint,
|
||||
handlers: Arc<HandlerRegistry>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
shutdown_rx: &mut watch::Receiver<bool>,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
debug!("iroh accept loop: shutdown signaled");
|
||||
break;
|
||||
}
|
||||
incoming = iroh.accept() => {
|
||||
let Some(incoming) = incoming else {
|
||||
debug!("iroh accept loop: endpoint closed");
|
||||
break;
|
||||
};
|
||||
let handlers = handlers.clone();
|
||||
let identity_provider = identity_provider.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut connecting = match incoming.accept() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
warn!("iroh accept failed: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let alpn = match connecting.alpn().await {
|
||||
Ok(alpn) => alpn,
|
||||
Err(e) => {
|
||||
warn!("iroh ALPN negotiation failed: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let connection = match connecting.await {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
warn!("iroh handshake completion failed: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
dispatch_iroh(connection, alpn, &handlers, &identity_provider);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
fn dispatch_iroh(
|
||||
connection: iroh::endpoint::Connection,
|
||||
alpn: Vec<u8>,
|
||||
handlers: &HandlerRegistry,
|
||||
identity_provider: &Arc<dyn IdentityProvider>,
|
||||
) {
|
||||
let handler = match handlers.get(&alpn) {
|
||||
Some(h) => h.clone(),
|
||||
None => {
|
||||
connection.close(0u32.into(), b"no handler");
|
||||
warn!(
|
||||
"iroh dispatch: no handler for ALPN {:?}",
|
||||
String::from_utf8_lossy(&alpn)
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let auth = build_auth_context(&alpn, None, None, identity_provider);
|
||||
let conn = Connection::from_iroh(connection);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handler.handle(conn, &auth).await {
|
||||
error!("handler returned error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
fn build_auth_context(
|
||||
alpn: &[u8],
|
||||
remote_addr: Option<SocketAddr>,
|
||||
tls_client_fingerprint: Option<String>,
|
||||
identity_provider: &Arc<dyn IdentityProvider>,
|
||||
) -> AuthContext {
|
||||
let identity = tls_client_fingerprint
|
||||
.as_ref()
|
||||
.and_then(|fp| identity_provider.resolve_from_fingerprint(fp));
|
||||
AuthContext {
|
||||
identity,
|
||||
alpn: alpn.to_vec(),
|
||||
remote_addr,
|
||||
tls_client_fingerprint,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
fn build_quinn_server_config(
|
||||
tls_identity: &TlsIdentity,
|
||||
alpns: &[Vec<u8>],
|
||||
) -> Result<quinn::ServerConfig, EndpointError> {
|
||||
use quinn::crypto::rustls::QuicServerConfig;
|
||||
let rustls_config = build_rustls_server_config(tls_identity, alpns)?;
|
||||
let quic_server_config = QuicServerConfig::try_from(rustls_config)
|
||||
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
|
||||
Ok(quinn::ServerConfig::with_crypto(Arc::new(
|
||||
quic_server_config,
|
||||
)))
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
fn build_rustls_server_config(
|
||||
tls_identity: &TlsIdentity,
|
||||
alpns: &[Vec<u8>],
|
||||
) -> Result<rustls::ServerConfig, EndpointError> {
|
||||
let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
|
||||
match tls_identity {
|
||||
TlsIdentity::X509 { cert, key } => {
|
||||
let cert_chain = load_cert_chain(cert)?;
|
||||
let private_key = load_private_key(key)?;
|
||||
let mut config = rustls::ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()
|
||||
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, private_key)
|
||||
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
|
||||
config.alpn_protocols = alpns.to_vec();
|
||||
config.max_early_data_size = u32::MAX;
|
||||
Ok(config)
|
||||
}
|
||||
#[cfg(feature = "iroh")]
|
||||
TlsIdentity::RawKey(secret_key) => {
|
||||
let resolver = Arc::new(RawKeyCertResolver::new(secret_key));
|
||||
let mut config = rustls::ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()
|
||||
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(resolver);
|
||||
config.alpn_protocols = alpns.to_vec();
|
||||
config.max_early_data_size = u32::MAX;
|
||||
Ok(config)
|
||||
}
|
||||
TlsIdentity::SelfSigned => {
|
||||
let cert = generate_self_signed_cert()?;
|
||||
let mut config = rustls::ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()
|
||||
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert.cert_chain, cert.private_key)
|
||||
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
|
||||
config.alpn_protocols = alpns.to_vec();
|
||||
config.max_early_data_size = u32::MAX;
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
async fn build_iroh_endpoint(
|
||||
static_config: &StaticConfig,
|
||||
alpns: &[Vec<u8>],
|
||||
) -> Result<iroh::Endpoint, EndpointError> {
|
||||
let mut builder = iroh::Endpoint::builder();
|
||||
|
||||
if let Some(TlsIdentity::RawKey(secret_key)) = static_config.tls_identity.as_ref() {
|
||||
builder = builder.secret_key(secret_key.clone());
|
||||
} else {
|
||||
let mut csprng = rand::rngs::OsRng;
|
||||
builder = builder.secret_key(iroh::SecretKey::generate(&mut csprng));
|
||||
}
|
||||
|
||||
builder = builder.alpns(alpns.to_vec());
|
||||
|
||||
if let Some(relay_url) = static_config.iroh_relay.as_ref() {
|
||||
let relay_map = iroh::RelayMap::from(relay_url.clone());
|
||||
builder = builder.relay_mode(iroh::RelayMode::Custom(relay_map));
|
||||
} else {
|
||||
builder = builder.relay_mode(iroh::RelayMode::Disabled);
|
||||
}
|
||||
|
||||
builder
|
||||
.bind()
|
||||
.await
|
||||
.map_err(|e| EndpointError::BindFailed(io::Error::other(e)))
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
fn load_cert_chain(
|
||||
path: &Path,
|
||||
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>, EndpointError> {
|
||||
let bytes = std::fs::read(path).map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
|
||||
let mut reader = std::io::BufReader::new(bytes.as_slice());
|
||||
rustls_pemfile::certs(&mut reader)
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
fn load_private_key(
|
||||
path: &Path,
|
||||
) -> Result<rustls::pki_types::PrivateKeyDer<'static>, EndpointError> {
|
||||
let bytes = std::fs::read(path).map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
|
||||
let mut reader = std::io::BufReader::new(bytes.as_slice());
|
||||
match rustls_pemfile::private_key(&mut reader) {
|
||||
Ok(Some(key)) => Ok(key),
|
||||
Ok(None) => Err(EndpointError::TlsConfig(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"no private key found in file",
|
||||
))),
|
||||
Err(e) => Err(EndpointError::TlsConfig(io::Error::other(e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
struct SelfSignedCert {
|
||||
cert_chain: Vec<rustls::pki_types::CertificateDer<'static>>,
|
||||
private_key: rustls::pki_types::PrivateKeyDer<'static>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
fn generate_self_signed_cert() -> Result<SelfSignedCert, EndpointError> {
|
||||
use rcgen::{CertificateParams, KeyPair};
|
||||
let key_pair =
|
||||
KeyPair::generate().map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
|
||||
let params = CertificateParams::default();
|
||||
let cert = params
|
||||
.self_signed(&key_pair)
|
||||
.map_err(|e| EndpointError::TlsConfig(io::Error::other(e)))?;
|
||||
let cert_der = cert.der().clone();
|
||||
let key_der = rustls::pki_types::PrivateKeyDer::Pkcs8(
|
||||
rustls::pki_types::PrivatePkcs8KeyDer::from(key_pair.serialize_der()),
|
||||
);
|
||||
Ok(SelfSignedCert {
|
||||
cert_chain: vec![cert_der],
|
||||
private_key: key_der,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
struct RawKeyCertResolver {
|
||||
key: Arc<rustls::sign::CertifiedKey>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
impl RawKeyCertResolver {
|
||||
fn new(secret_key: &iroh::SecretKey) -> Self {
|
||||
let signing_key = Arc::new(Ed25519SigningKey::new(secret_key.clone()));
|
||||
let public_key = signing_key.spki_public_key();
|
||||
let cert = rustls::pki_types::CertificateDer::from(public_key.to_vec());
|
||||
let certified_key = rustls::sign::CertifiedKey::new(vec![cert], signing_key);
|
||||
Self {
|
||||
key: Arc::new(certified_key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
impl rustls::server::ResolvesServerCert for RawKeyCertResolver {
|
||||
fn resolve(
|
||||
&self,
|
||||
_client_hello: rustls::server::ClientHello<'_>,
|
||||
) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
||||
Some(Arc::clone(&self.key))
|
||||
}
|
||||
|
||||
fn only_raw_public_keys(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
impl std::fmt::Debug for RawKeyCertResolver {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RawKeyCertResolver").finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
#[derive(Clone)]
|
||||
struct Ed25519SigningKey {
|
||||
key: iroh::SecretKey,
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
impl std::fmt::Debug for Ed25519SigningKey {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Ed25519SigningKey").finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
impl Ed25519SigningKey {
|
||||
fn new(key: iroh::SecretKey) -> Self {
|
||||
Self { key }
|
||||
}
|
||||
|
||||
fn spki_public_key(&self) -> rustls::pki_types::SubjectPublicKeyInfoDer<'static> {
|
||||
rustls::sign::public_key_to_spki(
|
||||
&rustls::pki_types::alg_id::ED25519,
|
||||
self.key.public().as_bytes(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
impl rustls::sign::SigningKey for Ed25519SigningKey {
|
||||
fn choose_scheme(
|
||||
&self,
|
||||
offered: &[rustls::SignatureScheme],
|
||||
) -> Option<Box<dyn rustls::sign::Signer>> {
|
||||
if offered.contains(&rustls::SignatureScheme::ED25519) {
|
||||
Some(Box::new(self.clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn algorithm(&self) -> rustls::SignatureAlgorithm {
|
||||
rustls::SignatureAlgorithm::ED25519
|
||||
}
|
||||
|
||||
fn public_key(&self) -> Option<rustls::pki_types::SubjectPublicKeyInfoDer<'_>> {
|
||||
Some(self.spki_public_key())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
impl rustls::sign::Signer for Ed25519SigningKey {
|
||||
fn sign(&self, message: &[u8]) -> Result<Vec<u8>, rustls::Error> {
|
||||
Ok(self.key.sign(message).to_bytes().to_vec())
|
||||
}
|
||||
|
||||
fn scheme(&self) -> rustls::SignatureScheme {
|
||||
rustls::SignatureScheme::ED25519
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::AuthContext;
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
use crate::auth::{AuthToken, Identity, IdentityProvider};
|
||||
use crate::types::{Connection, HandlerError};
|
||||
use async_trait::async_trait;
|
||||
|
||||
struct DummyHandler {
|
||||
alpn: &'static [u8],
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProtocolHandler for DummyHandler {
|
||||
fn alpn(&self) -> &'static [u8] {
|
||||
self.alpn
|
||||
}
|
||||
async fn handle(
|
||||
&self,
|
||||
_connection: Connection,
|
||||
_auth: &AuthContext,
|
||||
) -> Result<(), HandlerError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn make_handler(alpn: &'static [u8]) -> Arc<dyn ProtocolHandler> {
|
||||
Arc::new(DummyHandler { alpn })
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_registry_new_is_empty() {
|
||||
let reg = HandlerRegistry::new();
|
||||
assert!(reg.alpn_strings().is_empty());
|
||||
assert!(reg.get(b"alknet/test").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_registry_register_then_get() {
|
||||
let mut reg = HandlerRegistry::new();
|
||||
reg.register(make_handler(b"alknet/test"));
|
||||
assert_eq!(reg.alpn_strings(), vec![b"alknet/test".to_vec()]);
|
||||
assert!(reg.get(b"alknet/test").is_some());
|
||||
assert!(reg.get(b"alknet/other").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_registry_multiple_alpns() {
|
||||
let mut reg = HandlerRegistry::new();
|
||||
reg.register(make_handler(b"alknet/ssh"));
|
||||
reg.register(make_handler(b"alknet/call"));
|
||||
let mut alpns = reg
|
||||
.alpn_strings()
|
||||
.into_iter()
|
||||
.map(|a| String::from_utf8(a).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
alpns.sort();
|
||||
assert_eq!(alpns, vec!["alknet/call", "alknet/ssh"]);
|
||||
assert!(reg.get(b"alknet/ssh").is_some());
|
||||
assert!(reg.get(b"alknet/call").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "ALPN already registered")]
|
||||
fn handler_registry_register_panics_on_duplicate() {
|
||||
let mut reg = HandlerRegistry::new();
|
||||
reg.register(make_handler(b"alknet/test"));
|
||||
reg.register(make_handler(b"alknet/test"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_registry_debug_lists_alpns() {
|
||||
let mut reg = HandlerRegistry::new();
|
||||
reg.register(make_handler(b"alknet/test"));
|
||||
let s = format!("{:?}", reg);
|
||||
assert!(s.contains("alknet/test"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn endpoint_error_display() {
|
||||
let e = EndpointError::BindFailed(io::Error::new(io::ErrorKind::AddrInUse, "busy"));
|
||||
assert!(format!("{e}").contains("bind failed"));
|
||||
let e = EndpointError::TlsConfig(io::Error::new(io::ErrorKind::InvalidData, "bad"));
|
||||
assert!(format!("{e}").contains("tls config error"));
|
||||
let e = EndpointError::HandlerNotFound(b"alknet/test".to_vec());
|
||||
assert!(format!("{e}").contains("handler not found"));
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
#[test]
|
||||
fn build_auth_context_resolves_identity_from_fingerprint() {
|
||||
struct StaticProvider;
|
||||
impl IdentityProvider for StaticProvider {
|
||||
fn resolve_from_fingerprint(&self, fp: &str) -> Option<Identity> {
|
||||
if fp == "SHA256:known" {
|
||||
Some(Identity {
|
||||
id: "SHA256:known".to_string(),
|
||||
scopes: vec![],
|
||||
resources: HashMap::new(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
fn resolve_from_token(&self, _token: &AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticProvider);
|
||||
let auth = build_auth_context(
|
||||
b"alknet/test",
|
||||
None,
|
||||
Some("SHA256:known".to_string()),
|
||||
&provider,
|
||||
);
|
||||
assert_eq!(auth.identity.as_ref().unwrap().id, "SHA256:known");
|
||||
assert_eq!(auth.alpn, b"alknet/test");
|
||||
assert_eq!(auth.tls_client_fingerprint.as_deref(), Some("SHA256:known"));
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
#[test]
|
||||
fn build_auth_context_no_fingerprint_no_identity() {
|
||||
struct NoProvider;
|
||||
impl IdentityProvider for NoProvider {
|
||||
fn resolve_from_fingerprint(&self, _fp: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
fn resolve_from_token(&self, _token: &AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(NoProvider);
|
||||
let auth = build_auth_context(b"alknet/test", None, None, &provider);
|
||||
assert!(auth.identity.is_none());
|
||||
assert!(auth.tls_client_fingerprint.is_none());
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
#[test]
|
||||
fn build_auth_context_fingerprint_unknown_identity_none() {
|
||||
struct StaticProvider;
|
||||
impl IdentityProvider for StaticProvider {
|
||||
fn resolve_from_fingerprint(&self, _fp: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
fn resolve_from_token(&self, _token: &AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticProvider);
|
||||
let auth = build_auth_context(
|
||||
b"alknet/test",
|
||||
None,
|
||||
Some("SHA256:unknown".to_string()),
|
||||
&provider,
|
||||
);
|
||||
assert!(auth.identity.is_none());
|
||||
assert!(auth.tls_client_fingerprint.is_some());
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
#[test]
|
||||
fn raw_key_cert_resolver_only_raw_public_keys() {
|
||||
use rustls::server::ResolvesServerCert;
|
||||
let mut csprng = rand::rngs::OsRng;
|
||||
let sk = iroh::SecretKey::generate(&mut csprng);
|
||||
let resolver = RawKeyCertResolver::new(&sk);
|
||||
assert!(resolver.only_raw_public_keys());
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
#[tokio::test]
|
||||
async fn endpoint_constructs_with_iroh_raw_key_identity() {
|
||||
let mut csprng = rand::rngs::OsRng;
|
||||
let static_config = StaticConfig {
|
||||
listen_addr: None,
|
||||
tls_identity: Some(TlsIdentity::RawKey(iroh::SecretKey::generate(&mut csprng))),
|
||||
iroh_relay: None,
|
||||
drain_timeout: Duration::from_millis(10),
|
||||
};
|
||||
let dynamic = Arc::new(ArcSwap::from_pointee(DynamicConfig::default()));
|
||||
struct NoProvider;
|
||||
impl IdentityProvider for NoProvider {
|
||||
fn resolve_from_fingerprint(&self, _: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
fn resolve_from_token(&self, _: &AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(NoProvider);
|
||||
let mut registry = HandlerRegistry::new();
|
||||
registry.register(make_handler(b"alknet/test"));
|
||||
let endpoint = AlknetEndpoint::new(&static_config, registry, dynamic, provider)
|
||||
.await
|
||||
.expect("endpoint constructs");
|
||||
assert!(endpoint.shutdown_sender().send(true).is_ok());
|
||||
endpoint.shutdown().await.expect("shutdown ok");
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
#[tokio::test]
|
||||
async fn iroh_endpoint_runs_accept_loop_and_shutdown() {
|
||||
use std::sync::Mutex;
|
||||
let server_sk = iroh::SecretKey::generate(&mut rand::rngs::OsRng);
|
||||
let static_config = StaticConfig {
|
||||
listen_addr: None,
|
||||
tls_identity: Some(TlsIdentity::RawKey(server_sk)),
|
||||
iroh_relay: None,
|
||||
drain_timeout: Duration::from_millis(20),
|
||||
};
|
||||
let dynamic = Arc::new(ArcSwap::from_pointee(DynamicConfig::default()));
|
||||
struct NoProvider;
|
||||
impl IdentityProvider for NoProvider {
|
||||
fn resolve_from_fingerprint(&self, _: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
fn resolve_from_token(&self, _: &AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(NoProvider);
|
||||
|
||||
let connected = Arc::new(Mutex::new(false));
|
||||
let connected_clone = connected.clone();
|
||||
struct CountingHandler {
|
||||
alpn: &'static [u8],
|
||||
connected: Arc<Mutex<bool>>,
|
||||
}
|
||||
#[async_trait]
|
||||
impl ProtocolHandler for CountingHandler {
|
||||
fn alpn(&self) -> &'static [u8] {
|
||||
self.alpn
|
||||
}
|
||||
async fn handle(
|
||||
&self,
|
||||
_conn: Connection,
|
||||
_auth: &AuthContext,
|
||||
) -> Result<(), HandlerError> {
|
||||
*self.connected.lock().unwrap() = true;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
let mut registry = HandlerRegistry::new();
|
||||
registry.register(Arc::new(CountingHandler {
|
||||
alpn: b"alknet/test",
|
||||
connected: connected_clone,
|
||||
}));
|
||||
let endpoint = Arc::new(
|
||||
AlknetEndpoint::new(&static_config, registry, dynamic, provider)
|
||||
.await
|
||||
.expect("endpoint constructs"),
|
||||
);
|
||||
|
||||
let run_endpoint = endpoint.clone();
|
||||
let run_task = tokio::spawn(async move {
|
||||
run_endpoint.run().await;
|
||||
});
|
||||
|
||||
let _ = endpoint.shutdown_sender().send(true);
|
||||
endpoint.shutdown().await.ok();
|
||||
let _ = run_task.await;
|
||||
assert!(!*connected.lock().unwrap());
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
#[test]
|
||||
fn self_signed_cert_generation_produces_cert_and_key() {
|
||||
let cert = generate_self_signed_cert().expect("self-signed cert generates");
|
||||
assert!(!cert.cert_chain.is_empty());
|
||||
assert!(!cert.private_key.secret_der().is_empty());
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "quinn", feature = "iroh"))]
|
||||
#[test]
|
||||
fn dispatch_decision_logic_lookup_and_auth() {
|
||||
let mut registry = HandlerRegistry::new();
|
||||
registry.register(make_handler(b"alknet/ssh"));
|
||||
registry.register(make_handler(b"alknet/call"));
|
||||
|
||||
struct StaticProvider;
|
||||
impl IdentityProvider for StaticProvider {
|
||||
fn resolve_from_fingerprint(&self, fp: &str) -> Option<Identity> {
|
||||
if fp == "SHA256:caller" {
|
||||
Some(Identity {
|
||||
id: "SHA256:caller".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
fn resolve_from_token(&self, _: &AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticProvider);
|
||||
|
||||
let ssh_handler = registry.get(b"alknet/ssh").expect("ssh handler registered");
|
||||
assert_eq!(ssh_handler.alpn(), b"alknet/ssh");
|
||||
let auth = build_auth_context(
|
||||
b"alknet/ssh",
|
||||
Some(std::net::SocketAddr::new(
|
||||
std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST),
|
||||
1234,
|
||||
)),
|
||||
Some("SHA256:caller".to_string()),
|
||||
&provider,
|
||||
);
|
||||
assert_eq!(auth.identity.as_ref().unwrap().id, "SHA256:caller");
|
||||
assert_eq!(auth.alpn, b"alknet/ssh");
|
||||
|
||||
let unknown = registry.get(b"alknet/unknown");
|
||||
assert!(unknown.is_none(), "unknown ALPN has no handler");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user