diff --git a/crates/alknet-http/src/adapters/from_mcp/tests.rs b/crates/alknet-http/src/adapters/from_mcp/tests.rs index 8381a19..3274937 100644 --- a/crates/alknet-http/src/adapters/from_mcp/tests.rs +++ b/crates/alknet-http/src/adapters/from_mcp/tests.rs @@ -22,7 +22,11 @@ fn make_tool(name: &str, input: Value, output: Option) -> Tool { tool } -fn call_tool_result(content: Vec, structured: Option, is_error: Option) -> CallToolResult { +fn call_tool_result( + content: Vec, + structured: Option, + is_error: Option, +) -> CallToolResult { let json = serde_json::json!({ "content": content, "structuredContent": structured, @@ -204,7 +208,9 @@ fn build_spec_output_schema_present_shape() { let tool = make_tool( "get_weather", serde_json::json!({ "type": "object", "properties": { "city": { "type": "string" } } }), - Some(serde_json::json!({ "type": "object", "properties": { "temperature": { "type": "number" } } })), + Some( + serde_json::json!({ "type": "object", "properties": { "temperature": { "type": "number" } } }), + ), ); let spec = build_spec(&tool, "weather"); assert_eq!(spec.name, "weather/get_weather"); @@ -248,4 +254,4 @@ async fn forwarding_handler_reads_capabilities_not_env_vars() { let adapter = FromMCP::new("http://127.0.0.1:1/mcp", "ns"); let _ = adapter.auth_token(); assert!(adapter.auth_token().is_none()); -} \ No newline at end of file +} diff --git a/crates/alknet-http/src/adapters/from_openapi.rs b/crates/alknet-http/src/adapters/from_openapi.rs index 58f3e95..a57f2bb 100644 --- a/crates/alknet-http/src/adapters/from_openapi.rs +++ b/crates/alknet-http/src/adapters/from_openapi.rs @@ -17,10 +17,10 @@ use std::sync::Arc; use alknet_call::client::{AdapterError, OperationAdapter}; use alknet_call::protocol::wire::{CallError, ResponseEnvelope}; use alknet_call::registry::context::OperationContext; -use alknet_call::registry::registration::{ - make_handler, HandlerRegistration, OperationProvenance, +use alknet_call::registry::registration::{make_handler, HandlerRegistration, OperationProvenance}; +use alknet_call::registry::spec::{ + AccessControl, ErrorDefinition, OperationSpec, OperationType, Visibility, }; -use alknet_call::registry::spec::{AccessControl, ErrorDefinition, OperationSpec, OperationType, Visibility}; use alknet_core::types::Capabilities; use async_trait::async_trait; use futures::StreamExt; @@ -128,11 +128,9 @@ impl OpenAPISpec { .to_string(), }; - let paths_raw = raw - .get("paths") - .ok_or_else(|| AdapterError::SchemaParse { - message: "OpenAPI document missing `paths`".into(), - })?; + let paths_raw = raw.get("paths").ok_or_else(|| AdapterError::SchemaParse { + message: "OpenAPI document missing `paths`".into(), + })?; if !paths_raw.is_object() { return Err(AdapterError::SchemaParse { message: "`paths` must be a JSON object".into(), @@ -155,14 +153,13 @@ impl OpenAPISpec { if operations.is_empty() { continue; } - paths.insert( - path.clone(), - PathItem { operations }, - ); + paths.insert(path.clone(), PathItem { operations }); } - let components = raw.get("components").and_then(|c| c.get("schemas")).and_then( - |schemas| { + let components = raw + .get("components") + .and_then(|c| c.get("schemas")) + .and_then(|schemas| { if !schemas.is_object() { return None; } @@ -171,8 +168,7 @@ impl OpenAPISpec { map.insert(k.clone(), v.clone()); } Some(Components { schemas: map }) - }, - ); + }); Ok(Self { info, @@ -190,11 +186,9 @@ impl OpenAPISpec { } let mut current: &Value = &self.raw; for part in reference.trim_start_matches("#/").split('/') { - current = current - .get(part) - .ok_or_else(|| AdapterError::SchemaParse { - message: format!("cannot resolve $ref: {reference}"), - })?; + current = current.get(part).ok_or_else(|| AdapterError::SchemaParse { + message: format!("cannot resolve $ref: {reference}"), + })?; } Ok(current.clone()) } @@ -241,10 +235,7 @@ fn parse_operation(raw: &Value) -> Option { .filter_map(|p| { let name = p.get("name")?.as_str()?.to_string(); let in_ = p.get("in")?.as_str()?.to_string(); - let required = p - .get("required") - .and_then(|v| v.as_bool()) - .unwrap_or(false); + let required = p.get("required").and_then(|v| v.as_bool()).unwrap_or(false); let schema = p.get("schema").cloned(); Some(Parameter { name, @@ -297,7 +288,11 @@ pub struct FromOpenAPI { } impl FromOpenAPI { - pub fn new(spec: OpenAPISpec, config: HttpServiceConfig, http_client: Arc) -> Self { + pub fn new( + spec: OpenAPISpec, + config: HttpServiceConfig, + http_client: Arc, + ) -> Self { Self { spec, config, @@ -322,10 +317,7 @@ impl FromOpenAPI { } fn detect_op_type(method: &str, op: &Operation) -> OperationType { - let success = op - .responses - .get("200") - .or_else(|| op.responses.get("201")); + let success = op.responses.get("200").or_else(|| op.responses.get("201")); if let Some(resp) = success { if resp.content.contains_key("text/event-stream") { return OperationType::Subscription; @@ -531,9 +523,8 @@ fn build_request( } } - let base = Url::parse(base_url).map_err(|e| { - CallError::internal(format!("invalid base_url `{base_url}`: {e}")) - })?; + let base = Url::parse(base_url) + .map_err(|e| CallError::internal(format!("invalid base_url `{base_url}`: {e}")))?; let mut url = base .join(url_path.trim_start_matches('/')) .map_err(|e| CallError::internal(format!("invalid path `{url_path}`: {e}")))?; @@ -683,11 +674,12 @@ async fn forward( .find(|(s, _)| *s == status.as_u16()) .map(|(_, code)| code.clone()) .unwrap_or_else(|| format!("HTTP_{}", status.as_u16())); - let message = format!("HTTP {}: {}", status.as_u16(), status.canonical_reason().unwrap_or("")); - return ResponseEnvelope::error( - request_id, - CallError::new(code, message, false), + let message = format!( + "HTTP {}: {}", + status.as_u16(), + status.canonical_reason().unwrap_or("") ); + return ResponseEnvelope::error(request_id, CallError::new(code, message, false)); } let content_type = response @@ -716,10 +708,7 @@ async fn forward( } else { match response.bytes().await { Ok(b) => { - let arr: Vec = b - .iter() - .map(|byte| Value::Number((*byte).into())) - .collect(); + let arr: Vec = b.iter().map(|byte| Value::Number((*byte).into())).collect(); ResponseEnvelope::ok(request_id, Value::Array(arr)) } Err(err) => ResponseEnvelope::error( @@ -744,7 +733,8 @@ async fn stream_subscription(request_id: String, response: reqwest::Response) -> let parsed = if event.data.trim().is_empty() { Value::Null } else { - serde_json::from_str(&event.data).unwrap_or(Value::String(event.data.clone())) + serde_json::from_str(&event.data) + .unwrap_or(Value::String(event.data.clone())) }; last_event = Some(parsed.clone()); } @@ -1040,7 +1030,12 @@ mod tests { .unwrap(); let body = props.get("body").unwrap(); assert_eq!(body.get("type").unwrap(), "object"); - assert!(body.get("properties").unwrap().as_object().unwrap().contains_key("name")); + assert!(body + .get("properties") + .unwrap() + .as_object() + .unwrap() + .contains_key("name")); } #[tokio::test] @@ -1074,14 +1069,19 @@ mod tests { "https://api.vast.ai", "/machines", "GET", - &Some(HttpAuthScheme::ApiKey { header_name: "X-API-Key".to_string() }), + &Some(HttpAuthScheme::ApiKey { + header_name: "X-API-Key".to_string(), + }), &HashMap::new(), "vastai", &serde_json::json!({}), &ctx, ) .unwrap(); - assert_eq!(headers.get("X-API-Key").unwrap().to_str().unwrap(), "key-xyz"); + assert_eq!( + headers.get("X-API-Key").unwrap().to_str().unwrap(), + "key-xyz" + ); } #[tokio::test] @@ -1267,7 +1267,11 @@ mod tests { #[test] fn http_service_config_struct_fields() { - let cfg = config("ns", "https://api.example.com", Some(HttpAuthScheme::Bearer)); + let cfg = config( + "ns", + "https://api.example.com", + Some(HttpAuthScheme::Bearer), + ); assert_eq!(cfg.namespace, "ns"); assert_eq!(cfg.base_url, "https://api.example.com"); assert!(matches!(cfg.auth, Some(HttpAuthScheme::Bearer))); @@ -1289,7 +1293,12 @@ mod tests { }"#; let spec = OpenAPISpec::from_json(doc).unwrap(); assert!(spec.components.is_some()); - assert!(spec.components.as_ref().unwrap().schemas.contains_key("Foo")); + assert!(spec + .components + .as_ref() + .unwrap() + .schemas + .contains_key("Foo")); } #[tokio::test] @@ -1342,7 +1351,9 @@ mod tests { #[tokio::test] async fn resolve_ref_missing_target_returns_schema_parse() { let spec = OpenAPISpec::from_json(minimal_spec_json()).unwrap(); - let err = spec.resolve_ref("#/components/schemas/Missing").unwrap_err(); + let err = spec + .resolve_ref("#/components/schemas/Missing") + .unwrap_err(); assert!(matches!(err, AdapterError::SchemaParse { .. })); } @@ -1409,7 +1420,8 @@ mod tests { headers, body, }); - let response = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 2\r\n\r\n{}"; + let response = + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 2\r\n\r\n{}"; sock.write_all(response.as_bytes()).await.unwrap(); sock.flush().await.unwrap(); }); @@ -1440,12 +1452,19 @@ mod tests { ctx, ) .await; - assert!(response.result.is_ok(), "expected Ok, got {:?}", response.result); + assert!( + response.result.is_ok(), + "expected Ok, got {:?}", + response.result + ); let captured = rx.await.unwrap(); assert_eq!(captured.method, "POST"); assert_eq!(captured.path, "/items/42"); assert_eq!(captured.query, "filter=new"); - assert_eq!(captured.headers.get("content-type").unwrap(), "application/json"); + assert_eq!( + captured.headers.get("content-type").unwrap(), + "application/json" + ); assert!(captured.body.contains("\"name\":\"widget\"")); } @@ -1457,19 +1476,19 @@ mod tests { }"#; let (base, rx) = spawn_capturing_server().await; let spec = OpenAPISpec::from_json(doc).unwrap(); - let bundles = adapter( - spec, - config("openai", &base, Some(HttpAuthScheme::Bearer)), - ) - .import() - .await - .unwrap(); + let bundles = adapter(spec, config("openai", &base, Some(HttpAuthScheme::Bearer))) + .import() + .await + .unwrap(); let registration = &bundles[0]; let caps = Capabilities::new().with_http_token("openai", "sk-test-token".to_string()); let ctx = noop_context("req-17", caps); let _ = (registration.handler)(serde_json::json!({}), ctx).await; let captured = rx.await.unwrap(); - assert_eq!(captured.headers.get("authorization").unwrap(), "Bearer sk-test-token"); + assert_eq!( + captured.headers.get("authorization").unwrap(), + "Bearer sk-test-token" + ); } #[tokio::test] @@ -1527,4 +1546,4 @@ mod tests { other => panic!("expected HTTP_500, got {other:?}"), } } -} \ No newline at end of file +} diff --git a/crates/alknet-http/src/adapters/mod.rs b/crates/alknet-http/src/adapters/mod.rs index 3ec6ed1..9d8c95a 100644 --- a/crates/alknet-http/src/adapters/mod.rs +++ b/crates/alknet-http/src/adapters/mod.rs @@ -22,4 +22,4 @@ pub use from_openapi::{FromOpenAPI, HttpAuthScheme, HttpServiceConfig, OpenAPISp pub use from_mcp::FromMCP; #[cfg(feature = "mcp")] -pub use to_mcp::{ToMcpGateway, ToMcpService, to_mcp_service}; +pub use to_mcp::{to_mcp_service, ToMcpGateway, ToMcpService}; diff --git a/crates/alknet-http/src/adapters/to_mcp.rs b/crates/alknet-http/src/adapters/to_mcp.rs index 56a026a..96a8697 100644 --- a/crates/alknet-http/src/adapters/to_mcp.rs +++ b/crates/alknet-http/src/adapters/to_mcp.rs @@ -36,8 +36,8 @@ use rmcp::model::{ }; use rmcp::service::{RequestContext, RoleServer}; use rmcp::transport::{ - StreamableHttpServerConfig, streamable_http_server::{session::local::LocalSessionManager, tower::StreamableHttpService}, + StreamableHttpServerConfig, }; use serde_json::{Map, Value}; @@ -133,7 +133,10 @@ impl ToMcpGateway { fn extract_identity_from_extensions(extensions: &rmcp::model::Extensions) -> Option { let parts = extensions.get::()?; - parts.extensions.get::>().and_then(Option::clone) + parts + .extensions + .get::>() + .and_then(Option::clone) } async fn handle_search(&self, identity: Option) -> CallToolResult { @@ -144,8 +147,15 @@ impl ToMcpGateway { map_search_response(response, identity.as_ref()) } - async fn handle_schema(&self, arguments: Option, identity: Option) -> CallToolResult { - let name = match arguments.and_then(|mut a| a.remove("name")).and_then(|v| v.as_str().map(str::to_string)) { + async fn handle_schema( + &self, + arguments: Option, + identity: Option, + ) -> CallToolResult { + let name = match arguments + .and_then(|mut a| a.remove("name")) + .and_then(|v| v.as_str().map(str::to_string)) + { Some(n) => n, None => { return CallToolResult::structured_error(serde_json::json!({ @@ -156,12 +166,20 @@ impl ToMcpGateway { }; let response = self .dispatch - .invoke(identity, OP_SERVICES_SCHEMA, serde_json::json!({ "name": name })) + .invoke( + identity, + OP_SERVICES_SCHEMA, + serde_json::json!({ "name": name }), + ) .await; envelope_to_call_tool_result(response) } - async fn handle_call(&self, arguments: Option, identity: Option) -> CallToolResult { + async fn handle_call( + &self, + arguments: Option, + identity: Option, + ) -> CallToolResult { let (operation, input) = match parse_call_arguments(arguments) { Ok(pair) => pair, Err(err) => return err, @@ -170,7 +188,11 @@ impl ToMcpGateway { envelope_to_call_tool_result(response) } - async fn handle_batch(&self, arguments: Option, identity: Option) -> CallToolResult { + async fn handle_batch( + &self, + arguments: Option, + identity: Option, + ) -> CallToolResult { let calls = match arguments .and_then(|mut a| a.remove("calls")) .and_then(|v| v.as_array().cloned()) @@ -193,7 +215,10 @@ impl ToMcpGateway { continue; } }; - let response = self.dispatch.invoke(identity.clone(), &operation, input).await; + let response = self + .dispatch + .invoke(identity.clone(), &operation, input) + .await; results.push(envelope_to_value(response)); } CallToolResult::structured(Value::Array(results)) @@ -210,7 +235,10 @@ fn parse_call_arguments(arguments: Option) -> Result<(String, Value) }))); } }; - let operation = match map.remove("operation").and_then(|v| v.as_str().map(str::to_string)) { + let operation = match map + .remove("operation") + .and_then(|v| v.as_str().map(str::to_string)) + { Some(s) => s, None => { return Err(CallToolResult::structured_error(serde_json::json!({ @@ -359,7 +387,11 @@ impl rmcp::handler::server::ServerHandler for ToMcpGateway { TOOL_CALL => this.handle_call(arguments, identity).await, TOOL_BATCH => this.handle_batch(arguments, identity).await, unknown => { - let err = CallError::new("NOT_FOUND", format!("unknown gateway tool: {unknown}"), false); + let err = CallError::new( + "NOT_FOUND", + format!("unknown gateway tool: {unknown}"), + false, + ); call_error_to_structured_error(err) } }; @@ -368,9 +400,7 @@ impl rmcp::handler::server::ServerHandler for ToMcpGateway { } fn get_info(&self) -> ServerInfo { - let capabilities = ServerCapabilities::builder() - .enable_tools() - .build(); + let capabilities = ServerCapabilities::builder().enable_tools().build(); ServerInfo::new(capabilities) .with_server_info(Implementation::new( "alknet-to-mcp", @@ -462,10 +492,14 @@ mod tests { } fn make_echo_handler() -> alknet_call::registry::registration::Handler { - make_handler(|input, context| async move { ResponseEnvelope::ok(context.request_id, input) }) + make_handler( + |input, context| async move { ResponseEnvelope::ok(context.request_id, input) }, + ) } - fn full_registry_with_ops(specs: Vec<(String, OperationType, AccessControl)>) -> Arc { + fn full_registry_with_ops( + specs: Vec<(String, OperationType, AccessControl)>, + ) -> Arc { let mut inner = OperationRegistry::new(); for (name, op_type, acl) in specs { inner.register(HandlerRegistration::new( @@ -509,7 +543,10 @@ mod tests { Arc::new(dispatch_registry) } - fn dispatch(registry: Arc, provider: Arc) -> Arc { + fn dispatch( + registry: Arc, + provider: Arc, + ) -> Arc { Arc::new(GatewayDispatch::new(registry, provider)) } @@ -542,7 +579,11 @@ mod tests { TOOL_CALL => gateway.handle_call(arguments, identity).await, TOOL_BATCH => gateway.handle_batch(arguments, identity).await, unknown => { - let err = CallError::new("NOT_FOUND", format!("unknown gateway tool: {unknown}"), false); + let err = CallError::new( + "NOT_FOUND", + format!("unknown gateway tool: {unknown}"), + false, + ); call_error_to_structured_error(err) } } @@ -550,10 +591,7 @@ mod tests { #[tokio::test] async fn list_tools_returns_exactly_four_gateway_tools() { - let _gateway = ToMcpGateway::new(dispatch( - full_registry_with_ops(vec![]), - provider(), - )); + let _gateway = ToMcpGateway::new(dispatch(full_registry_with_ops(vec![]), provider())); let tools = gateway_tools(); let names: Vec = tools.iter().map(|t| t.name.to_string()).collect(); assert_eq!(names.len(), 4); @@ -583,7 +621,11 @@ mod tests { #[tokio::test] async fn search_returns_access_control_filtered_ops_excluding_subscriptions() { let registry = full_registry_with_ops(vec![ - ("public/echo".to_string(), OperationType::Query, AccessControl::default()), + ( + "public/echo".to_string(), + OperationType::Query, + AccessControl::default(), + ), ( "admin/secret".to_string(), OperationType::Query, @@ -592,13 +634,22 @@ mod tests { ..Default::default() }, ), - ("events/stream".to_string(), OperationType::Subscription, AccessControl::default()), + ( + "events/stream".to_string(), + OperationType::Subscription, + AccessControl::default(), + ), ]); let idp: Arc = Arc::new(StaticIdentityProvider::new()); let gateway = ToMcpGateway::new(dispatch(registry, idp)); - let result = invoke_tool(&gateway, "search", None, Some(identity_with_scopes("user", &["user"]))) - .await; + let result = invoke_tool( + &gateway, + "search", + None, + Some(identity_with_scopes("user", &["user"])), + ) + .await; assert_eq!(result.is_error, Some(false)); let structured = result.structured_content.expect("structured present"); let ops = structured @@ -610,11 +661,23 @@ mod tests { .filter_map(|o| o.get("name").and_then(Value::as_str)) .collect(); assert!(names.contains(&"public/echo")); - assert!(!names.contains(&"admin/secret"), "ACL-filtered op must not appear"); - assert!(!names.contains(&"events/stream"), "Subscription op must be excluded"); + assert!( + !names.contains(&"admin/secret"), + "ACL-filtered op must not appear" + ); + assert!( + !names.contains(&"events/stream"), + "Subscription op must be excluded" + ); for op in ops { - assert!(op.get("description").is_some(), "each entry has a description"); - assert!(op.get("input_schema").is_none(), "search must not return full schemas"); + assert!( + op.get("description").is_some(), + "each entry has a description" + ); + assert!( + op.get("input_schema").is_none(), + "search must not return full schemas" + ); } } @@ -632,7 +695,10 @@ mod tests { let result = invoke_tool(&gateway, "schema", Some(args), None).await; assert_eq!(result.is_error, Some(false)); let structured = result.structured_content.expect("structured present"); - assert_eq!(structured.get("name"), Some(&Value::String("fs/readFile".to_string()))); + assert_eq!( + structured.get("name"), + Some(&Value::String("fs/readFile".to_string())) + ); assert!(structured.get("input_schema").is_some()); assert!(structured.get("output_schema").is_some()); assert!(structured.get("error_schemas").is_some()); @@ -649,7 +715,10 @@ mod tests { let gateway = ToMcpGateway::new(dispatch(registry, provider())); let mut args = Map::new(); - args.insert("operation".to_string(), Value::String("echo/run".to_string())); + args.insert( + "operation".to_string(), + Value::String("echo/run".to_string()), + ); args.insert("input".to_string(), serde_json::json!({ "msg": "hi" })); let result = invoke_tool(&gateway, "call", Some(args), None).await; assert_eq!(result.is_error, Some(false)); @@ -665,12 +734,18 @@ mod tests { let gateway = ToMcpGateway::new(dispatch(registry, provider())); let mut args = Map::new(); - args.insert("operation".to_string(), Value::String("no/such".to_string())); + args.insert( + "operation".to_string(), + Value::String("no/such".to_string()), + ); args.insert("input".to_string(), Value::Object(Map::new())); let result = invoke_tool(&gateway, "call", Some(args), None).await; assert_eq!(result.is_error, Some(true)); let structured = result.structured_content.expect("structured error present"); - assert_eq!(structured.get("code"), Some(&Value::String("NOT_FOUND".to_string()))); + assert_eq!( + structured.get("code"), + Some(&Value::String("NOT_FOUND".to_string())) + ); } #[tokio::test] @@ -713,12 +788,18 @@ mod tests { let gateway = ToMcpGateway::new(dispatch(registry, idp)); let mut args = Map::new(); - args.insert("operation".to_string(), Value::String("admin/run".to_string())); + args.insert( + "operation".to_string(), + Value::String("admin/run".to_string()), + ); args.insert("input".to_string(), Value::Object(Map::new())); let result = invoke_tool(&gateway, "call", Some(args), None).await; assert_eq!(result.is_error, Some(true)); let structured = result.structured_content.expect("structured error present"); - assert_eq!(structured.get("code"), Some(&Value::String("FORBIDDEN".to_string()))); + assert_eq!( + structured.get("code"), + Some(&Value::String("FORBIDDEN".to_string())) + ); } #[tokio::test] @@ -727,7 +808,10 @@ mod tests { let result = invoke_tool(&gateway, "bogus", None, None).await; assert_eq!(result.is_error, Some(true)); let structured = result.structured_content.expect("structured error present"); - assert_eq!(structured.get("code"), Some(&Value::String("NOT_FOUND".to_string()))); + assert_eq!( + structured.get("code"), + Some(&Value::String("NOT_FOUND".to_string())) + ); } #[tokio::test] @@ -749,10 +833,16 @@ mod tests { let admin_identity = identity_with_scopes("admin-peer", &["admin"]); let extensions = extensions_with_identity(Some(admin_identity.clone())); let extracted = ToMcpGateway::extract_identity_from_extensions(&extensions); - assert_eq!(extracted.as_ref().map(|i| &i.id), Some(&"admin-peer".to_string())); + assert_eq!( + extracted.as_ref().map(|i| &i.id), + Some(&"admin-peer".to_string()) + ); let mut args = Map::new(); - args.insert("operation".to_string(), Value::String("admin/run".to_string())); + args.insert( + "operation".to_string(), + Value::String("admin/run".to_string()), + ); args.insert("input".to_string(), serde_json::json!({ "ok": 1 })); let result = gateway.handle_call(Some(args), extracted).await; assert_eq!(result.is_error, Some(false)); @@ -779,7 +869,10 @@ mod tests { let id = identity_with_scopes("caller", &["read"]); let extensions = extensions_with_identity(Some(id.clone())); let extracted = ToMcpGateway::extract_identity_from_extensions(&extensions); - assert_eq!(extracted.as_ref().map(|i| i.id.clone()), Some("caller".to_string())); + assert_eq!( + extracted.as_ref().map(|i| i.id.clone()), + Some("caller".to_string()) + ); assert_eq!( extracted.as_ref().map(|i| i.scopes.clone()), Some(vec!["read".to_string()]) @@ -834,12 +927,18 @@ mod tests { ); let mut call_args = Map::new(); - call_args.insert("operation".to_string(), Value::String(first_name.to_string())); - call_args.insert("input".to_string(), serde_json::json!({ "path": "/etc/hosts" })); + call_args.insert( + "operation".to_string(), + Value::String(first_name.to_string()), + ); + call_args.insert( + "input".to_string(), + serde_json::json!({ "path": "/etc/hosts" }), + ); let call_result = invoke_tool(&gateway, "call", Some(call_args), None).await; assert_eq!( call_result.structured_content, Some(serde_json::json!({ "path": "/etc/hosts" })) ); } -} \ No newline at end of file +} diff --git a/crates/alknet-http/src/client/http_client.rs b/crates/alknet-http/src/client/http_client.rs index c0155dd..4b27ef5 100644 --- a/crates/alknet-http/src/client/http_client.rs +++ b/crates/alknet-http/src/client/http_client.rs @@ -125,10 +125,11 @@ fn build_client(config: &HttpClientConfig) -> Result Result usize { - self.deadlines.lock().expect("deadlines mutex poisoned").len() + self.deadlines + .lock() + .expect("deadlines mutex poisoned") + .len() } #[cfg(test)] @@ -156,8 +159,8 @@ mod tests { #[test] fn parse_retry_after_http_date() { - let deadline = parse_retry_after("Wed, 21 Oct 2099 07:28:00 GMT") - .expect("HTTP-date value parses"); + let deadline = + parse_retry_after("Wed, 21 Oct 2099 07:28:00 GMT").expect("HTTP-date value parses"); assert!(deadline > SystemTime::now()); } @@ -272,7 +275,10 @@ mod tests { async fn middleware_sleeps_before_request_with_active_deadline() { let mw = std::sync::Arc::new(RetryAfterMiddleware::with_capacity(8)); let target = url("https://api.example.com/v1/chat"); - mw.record_test(target.clone(), SystemTime::now() + Duration::from_millis(50)); + mw.record_test( + target.clone(), + SystemTime::now() + Duration::from_millis(50), + ); let started = SystemTime::now(); mw.maybe_sleep_for(&target).await; let elapsed = SystemTime::now().duration_since(started).unwrap(); @@ -281,4 +287,4 @@ mod tests { "middleware must sleep until the deadline elapses" ); } -} \ No newline at end of file +} diff --git a/crates/alknet-http/src/gateway/dispatch.rs b/crates/alknet-http/src/gateway/dispatch.rs index 1a9446f..96c638f 100644 --- a/crates/alknet-http/src/gateway/dispatch.rs +++ b/crates/alknet-http/src/gateway/dispatch.rs @@ -83,11 +83,7 @@ impl GatewayDispatch { r.capabilities.clone(), r.scoped_env.clone().unwrap_or_else(ScopedPeerEnv::empty), ), - None => ( - None, - Capabilities::new(), - ScopedPeerEnv::empty(), - ), + None => (None, Capabilities::new(), ScopedPeerEnv::empty()), }; let env: Arc = @@ -254,10 +250,7 @@ mod tests { .invoke(None, "echo/run", serde_json::json!({ "msg": "hi" })) .await; assert!(response.result.is_ok()); - assert_eq!( - response.result.unwrap(), - serde_json::json!({ "msg": "hi" }) - ); + assert_eq!(response.result.unwrap(), serde_json::json!({ "msg": "hi" })); } #[tokio::test] @@ -270,9 +263,7 @@ mod tests { let provider: Arc = Arc::new(StaticIdentityProvider::new()); let dp = dispatch(registry, provider); - let response = dp - .invoke(None, "/echo/run", serde_json::json!({})) - .await; + let response = dp.invoke(None, "/echo/run", serde_json::json!({})).await; assert!(response.result.is_ok()); } @@ -369,9 +360,7 @@ mod tests { let provider: Arc = Arc::new(StaticIdentityProvider::new()); let dp = dispatch(registry, provider); - let response = dp - .invoke(None, "no/such", serde_json::json!({})) - .await; + let response = dp.invoke(None, "no/such", serde_json::json!({})).await; match response.result { Err(e) => { assert_eq!(e.code, "NOT_FOUND"); @@ -398,9 +387,7 @@ mod tests { let provider: Arc = Arc::new(StaticIdentityProvider::new()); let dp = dispatch(registry, provider); - let response = dp - .invoke(None, "secret/op", serde_json::json!({})) - .await; + let response = dp.invoke(None, "secret/op", serde_json::json!({})).await; match response.result { Err(e) => { assert_eq!(e.code, "NOT_FOUND"); @@ -423,9 +410,7 @@ mod tests { let provider: Arc = Arc::new(StaticIdentityProvider::new()); let dp = dispatch(registry, provider); - let response = dp - .invoke(None, "admin/run", serde_json::json!({})) - .await; + let response = dp.invoke(None, "admin/run", serde_json::json!({})).await; match response.result { Err(e) => { assert_eq!(e.code, "FORBIDDEN"); @@ -506,8 +491,10 @@ mod tests { #[test] fn build_root_context_carries_registration_bundle_fields() { - let authority = - alknet_call::registry::context::CompositionAuthority::new("agent", ["fs:read".to_string()]); + let authority = alknet_call::registry::context::CompositionAuthority::new( + "agent", + ["fs:read".to_string()], + ); let scoped = ScopedPeerEnv::new(["fs/readFile"]); let caps = Capabilities::new().with_api_key("google", "k".to_string()); @@ -545,4 +532,4 @@ mod tests { fn assert_concrete() {} assert_concrete::(); } -} \ No newline at end of file +} diff --git a/crates/alknet-http/src/gateway/error.rs b/crates/alknet-http/src/gateway/error.rs index b07674e..d848109 100644 --- a/crates/alknet-http/src/gateway/error.rs +++ b/crates/alknet-http/src/gateway/error.rs @@ -31,7 +31,10 @@ pub fn call_error_to_http_status(error: &CallError) -> u16 { call_error_to_http_status_with_identity(error, None) } -pub fn call_error_to_http_status_with_identity(error: &CallError, identity: Option<&Identity>) -> u16 { +pub fn call_error_to_http_status_with_identity( + error: &CallError, + identity: Option<&Identity>, +) -> u16 { match error.code.as_str() { PROTOCOL_CODE_NOT_FOUND => STATUS_NOT_FOUND, PROTOCOL_CODE_FORBIDDEN => { @@ -59,8 +62,8 @@ pub fn call_error_to_http_response(error: &CallError) -> Response { let retry_after = retry_after_value(error, status_code); if let Some(retry_after) = retry_after { - let header_value = HeaderValue::from_str(&retry_after) - .unwrap_or_else(|_| HeaderValue::from_static("0")); + let header_value = + HeaderValue::from_str(&retry_after).unwrap_or_else(|_| HeaderValue::from_static("0")); (status, [(header::RETRY_AFTER, header_value)], Json(body)).into_response() } else { (status, Json(body)).into_response() @@ -139,7 +142,10 @@ mod tests { fn forbidden_with_some_identity_maps_to_403() { let error = CallError::forbidden("insufficient scopes"); let id = identity(); - assert_eq!(call_error_to_http_status_with_identity(&error, Some(&id)), 403); + assert_eq!( + call_error_to_http_status_with_identity(&error, Some(&id)), + 403 + ); } #[test] @@ -213,7 +219,10 @@ mod tests { let error = CallError::new("HTTP_503", "slow down", true); let response = call_error_to_http_response(&error); assert_eq!(response.status(), StatusCode::from_u16(503).unwrap()); - assert!(response.headers().get(axum::http::header::RETRY_AFTER).is_none()); + assert!(response + .headers() + .get(axum::http::header::RETRY_AFTER) + .is_none()); } #[test] @@ -221,7 +230,10 @@ mod tests { let error = CallError::new("HTTP_503", "down", false) .with_details(serde_json::json!({ "retry_after": "5" })); let response = call_error_to_http_response(&error); - assert!(response.headers().get(axum::http::header::RETRY_AFTER).is_none()); + assert!(response + .headers() + .get(axum::http::header::RETRY_AFTER) + .is_none()); } #[test] @@ -241,7 +253,10 @@ mod tests { let error = CallError::timeout("timed out"); let response = call_error_to_http_response(&error); assert_eq!(response.status(), StatusCode::from_u16(504).unwrap()); - assert!(response.headers().get(axum::http::header::RETRY_AFTER).is_none()); + assert!(response + .headers() + .get(axum::http::header::RETRY_AFTER) + .is_none()); } #[test] @@ -266,4 +281,4 @@ mod tests { ); assert_eq!(call_error_to_http_status_with_identity(&error, None), 404); } -} \ No newline at end of file +} diff --git a/crates/alknet-http/src/server/adapter.rs b/crates/alknet-http/src/server/adapter.rs index 51920d8..0e80839 100644 --- a/crates/alknet-http/src/server/adapter.rs +++ b/crates/alknet-http/src/server/adapter.rs @@ -33,12 +33,12 @@ use super::auth::bearer_auth_middleware; use super::decoy::decoy_fallback; use super::gateway_routes; use super::healthz::healthz; -use crate::websocket::upgrade::ws_upgrade_handler; -use crate::websocket::upgrade::WS_UPGRADE_PATH; #[cfg(feature = "mcp")] use crate::adapters::to_mcp_service; #[cfg(feature = "mcp")] use crate::gateway::GatewayDispatch; +use crate::websocket::upgrade::ws_upgrade_handler; +use crate::websocket::upgrade::WS_UPGRADE_PATH; const ALPN_HTTP1: &[u8] = b"http/1.1"; const ALPN_H2: &[u8] = b"h2"; @@ -47,8 +47,12 @@ const ALPN_H2: &[u8] = b"h2"; pub enum DecoyConfig { #[default] NotFound, - StaticSite { root: PathBuf }, - Redirect { to: String }, + StaticSite { + root: PathBuf, + }, + Redirect { + to: String, + }, } #[derive(Clone)] @@ -87,11 +91,17 @@ pub struct HttpAdapter { } impl HttpAdapter { - pub fn new(identity_provider: Arc, registry: Arc) -> Self { + pub fn new( + identity_provider: Arc, + registry: Arc, + ) -> Self { Self::for_alpn(identity_provider, registry, ALPN_HTTP1) } - pub fn h2(identity_provider: Arc, registry: Arc) -> Self { + pub fn h2( + identity_provider: Arc, + registry: Arc, + ) -> Self { Self::for_alpn(identity_provider, registry, ALPN_H2) } @@ -163,7 +173,10 @@ fn build_router(state: RouterState, extra_routes: Option) -> Router { )); Router::new() .nest_service("/mcp", to_mcp_service(dispatch)) - .layer(from_fn_with_state(auth_state.clone(), bearer_auth_middleware)) + .layer(from_fn_with_state( + auth_state.clone(), + bearer_auth_middleware, + )) }; #[cfg(not(feature = "mcp"))] let mcp_router: Router = Router::new(); @@ -172,7 +185,10 @@ fn build_router(state: RouterState, extra_routes: Option) -> Router { .merge(gateway_routes::gateway_router()) .route("/openapi.json", get(not_implemented)) .route(WS_UPGRADE_PATH, get(ws_upgrade_handler)) - .route_layer(from_fn_with_state(auth_state.clone(), bearer_auth_middleware)) + .route_layer(from_fn_with_state( + auth_state.clone(), + bearer_auth_middleware, + )) .route("/healthz", get(healthz)) .fallback(decoy_fallback) .merge(mcp_router); @@ -203,7 +219,10 @@ impl ProtocolHandler for HttpAdapter { let _ = connection.set_identity(identity); } - let (send, recv) = connection.accept_bi().await.map_err(stream_error_to_handler)?; + let (send, recv) = connection + .accept_bi() + .await + .map_err(stream_error_to_handler)?; let io = QuicStream::new(send, recv); self.serve_io(io).await } @@ -295,7 +314,10 @@ mod tests { fn resolve_from_fingerprint(&self, _: &str) -> Option { None } - fn resolve_from_token(&self, _: &alknet_core::auth::AuthToken) -> Option { + fn resolve_from_token( + &self, + _: &alknet_core::auth::AuthToken, + ) -> Option { None } } @@ -341,7 +363,9 @@ mod tests { #[test] fn with_decoy_updates_decoy() { let adapter = HttpAdapter::new(provider(), empty_registry()); - let adapter = adapter.with_decoy(DecoyConfig::Redirect { to: "https://example.com".to_string() }); + let adapter = adapter.with_decoy(DecoyConfig::Redirect { + to: "https://example.com".to_string(), + }); assert!(matches!(adapter.decoy(), DecoyConfig::Redirect { .. })); } @@ -386,7 +410,10 @@ mod tests { ) -> (String, tokio::task::JoinHandle<()>) { let (mut client_send, server_recv) = duplex(8 * 1024); let (server_send, mut client_recv) = duplex(8 * 1024); - let server_io = QuicStreamDuplex { read: server_recv, write: server_send }; + let server_io = QuicStreamDuplex { + read: server_recv, + write: server_send, + }; let adapter = HttpAdapter::new(provider(), empty_registry()); let handle = tokio::spawn(async move { @@ -399,7 +426,12 @@ mod tests { let mut response = Vec::new(); let mut buf = [0u8; 4096]; loop { - match tokio::time::timeout(std::time::Duration::from_secs(5), client_recv.read(&mut buf)).await { + match tokio::time::timeout( + std::time::Duration::from_secs(5), + client_recv.read(&mut buf), + ) + .await + { Ok(Ok(0)) => break, Ok(Ok(n)) => response.extend_from_slice(&buf[..n]), Ok(Err(_)) => break, @@ -455,21 +487,24 @@ mod tests { let request = b"GET /healthz HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"; let (response, handle) = send_request_and_read_response(request).await; handle.await.ok(); - assert!(response.starts_with("HTTP/1.1 200 "), "expected 200, got: {response}"); + assert!( + response.starts_with("HTTP/1.1 200 "), + "expected 200, got: {response}" + ); assert!(response.contains("\r\n\r\nok")); } #[tokio::test] async fn custom_route_v1_foo_coexists_with_default_surface() { - let extra = Router::new().route( - "/v1/foo", - get(|| async { (StatusCode::OK, "foo-body") }), - ); + let extra = Router::new().route("/v1/foo", get(|| async { (StatusCode::OK, "foo-body") })); let adapter = HttpAdapter::new(provider(), empty_registry()).with_extra_routes(extra); let (mut client_send, server_recv) = duplex(8 * 1024); let (server_send, mut client_recv) = duplex(8 * 1024); - let server_io = QuicStreamDuplex { read: server_recv, write: server_send }; + let server_io = QuicStreamDuplex { + read: server_recv, + write: server_send, + }; let handle = tokio::spawn(async move { adapter.serve_io(server_io).await.ok(); @@ -482,7 +517,12 @@ mod tests { let mut response = Vec::new(); let mut buf = [0u8; 4096]; loop { - match tokio::time::timeout(std::time::Duration::from_secs(5), client_recv.read(&mut buf)).await { + match tokio::time::timeout( + std::time::Duration::from_secs(5), + client_recv.read(&mut buf), + ) + .await + { Ok(Ok(0)) => break, Ok(Ok(n)) => response.extend_from_slice(&buf[..n]), Ok(Err(_)) => break, @@ -491,7 +531,10 @@ mod tests { } handle.await.ok(); let response_str = String::from_utf8_lossy(&response); - assert!(response_str.starts_with("HTTP/1.1 200 "), "expected 200, got: {response_str}"); + assert!( + response_str.starts_with("HTTP/1.1 200 "), + "expected 200, got: {response_str}" + ); assert!(response_str.contains("foo-body")); } @@ -505,7 +548,10 @@ mod tests { let (mut client_send, server_recv) = duplex(8 * 1024); let (server_send, mut client_recv) = duplex(8 * 1024); - let server_io = QuicStreamDuplex { read: server_recv, write: server_send }; + let server_io = QuicStreamDuplex { + read: server_recv, + write: server_send, + }; let handle = tokio::spawn(async move { adapter.serve_io(server_io).await.ok(); @@ -518,7 +564,12 @@ mod tests { let mut response = Vec::new(); let mut buf = [0u8; 4096]; loop { - match tokio::time::timeout(std::time::Duration::from_secs(5), client_recv.read(&mut buf)).await { + match tokio::time::timeout( + std::time::Duration::from_secs(5), + client_recv.read(&mut buf), + ) + .await + { Ok(Ok(0)) => break, Ok(Ok(n)) => response.extend_from_slice(&buf[..n]), Ok(Err(_)) => break, @@ -527,7 +578,10 @@ mod tests { } handle.await.ok(); let response_str = String::from_utf8_lossy(&response); - assert!(response_str.starts_with("HTTP/1.1 200 "), "default GET /healthz wins, got: {response_str}"); + assert!( + response_str.starts_with("HTTP/1.1 200 "), + "default GET /healthz wins, got: {response_str}" + ); assert!(response_str.contains("\r\n\r\nok")); assert!(!response_str.contains("custom-healthz")); } @@ -547,7 +601,12 @@ mod tests { let mut response = Vec::new(); let mut buf = [0u8; 4096]; loop { - match tokio::time::timeout(std::time::Duration::from_secs(5), client_recv.read(&mut buf)).await { + match tokio::time::timeout( + std::time::Duration::from_secs(5), + client_recv.read(&mut buf), + ) + .await + { Ok(Ok(0)) => break, Ok(Ok(n)) => response.extend_from_slice(&buf[..n]), Ok(Err(_)) => break, @@ -569,7 +628,10 @@ mod tests { .with_extra_routes(extra); let request = b"POST /v1/chat/completions HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nContent-Length: 0\r\n\r\n"; let response = serve_and_read(adapter, request).await; - assert!(response.starts_with("HTTP/1.1 200"), "expected 200, got: {response}"); + assert!( + response.starts_with("HTTP/1.1 200"), + "expected 200, got: {response}" + ); assert!(response.contains("oai-proxy")); assert!(!response.contains("404 Not Found")); } @@ -583,32 +645,43 @@ mod tests { let adapter = HttpAdapter::new(provider(), empty_registry()) .with_decoy(DecoyConfig::NotFound) .with_extra_routes(extra); - let request = b"GET /totally/unknown HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"; + let request = + b"GET /totally/unknown HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"; let response = serve_and_read(adapter, request).await; - assert!(response.starts_with("HTTP/1.1 404"), "expected 404 decoy, got: {response}"); + assert!( + response.starts_with("HTTP/1.1 404"), + "expected 404 decoy, got: {response}" + ); assert!(response.contains("404 Not Found")); } #[tokio::test] async fn healthz_takes_precedence_over_decoy() { - let adapter = HttpAdapter::new(provider(), empty_registry()) - .with_decoy(DecoyConfig::Redirect { + let adapter = + HttpAdapter::new(provider(), empty_registry()).with_decoy(DecoyConfig::Redirect { to: "https://example.com".to_string(), }); let request = b"GET /healthz HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"; let response = serve_and_read(adapter, request).await; - assert!(response.starts_with("HTTP/1.1 200"), "expected 200 healthz, got: {response}"); + assert!( + response.starts_with("HTTP/1.1 200"), + "expected 200 healthz, got: {response}" + ); assert!(response.contains("\r\n\r\nok")); } #[tokio::test] async fn unknown_path_with_redirect_decoy_returns_redirect_over_wire() { - let adapter = HttpAdapter::new(provider(), empty_registry()).with_decoy(DecoyConfig::Redirect { - to: "https://example.com".to_string(), - }); + let adapter = + HttpAdapter::new(provider(), empty_registry()).with_decoy(DecoyConfig::Redirect { + to: "https://example.com".to_string(), + }); let request = b"GET /nope HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"; let response = serve_and_read(adapter, request).await; - assert!(response.starts_with("HTTP/1.1 302"), "expected 302 redirect, got: {response}"); + assert!( + response.starts_with("HTTP/1.1 302"), + "expected 302 redirect, got: {response}" + ); assert!(response.contains("location: https://example.com")); } -} \ No newline at end of file +} diff --git a/crates/alknet-http/src/server/auth.rs b/crates/alknet-http/src/server/auth.rs index 137ead5..e212e40 100644 --- a/crates/alknet-http/src/server/auth.rs +++ b/crates/alknet-http/src/server/auth.rs @@ -80,11 +80,12 @@ where { type Rejection = Infallible; - async fn from_request_parts( - parts: &mut Parts, - _state: &S, - ) -> Result { - let identity = parts.extensions.get::>().cloned().flatten(); + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let identity = parts + .extensions + .get::>() + .cloned() + .flatten(); Ok(ResolvedIdentity(identity)) } } @@ -174,15 +175,16 @@ mod tests { assert!(identity.is_none()); } - async fn run_middleware( - idp: Arc, - request: Request, - ) -> Response { + async fn run_middleware(idp: Arc, request: Request) -> Response { let app: Router<()> = Router::new() .route( "/", get(|req: Request| async move { - let identity = req.extensions().get::>().cloned().flatten(); + let identity = req + .extensions() + .get::>() + .cloned() + .flatten(); if let Some(id) = identity { (StatusCode::OK, id.id) } else { @@ -261,14 +263,12 @@ mod tests { let app: Router<()> = Router::new() .route( "/", - get( - |ResolvedIdentity(identity): ResolvedIdentity| async move { - match identity { - Some(id) => (StatusCode::OK, id.id), - None => (StatusCode::OK, "none".to_string()), - } - }, - ), + get(|ResolvedIdentity(identity): ResolvedIdentity| async move { + match identity { + Some(id) => (StatusCode::OK, id.id), + None => (StatusCode::OK, "none".to_string()), + } + }), ) .layer(from_fn_with_state(idp, bearer_auth_middleware)); @@ -287,14 +287,12 @@ mod tests { let app: Router<()> = Router::new() .route( "/", - get( - |ResolvedIdentity(identity): ResolvedIdentity| async move { - match identity { - Some(id) => (StatusCode::OK, id.id), - None => (StatusCode::OK, "none".to_string()), - } - }, - ), + get(|ResolvedIdentity(identity): ResolvedIdentity| async move { + match identity { + Some(id) => (StatusCode::OK, id.id), + None => (StatusCode::OK, "none".to_string()), + } + }), ) .layer(from_fn_with_state(idp, bearer_auth_middleware)); @@ -306,4 +304,4 @@ mod tests { .unwrap(); assert_eq!(&bytes[..], b"none"); } -} \ No newline at end of file +} diff --git a/crates/alknet-http/src/server/decoy.rs b/crates/alknet-http/src/server/decoy.rs index 58cb514..96c76c0 100644 --- a/crates/alknet-http/src/server/decoy.rs +++ b/crates/alknet-http/src/server/decoy.rs @@ -33,10 +33,8 @@ pub fn fake_nginx_404() -> Response { header::CONTENT_TYPE, HeaderValue::from_static("text/html; charset=utf-8"), ); - resp.headers_mut().insert( - header::SERVER, - HeaderValue::from_static("nginx"), - ); + resp.headers_mut() + .insert(header::SERVER, HeaderValue::from_static("nginx")); resp } @@ -61,10 +59,8 @@ pub async fn serve_static(root: &Path, request: Request) -> Response { let content_type = mime_for_path(&resolved); let mut resp = Response::new(Body::from(bytes)); *resp.status_mut() = StatusCode::OK; - resp.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static(content_type), - ); + resp.headers_mut() + .insert(header::CONTENT_TYPE, HeaderValue::from_static(content_type)); resp } Err(_) => fake_nginx_404(), @@ -173,10 +169,7 @@ mod tests { async fn send(router: axum::Router, uri: &str) -> axum::response::Response { tower::ServiceExt::>::oneshot( router, - Request::builder() - .uri(uri) - .body(Body::empty()) - .unwrap(), + Request::builder().uri(uri).body(Body::empty()).unwrap(), ) .await .unwrap() @@ -220,9 +213,7 @@ mod tests { async fn unknown_path_with_static_site_decoy_serves_file() { let dir = tempfile_dir(); let file = dir.join("index.html"); - tokio::fs::write(&file, "

hello

") - .await - .unwrap(); + tokio::fs::write(&file, "

hello

").await.unwrap(); let decoy = DecoyConfig::StaticSite { root: dir.clone() }; let resp = send(decoy_router(decoy), "/").await; @@ -293,11 +284,9 @@ mod tests { } fn tempfile_dir() -> PathBuf { - let dir = PathBuf::from("/tmp").join(format!( - "alknet-http-decoy-test-{}", - uuid::Uuid::new_v4() - )); + let dir = + PathBuf::from("/tmp").join(format!("alknet-http-decoy-test-{}", uuid::Uuid::new_v4())); std::fs::create_dir_all(&dir).unwrap(); dir } -} \ No newline at end of file +} diff --git a/crates/alknet-http/src/server/gateway_routes.rs b/crates/alknet-http/src/server/gateway_routes.rs index 595f7c7..4b7b8bc 100644 --- a/crates/alknet-http/src/server/gateway_routes.rs +++ b/crates/alknet-http/src/server/gateway_routes.rs @@ -52,13 +52,19 @@ impl GatewayState { } fn dispatch(&self) -> GatewayDispatch { - GatewayDispatch::new(Arc::clone(&self.registry), Arc::clone(&self.identity_provider)) + GatewayDispatch::new( + Arc::clone(&self.registry), + Arc::clone(&self.identity_provider), + ) } } impl FromRef for GatewayState { fn from_ref(state: &RouterState) -> Self { - GatewayState::new(Arc::clone(&state.registry), Arc::clone(&state.identity_provider)) + GatewayState::new( + Arc::clone(&state.registry), + Arc::clone(&state.identity_provider), + ) } } @@ -92,7 +98,9 @@ pub(crate) async fn call_handler( return not_found_response(&request.operation); } let dispatch = state.dispatch(); - let envelope = dispatch.invoke(identity.clone(), &request.operation, request.input).await; + let envelope = dispatch + .invoke(identity.clone(), &request.operation, request.input) + .await; envelope_to_response(envelope, identity.as_ref()) } @@ -101,7 +109,9 @@ pub(crate) async fn search_handler( ResolvedIdentity(identity): ResolvedIdentity, ) -> Response { let dispatch = state.dispatch(); - let envelope = dispatch.invoke(identity.clone(), SERVICES_LIST, json!({})).await; + let envelope = dispatch + .invoke(identity.clone(), SERVICES_LIST, json!({})) + .await; envelope_to_response(envelope, identity.as_ref()) } @@ -115,7 +125,11 @@ pub(crate) async fn schema_handler( } let dispatch = state.dispatch(); let envelope = dispatch - .invoke(identity.clone(), SERVICES_SCHEMA, json!({ "name": query.name })) + .invoke( + identity.clone(), + SERVICES_SCHEMA, + json!({ "name": query.name }), + ) .await; envelope_to_response(envelope, identity.as_ref()) } @@ -149,7 +163,9 @@ pub(crate) async fn subscribe_handler( subscribe_stream_internal_error(request.operation) } else { let dispatch = state.dispatch(); - let envelope = dispatch.invoke(identity, &request.operation, request.input).await; + let envelope = dispatch + .invoke(identity, &request.operation, request.input) + .await; subscribe_stream_from_envelope(envelope) }; Sse::new(stream) @@ -221,8 +237,7 @@ fn not_found_response(operation: &str) -> Response { fn forbidden_response(message: String, identity: Option<&Identity>) -> Response { let error = CallError::forbidden(message); let status_code = call_error_to_http_status_with_identity(&error, identity); - let status = - StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + let status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); let body = serde_json::to_value(&error).unwrap_or(Value::Null); (status, Json(body)).into_response() } @@ -248,7 +263,9 @@ fn is_internal_op(registry: &OperationRegistry, operation: &str) -> bool { } } -fn envelope_to_sse_stream(envelope: ResponseEnvelope) -> impl Stream> { +fn envelope_to_sse_stream( + envelope: ResponseEnvelope, +) -> impl Stream> { stream::once(async move { match envelope.result { Ok(output) => { @@ -756,7 +773,10 @@ mod tests { .get(axum::http::header::CONTENT_TYPE) .map(|v| v.to_str().unwrap().to_string()); assert!( - ctype.as_deref().unwrap_or("").starts_with("text/event-stream"), + ctype + .as_deref() + .unwrap_or("") + .starts_with("text/event-stream"), "expected text/event-stream, got {ctype:?}" ); let bytes = resp.into_body().collect().await.unwrap().to_bytes(); @@ -950,4 +970,4 @@ mod tests { assert_eq!(status, StatusCode::NOT_FOUND); assert_eq!(body.get("code"), Some(&json!("NOT_FOUND"))); } -} \ No newline at end of file +} diff --git a/crates/alknet-http/src/server/healthz.rs b/crates/alknet-http/src/server/healthz.rs index 7a23b20..8d101c7 100644 --- a/crates/alknet-http/src/server/healthz.rs +++ b/crates/alknet-http/src/server/healthz.rs @@ -59,4 +59,4 @@ mod tests { let resp = call_healthz(req).await; assert_eq!(resp.status(), StatusCode::OK); } -} \ No newline at end of file +} diff --git a/crates/alknet-http/src/websocket/mod.rs b/crates/alknet-http/src/websocket/mod.rs index 8712f97..e594e15 100644 --- a/crates/alknet-http/src/websocket/mod.rs +++ b/crates/alknet-http/src/websocket/mod.rs @@ -128,7 +128,10 @@ mod tests { let out: EventEnvelope = response.into(); assert_eq!(out.r#type, EVENT_RESPONDED); assert_eq!(out.id, "ws-rt-1"); - assert_eq!(out.payload.get("output"), Some(&serde_json::json!({ "v": 7 }))); + assert_eq!( + out.payload.get("output"), + Some(&serde_json::json!({ "v": 7 })) + ); } #[tokio::test] @@ -160,7 +163,10 @@ mod tests { async fn ws_overlay_only_connection_holds_overlay_and_pending() { let conn = CallConnection::new_overlay_only(identity("ws-peer")); assert!(conn.connection().is_none()); - assert_eq!(conn.identity().map(|i| i.id.clone()), Some("ws-peer".to_string())); + assert_eq!( + conn.identity().map(|i| i.id.clone()), + Some("ws-peer".to_string()) + ); assert!(conn.pending().lock().is_empty()); let env = conn.overlay_env(); diff --git a/crates/alknet-http/src/websocket/upgrade.rs b/crates/alknet-http/src/websocket/upgrade.rs index 4d89dff..cef07bf 100644 --- a/crates/alknet-http/src/websocket/upgrade.rs +++ b/crates/alknet-http/src/websocket/upgrade.rs @@ -84,8 +84,9 @@ async fn ws_upgrade_handler_inner( }; match ws_upgrade { - Some(upgrade) => upgrade - .on_upgrade(move |socket| run_ws_session(socket, registry, identity_provider, identity)), + Some(upgrade) => upgrade.on_upgrade(move |socket| { + run_ws_session(socket, registry, identity_provider, identity) + }), None => { let _ = registry; let _ = identity_provider; @@ -240,19 +241,19 @@ fn serialize_envelope(envelope: &EventEnvelope) -> Result, serde_json::E #[cfg(test)] mod tests { use super::*; + use alknet_call::registry::context::{ + AbortPolicy, CompositionAuthority, OperationContext, ScopedPeerEnv, + }; use alknet_call::registry::discovery::{ services_list_handler, services_list_spec, services_schema_handler, services_schema_spec, }; + use alknet_call::registry::env::OperationEnv; use alknet_call::registry::registration::{ make_handler, HandlerRegistration, OperationProvenance, }; use alknet_call::registry::spec::{AccessControl, OperationSpec, OperationType, Visibility}; use alknet_core::auth::{AuthToken, Identity}; use alknet_core::types::Capabilities; - use alknet_call::registry::context::{ - AbortPolicy, CompositionAuthority, OperationContext, ScopedPeerEnv, - }; - use alknet_call::registry::env::OperationEnv; use std::collections::HashMap; use std::sync::Mutex as StdMutex; use std::time::{Duration, Instant}; @@ -331,9 +332,7 @@ mod tests { let mut registry = OperationRegistry::new(); registry.register(HandlerRegistration::new( external_spec("echo/run", AccessControl::default()), - make_handler(|input, ctx| async move { - ResponseEnvelope::ok(ctx.request_id, input) - }), + make_handler(|input, ctx| async move { ResponseEnvelope::ok(ctx.request_id, input) }), OperationProvenance::Local, None, None, @@ -352,9 +351,7 @@ mod tests { ..Default::default() }, ), - make_handler(|input, ctx| async move { - ResponseEnvelope::ok(ctx.request_id, input) - }), + make_handler(|input, ctx| async move { ResponseEnvelope::ok(ctx.request_id, input) }), OperationProvenance::Local, None, None, @@ -519,9 +516,8 @@ mod tests { #[tokio::test] async fn handle_inbound_envelope_forbidden_yields_call_error() { let registry = registry_with_restricted_op(); - let provider: Arc = Arc::new( - StaticIdentityProvider::new().with_token("none", identity("unpriv")), - ); + let provider: Arc = + Arc::new(StaticIdentityProvider::new().with_token("none", identity("unpriv"))); let dp = dispatcher(registry, provider); let conn = Arc::new(CallConnection::new_overlay_only(identity("unpriv"))); @@ -727,9 +723,8 @@ mod tests { #[tokio::test] async fn round_trip_call_requested_to_call_responded_over_ws_message_stream() { let registry = echo_registry(); - let provider: Arc = Arc::new( - StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), - ); + let provider: Arc = + Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer"))); let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider)); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); @@ -753,9 +748,8 @@ mod tests { #[tokio::test] async fn subscription_streams_multiple_call_responded_events() { let registry = registry_with_subscription(); - let provider: Arc = Arc::new( - StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), - ); + let provider: Arc = + Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer"))); let dp = dispatcher(registry, provider); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); @@ -782,8 +776,10 @@ mod tests { .with_token("no-admin", identity_with_scopes("user", &["user"])), ); let dp = dispatcher(registry, provider); - let conn = - Arc::new(CallConnection::new_overlay_only(identity_with_scopes("user", &["user"]))); + let conn = Arc::new(CallConnection::new_overlay_only(identity_with_scopes( + "user", + &["user"], + ))); let request = EventEnvelope::requested( "req-admin", @@ -882,8 +878,10 @@ mod tests { let overlay_env = conn.overlay_env(); assert!(overlay_env.contains("ui/dragged")); - let composed_env: Arc = dp - .compose_root_env(&conn, &root_context_for_compose("hub-call-1", overlay_env.clone())); + let composed_env: Arc = dp.compose_root_env( + &conn, + &root_context_for_compose("hub-call-1", overlay_env.clone()), + ); let ctx = root_context_with_env("hub-call-1", composed_env); let response = overlay_env .invoke("ui", "dragged", serde_json::json!({ "x": 5 }), &ctx) @@ -935,9 +933,8 @@ mod tests { #[tokio::test] async fn drive_ws_session_round_trips_binary_call_requested_to_call_responded() { let registry = echo_registry(); - let provider: Arc = Arc::new( - StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), - ); + let provider: Arc = + Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer"))); let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider)); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); @@ -960,7 +957,10 @@ mod tests { let env: EventEnvelope = serde_json::from_slice(&bytes).unwrap(); assert_eq!(env.r#type, EVENT_RESPONDED); assert_eq!(env.id, "ws-socket-1"); - assert_eq!(env.payload.get("output"), Some(&serde_json::json!({ "v": 7 }))); + assert_eq!( + env.payload.get("output"), + Some(&serde_json::json!({ "v": 7 })) + ); } other => panic!("expected binary, got {other:?}"), } @@ -972,9 +972,8 @@ mod tests { #[tokio::test] async fn drive_ws_session_rejects_text_with_protocol_close() { let registry = echo_registry(); - let provider: Arc = Arc::new( - StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), - ); + let provider: Arc = + Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer"))); let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider)); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); @@ -999,9 +998,8 @@ mod tests { #[tokio::test] async fn drive_ws_session_disconnect_aborts_in_flight_pending() { let registry = echo_registry(); - let provider: Arc = Arc::new( - StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), - ); + let provider: Arc = + Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer"))); let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider)); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); @@ -1036,9 +1034,8 @@ mod tests { #[tokio::test] async fn drive_ws_session_subscription_streams_call_responded_events() { let registry = registry_with_subscription(); - let provider: Arc = Arc::new( - StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), - ); + let provider: Arc = + Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer"))); let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider)); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); @@ -1077,9 +1074,8 @@ mod tests { #[tokio::test] async fn drive_ws_session_invalid_binary_closes_with_protocol_error() { let registry = echo_registry(); - let provider: Arc = Arc::new( - StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), - ); + let provider: Arc = + Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer"))); let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider)); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); @@ -1102,9 +1098,8 @@ mod tests { #[tokio::test] async fn drive_ws_session_client_close_terminates_server() { let registry = echo_registry(); - let provider: Arc = Arc::new( - StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), - ); + let provider: Arc = + Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer"))); let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider)); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); @@ -1164,17 +1159,11 @@ mod tests { } async fn send_text(&mut self, text: String) { - self.outbound_tx - .send(Message::Text(text.into())) - .await - .ok(); + self.outbound_tx.send(Message::Text(text.into())).await.ok(); } async fn send_close(&mut self) { - self.outbound_tx - .send(Message::Close(None)) - .await - .ok(); + self.outbound_tx.send(Message::Close(None)).await.ok(); } async fn close(&mut self) { @@ -1215,9 +1204,8 @@ mod tests { #[tokio::test] async fn ws_upgrade_handler_returns_401_when_identity_is_none() { let registry = echo_registry(); - let provider: Arc = Arc::new( - StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), - ); + let provider: Arc = + Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer"))); let identity: Option = None; let response = ws_upgrade_handler_inner(registry, provider, identity, None).await; @@ -1227,12 +1215,11 @@ mod tests { #[tokio::test] async fn ws_upgrade_handler_does_not_reject_when_identity_present() { let registry = echo_registry(); - let provider: Arc = Arc::new( - StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), - ); + let provider: Arc = + Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer"))); let identity = identity("ws-peer"); let response = ws_upgrade_handler_inner(registry, provider, Some(identity), None).await; assert_ne!(response.status(), StatusCode::UNAUTHORIZED); } -} \ No newline at end of file +} diff --git a/crates/alknet-http/tests/from_mcp_integration.rs b/crates/alknet-http/tests/from_mcp_integration.rs index 9d49867..afd69f6 100644 --- a/crates/alknet-http/tests/from_mcp_integration.rs +++ b/crates/alknet-http/tests/from_mcp_integration.rs @@ -9,11 +9,11 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; +use alknet_call::client::OperationAdapter; use alknet_call::protocol::wire::ResponseEnvelope; use alknet_call::registry::context::{AbortPolicy, OperationContext, ScopedPeerEnv}; use alknet_call::registry::env::OperationEnv; use alknet_call::registry::registration::OperationProvenance; -use alknet_call::client::OperationAdapter; use alknet_core::types::Capabilities; use alknet_http::adapters::FromMCP; use axum::Router; @@ -22,8 +22,8 @@ use rmcp::model::{ }; use rmcp::service::RequestContext; use rmcp::transport::{ - StreamableHttpServerConfig, streamable_http_server::{session::local::LocalSessionManager, tower::StreamableHttpService}, + StreamableHttpServerConfig, }; use rmcp::{RoleServer, ServerHandler}; use serde_json::Value; @@ -72,18 +72,19 @@ impl ServerHandler for EchoServer { &self, _request: Option, _context: RequestContext, - ) -> impl std::future::Future< - Output = Result, - > + rmcp::service::MaybeSendFuture + '_ { + ) -> impl std::future::Future> + + rmcp::service::MaybeSendFuture + + '_ { let tools = vec![ Tool::new_with_raw( "echo", Some("Echo the input back as structured content".into()), Arc::new(serde_json::Map::new()), ) - .with_raw_output_schema(Arc::new(serde_json::Map::from_iter([ - ("type".to_string(), Value::String("object".into())), - ]))), + .with_raw_output_schema(Arc::new(serde_json::Map::from_iter([( + "type".to_string(), + Value::String("object".into()), + )]))), Tool::new_with_raw( "legacy", Some("Legacy tool returning text content blocks".into()), @@ -101,22 +102,17 @@ impl ServerHandler for EchoServer { &self, request: CallToolRequestParams, _context: RequestContext, - ) -> impl std::future::Future< - Output = Result, - > + rmcp::service::MaybeSendFuture + '_ { + ) -> impl std::future::Future> + + rmcp::service::MaybeSendFuture + + '_ { let name = request.name.to_string(); std::future::ready(Ok(match name.as_str() { "echo" => { - let args = request - .arguments - .map(Value::Object) - .unwrap_or(Value::Null); + let args = request.arguments.map(Value::Object).unwrap_or(Value::Null); CallToolResult::structured(serde_json::json!({ "echoed": args })) } "legacy" => CallToolResult::success(vec![Content::text("plain text result")]), - other => CallToolResult::error(vec![Content::text(format!( - "unknown tool: {other}" - ))]), + other => CallToolResult::error(vec![Content::text(format!("unknown tool: {other}"))]), })) } @@ -234,4 +230,4 @@ async fn import_unreachable_server_returns_discovery_failed() { Err(alknet_call::client::AdapterError::Transport { .. }) => {} Err(other) => panic!("expected DiscoveryFailed or Transport, got {other}"), } -} \ No newline at end of file +}