diff --git a/crates/alknet-http/src/adapters/from_openapi.rs b/crates/alknet-http/src/adapters/from_openapi.rs index d41c0dd..35541a0 100644 --- a/crates/alknet-http/src/adapters/from_openapi.rs +++ b/crates/alknet-http/src/adapters/from_openapi.rs @@ -18,13 +18,15 @@ use alknet_call::client::{AdapterError, OperationAdapter}; use alknet_call::protocol::wire::{CallError, ResponseEnvelope}; use alknet_call::registry::context::OperationContext; use alknet_call::registry::registration::{ - make_handler, HandlerKind, HandlerRegistration, OperationProvenance, + make_handler, make_streaming_handler, HandlerKind, HandlerRegistration, OperationProvenance, + ResponseStream, }; use alknet_call::registry::spec::{ AccessControl, ErrorDefinition, OperationSpec, OperationType, Visibility, }; use alknet_core::types::Capabilities; use async_trait::async_trait; +use futures::stream; use futures::StreamExt; use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE}; use reqwest::Method; @@ -440,38 +442,66 @@ impl FromOpenAPI { .map(|e| (e.http_status.unwrap_or(0), e.code.clone())) .collect(); - let handler = make_handler(move |input: Value, context: OperationContext| { - let path_template = path_template.clone(); - let method_upper = method_upper.clone(); - let auth_scheme = auth_scheme.clone(); - let default_headers = default_headers.clone(); - let base_url = base_url.clone(); - let namespace = namespace.clone(); - let http_client = Arc::clone(&http_client); - let error_status_codes = error_status_codes.clone(); - let op_type = op_type; - async move { - forward( - &http_client, - &base_url, - &path_template, - &method_upper, - &auth_scheme, - &default_headers, - &namespace, - &error_status_codes, - op_type, - input, - context, - ) - .await - } - }); + let handler = if op_type == OperationType::Subscription { + let stream_handler = + make_streaming_handler(move |input: Value, context: OperationContext| { + let path_template = path_template.clone(); + let method_upper = method_upper.clone(); + let auth_scheme = auth_scheme.clone(); + let default_headers = default_headers.clone(); + let base_url = base_url.clone(); + let namespace = namespace.clone(); + let http_client = Arc::clone(&http_client); + let error_status_codes = error_status_codes.clone(); + forward_stream( + &http_client, + &base_url, + &path_template, + &method_upper, + &auth_scheme, + &default_headers, + &namespace, + &error_status_codes, + input, + context, + ) + }); + HandlerKind::Stream(stream_handler) + } else { + let once_handler = make_handler(move |input: Value, context: OperationContext| { + let path_template = path_template.clone(); + let method_upper = method_upper.clone(); + let auth_scheme = auth_scheme.clone(); + let default_headers = default_headers.clone(); + let base_url = base_url.clone(); + let namespace = namespace.clone(); + let http_client = Arc::clone(&http_client); + let error_status_codes = error_status_codes.clone(); + let op_type = op_type; + async move { + forward( + &http_client, + &base_url, + &path_template, + &method_upper, + &auth_scheme, + &default_headers, + &namespace, + &error_status_codes, + op_type, + input, + context, + ) + .await + } + }); + HandlerKind::Once(once_handler) + }; let capabilities = Capabilities::new(); Ok(HandlerRegistration::new( spec, - HandlerKind::Once(handler), + handler, OperationProvenance::FromOpenAPI, None, None, @@ -666,10 +696,6 @@ async fn forward( let status = response.status(); - if op_type == OperationType::Subscription && status.is_success() { - return stream_subscription(request_id, response).await; - } - if !status.is_success() { let code = error_status_codes .iter() @@ -721,35 +747,136 @@ async fn forward( } } -async fn stream_subscription(request_id: String, response: reqwest::Response) -> ResponseEnvelope { - let mut stream = response.bytes_stream(); - let mut buffer = String::new(); - let mut last_event: Option = None; - while let Some(chunk_result) = stream.next().await { - match chunk_result { - Ok(chunk) => { - buffer.push_str(&String::from_utf8_lossy(&chunk)); - let (events, remaining) = parse_sse_frames(&buffer); - buffer = remaining; - for event in events { - let parsed = if event.data.trim().is_empty() { - Value::Null - } else { - serde_json::from_str(&event.data) - .unwrap_or(Value::String(event.data.clone())) - }; - last_event = Some(parsed.clone()); +#[allow(clippy::too_many_arguments)] +fn forward_stream( + http_client: &Arc, + base_url: &str, + path_template: &str, + method: &str, + auth_scheme: &Option, + default_headers: &HashMap, + namespace: &str, + error_status_codes: &[(u16, String)], + input: Value, + context: OperationContext, +) -> ResponseStream { + let request_id = context.request_id.clone(); + + let (http_method, url, body, headers) = match build_request( + base_url, + path_template, + method, + auth_scheme, + default_headers, + namespace, + &input, + &context, + ) { + Ok(parts) => parts, + Err(err) => { + return Box::pin(stream::once(async move { + ResponseEnvelope::error(request_id, err) + })); + } + }; + + let http_client = Arc::clone(http_client); + let error_status_codes = error_status_codes.to_vec(); + + let request_id_stream = request_id.clone(); + let error_status_codes_stream = error_status_codes.clone(); + + let init = async move { + let request_builder = http_client + .client() + .request(http_method, url.as_str()) + .headers(headers) + .header(ACCEPT, "text/event-stream"); + let request_builder = match body.as_ref() { + Some(b) => { + let serialized = serde_json::to_string(b).unwrap_or_else(|_| String::from("null")); + request_builder.body(serialized) + } + None => request_builder, + }; + request_builder.send().await + }; + + let sse = stream::once(init).flat_map(move |result| { + let request_id = request_id_stream.clone(); + let error_status_codes = error_status_codes_stream.clone(); + match result { + Err(err) => Box::pin(stream::once(async move { + ResponseEnvelope::error( + request_id, + CallError::internal(format!("HTTP request failed: {err}")), + ) + })) as ResponseStream, + Ok(response) => { + let status = response.status(); + if !status.is_success() { + let code = error_status_codes + .iter() + .find(|(s, _)| *s == status.as_u16()) + .map(|(_, c)| c.clone()) + .unwrap_or_else(|| format!("HTTP_{}", status.as_u16())); + let message = format!( + "HTTP {}: {}", + status.as_u16(), + status.canonical_reason().unwrap_or("") + ); + Box::pin(stream::once(async move { + ResponseEnvelope::error(request_id, CallError::new(code, message, false)) + })) as ResponseStream + } else { + let request_id_inner = request_id.clone(); + Box::pin( + stream::unfold( + (response.bytes_stream(), String::new()), + move |(mut bytes, mut buffer)| { + let request_id = request_id_inner.clone(); + async move { + match bytes.next().await { + Some(Ok(chunk)) => { + buffer.push_str(&String::from_utf8_lossy(&chunk)); + let (events, remaining) = parse_sse_frames(&buffer); + let envelopes: Vec = events + .into_iter() + .map(|e| { + let parsed = if e.data.trim().is_empty() { + Value::Null + } else { + serde_json::from_str(&e.data).unwrap_or( + Value::String(e.data.clone()), + ) + }; + ResponseEnvelope::ok(&request_id, parsed) + }) + .collect(); + Some((envelopes, (bytes, remaining))) + } + Some(Err(err)) => { + let error = CallError::internal(format!( + "SSE stream error: {err}" + )); + Some(( + vec![ResponseEnvelope::error(request_id, error)], + (bytes, buffer), + )) + } + None => None, + } + } + }, + ) + .flat_map(stream::iter), + ) as ResponseStream } } - Err(err) => { - return ResponseEnvelope::error( - request_id, - CallError::internal(format!("SSE stream error: {err}")), - ); - } } - } - ResponseEnvelope::ok(request_id, last_event.unwrap_or(Value::Null)) + }); + + Box::pin(sse) } struct SseEvent { @@ -1194,6 +1321,34 @@ mod tests { } } + #[tokio::test] + async fn subscription_op_registration_is_handler_kind_stream() { + let spec = OpenAPISpec::from_json( + r##"{"openapi":"3.0.0","info":{"title":"T","version":"1"}, + "paths":{"/stream":{"post":{"operationId":"stream","responses":{"200":{"content":{"text/event-stream":{"schema":{}}}}}}}}}"##, + ) + .unwrap(); + let bundles = adapter(spec, config("svc", "https://x", None)) + .import() + .await + .unwrap(); + assert!(matches!(bundles[0].handler, HandlerKind::Stream(_))); + } + + #[tokio::test] + async fn query_op_registration_is_handler_kind_once() { + let spec = OpenAPISpec::from_json( + r#"{"openapi":"3.0.0","info":{"title":"T","version":"1"}, + "paths":{"/data":{"get":{"operationId":"data","responses":{"200":{"content":{"application/json":{"schema":{}}}}}}}}}"#, + ) + .unwrap(); + let bundles = adapter(spec, config("svc", "https://x", None)) + .import() + .await + .unwrap(); + assert!(matches!(bundles[0].handler, HandlerKind::Once(_))); + } + #[tokio::test] async fn integration_sse_subscription_streams_responded_events() { let sse_body = "data: {\"n\":1}\n\ndata: {\"n\":2}\n\n"; @@ -1209,13 +1364,67 @@ mod tests { .unwrap(); let registration = &bundles[0]; let ctx = noop_context("req-12", Capabilities::new()); + let stream = match ®istration.handler { + HandlerKind::Stream(h) => h(serde_json::json!({}), ctx), + _ => panic!("expected Stream handler"), + }; + let collected: Vec = stream.collect().await; + assert_eq!(collected.len(), 2); + assert_eq!(collected[0].result, Ok(serde_json::json!({"n":1}))); + assert_eq!(collected[1].result, Ok(serde_json::json!({"n":2}))); + assert_eq!(collected[0].request_id, "req-12"); + assert_eq!(collected[1].request_id, "req-12"); + } + + #[tokio::test] + async fn integration_sse_subscription_http_error_returns_single_error_envelope() { + let base = spawn_echo_server(404, r#"{"error":"missing"}"#, "application/json").await; + let spec = OpenAPISpec::from_json( + r##"{"openapi":"3.0.0","info":{"title":"T","version":"1"}, + "paths":{"/stream":{"post":{"operationId":"stream","responses":{ + "200":{"content":{"text/event-stream":{"schema":{}}}}, + "404":{"content":{"application/json":{"schema":{"type":"object"}}}} + }}}}}"##, + ) + .unwrap(); + let bundles = adapter(spec, config("svc", &base, None)) + .import() + .await + .unwrap(); + let registration = &bundles[0]; + let ctx = noop_context("req-err", Capabilities::new()); + let stream = match ®istration.handler { + HandlerKind::Stream(h) => h(serde_json::json!({}), ctx), + _ => panic!("expected Stream handler"), + }; + let collected: Vec = stream.collect().await; + assert_eq!(collected.len(), 1); + match &collected[0].result { + Err(e) => assert_eq!(e.code, "HTTP_404"), + other => panic!("expected HTTP_404 error, got {other:?}"), + } + } + + #[tokio::test] + async fn integration_query_forwarding_unchanged_single_response() { + let base = spawn_echo_server(200, r#"{"ok":true}"#, "application/json").await; + let spec = OpenAPISpec::from_json( + r#"{"openapi":"3.0.0","info":{"title":"T","version":"1"}, + "paths":{"/data":{"get":{"operationId":"data","responses":{"200":{"content":{"application/json":{"schema":{}}}}}}}}}"#, + ) + .unwrap(); + let bundles = adapter(spec, config("svc", &base, None)) + .import() + .await + .unwrap(); + let registration = &bundles[0]; + let ctx = noop_context("req-q", Capabilities::new()); let response = match ®istration.handler { HandlerKind::Once(h) => h(serde_json::json!({}), ctx).await, _ => panic!("expected Once handler"), }; - assert!(response.result.is_ok()); - let last = response.result.unwrap(); - assert_eq!(last, serde_json::json!({"n":2})); + assert_eq!(response.request_id, "req-q"); + assert_eq!(response.result, Ok(serde_json::json!({"ok":true}))); } #[test]