diff --git a/crates/alknet-http/src/server/gateway_routes.rs b/crates/alknet-http/src/server/gateway_routes.rs index 1fe675d..c2d19be 100644 --- a/crates/alknet-http/src/server/gateway_routes.rs +++ b/crates/alknet-http/src/server/gateway_routes.rs @@ -17,7 +17,8 @@ use axum::response::sse::Event; use axum::response::{IntoResponse, Json, Response, Sse}; use axum::routing::{get, post}; use axum::Router; -use futures::stream::{self, BoxStream, Stream}; +use futures::stream::{self, BoxStream}; +use futures::StreamExt; use serde::Deserialize; use serde_json::{json, Value}; @@ -163,18 +164,29 @@ pub(crate) async fn subscribe_handler( subscribe_stream_internal_error(request.operation) } else { let dispatch = state.dispatch(); - let envelope = dispatch - .invoke(identity, &request.operation, request.input) - .await; - subscribe_stream_from_envelope(envelope) + let envelope_stream = + dispatch.invoke_streaming(identity, &request.operation, request.input); + subscribe_stream_from_envelope_stream(envelope_stream) }; Sse::new(stream) } pub type SubscribeStream = BoxStream<'static, Result>; -fn subscribe_stream_from_envelope(envelope: ResponseEnvelope) -> SubscribeStream { - Box::pin(envelope_to_sse_stream(envelope)) +fn subscribe_stream_from_envelope_stream( + stream: BoxStream<'static, ResponseEnvelope>, +) -> SubscribeStream { + Box::pin(stream.map(|envelope| match envelope.result { + Ok(output) => { + let data = serde_json::to_string(&output).unwrap_or_else(|_| "null".to_string()); + Ok(Event::default().data(data)) + } + Err(error) => { + let payload = serde_json::to_value(&error).unwrap_or(Value::Null); + let data = serde_json::to_string(&payload).unwrap_or_else(|_| "null".to_string()); + Ok(Event::default().event("error").data(data)) + } + })) } fn subscribe_stream_internal_error(operation: String) -> SubscribeStream { @@ -263,24 +275,6 @@ fn is_internal_op(registry: &OperationRegistry, operation: &str) -> bool { } } -fn envelope_to_sse_stream( - envelope: ResponseEnvelope, -) -> impl Stream> { - stream::once(async move { - match envelope.result { - Ok(output) => { - let data = serde_json::to_string(&output).unwrap_or_else(|_| "null".to_string()); - Ok(Event::default().data(data)) - } - Err(error) => { - let payload = serde_json::to_value(&error).unwrap_or(Value::Null); - let data = serde_json::to_string(&payload).unwrap_or_else(|_| "null".to_string()); - Ok(Event::default().event("error").data(data)) - } - } - }) -} - fn error_event(operation: &str) -> Result { let error = CallError::not_found(operation); let payload = serde_json::to_value(&error).unwrap_or(Value::Null); @@ -295,7 +289,7 @@ mod tests { services_list_handler, services_list_spec, services_schema_handler, services_schema_spec, }; use alknet_call::registry::registration::{ - make_handler, HandlerKind, HandlerRegistration, OperationProvenance, + make_handler, make_streaming_handler, HandlerKind, HandlerRegistration, OperationProvenance, }; use alknet_call::registry::spec::{AccessControl, OperationSpec, OperationType}; use alknet_core::auth::{AuthToken, Identity}; @@ -425,6 +419,73 @@ mod tests { Arc::new(registry) } + fn subscription_spec(name: &str, visibility: Visibility, acl: AccessControl) -> OperationSpec { + OperationSpec::new( + name, + OperationType::Subscription, + visibility, + json!({}), + json!({}), + vec![], + acl, + ) + } + + fn multi_event_streaming_handler( + outputs: Vec, + ) -> alknet_call::registry::registration::StreamingHandler { + make_streaming_handler(move |_input, ctx| { + let request_id = ctx.request_id.clone(); + let outputs = outputs.clone(); + futures::stream::iter( + outputs + .into_iter() + .map(move |o| ResponseEnvelope::ok(request_id.clone(), o)), + ) + }) + } + + fn error_streaming_handler(error: CallError) -> HandlerKind { + HandlerKind::Stream(make_streaming_handler(move |_input, ctx| { + let request_id = ctx.request_id.clone(); + let error = error.clone(); + futures::stream::iter(vec![ResponseEnvelope::error(request_id, error)]) + })) + } + + fn registry_with_subscription_stream( + name: &str, + outputs: Vec, + ) -> Arc { + let mut registry = OperationRegistry::new(); + registry + .register(HandlerRegistration::new( + subscription_spec(name, Visibility::External, AccessControl::default()), + HandlerKind::Stream(multi_event_streaming_handler(outputs)), + OperationProvenance::Local, + None, + None, + Capabilities::new(), + )) + .unwrap(); + Arc::new(registry) + } + + fn registry_with_subscription_error(name: &str, error: CallError) -> Arc { + let mut registry = OperationRegistry::new(); + registry + .register(HandlerRegistration::new( + subscription_spec(name, Visibility::External, AccessControl::default()), + error_streaming_handler(error), + OperationProvenance::Local, + None, + None, + Capabilities::new(), + )) + .unwrap(); + Arc::new(registry) + } + fn registry_with_discovery_and_ops( inner_ops: Vec, ) -> Arc { @@ -771,15 +832,20 @@ mod tests { } #[tokio::test] - async fn subscribe_streams_sse_data_event_until_completed() { - let router = build_router(registry_with_echo(), unused_provider()); + async fn subscribe_on_subscription_streams_multiple_data_frames() { + let router = build_router( + registry_with_subscription_stream( + "events/stream", + vec![json!({ "n": 1 }), json!({ "n": 2 }), json!({ "n": 3 })], + ), + unused_provider(), + ); let req = Request::builder() .method("POST") .uri("/subscribe") .header("content-type", "application/json") .body(Body::from( - serde_json::to_vec(&json!({ "operation": "echo/run", "input": { "v": 9 } })) - .unwrap(), + serde_json::to_vec(&json!({ "operation": "events/stream", "input": {} })).unwrap(), )) .unwrap(); let resp = router.oneshot(req).await.unwrap(); @@ -797,10 +863,73 @@ mod tests { ); let bytes = resp.into_body().collect().await.unwrap().to_bytes(); let body = String::from_utf8_lossy(&bytes); - assert!(body.contains("data:"), "expected a data frame, got: {body}"); + let data_frames = body.matches("data:").count(); + assert_eq!(data_frames, 3, "expected 3 data frames, got: {body}"); + assert!(body.contains("\"n\":1"), "expected n=1, got: {body}"); + assert!(body.contains("\"n\":2"), "expected n=2, got: {body}"); + assert!(body.contains("\"n\":3"), "expected n=3, got: {body}"); + } + + #[tokio::test] + async fn subscribe_on_subscription_that_yields_error_emits_error_event_then_closes() { + let router = build_router( + registry_with_subscription_error("events/fail", CallError::internal("handler blew up")), + unused_provider(), + ); + let req = Request::builder() + .method("POST") + .uri("/subscribe") + .header("content-type", "application/json") + .body(Body::from( + serde_json::to_vec(&json!({ "operation": "events/fail", "input": {} })).unwrap(), + )) + .unwrap(); + let resp = router.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let bytes = resp.into_body().collect().await.unwrap().to_bytes(); + let body = String::from_utf8_lossy(&bytes); assert!( - body.contains("\"v\":9"), - "expected output payload, got: {body}" + body.contains("event:error") || body.contains("event: error"), + "expected error event, got: {body}" + ); + assert!( + body.contains("INTERNAL"), + "expected INTERNAL code, got: {body}" + ); + assert!( + body.contains("handler blew up"), + "expected error message, got: {body}" + ); + let data_frames = body.matches("data:").count(); + assert_eq!( + data_frames, 1, + "expected exactly one data frame (the error payload), got: {body}" + ); + } + + #[tokio::test] + async fn subscribe_response_content_type_is_text_event_stream() { + let router = build_router( + registry_with_subscription_stream("events/stream", vec![json!({ "ok": true })]), + unused_provider(), + ); + let req = Request::builder() + .method("POST") + .uri("/subscribe") + .header("content-type", "application/json") + .body(Body::from( + serde_json::to_vec(&json!({ "operation": "events/stream", "input": {} })).unwrap(), + )) + .unwrap(); + let resp = router.oneshot(req).await.unwrap(); + let ctype = resp + .headers() + .get(axum::http::header::CONTENT_TYPE) + .map(|v| v.to_str().unwrap().to_string()); + assert_eq!( + ctype.as_deref(), + Some("text/event-stream"), + "expected text/event-stream, got {ctype:?}" ); } @@ -829,6 +958,59 @@ mod tests { ); } + #[tokio::test] + async fn subscribe_unknown_op_emits_not_found_error_event() { + let router = build_router( + registry_with_subscription_stream("events/stream", vec![json!({})]), + unused_provider(), + ); + let req = Request::builder() + .method("POST") + .uri("/subscribe") + .header("content-type", "application/json") + .body(Body::from( + serde_json::to_vec(&json!({ "operation": "no/such", "input": {} })).unwrap(), + )) + .unwrap(); + let resp = router.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let bytes = resp.into_body().collect().await.unwrap().to_bytes(); + let body = String::from_utf8_lossy(&bytes); + assert!( + body.contains("event:error") || body.contains("event: error"), + "expected error event, got: {body}" + ); + assert!( + body.contains("NOT_FOUND"), + "expected NOT_FOUND, got: {body}" + ); + } + + #[tokio::test] + async fn subscribe_on_query_op_emits_invalid_operation_type_error_event() { + let router = build_router(registry_with_echo(), unused_provider()); + let req = Request::builder() + .method("POST") + .uri("/subscribe") + .header("content-type", "application/json") + .body(Body::from( + serde_json::to_vec(&json!({ "operation": "echo/run", "input": {} })).unwrap(), + )) + .unwrap(); + let resp = router.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let bytes = resp.into_body().collect().await.unwrap().to_bytes(); + let body = String::from_utf8_lossy(&bytes); + assert!( + body.contains("event:error") || body.contains("event: error"), + "expected error event, got: {body}" + ); + assert!( + body.contains("INVALID_OPERATION_TYPE"), + "expected INVALID_OPERATION_TYPE, got: {body}" + ); + } + #[test] fn is_internal_op_returns_false_for_unknown() { let registry = OperationRegistry::new();