diff --git a/crates/alknet-call/src/registry/registration.rs b/crates/alknet-call/src/registry/registration.rs index f2700ed..70a30c5 100644 --- a/crates/alknet-call/src/registry/registration.rs +++ b/crates/alknet-call/src/registry/registration.rs @@ -4,7 +4,7 @@ use std::pin::Pin; use std::sync::Arc; use alknet_core::types::Capabilities; -use futures::stream::Stream; +use futures::stream::{self, Stream}; use serde_json::Value; use super::context::{CompositionAuthority, OperationContext, ScopedPeerEnv}; @@ -156,6 +156,63 @@ impl OperationRegistry { ), } } + + pub fn invoke_streaming( + &self, + name: &str, + input: Value, + context: OperationContext, + ) -> ResponseStream { + let request_id = context.request_id.clone(); + let name_owned = name.to_string(); + + let registration = match self.operations.get(name) { + Some(r) => r, + None => { + return Box::pin(stream::once(async move { + ResponseEnvelope::not_found(request_id, &name_owned) + })); + } + }; + + if registration.spec.visibility == Visibility::Internal && !context.internal { + return Box::pin(stream::once(async move { + ResponseEnvelope::not_found(request_id, &name_owned) + })); + } + + let acl = ®istration.spec.access_control; + let identity = if context.internal { + context + .handler_identity + .as_ref() + .and_then(|ca| ca.as_identity()) + } else { + context.identity.clone() + }; + + if let AccessResult::Forbidden(message) = acl.check(identity.as_ref()) { + return Box::pin(stream::once(async move { + ResponseEnvelope::forbidden(request_id, message) + })); + } + + let streaming_handler = match ®istration.handler { + HandlerKind::Stream(h) => Arc::clone(h), + HandlerKind::Once(_) => { + return Box::pin(stream::once(async move { + ResponseEnvelope::error( + request_id, + CallError::invalid_operation_type( + "invoke_streaming() called on a Query/Mutation op; use invoke()", + ), + ) + })); + } + }; + + streaming_handler(input, context) + } } impl Default for OperationRegistry { @@ -1006,4 +1063,189 @@ mod tests { assert!(!err.retryable); assert!(err.details.is_none()); } + + async fn collect_stream(mut s: ResponseStream) -> Vec { + use futures::stream::StreamExt; + let mut out = Vec::new(); + while let Some(env) = s.next().await { + out.push(env); + } + out + } + + #[tokio::test] + async fn invoke_streaming_on_subscription_dispatches_handler_stream() { + let mut registry = OperationRegistry::new(); + registry + .register(HandlerRegistration::new( + subscription_spec("events/stream"), + HandlerKind::Stream(echo_streaming_handler()), + OperationProvenance::Local, + None, + None, + Capabilities::new(), + )) + .unwrap(); + let ctx = root_context("req-is-1", None, None, false, ScopedPeerEnv::empty()); + let stream = registry.invoke_streaming("events/stream", serde_json::json!({"v": 7}), ctx); + let items = collect_stream(stream).await; + assert_eq!(items.len(), 1); + assert_eq!(items[0].request_id, "req-is-1"); + assert_eq!(items[0].result, Ok(serde_json::json!({"v": 7}))); + } + + #[tokio::test] + async fn invoke_streaming_on_unknown_op_yields_single_not_found() { + let registry = OperationRegistry::new(); + let ctx = root_context("req-is-2", None, None, false, ScopedPeerEnv::empty()); + let stream = registry.invoke_streaming("missing", serde_json::json!({}), ctx); + let items = collect_stream(stream).await; + assert_eq!(items.len(), 1); + match &items[0].result { + Err(e) => { + assert_eq!(e.code, "NOT_FOUND"); + assert!(e.message.contains("missing")); + } + other => panic!("expected NOT_FOUND, got {other:?}"), + } + } + + #[tokio::test] + async fn invoke_streaming_on_query_op_yields_invalid_operation_type() { + let mut registry = OperationRegistry::new(); + registry + .register(HandlerRegistration::new( + external_spec("echo", AccessControl::default()), + HandlerKind::Once(echo_handler()), + OperationProvenance::Local, + None, + None, + Capabilities::new(), + )) + .unwrap(); + let ctx = root_context("req-is-3", None, None, false, ScopedPeerEnv::empty()); + let stream = registry.invoke_streaming("echo", serde_json::json!({}), ctx); + let items = collect_stream(stream).await; + assert_eq!(items.len(), 1); + match &items[0].result { + Err(e) => assert_eq!(e.code, "INVALID_OPERATION_TYPE"), + other => panic!("expected INVALID_OPERATION_TYPE, got {other:?}"), + } + } + + #[tokio::test] + async fn invoke_streaming_internal_op_from_external_yields_not_found() { + let mut registry = OperationRegistry::new(); + registry + .register(HandlerRegistration::new( + internal_subscription_spec(AccessControl::default()), + HandlerKind::Stream(echo_streaming_handler()), + OperationProvenance::Local, + None, + None, + Capabilities::new(), + )) + .unwrap(); + let ctx = root_context("req-is-4", None, None, false, ScopedPeerEnv::empty()); + let stream = registry.invoke_streaming("events/stream", serde_json::json!({}), ctx); + let items = collect_stream(stream).await; + assert_eq!(items.len(), 1); + match &items[0].result { + Err(e) => { + assert_eq!(e.code, "NOT_FOUND"); + assert!(e.message.contains("events/stream")); + } + other => panic!("expected NOT_FOUND, got {other:?}"), + } + } + + #[tokio::test] + async fn invoke_streaming_acl_denied_yields_forbidden() { + let mut registry = OperationRegistry::new(); + registry + .register(HandlerRegistration::new( + subscription_spec_with_acl(AccessControl { + required_scopes: vec!["admin".to_string()], + ..Default::default() + }), + HandlerKind::Stream(echo_streaming_handler()), + OperationProvenance::Local, + None, + None, + Capabilities::new(), + )) + .unwrap(); + let ctx = root_context( + "req-is-5", + Some(identity_with_scopes("caller", &["user"])), + None, + false, + ScopedPeerEnv::empty(), + ); + let stream = registry.invoke_streaming("events/stream", serde_json::json!({}), ctx); + let items = collect_stream(stream).await; + assert_eq!(items.len(), 1); + match &items[0].result { + Err(e) => { + assert_eq!(e.code, "FORBIDDEN"); + assert!(e.message.contains("admin")); + } + other => panic!("expected FORBIDDEN, got {other:?}"), + } + } + + #[tokio::test] + async fn invoke_streaming_internal_call_uses_handler_identity_for_acl() { + let mut registry = OperationRegistry::new(); + let composing_authority = CompositionAuthority::new("agent-chat", ["admin".to_string()]); + registry + .register(HandlerRegistration::new( + internal_subscription_spec(AccessControl { + required_scopes: vec!["admin".to_string()], + ..Default::default() + }), + HandlerKind::Stream(echo_streaming_handler()), + OperationProvenance::Local, + None, + None, + Capabilities::new(), + )) + .unwrap(); + let ctx = root_context( + "req-is-6", + Some(identity_with_scopes("user", &["user"])), + Some(composing_authority), + true, + ScopedPeerEnv::empty(), + ); + let stream = registry.invoke_streaming("events/stream", serde_json::json!({"ok": 1}), ctx); + let items = collect_stream(stream).await; + assert_eq!(items.len(), 1); + assert_eq!(items[0].request_id, "req-is-6"); + assert_eq!(items[0].result, Ok(serde_json::json!({"ok": 1}))); + } + + fn subscription_spec_with_acl(acl: AccessControl) -> OperationSpec { + OperationSpec::new( + "events/stream", + OperationType::Subscription, + Visibility::External, + serde_json::json!({}), + serde_json::json!({}), + vec![], + acl, + ) + } + + fn internal_subscription_spec(acl: AccessControl) -> OperationSpec { + OperationSpec::new( + "events/stream", + OperationType::Subscription, + Visibility::Internal, + serde_json::json!({}), + serde_json::json!({}), + vec![], + acl, + ) + } }