diff --git a/crates/alknet-call/src/protocol/dispatch.rs b/crates/alknet-call/src/protocol/dispatch.rs index ab47c0f..4bd8597 100644 --- a/crates/alknet-call/src/protocol/dispatch.rs +++ b/crates/alknet-call/src/protocol/dispatch.rs @@ -180,7 +180,6 @@ impl Dispatcher { let connection_identity = connection.connection().identity().cloned(); let identity = self.resolve_identity(connection_identity, &payload); - let forwarded_for = payload .get("forwarded_for") .and_then(|v| serde_json::from_value::(v.clone()).ok()); @@ -316,3 +315,390 @@ impl Clone for Dispatcher { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::registry::registration::{make_handler, HandlerRegistration, OperationProvenance}; + use crate::registry::spec::{AccessControl, OperationSpec, OperationType, Visibility}; + use alknet_core::auth::{AuthToken, Identity, IdentityProvider}; + use alknet_core::types::{Capabilities, MockConnection}; + use std::collections::HashMap; + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + use std::sync::Mutex as StdMutex; + + struct StubConnection { + alpn: &'static [u8], + addr: Option, + closed: StdMutex>, + } + + impl MockConnection for StubConnection { + fn remote_alpn(&self) -> &[u8] { + self.alpn + } + fn remote_addr(&self) -> Option { + self.addr + } + fn close(&self, code: u32, reason: &str) { + *self.closed.lock().unwrap() = Some((code, reason.to_string())); + } + } + + fn stub_connection() -> alknet_core::types::Connection { + alknet_core::types::Connection::from_mock(Arc::new(StubConnection { + alpn: b"alknet/call", + addr: Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4321)), + closed: StdMutex::new(None), + })) + } + + struct StaticIdentityProvider { + tokens: StdMutex>, + } + + impl StaticIdentityProvider { + fn new() -> Self { + Self { + tokens: StdMutex::new(HashMap::new()), + } + } + + fn with_token(self, token: &str, identity: Identity) -> Self { + self.tokens + .lock() + .unwrap() + .insert(token.to_string(), identity); + self + } + } + + impl IdentityProvider for StaticIdentityProvider { + fn resolve_from_fingerprint(&self, _fp: &str) -> Option { + None + } + fn resolve_from_token(&self, token: &AuthToken) -> Option { + let token_str = String::from_utf8_lossy(&token.raw); + self.tokens.lock().unwrap().get(token_str.as_ref()).cloned() + } + } + + fn identity_with_scopes(id: &str, scopes: &[&str]) -> Identity { + Identity { + id: id.to_string(), + scopes: scopes.iter().map(|s| s.to_string()).collect(), + resources: HashMap::new(), + } + } + + fn external_spec(name: &str, acl: AccessControl) -> OperationSpec { + OperationSpec::new( + name, + OperationType::Query, + Visibility::External, + serde_json::json!({}), + serde_json::json!({}), + vec![], + acl, + ) + } + + fn internal_spec(name: &str, acl: AccessControl) -> OperationSpec { + OperationSpec::new( + name, + OperationType::Query, + Visibility::Internal, + serde_json::json!({}), + serde_json::json!({}), + vec![], + acl, + ) + } + + fn registry_with(name: &str, visibility: Visibility, acl: AccessControl) -> OperationRegistry { + let mut registry = OperationRegistry::new(); + registry.register(HandlerRegistration::new( + OperationSpec::new( + name, + OperationType::Query, + visibility, + serde_json::json!({}), + serde_json::json!({}), + vec![], + acl, + ), + make_handler(|input, context| async move { + ResponseEnvelope::ok(context.request_id, input) + }), + OperationProvenance::Local, + None, + None, + Capabilities::new(), + )); + registry + } + + fn dispatcher() -> Dispatcher { + Dispatcher::new( + Arc::new(OperationRegistry::new()), + Arc::new(StaticIdentityProvider::new()), + ) + } + + #[tokio::test] + async fn dispatch_authorized_peer_dispatches_and_populates_capabilities() { + let caps = Capabilities::new().with_api_key("google", "k".to_string()); + let mut registry = OperationRegistry::new(); + let handler = make_handler(|_input, context| async move { + let has_google = context.capabilities.get("google").is_some(); + ResponseEnvelope::ok( + context.request_id, + serde_json::json!({ "has_google": has_google }), + ) + }); + registry.register(HandlerRegistration::new( + external_spec("admin/run", AccessControl::default()), + handler, + OperationProvenance::Local, + None, + None, + caps, + )); + let registry = Arc::new(registry); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let payload = serde_json::json!({ + "operationId": "/admin/run", + "input": {}, + }); + let response = dp + .dispatch_requested(&conn, "req-1".to_string(), payload) + .await; + let out = response.result.expect("dispatch ok"); + assert_eq!(out["has_google"], Value::Bool(true)); + } + + #[tokio::test] + async fn dispatch_unauthorized_peer_returns_forbidden_capabilities_never_populated() { + let caps = Capabilities::new().with_api_key("google", "k".to_string()); + let mut registry = OperationRegistry::new(); + let handler = make_handler(|_input, context| async move { + let has_google = context.capabilities.get("google").is_some(); + ResponseEnvelope::ok( + context.request_id, + serde_json::json!({ "has_google": has_google }), + ) + }); + registry.register(HandlerRegistration::new( + external_spec( + "admin/run", + AccessControl { + required_scopes: vec!["admin".to_string()], + ..Default::default() + }, + ), + handler, + OperationProvenance::Local, + None, + None, + caps, + )); + let registry = Arc::new(registry); + let provider: Arc = Arc::new( + StaticIdentityProvider::new() + .with_token("alk_user", identity_with_scopes("regular-user", &["user"])), + ); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let payload = serde_json::json!({ + "operationId": "/admin/run", + "input": {}, + "auth_token": "alk_user", + }); + let response = dp + .dispatch_requested(&conn, "req-2".to_string(), payload) + .await; + match response.result { + Err(e) => { + assert_eq!(e.code, "FORBIDDEN"); + assert!(e.message.contains("admin")); + } + other => panic!("expected FORBIDDEN, got {other:?}"), + } + } + + #[tokio::test] + async fn dispatch_internal_op_from_wire_returns_not_found_before_acl() { + let registry = Arc::new(registry_with( + "secret/op", + Visibility::Internal, + AccessControl::default(), + )); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let payload = serde_json::json!({ + "operationId": "/secret/op", + "input": {}, + }); + let response = dp + .dispatch_requested(&conn, "req-3".to_string(), payload) + .await; + match response.result { + Err(e) => { + assert_eq!(e.code, "NOT_FOUND"); + assert!(e.message.contains("secret/op")); + } + other => panic!("expected NOT_FOUND, got {other:?}"), + } + } + + #[tokio::test] + async fn dispatch_connection_with_no_identity_produces_no_peer_id_in_env() { + let registry = Arc::new(registry_with( + "fs/readFile", + Visibility::External, + AccessControl::default(), + )); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = CallConnection::new(stub_connection()); + + let context = dp.build_root_context("req-4".to_string(), "fs/readFile", None, None, &conn); + + assert!( + context.identity.is_none(), + "no connection identity → context.identity is None" + ); + assert!( + context.env.peer_ids().is_empty(), + "no peer overlay attached when connection has no identity" + ); + } + + #[tokio::test] + async fn dispatch_connection_with_identity_attaches_peer_overlay_keyed_by_identity_id() { + let registry = Arc::new(registry_with( + "fs/readFile", + Visibility::External, + AccessControl::default(), + )); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = CallConnection::new(stub_connection()); + conn.connection() + .set_identity(identity_with_scopes("worker-a", &[])) + .expect("identity not yet set"); + + let context = dp.build_root_context("req-5".to_string(), "fs/readFile", None, None, &conn); + + assert_eq!( + context.env.peer_ids(), + vec!["worker-a".to_string()], + "PeerId for connection comes from connection.identity().id" + ); + } + + #[tokio::test] + async fn dispatch_extract_forwarded_for_from_payload_into_context() { + let mut registry = OperationRegistry::new(); + let handler = make_handler(|_input, context| async move { + let forwarded_id = context.forwarded_for.as_ref().map(|i| i.id.clone()); + ResponseEnvelope::ok( + context.request_id, + serde_json::json!({ "forwarded_for_id": forwarded_id }), + ) + }); + registry.register(HandlerRegistration::new( + external_spec("fs/readFile", AccessControl::default()), + handler, + OperationProvenance::Local, + None, + None, + Capabilities::new(), + )); + let registry = Arc::new(registry); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let payload = serde_json::json!({ + "operationId": "/fs/readFile", + "input": {}, + "forwarded_for": { + "id": "alice", + "scopes": ["fs:read"], + "resources": {} + }, + }); + let response = dp + .dispatch_requested(&conn, "req-6".to_string(), payload) + .await; + let out = response.result.expect("ok"); + assert_eq!(out["forwarded_for_id"], Value::String("alice".into())); + } + + #[tokio::test] + async fn dispatch_without_forwarded_for_field_is_none() { + let mut registry = OperationRegistry::new(); + let handler = make_handler(|_input, context| async move { + let present = context.forwarded_for.is_some(); + ResponseEnvelope::ok( + context.request_id, + serde_json::json!({ "present": present }), + ) + }); + registry.register(HandlerRegistration::new( + external_spec("fs/readFile", AccessControl::default()), + handler, + OperationProvenance::Local, + None, + None, + Capabilities::new(), + )); + let registry = Arc::new(registry); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let payload = serde_json::json!({ + "operationId": "/fs/readFile", + "input": {}, + }); + let response = dp + .dispatch_requested(&conn, "req-7".to_string(), payload) + .await; + let out = response.result.expect("ok"); + assert_eq!(out["present"], Value::Bool(false)); + } + + #[tokio::test] + async fn dispatch_default_access_control_dispatches_to_any_peer() { + let registry = Arc::new(registry_with( + "echo/run", + Visibility::External, + AccessControl::default(), + )); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let payload = serde_json::json!({ + "operationId": "/echo/run", + "input": { "msg": "hi" }, + }); + let response = dp + .dispatch_requested(&conn, "req-8".to_string(), payload) + .await; + assert_eq!(response.result, Ok(serde_json::json!({ "msg": "hi" }))); + } + + #[test] + fn dispatcher_helper_compiles_with_full_signature() { + let _dp = dispatcher(); + } +} diff --git a/crates/alknet-call/src/registry/env.rs b/crates/alknet-call/src/registry/env.rs index 46e29ee..93e0258 100644 --- a/crates/alknet-call/src/registry/env.rs +++ b/crates/alknet-call/src/registry/env.rs @@ -287,6 +287,10 @@ impl OperationEnv for PeerCompositeEnv { .get(peer) .is_some_and(|c| c.contains(name)) } + + fn peer_ids(&self) -> Vec { + self.connection_order.clone() + } } #[cfg(test)]