diff --git a/crates/alknet-call/src/protocol/dispatch.rs b/crates/alknet-call/src/protocol/dispatch.rs index 46c6a63..32cd979 100644 --- a/crates/alknet-call/src/protocol/dispatch.rs +++ b/crates/alknet-call/src/protocol/dispatch.rs @@ -17,6 +17,7 @@ use std::time::{Duration, Instant}; use alknet_core::auth::{AuthToken, Identity, IdentityProvider}; use alknet_core::types::StreamError; +use futures::stream::StreamExt; use serde_json::Value; use tokio::task::JoinHandle; use tracing::{debug, warn}; @@ -30,11 +31,37 @@ use super::wire::{ use crate::protocol::adapter::SessionOverlaySource; use crate::registry::context::{AbortPolicy, OperationContext, ScopedPeerEnv}; use crate::registry::env::{LocalOperationEnv, OperationEnv, PeerCompositeEnv}; -use crate::registry::registration::OperationRegistry; +use crate::registry::registration::{OperationRegistry, ResponseStream}; +use crate::registry::spec::OperationType; const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); const SWEEPER_INTERVAL: Duration = Duration::from_secs(10); +/// Outcome of dispatching a `call.requested` event. The dispatcher branches on +/// the registered operation's `op_type` (ADR-049 §6): `Query`/`Mutation` produce +/// a single [`ResponseEnvelope`] (`Once`), `Subscription` produces a +/// [`ResponseStream`] (`Stream`) that `handle_stream` pumps to the wire. +/// +/// This enum is the branch point the spec describes ("branches on `op_type` in +/// `handle_stream`"): `dispatch` returns it and `handle_stream` matches on it, +/// keeping the Once path (one frame, no `call.completed`) and the Stream path +/// (each envelope → frame, `call.completed` on natural end) visibly distinct. +pub enum DispatchResult { + Once(ResponseEnvelope), + Stream(ResponseStream), +} + +impl std::fmt::Debug for DispatchResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DispatchResult::Once(env) => f.debug_tuple("Once").field(env).finish(), + DispatchResult::Stream(_) => { + f.debug_tuple("Stream").field(&"").finish() + } + } + } +} + /// Shared dispatcher for an established `CallConnection`. Constructed by /// both `CallAdapter` (accept path) and `CallClient` (connect path) and used /// to run the dispatch loop. Holds no per-connection state; the @@ -166,6 +193,36 @@ impl Dispatcher { request_id: String, payload: Value, ) -> ResponseEnvelope { + match self.dispatch(connection, request_id, payload).await { + DispatchResult::Once(envelope) => envelope, + DispatchResult::Stream(mut stream) => stream.next().await.unwrap_or_else(|| { + ResponseEnvelope::error( + String::new(), + CallError::internal( + "dispatch_requested called on a Subscription op; use the streaming path", + ), + ) + }), + } + } + + /// Dispatch a `call.requested` event, branching on the registered + /// operation's `op_type` (ADR-049 §6). `Query`/`Mutation` → `invoke()` → + /// [`DispatchResult::Once`]; `Subscription` → `invoke_streaming()` → + /// [`DispatchResult::Stream`]. Unknown ops and ACL failures resolve via + /// the registry's own envelope/error paths (Once for `invoke`, a single + /// error envelope for `invoke_streaming`). + /// + /// For the streaming branch the root context's deadline is cleared + /// (`deadline: None`): subscriptions are long-running and unbounded — the + /// 30s request/response deadline does not apply (ADR-049 §6, call-protocol + /// Timeouts). The Once branch keeps the deadline from `build_root_context`. + pub async fn dispatch( + &self, + connection: &Arc, + request_id: String, + payload: Value, + ) -> DispatchResult { let operation_id = payload .get("operationId") .and_then(|v| v.as_str()) @@ -180,7 +237,13 @@ impl Dispatcher { let input = payload.get("input").cloned().unwrap_or(Value::Null); - let context = self.build_root_context( + let is_subscription = self + .registry + .registration(&operation_name) + .map(|r| r.spec.op_type == OperationType::Subscription) + .unwrap_or(false); + + let mut context = self.build_root_context( request_id.clone(), &operation_name, identity, @@ -188,7 +251,16 @@ impl Dispatcher { connection, ); - self.registry.invoke(&operation_name, input, context).await + if is_subscription { + context.deadline = None; + let stream = self + .registry + .invoke_streaming(&operation_name, input, context); + DispatchResult::Stream(stream) + } else { + let envelope = self.registry.invoke(&operation_name, input, context).await; + DispatchResult::Once(envelope) + } } pub async fn handle_abort(&self, connection: &Arc, request_id: &str) { @@ -225,14 +297,20 @@ impl Dispatcher { let request_id = envelope.id.clone(); let payload = envelope.payload.clone(); - let response = self - .dispatch_requested(&connection, request_id.clone(), payload) - .await; - - let event: EventEnvelope = response.into(); - if let Err(err) = writer.write_frame(&event).await { - warn!(error = %err, "failed to write response frame; closing stream"); - break; + match self + .dispatch(&connection, request_id.clone(), payload) + .await + { + DispatchResult::Once(response) => { + let event: EventEnvelope = response.into(); + if let Err(err) = writer.write_frame(&event).await { + warn!(error = %err, "failed to write response frame; closing stream"); + break; + } + } + DispatchResult::Stream(stream) => { + self.pump_stream(&mut writer, &request_id, stream).await; + } } } EVENT_ABORTED => { @@ -246,6 +324,43 @@ impl Dispatcher { } } + /// Pump a subscription's [`ResponseStream`] to the wire: each + /// [`ResponseEnvelope`] becomes an [`EventEnvelope`] frame (`call.responded` + /// for `Ok`, `call.error` for `Err`). On natural stream end (the stream + /// returned `None` without the last item being an `Err`), write a + /// `call.completed` frame. An `Err` envelope is terminal — the stream + /// ends after it and we do NOT write `call.completed` (ADR-049 §6). + /// + /// If a frame write fails the pump stops early; the stream is dropped on + /// return, releasing the handler's resources via `Drop` (ADR-016). The + /// pump is cancellable: it runs inside the `handle_stream` task, so a + /// `call.aborted` for this request ID (handled by `handle_abort` on + /// another stream) or connection close cancels the task and drops the + /// stream. + pub(crate) async fn pump_stream( + &self, + writer: &mut super::wire::FrameFramedWriter, + request_id: &str, + mut stream: ResponseStream, + ) { + let mut last_was_error = false; + while let Some(envelope) = stream.next().await { + last_was_error = envelope.result.is_err(); + let event: EventEnvelope = envelope.into(); + if let Err(err) = writer.write_frame(&event).await { + warn!(error = %err, "failed to write streaming frame; closing stream"); + return; + } + } + + if !last_was_error { + let completed = EventEnvelope::completed(request_id); + if let Err(err) = writer.write_frame(&completed).await { + warn!(error = %err, "failed to write call.completed"); + } + } + } + /// Run the shared dispatch loop over an established `CallConnection`: /// spawn the pending-entry sweeper, accept bidirectional streams until the /// connection closes, dispatch each stream via `handle_stream`, and fail @@ -325,9 +440,9 @@ impl Clone for Dispatcher { #[cfg(test)] mod tests { use super::*; - use crate::protocol::wire::EVENT_RESPONDED; + use crate::protocol::wire::{EVENT_COMPLETED, EVENT_ERROR, EVENT_RESPONDED}; use crate::registry::registration::{ - make_handler, HandlerKind, HandlerRegistration, OperationProvenance, + make_handler, make_streaming_handler, HandlerKind, HandlerRegistration, OperationProvenance, }; use crate::registry::spec::{AccessControl, OperationSpec, OperationType, Visibility}; use alknet_core::auth::{AuthToken, Identity, IdentityProvider}; @@ -874,4 +989,388 @@ mod tests { Some(&serde_json::json!({ "v": 42 })) ); } + + // --- streaming dispatch branch (ADR-049 §6) --------------------------- + + fn subscription_spec(name: &str, acl: AccessControl) -> OperationSpec { + OperationSpec::new( + name, + OperationType::Subscription, + Visibility::External, + serde_json::json!({}), + serde_json::json!({}), + vec![], + acl, + ) + } + + fn encode_frame(envelope: &EventEnvelope) -> Vec { + let body = serde_json::to_vec(envelope).expect("serialize envelope"); + let mut buf = (body.len() as u32).to_be_bytes().to_vec(); + buf.extend_from_slice(&body); + buf + } + + async fn read_all_frames( + reader: &mut (impl tokio::io::AsyncRead + Unpin), + ) -> Vec { + let mut buf = Vec::new(); + use tokio::io::AsyncReadExt; + let _ = reader.read_to_end(&mut buf).await; + let mut frames = Vec::new(); + let mut cursor = std::io::Cursor::new(buf); + loop { + let mut len_buf = [0u8; 4]; + match tokio::io::AsyncReadExt::read_exact(&mut cursor, &mut len_buf).await { + Ok(_) => {} + Err(_) => break, + } + let len = u32::from_be_bytes(len_buf) as usize; + let mut body = vec![0u8; len]; + if tokio::io::AsyncReadExt::read_exact(&mut cursor, &mut body) + .await + .is_err() + { + break; + } + let envelope: EventEnvelope = + serde_json::from_slice(&body).expect("deserialize written frame"); + frames.push(envelope); + } + frames + } + + fn registry_with_subscription( + name: &str, + handler: crate::registry::registration::StreamingHandler, + ) -> Arc { + let mut registry = OperationRegistry::new(); + registry + .register(HandlerRegistration::new( + subscription_spec(name, AccessControl::default()), + HandlerKind::Stream(handler), + OperationProvenance::Local, + None, + None, + Capabilities::new(), + )) + .unwrap(); + Arc::new(registry) + } + + #[tokio::test] + async fn dispatch_subscription_returns_stream_result() { + let handler = make_streaming_handler(|input, ctx| { + futures::stream::iter(vec![ + ResponseEnvelope::ok(ctx.request_id.clone(), input.clone()), + ResponseEnvelope::ok(ctx.request_id.clone(), serde_json::json!({"done": true})), + ]) + }); + let registry = registry_with_subscription("events/stream", handler); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let payload = serde_json::json!({ + "operationId": "/events/stream", + "input": { "v": 1 }, + }); + match dp.dispatch(&conn, "sub-1".to_string(), payload).await { + DispatchResult::Stream(mut stream) => { + use futures::stream::StreamExt; + let first = stream.next().await.expect("first envelope"); + assert_eq!(first.request_id, "sub-1"); + assert_eq!(first.result, Ok(serde_json::json!({ "v": 1 }))); + let second = stream.next().await.expect("second envelope"); + assert_eq!(second.result, Ok(serde_json::json!({ "done": true }))); + assert!( + stream.next().await.is_none(), + "stream ends after two values" + ); + } + other => panic!("expected Stream, got {other:?}"), + } + } + + #[tokio::test] + async fn dispatch_subscription_clears_deadline_to_none() { + let handler = make_streaming_handler(|_input, ctx| { + let deadline = ctx.deadline; + futures::stream::iter(vec![ResponseEnvelope::ok( + ctx.request_id.clone(), + serde_json::json!({ "deadline_is_none": deadline.is_none() }), + )]) + }); + let registry = registry_with_subscription("events/stream", handler); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let payload = serde_json::json!({ + "operationId": "/events/stream", + "input": {}, + }); + match dp.dispatch(&conn, "sub-dl".to_string(), payload).await { + DispatchResult::Stream(mut stream) => { + use futures::stream::StreamExt; + let env = stream.next().await.expect("one envelope"); + let out = env.result.expect("ok"); + assert_eq!(out["deadline_is_none"], Value::Bool(true)); + } + other => panic!("expected Stream, got {other:?}"), + } + } + + #[tokio::test] + async fn dispatch_query_keeps_deadline_some() { + let mut registry = OperationRegistry::new(); + let handler = make_handler(|_input, ctx| async move { + let deadline_is_some = ctx.deadline.is_some(); + ResponseEnvelope::ok( + ctx.request_id.clone(), + serde_json::json!({ "deadline_is_some": deadline_is_some }), + ) + }); + registry + .register(HandlerRegistration::new( + external_spec("echo/run", AccessControl::default()), + HandlerKind::Once(handler), + OperationProvenance::Local, + None, + None, + Capabilities::new(), + )) + .unwrap(); + let registry = Arc::new(registry); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let payload = serde_json::json!({ + "operationId": "/echo/run", + "input": {}, + }); + match dp.dispatch(&conn, "q-1".to_string(), payload).await { + DispatchResult::Once(env) => { + let out = env.result.expect("ok"); + assert_eq!(out["deadline_is_some"], Value::Bool(true)); + } + other => panic!("expected Once, got {other:?}"), + } + } + + #[tokio::test] + async fn handle_stream_subscription_pumps_each_frame_then_completed() { + let handler = make_streaming_handler(|input, ctx| { + let first = input.clone(); + let rid = ctx.request_id.clone(); + futures::stream::iter(vec![ + ResponseEnvelope::ok(rid.clone(), first), + ResponseEnvelope::ok(rid.clone(), serde_json::json!({"n": 2})), + ResponseEnvelope::ok(rid, serde_json::json!({"n": 3})), + ]) + }); + let registry = registry_with_subscription("events/stream", handler); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let request = EventEnvelope::requested( + "sub-pump-1", + serde_json::json!({ + "operationId": "/events/stream", + "input": { "n": 1 }, + }), + ); + let recv = tokio::io::BufReader::new(std::io::Cursor::new(encode_frame(&request))); + let (send, mut sink) = tokio::io::duplex(8 * 1024); + let send = alknet_core::types::SendStream::from_mock(send); + let recv = alknet_core::types::RecvStream::from_mock(recv); + + dp.handle_stream(conn, send, recv).await; + + let frames = read_all_frames(&mut sink).await; + assert_eq!(frames.len(), 4, "3 responded + 1 completed"); + for (i, f) in frames[..3].iter().enumerate() { + assert_eq!(f.r#type, EVENT_RESPONDED, "frame {i} is call.responded"); + assert_eq!(f.id, "sub-pump-1"); + } + assert_eq!(frames[3].r#type, EVENT_COMPLETED); + assert_eq!(frames[3].id, "sub-pump-1"); + assert_eq!(frames[3].payload, serde_json::json!({})); + } + + #[tokio::test] + async fn handle_stream_subscription_error_is_terminal_no_completed() { + let handler = make_streaming_handler(|_input, ctx| { + let rid = ctx.request_id.clone(); + futures::stream::iter(vec![ + ResponseEnvelope::ok(rid.clone(), serde_json::json!({"ok": true})), + ResponseEnvelope::error(rid.clone(), CallError::internal("boom")), + ]) + }); + let registry = registry_with_subscription("events/stream", handler); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let request = EventEnvelope::requested( + "sub-err-1", + serde_json::json!({ + "operationId": "/events/stream", + "input": {}, + }), + ); + let recv = tokio::io::BufReader::new(std::io::Cursor::new(encode_frame(&request))); + let (send, mut sink) = tokio::io::duplex(8 * 1024); + let send = alknet_core::types::SendStream::from_mock(send); + let recv = alknet_core::types::RecvStream::from_mock(recv); + + dp.handle_stream(conn, send, recv).await; + + let frames = read_all_frames(&mut sink).await; + assert_eq!(frames.len(), 2, "1 responded + 1 error, no completed"); + assert_eq!(frames[0].r#type, EVENT_RESPONDED); + assert_eq!(frames[1].r#type, EVENT_ERROR); + assert_eq!(frames[1].id, "sub-err-1"); + assert_eq!( + frames[1].payload.get("code"), + Some(&Value::String("INTERNAL".into())) + ); + } + + #[tokio::test] + async fn handle_stream_query_dispatch_unchanged_one_frame_no_completed() { + let registry = Arc::new(registry_with( + "echo/run", + Visibility::External, + AccessControl::default(), + )); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let request = EventEnvelope::requested( + "q-pump-1", + serde_json::json!({ + "operationId": "/echo/run", + "input": { "msg": "hi" }, + }), + ); + let recv = tokio::io::BufReader::new(std::io::Cursor::new(encode_frame(&request))); + let (send, mut sink) = tokio::io::duplex(8 * 1024); + let send = alknet_core::types::SendStream::from_mock(send); + let recv = alknet_core::types::RecvStream::from_mock(recv); + + dp.handle_stream(conn, send, recv).await; + + let frames = read_all_frames(&mut sink).await; + assert_eq!(frames.len(), 1, "query: exactly one frame, no completed"); + assert_eq!(frames[0].r#type, EVENT_RESPONDED); + assert_eq!(frames[0].id, "q-pump-1"); + assert_eq!( + frames[0].payload.get("output"), + Some(&serde_json::json!({ "msg": "hi" })) + ); + } + + #[tokio::test] + async fn handle_stream_subscription_unknown_op_yields_single_error_no_completed() { + let registry = Arc::new(OperationRegistry::new()); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let request = EventEnvelope::requested( + "sub-missing-1", + serde_json::json!({ + "operationId": "/no/such/stream", + "input": {}, + }), + ); + let recv = tokio::io::BufReader::new(std::io::Cursor::new(encode_frame(&request))); + let (send, mut sink) = tokio::io::duplex(8 * 1024); + let send = alknet_core::types::SendStream::from_mock(send); + let recv = alknet_core::types::RecvStream::from_mock(recv); + + dp.handle_stream(conn, send, recv).await; + + let frames = read_all_frames(&mut sink).await; + assert_eq!(frames.len(), 1, "unknown op: single error, no completed"); + assert_eq!(frames[0].r#type, EVENT_ERROR); + assert_eq!(frames[0].id, "sub-missing-1"); + assert_eq!( + frames[0].payload.get("code"), + Some(&Value::String("NOT_FOUND".into())) + ); + } + + #[tokio::test] + async fn handle_stream_aborted_for_streaming_request_drops_stream() { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc as StdArc; + + let dropped = StdArc::new(AtomicBool::new(false)); + let dropped_clone = StdArc::clone(&dropped); + let handler = make_streaming_handler(move |_input, ctx| { + let rid = ctx.request_id.clone(); + let flag = StdArc::clone(&dropped_clone); + struct DropGuard(StdArc); + impl Drop for DropGuard { + fn drop(&mut self) { + self.0.store(true, Ordering::SeqCst); + } + } + let guard = DropGuard(StdArc::clone(&flag)); + futures::stream::poll_fn(move |_cx| { + if flag.load(Ordering::SeqCst) { + return std::task::Poll::Ready(None); + } + std::task::Poll::Ready(Some(ResponseEnvelope::ok( + rid.clone(), + serde_json::json!({"tick": 1}), + ))) + }) + .map(move |env| { + let _keep_guard = &guard; + env + }) + }); + let registry = registry_with_subscription("events/stream", handler); + let provider: Arc = Arc::new(StaticIdentityProvider::new()); + let dp = Dispatcher::new(registry, provider); + let conn = Arc::new(CallConnection::new(stub_connection())); + + let request = EventEnvelope::requested( + "sub-abort-1", + serde_json::json!({ + "operationId": "/events/stream", + "input": {}, + }), + ); + let recv = tokio::io::BufReader::new(std::io::Cursor::new(encode_frame(&request))); + let (send, _sink) = tokio::io::duplex(8 * 1024); + let send = alknet_core::types::SendStream::from_mock(send); + let recv = alknet_core::types::RecvStream::from_mock(recv); + + let conn_clone = Arc::clone(&conn); + let dp_clone = dp.clone(); + let handle = tokio::spawn(async move { + dp_clone.handle_stream(conn_clone, send, recv).await; + }); + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + dp.handle_abort(&conn, "sub-abort-1").await; + assert!( + !conn.pending().lock().contains("sub-abort-1"), + "abort removes the pending entry" + ); + + handle.abort(); + let _ = handle.await; + assert!( + dropped.load(Ordering::SeqCst), + "stream future dropped → Drop guard released handler resources" + ); + } }