feat(http): complete to_openapi gateway projection with error fidelity and route wiring

Refine to_openapi to project operation-level errors (with http_status)
onto /call and /subscribe responses via oneOf merge with protocol-level
errors, preserving HTTP_<status> prefix codes without collision. Fix
BTreeMap→serde_json::Map for Value::Object compatibility. Wire GET
/openapi.json route test. Apply cargo fmt across the crate.
This commit is contained in:
2026-07-01 20:11:09 +00:00
parent 2695a19502
commit dd6aacc598
17 changed files with 1227 additions and 683 deletions

View File

@@ -22,7 +22,11 @@ fn make_tool(name: &str, input: Value, output: Option<Value>) -> Tool {
tool tool
} }
fn call_tool_result(content: Vec<Content>, structured: Option<Value>, is_error: Option<bool>) -> CallToolResult { fn call_tool_result(
content: Vec<Content>,
structured: Option<Value>,
is_error: Option<bool>,
) -> CallToolResult {
let json = serde_json::json!({ let json = serde_json::json!({
"content": content, "content": content,
"structuredContent": structured, "structuredContent": structured,
@@ -204,7 +208,9 @@ fn build_spec_output_schema_present_shape() {
let tool = make_tool( let tool = make_tool(
"get_weather", "get_weather",
serde_json::json!({ "type": "object", "properties": { "city": { "type": "string" } } }), serde_json::json!({ "type": "object", "properties": { "city": { "type": "string" } } }),
Some(serde_json::json!({ "type": "object", "properties": { "temperature": { "type": "number" } } })), Some(
serde_json::json!({ "type": "object", "properties": { "temperature": { "type": "number" } } }),
),
); );
let spec = build_spec(&tool, "weather"); let spec = build_spec(&tool, "weather");
assert_eq!(spec.name, "weather/get_weather"); assert_eq!(spec.name, "weather/get_weather");

View File

@@ -17,10 +17,10 @@ use std::sync::Arc;
use alknet_call::client::{AdapterError, OperationAdapter}; use alknet_call::client::{AdapterError, OperationAdapter};
use alknet_call::protocol::wire::{CallError, ResponseEnvelope}; use alknet_call::protocol::wire::{CallError, ResponseEnvelope};
use alknet_call::registry::context::OperationContext; use alknet_call::registry::context::OperationContext;
use alknet_call::registry::registration::{ use alknet_call::registry::registration::{make_handler, HandlerRegistration, OperationProvenance};
make_handler, HandlerRegistration, OperationProvenance, use alknet_call::registry::spec::{
AccessControl, ErrorDefinition, OperationSpec, OperationType, Visibility,
}; };
use alknet_call::registry::spec::{AccessControl, ErrorDefinition, OperationSpec, OperationType, Visibility};
use alknet_core::types::Capabilities; use alknet_core::types::Capabilities;
use async_trait::async_trait; use async_trait::async_trait;
use futures::StreamExt; use futures::StreamExt;
@@ -128,11 +128,9 @@ impl OpenAPISpec {
.to_string(), .to_string(),
}; };
let paths_raw = raw let paths_raw = raw.get("paths").ok_or_else(|| AdapterError::SchemaParse {
.get("paths") message: "OpenAPI document missing `paths`".into(),
.ok_or_else(|| AdapterError::SchemaParse { })?;
message: "OpenAPI document missing `paths`".into(),
})?;
if !paths_raw.is_object() { if !paths_raw.is_object() {
return Err(AdapterError::SchemaParse { return Err(AdapterError::SchemaParse {
message: "`paths` must be a JSON object".into(), message: "`paths` must be a JSON object".into(),
@@ -155,14 +153,13 @@ impl OpenAPISpec {
if operations.is_empty() { if operations.is_empty() {
continue; continue;
} }
paths.insert( paths.insert(path.clone(), PathItem { operations });
path.clone(),
PathItem { operations },
);
} }
let components = raw.get("components").and_then(|c| c.get("schemas")).and_then( let components = raw
|schemas| { .get("components")
.and_then(|c| c.get("schemas"))
.and_then(|schemas| {
if !schemas.is_object() { if !schemas.is_object() {
return None; return None;
} }
@@ -171,8 +168,7 @@ impl OpenAPISpec {
map.insert(k.clone(), v.clone()); map.insert(k.clone(), v.clone());
} }
Some(Components { schemas: map }) Some(Components { schemas: map })
}, });
);
Ok(Self { Ok(Self {
info, info,
@@ -190,11 +186,9 @@ impl OpenAPISpec {
} }
let mut current: &Value = &self.raw; let mut current: &Value = &self.raw;
for part in reference.trim_start_matches("#/").split('/') { for part in reference.trim_start_matches("#/").split('/') {
current = current current = current.get(part).ok_or_else(|| AdapterError::SchemaParse {
.get(part) message: format!("cannot resolve $ref: {reference}"),
.ok_or_else(|| AdapterError::SchemaParse { })?;
message: format!("cannot resolve $ref: {reference}"),
})?;
} }
Ok(current.clone()) Ok(current.clone())
} }
@@ -241,10 +235,7 @@ fn parse_operation(raw: &Value) -> Option<Operation> {
.filter_map(|p| { .filter_map(|p| {
let name = p.get("name")?.as_str()?.to_string(); let name = p.get("name")?.as_str()?.to_string();
let in_ = p.get("in")?.as_str()?.to_string(); let in_ = p.get("in")?.as_str()?.to_string();
let required = p let required = p.get("required").and_then(|v| v.as_bool()).unwrap_or(false);
.get("required")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let schema = p.get("schema").cloned(); let schema = p.get("schema").cloned();
Some(Parameter { Some(Parameter {
name, name,
@@ -297,7 +288,11 @@ pub struct FromOpenAPI {
} }
impl FromOpenAPI { impl FromOpenAPI {
pub fn new(spec: OpenAPISpec, config: HttpServiceConfig, http_client: Arc<SharedHttpClient>) -> Self { pub fn new(
spec: OpenAPISpec,
config: HttpServiceConfig,
http_client: Arc<SharedHttpClient>,
) -> Self {
Self { Self {
spec, spec,
config, config,
@@ -322,10 +317,7 @@ impl FromOpenAPI {
} }
fn detect_op_type(method: &str, op: &Operation) -> OperationType { fn detect_op_type(method: &str, op: &Operation) -> OperationType {
let success = op let success = op.responses.get("200").or_else(|| op.responses.get("201"));
.responses
.get("200")
.or_else(|| op.responses.get("201"));
if let Some(resp) = success { if let Some(resp) = success {
if resp.content.contains_key("text/event-stream") { if resp.content.contains_key("text/event-stream") {
return OperationType::Subscription; return OperationType::Subscription;
@@ -531,9 +523,8 @@ fn build_request(
} }
} }
let base = Url::parse(base_url).map_err(|e| { let base = Url::parse(base_url)
CallError::internal(format!("invalid base_url `{base_url}`: {e}")) .map_err(|e| CallError::internal(format!("invalid base_url `{base_url}`: {e}")))?;
})?;
let mut url = base let mut url = base
.join(url_path.trim_start_matches('/')) .join(url_path.trim_start_matches('/'))
.map_err(|e| CallError::internal(format!("invalid path `{url_path}`: {e}")))?; .map_err(|e| CallError::internal(format!("invalid path `{url_path}`: {e}")))?;
@@ -683,11 +674,12 @@ async fn forward(
.find(|(s, _)| *s == status.as_u16()) .find(|(s, _)| *s == status.as_u16())
.map(|(_, code)| code.clone()) .map(|(_, code)| code.clone())
.unwrap_or_else(|| format!("HTTP_{}", status.as_u16())); .unwrap_or_else(|| format!("HTTP_{}", status.as_u16()));
let message = format!("HTTP {}: {}", status.as_u16(), status.canonical_reason().unwrap_or("")); let message = format!(
return ResponseEnvelope::error( "HTTP {}: {}",
request_id, status.as_u16(),
CallError::new(code, message, false), status.canonical_reason().unwrap_or("")
); );
return ResponseEnvelope::error(request_id, CallError::new(code, message, false));
} }
let content_type = response let content_type = response
@@ -716,10 +708,7 @@ async fn forward(
} else { } else {
match response.bytes().await { match response.bytes().await {
Ok(b) => { Ok(b) => {
let arr: Vec<Value> = b let arr: Vec<Value> = b.iter().map(|byte| Value::Number((*byte).into())).collect();
.iter()
.map(|byte| Value::Number((*byte).into()))
.collect();
ResponseEnvelope::ok(request_id, Value::Array(arr)) ResponseEnvelope::ok(request_id, Value::Array(arr))
} }
Err(err) => ResponseEnvelope::error( Err(err) => ResponseEnvelope::error(
@@ -744,7 +733,8 @@ async fn stream_subscription(request_id: String, response: reqwest::Response) ->
let parsed = if event.data.trim().is_empty() { let parsed = if event.data.trim().is_empty() {
Value::Null Value::Null
} else { } else {
serde_json::from_str(&event.data).unwrap_or(Value::String(event.data.clone())) serde_json::from_str(&event.data)
.unwrap_or(Value::String(event.data.clone()))
}; };
last_event = Some(parsed.clone()); last_event = Some(parsed.clone());
} }
@@ -1040,7 +1030,12 @@ mod tests {
.unwrap(); .unwrap();
let body = props.get("body").unwrap(); let body = props.get("body").unwrap();
assert_eq!(body.get("type").unwrap(), "object"); assert_eq!(body.get("type").unwrap(), "object");
assert!(body.get("properties").unwrap().as_object().unwrap().contains_key("name")); assert!(body
.get("properties")
.unwrap()
.as_object()
.unwrap()
.contains_key("name"));
} }
#[tokio::test] #[tokio::test]
@@ -1074,14 +1069,19 @@ mod tests {
"https://api.vast.ai", "https://api.vast.ai",
"/machines", "/machines",
"GET", "GET",
&Some(HttpAuthScheme::ApiKey { header_name: "X-API-Key".to_string() }), &Some(HttpAuthScheme::ApiKey {
header_name: "X-API-Key".to_string(),
}),
&HashMap::new(), &HashMap::new(),
"vastai", "vastai",
&serde_json::json!({}), &serde_json::json!({}),
&ctx, &ctx,
) )
.unwrap(); .unwrap();
assert_eq!(headers.get("X-API-Key").unwrap().to_str().unwrap(), "key-xyz"); assert_eq!(
headers.get("X-API-Key").unwrap().to_str().unwrap(),
"key-xyz"
);
} }
#[tokio::test] #[tokio::test]
@@ -1267,7 +1267,11 @@ mod tests {
#[test] #[test]
fn http_service_config_struct_fields() { fn http_service_config_struct_fields() {
let cfg = config("ns", "https://api.example.com", Some(HttpAuthScheme::Bearer)); let cfg = config(
"ns",
"https://api.example.com",
Some(HttpAuthScheme::Bearer),
);
assert_eq!(cfg.namespace, "ns"); assert_eq!(cfg.namespace, "ns");
assert_eq!(cfg.base_url, "https://api.example.com"); assert_eq!(cfg.base_url, "https://api.example.com");
assert!(matches!(cfg.auth, Some(HttpAuthScheme::Bearer))); assert!(matches!(cfg.auth, Some(HttpAuthScheme::Bearer)));
@@ -1289,7 +1293,12 @@ mod tests {
}"#; }"#;
let spec = OpenAPISpec::from_json(doc).unwrap(); let spec = OpenAPISpec::from_json(doc).unwrap();
assert!(spec.components.is_some()); assert!(spec.components.is_some());
assert!(spec.components.as_ref().unwrap().schemas.contains_key("Foo")); assert!(spec
.components
.as_ref()
.unwrap()
.schemas
.contains_key("Foo"));
} }
#[tokio::test] #[tokio::test]
@@ -1342,7 +1351,9 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn resolve_ref_missing_target_returns_schema_parse() { async fn resolve_ref_missing_target_returns_schema_parse() {
let spec = OpenAPISpec::from_json(minimal_spec_json()).unwrap(); let spec = OpenAPISpec::from_json(minimal_spec_json()).unwrap();
let err = spec.resolve_ref("#/components/schemas/Missing").unwrap_err(); let err = spec
.resolve_ref("#/components/schemas/Missing")
.unwrap_err();
assert!(matches!(err, AdapterError::SchemaParse { .. })); assert!(matches!(err, AdapterError::SchemaParse { .. }));
} }
@@ -1409,7 +1420,8 @@ mod tests {
headers, headers,
body, body,
}); });
let response = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 2\r\n\r\n{}"; let response =
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 2\r\n\r\n{}";
sock.write_all(response.as_bytes()).await.unwrap(); sock.write_all(response.as_bytes()).await.unwrap();
sock.flush().await.unwrap(); sock.flush().await.unwrap();
}); });
@@ -1440,12 +1452,19 @@ mod tests {
ctx, ctx,
) )
.await; .await;
assert!(response.result.is_ok(), "expected Ok, got {:?}", response.result); assert!(
response.result.is_ok(),
"expected Ok, got {:?}",
response.result
);
let captured = rx.await.unwrap(); let captured = rx.await.unwrap();
assert_eq!(captured.method, "POST"); assert_eq!(captured.method, "POST");
assert_eq!(captured.path, "/items/42"); assert_eq!(captured.path, "/items/42");
assert_eq!(captured.query, "filter=new"); assert_eq!(captured.query, "filter=new");
assert_eq!(captured.headers.get("content-type").unwrap(), "application/json"); assert_eq!(
captured.headers.get("content-type").unwrap(),
"application/json"
);
assert!(captured.body.contains("\"name\":\"widget\"")); assert!(captured.body.contains("\"name\":\"widget\""));
} }
@@ -1457,19 +1476,19 @@ mod tests {
}"#; }"#;
let (base, rx) = spawn_capturing_server().await; let (base, rx) = spawn_capturing_server().await;
let spec = OpenAPISpec::from_json(doc).unwrap(); let spec = OpenAPISpec::from_json(doc).unwrap();
let bundles = adapter( let bundles = adapter(spec, config("openai", &base, Some(HttpAuthScheme::Bearer)))
spec, .import()
config("openai", &base, Some(HttpAuthScheme::Bearer)), .await
) .unwrap();
.import()
.await
.unwrap();
let registration = &bundles[0]; let registration = &bundles[0];
let caps = Capabilities::new().with_http_token("openai", "sk-test-token".to_string()); let caps = Capabilities::new().with_http_token("openai", "sk-test-token".to_string());
let ctx = noop_context("req-17", caps); let ctx = noop_context("req-17", caps);
let _ = (registration.handler)(serde_json::json!({}), ctx).await; let _ = (registration.handler)(serde_json::json!({}), ctx).await;
let captured = rx.await.unwrap(); let captured = rx.await.unwrap();
assert_eq!(captured.headers.get("authorization").unwrap(), "Bearer sk-test-token"); assert_eq!(
captured.headers.get("authorization").unwrap(),
"Bearer sk-test-token"
);
} }
#[tokio::test] #[tokio::test]

View File

@@ -36,8 +36,8 @@ use rmcp::model::{
}; };
use rmcp::service::{RequestContext, RoleServer}; use rmcp::service::{RequestContext, RoleServer};
use rmcp::transport::{ use rmcp::transport::{
StreamableHttpServerConfig,
streamable_http_server::{session::local::LocalSessionManager, tower::StreamableHttpService}, streamable_http_server::{session::local::LocalSessionManager, tower::StreamableHttpService},
StreamableHttpServerConfig,
}; };
use serde_json::{Map, Value}; use serde_json::{Map, Value};
@@ -133,7 +133,10 @@ impl ToMcpGateway {
fn extract_identity_from_extensions(extensions: &rmcp::model::Extensions) -> Option<Identity> { fn extract_identity_from_extensions(extensions: &rmcp::model::Extensions) -> Option<Identity> {
let parts = extensions.get::<http::request::Parts>()?; let parts = extensions.get::<http::request::Parts>()?;
parts.extensions.get::<Option<Identity>>().and_then(Option::clone) parts
.extensions
.get::<Option<Identity>>()
.and_then(Option::clone)
} }
async fn handle_search(&self, identity: Option<Identity>) -> CallToolResult { async fn handle_search(&self, identity: Option<Identity>) -> CallToolResult {
@@ -144,8 +147,15 @@ impl ToMcpGateway {
map_search_response(response, identity.as_ref()) map_search_response(response, identity.as_ref())
} }
async fn handle_schema(&self, arguments: Option<JsonObject>, identity: Option<Identity>) -> CallToolResult { async fn handle_schema(
let name = match arguments.and_then(|mut a| a.remove("name")).and_then(|v| v.as_str().map(str::to_string)) { &self,
arguments: Option<JsonObject>,
identity: Option<Identity>,
) -> CallToolResult {
let name = match arguments
.and_then(|mut a| a.remove("name"))
.and_then(|v| v.as_str().map(str::to_string))
{
Some(n) => n, Some(n) => n,
None => { None => {
return CallToolResult::structured_error(serde_json::json!({ return CallToolResult::structured_error(serde_json::json!({
@@ -156,12 +166,20 @@ impl ToMcpGateway {
}; };
let response = self let response = self
.dispatch .dispatch
.invoke(identity, OP_SERVICES_SCHEMA, serde_json::json!({ "name": name })) .invoke(
identity,
OP_SERVICES_SCHEMA,
serde_json::json!({ "name": name }),
)
.await; .await;
envelope_to_call_tool_result(response) envelope_to_call_tool_result(response)
} }
async fn handle_call(&self, arguments: Option<JsonObject>, identity: Option<Identity>) -> CallToolResult { async fn handle_call(
&self,
arguments: Option<JsonObject>,
identity: Option<Identity>,
) -> CallToolResult {
let (operation, input) = match parse_call_arguments(arguments) { let (operation, input) = match parse_call_arguments(arguments) {
Ok(pair) => pair, Ok(pair) => pair,
Err(err) => return err, Err(err) => return err,
@@ -170,7 +188,11 @@ impl ToMcpGateway {
envelope_to_call_tool_result(response) envelope_to_call_tool_result(response)
} }
async fn handle_batch(&self, arguments: Option<JsonObject>, identity: Option<Identity>) -> CallToolResult { async fn handle_batch(
&self,
arguments: Option<JsonObject>,
identity: Option<Identity>,
) -> CallToolResult {
let calls = match arguments let calls = match arguments
.and_then(|mut a| a.remove("calls")) .and_then(|mut a| a.remove("calls"))
.and_then(|v| v.as_array().cloned()) .and_then(|v| v.as_array().cloned())
@@ -193,7 +215,10 @@ impl ToMcpGateway {
continue; continue;
} }
}; };
let response = self.dispatch.invoke(identity.clone(), &operation, input).await; let response = self
.dispatch
.invoke(identity.clone(), &operation, input)
.await;
results.push(envelope_to_value(response)); results.push(envelope_to_value(response));
} }
CallToolResult::structured(Value::Array(results)) CallToolResult::structured(Value::Array(results))
@@ -210,7 +235,10 @@ fn parse_call_arguments(arguments: Option<JsonObject>) -> Result<(String, Value)
}))); })));
} }
}; };
let operation = match map.remove("operation").and_then(|v| v.as_str().map(str::to_string)) { let operation = match map
.remove("operation")
.and_then(|v| v.as_str().map(str::to_string))
{
Some(s) => s, Some(s) => s,
None => { None => {
return Err(CallToolResult::structured_error(serde_json::json!({ return Err(CallToolResult::structured_error(serde_json::json!({
@@ -359,7 +387,11 @@ impl rmcp::handler::server::ServerHandler for ToMcpGateway {
TOOL_CALL => this.handle_call(arguments, identity).await, TOOL_CALL => this.handle_call(arguments, identity).await,
TOOL_BATCH => this.handle_batch(arguments, identity).await, TOOL_BATCH => this.handle_batch(arguments, identity).await,
unknown => { unknown => {
let err = CallError::new("NOT_FOUND", format!("unknown gateway tool: {unknown}"), false); let err = CallError::new(
"NOT_FOUND",
format!("unknown gateway tool: {unknown}"),
false,
);
call_error_to_structured_error(err) call_error_to_structured_error(err)
} }
}; };
@@ -368,9 +400,7 @@ impl rmcp::handler::server::ServerHandler for ToMcpGateway {
} }
fn get_info(&self) -> ServerInfo { fn get_info(&self) -> ServerInfo {
let capabilities = ServerCapabilities::builder() let capabilities = ServerCapabilities::builder().enable_tools().build();
.enable_tools()
.build();
ServerInfo::new(capabilities) ServerInfo::new(capabilities)
.with_server_info(Implementation::new( .with_server_info(Implementation::new(
"alknet-to-mcp", "alknet-to-mcp",
@@ -462,10 +492,14 @@ mod tests {
} }
fn make_echo_handler() -> alknet_call::registry::registration::Handler { fn make_echo_handler() -> alknet_call::registry::registration::Handler {
make_handler(|input, context| async move { ResponseEnvelope::ok(context.request_id, input) }) make_handler(
|input, context| async move { ResponseEnvelope::ok(context.request_id, input) },
)
} }
fn full_registry_with_ops(specs: Vec<(String, OperationType, AccessControl)>) -> Arc<OperationRegistry> { fn full_registry_with_ops(
specs: Vec<(String, OperationType, AccessControl)>,
) -> Arc<OperationRegistry> {
let mut inner = OperationRegistry::new(); let mut inner = OperationRegistry::new();
for (name, op_type, acl) in specs { for (name, op_type, acl) in specs {
inner.register(HandlerRegistration::new( inner.register(HandlerRegistration::new(
@@ -509,7 +543,10 @@ mod tests {
Arc::new(dispatch_registry) Arc::new(dispatch_registry)
} }
fn dispatch(registry: Arc<OperationRegistry>, provider: Arc<dyn IdentityProvider>) -> Arc<GatewayDispatch> { fn dispatch(
registry: Arc<OperationRegistry>,
provider: Arc<dyn IdentityProvider>,
) -> Arc<GatewayDispatch> {
Arc::new(GatewayDispatch::new(registry, provider)) Arc::new(GatewayDispatch::new(registry, provider))
} }
@@ -542,7 +579,11 @@ mod tests {
TOOL_CALL => gateway.handle_call(arguments, identity).await, TOOL_CALL => gateway.handle_call(arguments, identity).await,
TOOL_BATCH => gateway.handle_batch(arguments, identity).await, TOOL_BATCH => gateway.handle_batch(arguments, identity).await,
unknown => { unknown => {
let err = CallError::new("NOT_FOUND", format!("unknown gateway tool: {unknown}"), false); let err = CallError::new(
"NOT_FOUND",
format!("unknown gateway tool: {unknown}"),
false,
);
call_error_to_structured_error(err) call_error_to_structured_error(err)
} }
} }
@@ -550,10 +591,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn list_tools_returns_exactly_four_gateway_tools() { async fn list_tools_returns_exactly_four_gateway_tools() {
let _gateway = ToMcpGateway::new(dispatch( let _gateway = ToMcpGateway::new(dispatch(full_registry_with_ops(vec![]), provider()));
full_registry_with_ops(vec![]),
provider(),
));
let tools = gateway_tools(); let tools = gateway_tools();
let names: Vec<String> = tools.iter().map(|t| t.name.to_string()).collect(); let names: Vec<String> = tools.iter().map(|t| t.name.to_string()).collect();
assert_eq!(names.len(), 4); assert_eq!(names.len(), 4);
@@ -583,7 +621,11 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn search_returns_access_control_filtered_ops_excluding_subscriptions() { async fn search_returns_access_control_filtered_ops_excluding_subscriptions() {
let registry = full_registry_with_ops(vec![ let registry = full_registry_with_ops(vec![
("public/echo".to_string(), OperationType::Query, AccessControl::default()), (
"public/echo".to_string(),
OperationType::Query,
AccessControl::default(),
),
( (
"admin/secret".to_string(), "admin/secret".to_string(),
OperationType::Query, OperationType::Query,
@@ -592,13 +634,22 @@ mod tests {
..Default::default() ..Default::default()
}, },
), ),
("events/stream".to_string(), OperationType::Subscription, AccessControl::default()), (
"events/stream".to_string(),
OperationType::Subscription,
AccessControl::default(),
),
]); ]);
let idp: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new()); let idp: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
let gateway = ToMcpGateway::new(dispatch(registry, idp)); let gateway = ToMcpGateway::new(dispatch(registry, idp));
let result = invoke_tool(&gateway, "search", None, Some(identity_with_scopes("user", &["user"]))) let result = invoke_tool(
.await; &gateway,
"search",
None,
Some(identity_with_scopes("user", &["user"])),
)
.await;
assert_eq!(result.is_error, Some(false)); assert_eq!(result.is_error, Some(false));
let structured = result.structured_content.expect("structured present"); let structured = result.structured_content.expect("structured present");
let ops = structured let ops = structured
@@ -610,11 +661,23 @@ mod tests {
.filter_map(|o| o.get("name").and_then(Value::as_str)) .filter_map(|o| o.get("name").and_then(Value::as_str))
.collect(); .collect();
assert!(names.contains(&"public/echo")); assert!(names.contains(&"public/echo"));
assert!(!names.contains(&"admin/secret"), "ACL-filtered op must not appear"); assert!(
assert!(!names.contains(&"events/stream"), "Subscription op must be excluded"); !names.contains(&"admin/secret"),
"ACL-filtered op must not appear"
);
assert!(
!names.contains(&"events/stream"),
"Subscription op must be excluded"
);
for op in ops { for op in ops {
assert!(op.get("description").is_some(), "each entry has a description"); assert!(
assert!(op.get("input_schema").is_none(), "search must not return full schemas"); op.get("description").is_some(),
"each entry has a description"
);
assert!(
op.get("input_schema").is_none(),
"search must not return full schemas"
);
} }
} }
@@ -632,7 +695,10 @@ mod tests {
let result = invoke_tool(&gateway, "schema", Some(args), None).await; let result = invoke_tool(&gateway, "schema", Some(args), None).await;
assert_eq!(result.is_error, Some(false)); assert_eq!(result.is_error, Some(false));
let structured = result.structured_content.expect("structured present"); let structured = result.structured_content.expect("structured present");
assert_eq!(structured.get("name"), Some(&Value::String("fs/readFile".to_string()))); assert_eq!(
structured.get("name"),
Some(&Value::String("fs/readFile".to_string()))
);
assert!(structured.get("input_schema").is_some()); assert!(structured.get("input_schema").is_some());
assert!(structured.get("output_schema").is_some()); assert!(structured.get("output_schema").is_some());
assert!(structured.get("error_schemas").is_some()); assert!(structured.get("error_schemas").is_some());
@@ -649,7 +715,10 @@ mod tests {
let gateway = ToMcpGateway::new(dispatch(registry, provider())); let gateway = ToMcpGateway::new(dispatch(registry, provider()));
let mut args = Map::new(); let mut args = Map::new();
args.insert("operation".to_string(), Value::String("echo/run".to_string())); args.insert(
"operation".to_string(),
Value::String("echo/run".to_string()),
);
args.insert("input".to_string(), serde_json::json!({ "msg": "hi" })); args.insert("input".to_string(), serde_json::json!({ "msg": "hi" }));
let result = invoke_tool(&gateway, "call", Some(args), None).await; let result = invoke_tool(&gateway, "call", Some(args), None).await;
assert_eq!(result.is_error, Some(false)); assert_eq!(result.is_error, Some(false));
@@ -665,12 +734,18 @@ mod tests {
let gateway = ToMcpGateway::new(dispatch(registry, provider())); let gateway = ToMcpGateway::new(dispatch(registry, provider()));
let mut args = Map::new(); let mut args = Map::new();
args.insert("operation".to_string(), Value::String("no/such".to_string())); args.insert(
"operation".to_string(),
Value::String("no/such".to_string()),
);
args.insert("input".to_string(), Value::Object(Map::new())); args.insert("input".to_string(), Value::Object(Map::new()));
let result = invoke_tool(&gateway, "call", Some(args), None).await; let result = invoke_tool(&gateway, "call", Some(args), None).await;
assert_eq!(result.is_error, Some(true)); assert_eq!(result.is_error, Some(true));
let structured = result.structured_content.expect("structured error present"); let structured = result.structured_content.expect("structured error present");
assert_eq!(structured.get("code"), Some(&Value::String("NOT_FOUND".to_string()))); assert_eq!(
structured.get("code"),
Some(&Value::String("NOT_FOUND".to_string()))
);
} }
#[tokio::test] #[tokio::test]
@@ -713,12 +788,18 @@ mod tests {
let gateway = ToMcpGateway::new(dispatch(registry, idp)); let gateway = ToMcpGateway::new(dispatch(registry, idp));
let mut args = Map::new(); let mut args = Map::new();
args.insert("operation".to_string(), Value::String("admin/run".to_string())); args.insert(
"operation".to_string(),
Value::String("admin/run".to_string()),
);
args.insert("input".to_string(), Value::Object(Map::new())); args.insert("input".to_string(), Value::Object(Map::new()));
let result = invoke_tool(&gateway, "call", Some(args), None).await; let result = invoke_tool(&gateway, "call", Some(args), None).await;
assert_eq!(result.is_error, Some(true)); assert_eq!(result.is_error, Some(true));
let structured = result.structured_content.expect("structured error present"); let structured = result.structured_content.expect("structured error present");
assert_eq!(structured.get("code"), Some(&Value::String("FORBIDDEN".to_string()))); assert_eq!(
structured.get("code"),
Some(&Value::String("FORBIDDEN".to_string()))
);
} }
#[tokio::test] #[tokio::test]
@@ -727,7 +808,10 @@ mod tests {
let result = invoke_tool(&gateway, "bogus", None, None).await; let result = invoke_tool(&gateway, "bogus", None, None).await;
assert_eq!(result.is_error, Some(true)); assert_eq!(result.is_error, Some(true));
let structured = result.structured_content.expect("structured error present"); let structured = result.structured_content.expect("structured error present");
assert_eq!(structured.get("code"), Some(&Value::String("NOT_FOUND".to_string()))); assert_eq!(
structured.get("code"),
Some(&Value::String("NOT_FOUND".to_string()))
);
} }
#[tokio::test] #[tokio::test]
@@ -749,10 +833,16 @@ mod tests {
let admin_identity = identity_with_scopes("admin-peer", &["admin"]); let admin_identity = identity_with_scopes("admin-peer", &["admin"]);
let extensions = extensions_with_identity(Some(admin_identity.clone())); let extensions = extensions_with_identity(Some(admin_identity.clone()));
let extracted = ToMcpGateway::extract_identity_from_extensions(&extensions); let extracted = ToMcpGateway::extract_identity_from_extensions(&extensions);
assert_eq!(extracted.as_ref().map(|i| &i.id), Some(&"admin-peer".to_string())); assert_eq!(
extracted.as_ref().map(|i| &i.id),
Some(&"admin-peer".to_string())
);
let mut args = Map::new(); let mut args = Map::new();
args.insert("operation".to_string(), Value::String("admin/run".to_string())); args.insert(
"operation".to_string(),
Value::String("admin/run".to_string()),
);
args.insert("input".to_string(), serde_json::json!({ "ok": 1 })); args.insert("input".to_string(), serde_json::json!({ "ok": 1 }));
let result = gateway.handle_call(Some(args), extracted).await; let result = gateway.handle_call(Some(args), extracted).await;
assert_eq!(result.is_error, Some(false)); assert_eq!(result.is_error, Some(false));
@@ -779,7 +869,10 @@ mod tests {
let id = identity_with_scopes("caller", &["read"]); let id = identity_with_scopes("caller", &["read"]);
let extensions = extensions_with_identity(Some(id.clone())); let extensions = extensions_with_identity(Some(id.clone()));
let extracted = ToMcpGateway::extract_identity_from_extensions(&extensions); let extracted = ToMcpGateway::extract_identity_from_extensions(&extensions);
assert_eq!(extracted.as_ref().map(|i| i.id.clone()), Some("caller".to_string())); assert_eq!(
extracted.as_ref().map(|i| i.id.clone()),
Some("caller".to_string())
);
assert_eq!( assert_eq!(
extracted.as_ref().map(|i| i.scopes.clone()), extracted.as_ref().map(|i| i.scopes.clone()),
Some(vec!["read".to_string()]) Some(vec!["read".to_string()])
@@ -834,8 +927,14 @@ mod tests {
); );
let mut call_args = Map::new(); let mut call_args = Map::new();
call_args.insert("operation".to_string(), Value::String(first_name.to_string())); call_args.insert(
call_args.insert("input".to_string(), serde_json::json!({ "path": "/etc/hosts" })); "operation".to_string(),
Value::String(first_name.to_string()),
);
call_args.insert(
"input".to_string(),
serde_json::json!({ "path": "/etc/hosts" }),
);
let call_result = invoke_tool(&gateway, "call", Some(call_args), None).await; let call_result = invoke_tool(&gateway, "call", Some(call_args), None).await;
assert_eq!( assert_eq!(
call_result.structured_content, call_result.structured_content,

File diff suppressed because it is too large Load Diff

View File

@@ -125,10 +125,11 @@ fn build_client(config: &HttpClientConfig) -> Result<ClientWithMiddleware, HttpC
builder = builder.timeout(timeout); builder = builder.timeout(timeout);
} }
if let Some(ca_bundle_path) = &config.ca_bundle { if let Some(ca_bundle_path) = &config.ca_bundle {
let pem = std::fs::read(ca_bundle_path).map_err(|source| HttpClientBuildError::CaBundleRead { let pem =
path: ca_bundle_path.clone(), std::fs::read(ca_bundle_path).map_err(|source| HttpClientBuildError::CaBundleRead {
source, path: ca_bundle_path.clone(),
})?; source,
})?;
let certs = reqwest::Certificate::from_pem_bundle(&pem).map_err(|source| { let certs = reqwest::Certificate::from_pem_bundle(&pem).map_err(|source| {
HttpClientBuildError::CaBundleParse { HttpClientBuildError::CaBundleParse {
path: ca_bundle_path.clone(), path: ca_bundle_path.clone(),
@@ -152,19 +153,21 @@ fn build_client(config: &HttpClientConfig) -> Result<ClientWithMiddleware, HttpC
source, source,
} }
})?; })?;
let identity = reqwest::Identity::from_pem( let identity = reqwest::Identity::from_pem(concat_pem(&cert_pem, &key_pem).as_slice())
concat_pem(&cert_pem, &key_pem).as_slice(), .map_err(|source| HttpClientBuildError::ClientCertParse {
) path: client_cert_cfg.cert_pem.clone(),
.map_err(|source| HttpClientBuildError::ClientCertParse { source,
path: client_cert_cfg.cert_pem.clone(), })?;
source,
})?;
builder = builder.identity(identity); builder = builder.identity(identity);
} }
let reqwest_client = builder.build().map_err(HttpClientBuildError::Build)?; let reqwest_client = builder.build().map_err(HttpClientBuildError::Build)?;
let client = reqwest_middleware::ClientBuilder::new(reqwest_client) let client = reqwest_middleware::ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(config.retry_policy)) .with(RetryTransientMiddleware::new_with_policy(
.with(RetryAfterMiddleware::with_capacity(DEFAULT_RETRY_AFTER_CAPACITY)) config.retry_policy,
))
.with(RetryAfterMiddleware::with_capacity(
DEFAULT_RETRY_AFTER_CAPACITY,
))
.build(); .build();
Ok(client) Ok(client)
} }
@@ -203,10 +206,7 @@ mod tests {
.build() .build()
.expect("RequestBuilder builds"); .expect("RequestBuilder builds");
assert_eq!(request.method(), reqwest::Method::GET); assert_eq!(request.method(), reqwest::Method::GET);
assert_eq!( assert_eq!(request.url().as_str(), "https://api.example.com/v1/chat");
request.url().as_str(),
"https://api.example.com/v1/chat"
);
} }
#[test] #[test]

View File

@@ -99,7 +99,10 @@ impl RetryAfterMiddleware {
#[cfg(test)] #[cfg(test)]
fn len(&self) -> usize { fn len(&self) -> usize {
self.deadlines.lock().expect("deadlines mutex poisoned").len() self.deadlines
.lock()
.expect("deadlines mutex poisoned")
.len()
} }
#[cfg(test)] #[cfg(test)]
@@ -156,8 +159,8 @@ mod tests {
#[test] #[test]
fn parse_retry_after_http_date() { fn parse_retry_after_http_date() {
let deadline = parse_retry_after("Wed, 21 Oct 2099 07:28:00 GMT") let deadline =
.expect("HTTP-date value parses"); parse_retry_after("Wed, 21 Oct 2099 07:28:00 GMT").expect("HTTP-date value parses");
assert!(deadline > SystemTime::now()); assert!(deadline > SystemTime::now());
} }
@@ -272,7 +275,10 @@ mod tests {
async fn middleware_sleeps_before_request_with_active_deadline() { async fn middleware_sleeps_before_request_with_active_deadline() {
let mw = std::sync::Arc::new(RetryAfterMiddleware::with_capacity(8)); let mw = std::sync::Arc::new(RetryAfterMiddleware::with_capacity(8));
let target = url("https://api.example.com/v1/chat"); let target = url("https://api.example.com/v1/chat");
mw.record_test(target.clone(), SystemTime::now() + Duration::from_millis(50)); mw.record_test(
target.clone(),
SystemTime::now() + Duration::from_millis(50),
);
let started = SystemTime::now(); let started = SystemTime::now();
mw.maybe_sleep_for(&target).await; mw.maybe_sleep_for(&target).await;
let elapsed = SystemTime::now().duration_since(started).unwrap(); let elapsed = SystemTime::now().duration_since(started).unwrap();

View File

@@ -83,11 +83,7 @@ impl GatewayDispatch {
r.capabilities.clone(), r.capabilities.clone(),
r.scoped_env.clone().unwrap_or_else(ScopedPeerEnv::empty), r.scoped_env.clone().unwrap_or_else(ScopedPeerEnv::empty),
), ),
None => ( None => (None, Capabilities::new(), ScopedPeerEnv::empty()),
None,
Capabilities::new(),
ScopedPeerEnv::empty(),
),
}; };
let env: Arc<dyn alknet_call::registry::env::OperationEnv + Send + Sync> = let env: Arc<dyn alknet_call::registry::env::OperationEnv + Send + Sync> =
@@ -254,10 +250,7 @@ mod tests {
.invoke(None, "echo/run", serde_json::json!({ "msg": "hi" })) .invoke(None, "echo/run", serde_json::json!({ "msg": "hi" }))
.await; .await;
assert!(response.result.is_ok()); assert!(response.result.is_ok());
assert_eq!( assert_eq!(response.result.unwrap(), serde_json::json!({ "msg": "hi" }));
response.result.unwrap(),
serde_json::json!({ "msg": "hi" })
);
} }
#[tokio::test] #[tokio::test]
@@ -270,9 +263,7 @@ mod tests {
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new()); let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
let dp = dispatch(registry, provider); let dp = dispatch(registry, provider);
let response = dp let response = dp.invoke(None, "/echo/run", serde_json::json!({})).await;
.invoke(None, "/echo/run", serde_json::json!({}))
.await;
assert!(response.result.is_ok()); assert!(response.result.is_ok());
} }
@@ -369,9 +360,7 @@ mod tests {
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new()); let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
let dp = dispatch(registry, provider); let dp = dispatch(registry, provider);
let response = dp let response = dp.invoke(None, "no/such", serde_json::json!({})).await;
.invoke(None, "no/such", serde_json::json!({}))
.await;
match response.result { match response.result {
Err(e) => { Err(e) => {
assert_eq!(e.code, "NOT_FOUND"); assert_eq!(e.code, "NOT_FOUND");
@@ -398,9 +387,7 @@ mod tests {
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new()); let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
let dp = dispatch(registry, provider); let dp = dispatch(registry, provider);
let response = dp let response = dp.invoke(None, "secret/op", serde_json::json!({})).await;
.invoke(None, "secret/op", serde_json::json!({}))
.await;
match response.result { match response.result {
Err(e) => { Err(e) => {
assert_eq!(e.code, "NOT_FOUND"); assert_eq!(e.code, "NOT_FOUND");
@@ -423,9 +410,7 @@ mod tests {
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new()); let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
let dp = dispatch(registry, provider); let dp = dispatch(registry, provider);
let response = dp let response = dp.invoke(None, "admin/run", serde_json::json!({})).await;
.invoke(None, "admin/run", serde_json::json!({}))
.await;
match response.result { match response.result {
Err(e) => { Err(e) => {
assert_eq!(e.code, "FORBIDDEN"); assert_eq!(e.code, "FORBIDDEN");
@@ -506,8 +491,10 @@ mod tests {
#[test] #[test]
fn build_root_context_carries_registration_bundle_fields() { fn build_root_context_carries_registration_bundle_fields() {
let authority = let authority = alknet_call::registry::context::CompositionAuthority::new(
alknet_call::registry::context::CompositionAuthority::new("agent", ["fs:read".to_string()]); "agent",
["fs:read".to_string()],
);
let scoped = ScopedPeerEnv::new(["fs/readFile"]); let scoped = ScopedPeerEnv::new(["fs/readFile"]);
let caps = Capabilities::new().with_api_key("google", "k".to_string()); let caps = Capabilities::new().with_api_key("google", "k".to_string());

View File

@@ -31,7 +31,10 @@ pub fn call_error_to_http_status(error: &CallError) -> u16 {
call_error_to_http_status_with_identity(error, None) call_error_to_http_status_with_identity(error, None)
} }
pub fn call_error_to_http_status_with_identity(error: &CallError, identity: Option<&Identity>) -> u16 { pub fn call_error_to_http_status_with_identity(
error: &CallError,
identity: Option<&Identity>,
) -> u16 {
match error.code.as_str() { match error.code.as_str() {
PROTOCOL_CODE_NOT_FOUND => STATUS_NOT_FOUND, PROTOCOL_CODE_NOT_FOUND => STATUS_NOT_FOUND,
PROTOCOL_CODE_FORBIDDEN => { PROTOCOL_CODE_FORBIDDEN => {
@@ -59,8 +62,8 @@ pub fn call_error_to_http_response(error: &CallError) -> Response {
let retry_after = retry_after_value(error, status_code); let retry_after = retry_after_value(error, status_code);
if let Some(retry_after) = retry_after { if let Some(retry_after) = retry_after {
let header_value = HeaderValue::from_str(&retry_after) let header_value =
.unwrap_or_else(|_| HeaderValue::from_static("0")); HeaderValue::from_str(&retry_after).unwrap_or_else(|_| HeaderValue::from_static("0"));
(status, [(header::RETRY_AFTER, header_value)], Json(body)).into_response() (status, [(header::RETRY_AFTER, header_value)], Json(body)).into_response()
} else { } else {
(status, Json(body)).into_response() (status, Json(body)).into_response()
@@ -139,7 +142,10 @@ mod tests {
fn forbidden_with_some_identity_maps_to_403() { fn forbidden_with_some_identity_maps_to_403() {
let error = CallError::forbidden("insufficient scopes"); let error = CallError::forbidden("insufficient scopes");
let id = identity(); let id = identity();
assert_eq!(call_error_to_http_status_with_identity(&error, Some(&id)), 403); assert_eq!(
call_error_to_http_status_with_identity(&error, Some(&id)),
403
);
} }
#[test] #[test]
@@ -213,7 +219,10 @@ mod tests {
let error = CallError::new("HTTP_503", "slow down", true); let error = CallError::new("HTTP_503", "slow down", true);
let response = call_error_to_http_response(&error); let response = call_error_to_http_response(&error);
assert_eq!(response.status(), StatusCode::from_u16(503).unwrap()); assert_eq!(response.status(), StatusCode::from_u16(503).unwrap());
assert!(response.headers().get(axum::http::header::RETRY_AFTER).is_none()); assert!(response
.headers()
.get(axum::http::header::RETRY_AFTER)
.is_none());
} }
#[test] #[test]
@@ -221,7 +230,10 @@ mod tests {
let error = CallError::new("HTTP_503", "down", false) let error = CallError::new("HTTP_503", "down", false)
.with_details(serde_json::json!({ "retry_after": "5" })); .with_details(serde_json::json!({ "retry_after": "5" }));
let response = call_error_to_http_response(&error); let response = call_error_to_http_response(&error);
assert!(response.headers().get(axum::http::header::RETRY_AFTER).is_none()); assert!(response
.headers()
.get(axum::http::header::RETRY_AFTER)
.is_none());
} }
#[test] #[test]
@@ -241,7 +253,10 @@ mod tests {
let error = CallError::timeout("timed out"); let error = CallError::timeout("timed out");
let response = call_error_to_http_response(&error); let response = call_error_to_http_response(&error);
assert_eq!(response.status(), StatusCode::from_u16(504).unwrap()); assert_eq!(response.status(), StatusCode::from_u16(504).unwrap());
assert!(response.headers().get(axum::http::header::RETRY_AFTER).is_none()); assert!(response
.headers()
.get(axum::http::header::RETRY_AFTER)
.is_none());
} }
#[test] #[test]

View File

@@ -694,4 +694,22 @@ mod tests {
); );
assert!(response.contains("location: https://example.com")); assert!(response.contains("location: https://example.com"));
} }
#[tokio::test]
async fn openapi_json_route_serves_gateway_spec() {
let adapter = HttpAdapter::new(provider(), empty_registry());
let request = b"GET /openapi.json HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
let response = serve_and_read(adapter, request).await;
assert!(
response.starts_with("HTTP/1.1 200"),
"expected 200 for /openapi.json, got: {response}"
);
assert!(response.contains("\"openapi\""));
assert!(response.contains("\"/search\""));
assert!(response.contains("\"/schema\""));
assert!(response.contains("\"/call\""));
assert!(response.contains("\"/batch\""));
assert!(response.contains("\"/subscribe\""));
assert!(response.contains("\"1.0.0\""));
}
} }

View File

@@ -80,11 +80,12 @@ where
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request_parts( async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
parts: &mut Parts, let identity = parts
_state: &S, .extensions
) -> Result<Self, Self::Rejection> { .get::<Option<Identity>>()
let identity = parts.extensions.get::<Option<Identity>>().cloned().flatten(); .cloned()
.flatten();
Ok(ResolvedIdentity(identity)) Ok(ResolvedIdentity(identity))
} }
} }
@@ -174,15 +175,16 @@ mod tests {
assert!(identity.is_none()); assert!(identity.is_none());
} }
async fn run_middleware( async fn run_middleware(idp: Arc<dyn IdentityProvider>, request: Request) -> Response {
idp: Arc<dyn IdentityProvider>,
request: Request,
) -> Response {
let app: Router<()> = Router::new() let app: Router<()> = Router::new()
.route( .route(
"/", "/",
get(|req: Request| async move { get(|req: Request| async move {
let identity = req.extensions().get::<Option<Identity>>().cloned().flatten(); let identity = req
.extensions()
.get::<Option<Identity>>()
.cloned()
.flatten();
if let Some(id) = identity { if let Some(id) = identity {
(StatusCode::OK, id.id) (StatusCode::OK, id.id)
} else { } else {
@@ -261,14 +263,12 @@ mod tests {
let app: Router<()> = Router::new() let app: Router<()> = Router::new()
.route( .route(
"/", "/",
get( get(|ResolvedIdentity(identity): ResolvedIdentity| async move {
|ResolvedIdentity(identity): ResolvedIdentity| async move { match identity {
match identity { Some(id) => (StatusCode::OK, id.id),
Some(id) => (StatusCode::OK, id.id), None => (StatusCode::OK, "none".to_string()),
None => (StatusCode::OK, "none".to_string()), }
} }),
},
),
) )
.layer(from_fn_with_state(idp, bearer_auth_middleware)); .layer(from_fn_with_state(idp, bearer_auth_middleware));
@@ -287,14 +287,12 @@ mod tests {
let app: Router<()> = Router::new() let app: Router<()> = Router::new()
.route( .route(
"/", "/",
get( get(|ResolvedIdentity(identity): ResolvedIdentity| async move {
|ResolvedIdentity(identity): ResolvedIdentity| async move { match identity {
match identity { Some(id) => (StatusCode::OK, id.id),
Some(id) => (StatusCode::OK, id.id), None => (StatusCode::OK, "none".to_string()),
None => (StatusCode::OK, "none".to_string()), }
} }),
},
),
) )
.layer(from_fn_with_state(idp, bearer_auth_middleware)); .layer(from_fn_with_state(idp, bearer_auth_middleware));

View File

@@ -33,10 +33,8 @@ pub fn fake_nginx_404() -> Response {
header::CONTENT_TYPE, header::CONTENT_TYPE,
HeaderValue::from_static("text/html; charset=utf-8"), HeaderValue::from_static("text/html; charset=utf-8"),
); );
resp.headers_mut().insert( resp.headers_mut()
header::SERVER, .insert(header::SERVER, HeaderValue::from_static("nginx"));
HeaderValue::from_static("nginx"),
);
resp resp
} }
@@ -61,10 +59,8 @@ pub async fn serve_static(root: &Path, request: Request) -> Response {
let content_type = mime_for_path(&resolved); let content_type = mime_for_path(&resolved);
let mut resp = Response::new(Body::from(bytes)); let mut resp = Response::new(Body::from(bytes));
*resp.status_mut() = StatusCode::OK; *resp.status_mut() = StatusCode::OK;
resp.headers_mut().insert( resp.headers_mut()
header::CONTENT_TYPE, .insert(header::CONTENT_TYPE, HeaderValue::from_static(content_type));
HeaderValue::from_static(content_type),
);
resp resp
} }
Err(_) => fake_nginx_404(), Err(_) => fake_nginx_404(),
@@ -173,10 +169,7 @@ mod tests {
async fn send(router: axum::Router, uri: &str) -> axum::response::Response { async fn send(router: axum::Router, uri: &str) -> axum::response::Response {
tower::ServiceExt::<Request<Body>>::oneshot( tower::ServiceExt::<Request<Body>>::oneshot(
router, router,
Request::builder() Request::builder().uri(uri).body(Body::empty()).unwrap(),
.uri(uri)
.body(Body::empty())
.unwrap(),
) )
.await .await
.unwrap() .unwrap()
@@ -220,9 +213,7 @@ mod tests {
async fn unknown_path_with_static_site_decoy_serves_file() { async fn unknown_path_with_static_site_decoy_serves_file() {
let dir = tempfile_dir(); let dir = tempfile_dir();
let file = dir.join("index.html"); let file = dir.join("index.html");
tokio::fs::write(&file, "<h1>hello</h1>") tokio::fs::write(&file, "<h1>hello</h1>").await.unwrap();
.await
.unwrap();
let decoy = DecoyConfig::StaticSite { root: dir.clone() }; let decoy = DecoyConfig::StaticSite { root: dir.clone() };
let resp = send(decoy_router(decoy), "/").await; let resp = send(decoy_router(decoy), "/").await;
@@ -293,10 +284,8 @@ mod tests {
} }
fn tempfile_dir() -> PathBuf { fn tempfile_dir() -> PathBuf {
let dir = PathBuf::from("/tmp").join(format!( let dir =
"alknet-http-decoy-test-{}", PathBuf::from("/tmp").join(format!("alknet-http-decoy-test-{}", uuid::Uuid::new_v4()));
uuid::Uuid::new_v4()
));
std::fs::create_dir_all(&dir).unwrap(); std::fs::create_dir_all(&dir).unwrap();
dir dir
} }

View File

@@ -52,13 +52,19 @@ impl GatewayState {
} }
fn dispatch(&self) -> GatewayDispatch { fn dispatch(&self) -> GatewayDispatch {
GatewayDispatch::new(Arc::clone(&self.registry), Arc::clone(&self.identity_provider)) GatewayDispatch::new(
Arc::clone(&self.registry),
Arc::clone(&self.identity_provider),
)
} }
} }
impl FromRef<RouterState> for GatewayState { impl FromRef<RouterState> for GatewayState {
fn from_ref(state: &RouterState) -> Self { fn from_ref(state: &RouterState) -> Self {
GatewayState::new(Arc::clone(&state.registry), Arc::clone(&state.identity_provider)) GatewayState::new(
Arc::clone(&state.registry),
Arc::clone(&state.identity_provider),
)
} }
} }
@@ -92,7 +98,9 @@ pub(crate) async fn call_handler(
return not_found_response(&request.operation); return not_found_response(&request.operation);
} }
let dispatch = state.dispatch(); let dispatch = state.dispatch();
let envelope = dispatch.invoke(identity.clone(), &request.operation, request.input).await; let envelope = dispatch
.invoke(identity.clone(), &request.operation, request.input)
.await;
envelope_to_response(envelope, identity.as_ref()) envelope_to_response(envelope, identity.as_ref())
} }
@@ -101,7 +109,9 @@ pub(crate) async fn search_handler(
ResolvedIdentity(identity): ResolvedIdentity, ResolvedIdentity(identity): ResolvedIdentity,
) -> Response { ) -> Response {
let dispatch = state.dispatch(); let dispatch = state.dispatch();
let envelope = dispatch.invoke(identity.clone(), SERVICES_LIST, json!({})).await; let envelope = dispatch
.invoke(identity.clone(), SERVICES_LIST, json!({}))
.await;
envelope_to_response(envelope, identity.as_ref()) envelope_to_response(envelope, identity.as_ref())
} }
@@ -115,7 +125,11 @@ pub(crate) async fn schema_handler(
} }
let dispatch = state.dispatch(); let dispatch = state.dispatch();
let envelope = dispatch let envelope = dispatch
.invoke(identity.clone(), SERVICES_SCHEMA, json!({ "name": query.name })) .invoke(
identity.clone(),
SERVICES_SCHEMA,
json!({ "name": query.name }),
)
.await; .await;
envelope_to_response(envelope, identity.as_ref()) envelope_to_response(envelope, identity.as_ref())
} }
@@ -149,7 +163,9 @@ pub(crate) async fn subscribe_handler(
subscribe_stream_internal_error(request.operation) subscribe_stream_internal_error(request.operation)
} else { } else {
let dispatch = state.dispatch(); let dispatch = state.dispatch();
let envelope = dispatch.invoke(identity, &request.operation, request.input).await; let envelope = dispatch
.invoke(identity, &request.operation, request.input)
.await;
subscribe_stream_from_envelope(envelope) subscribe_stream_from_envelope(envelope)
}; };
Sse::new(stream) Sse::new(stream)
@@ -221,8 +237,7 @@ fn not_found_response(operation: &str) -> Response {
fn forbidden_response(message: String, identity: Option<&Identity>) -> Response { fn forbidden_response(message: String, identity: Option<&Identity>) -> Response {
let error = CallError::forbidden(message); let error = CallError::forbidden(message);
let status_code = call_error_to_http_status_with_identity(&error, identity); let status_code = call_error_to_http_status_with_identity(&error, identity);
let status = let status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let body = serde_json::to_value(&error).unwrap_or(Value::Null); let body = serde_json::to_value(&error).unwrap_or(Value::Null);
(status, Json(body)).into_response() (status, Json(body)).into_response()
} }
@@ -248,7 +263,9 @@ fn is_internal_op(registry: &OperationRegistry, operation: &str) -> bool {
} }
} }
fn envelope_to_sse_stream(envelope: ResponseEnvelope) -> impl Stream<Item = Result<Event, Infallible>> { fn envelope_to_sse_stream(
envelope: ResponseEnvelope,
) -> impl Stream<Item = Result<Event, Infallible>> {
stream::once(async move { stream::once(async move {
match envelope.result { match envelope.result {
Ok(output) => { Ok(output) => {
@@ -756,7 +773,10 @@ mod tests {
.get(axum::http::header::CONTENT_TYPE) .get(axum::http::header::CONTENT_TYPE)
.map(|v| v.to_str().unwrap().to_string()); .map(|v| v.to_str().unwrap().to_string());
assert!( assert!(
ctype.as_deref().unwrap_or("").starts_with("text/event-stream"), ctype
.as_deref()
.unwrap_or("")
.starts_with("text/event-stream"),
"expected text/event-stream, got {ctype:?}" "expected text/event-stream, got {ctype:?}"
); );
let bytes = resp.into_body().collect().await.unwrap().to_bytes(); let bytes = resp.into_body().collect().await.unwrap().to_bytes();

View File

@@ -128,7 +128,10 @@ mod tests {
let out: EventEnvelope = response.into(); let out: EventEnvelope = response.into();
assert_eq!(out.r#type, EVENT_RESPONDED); assert_eq!(out.r#type, EVENT_RESPONDED);
assert_eq!(out.id, "ws-rt-1"); assert_eq!(out.id, "ws-rt-1");
assert_eq!(out.payload.get("output"), Some(&serde_json::json!({ "v": 7 }))); assert_eq!(
out.payload.get("output"),
Some(&serde_json::json!({ "v": 7 }))
);
} }
#[tokio::test] #[tokio::test]
@@ -160,7 +163,10 @@ mod tests {
async fn ws_overlay_only_connection_holds_overlay_and_pending() { async fn ws_overlay_only_connection_holds_overlay_and_pending() {
let conn = CallConnection::new_overlay_only(identity("ws-peer")); let conn = CallConnection::new_overlay_only(identity("ws-peer"));
assert!(conn.connection().is_none()); assert!(conn.connection().is_none());
assert_eq!(conn.identity().map(|i| i.id.clone()), Some("ws-peer".to_string())); assert_eq!(
conn.identity().map(|i| i.id.clone()),
Some("ws-peer".to_string())
);
assert!(conn.pending().lock().is_empty()); assert!(conn.pending().lock().is_empty());
let env = conn.overlay_env(); let env = conn.overlay_env();

View File

@@ -84,8 +84,9 @@ async fn ws_upgrade_handler_inner(
}; };
match ws_upgrade { match ws_upgrade {
Some(upgrade) => upgrade Some(upgrade) => upgrade.on_upgrade(move |socket| {
.on_upgrade(move |socket| run_ws_session(socket, registry, identity_provider, identity)), run_ws_session(socket, registry, identity_provider, identity)
}),
None => { None => {
let _ = registry; let _ = registry;
let _ = identity_provider; let _ = identity_provider;
@@ -240,19 +241,19 @@ fn serialize_envelope(envelope: &EventEnvelope) -> Result<Vec<u8>, serde_json::E
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use alknet_call::registry::context::{
AbortPolicy, CompositionAuthority, OperationContext, ScopedPeerEnv,
};
use alknet_call::registry::discovery::{ use alknet_call::registry::discovery::{
services_list_handler, services_list_spec, services_schema_handler, services_schema_spec, services_list_handler, services_list_spec, services_schema_handler, services_schema_spec,
}; };
use alknet_call::registry::env::OperationEnv;
use alknet_call::registry::registration::{ use alknet_call::registry::registration::{
make_handler, HandlerRegistration, OperationProvenance, make_handler, HandlerRegistration, OperationProvenance,
}; };
use alknet_call::registry::spec::{AccessControl, OperationSpec, OperationType, Visibility}; use alknet_call::registry::spec::{AccessControl, OperationSpec, OperationType, Visibility};
use alknet_core::auth::{AuthToken, Identity}; use alknet_core::auth::{AuthToken, Identity};
use alknet_core::types::Capabilities; use alknet_core::types::Capabilities;
use alknet_call::registry::context::{
AbortPolicy, CompositionAuthority, OperationContext, ScopedPeerEnv,
};
use alknet_call::registry::env::OperationEnv;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Mutex as StdMutex; use std::sync::Mutex as StdMutex;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@@ -331,9 +332,7 @@ mod tests {
let mut registry = OperationRegistry::new(); let mut registry = OperationRegistry::new();
registry.register(HandlerRegistration::new( registry.register(HandlerRegistration::new(
external_spec("echo/run", AccessControl::default()), external_spec("echo/run", AccessControl::default()),
make_handler(|input, ctx| async move { make_handler(|input, ctx| async move { ResponseEnvelope::ok(ctx.request_id, input) }),
ResponseEnvelope::ok(ctx.request_id, input)
}),
OperationProvenance::Local, OperationProvenance::Local,
None, None,
None, None,
@@ -352,9 +351,7 @@ mod tests {
..Default::default() ..Default::default()
}, },
), ),
make_handler(|input, ctx| async move { make_handler(|input, ctx| async move { ResponseEnvelope::ok(ctx.request_id, input) }),
ResponseEnvelope::ok(ctx.request_id, input)
}),
OperationProvenance::Local, OperationProvenance::Local,
None, None,
None, None,
@@ -519,9 +516,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn handle_inbound_envelope_forbidden_yields_call_error() { async fn handle_inbound_envelope_forbidden_yields_call_error() {
let registry = registry_with_restricted_op(); let registry = registry_with_restricted_op();
let provider: Arc<dyn IdentityProvider> = Arc::new( let provider: Arc<dyn IdentityProvider> =
StaticIdentityProvider::new().with_token("none", identity("unpriv")), Arc::new(StaticIdentityProvider::new().with_token("none", identity("unpriv")));
);
let dp = dispatcher(registry, provider); let dp = dispatcher(registry, provider);
let conn = Arc::new(CallConnection::new_overlay_only(identity("unpriv"))); let conn = Arc::new(CallConnection::new_overlay_only(identity("unpriv")));
@@ -727,9 +723,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn round_trip_call_requested_to_call_responded_over_ws_message_stream() { async fn round_trip_call_requested_to_call_responded_over_ws_message_stream() {
let registry = echo_registry(); let registry = echo_registry();
let provider: Arc<dyn IdentityProvider> = Arc::new( let provider: Arc<dyn IdentityProvider> =
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
);
let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider)); let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider));
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
@@ -753,9 +748,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn subscription_streams_multiple_call_responded_events() { async fn subscription_streams_multiple_call_responded_events() {
let registry = registry_with_subscription(); let registry = registry_with_subscription();
let provider: Arc<dyn IdentityProvider> = Arc::new( let provider: Arc<dyn IdentityProvider> =
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
);
let dp = dispatcher(registry, provider); let dp = dispatcher(registry, provider);
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
@@ -782,8 +776,10 @@ mod tests {
.with_token("no-admin", identity_with_scopes("user", &["user"])), .with_token("no-admin", identity_with_scopes("user", &["user"])),
); );
let dp = dispatcher(registry, provider); let dp = dispatcher(registry, provider);
let conn = let conn = Arc::new(CallConnection::new_overlay_only(identity_with_scopes(
Arc::new(CallConnection::new_overlay_only(identity_with_scopes("user", &["user"]))); "user",
&["user"],
)));
let request = EventEnvelope::requested( let request = EventEnvelope::requested(
"req-admin", "req-admin",
@@ -882,8 +878,10 @@ mod tests {
let overlay_env = conn.overlay_env(); let overlay_env = conn.overlay_env();
assert!(overlay_env.contains("ui/dragged")); assert!(overlay_env.contains("ui/dragged"));
let composed_env: Arc<dyn OperationEnv + Send + Sync> = dp let composed_env: Arc<dyn OperationEnv + Send + Sync> = dp.compose_root_env(
.compose_root_env(&conn, &root_context_for_compose("hub-call-1", overlay_env.clone())); &conn,
&root_context_for_compose("hub-call-1", overlay_env.clone()),
);
let ctx = root_context_with_env("hub-call-1", composed_env); let ctx = root_context_with_env("hub-call-1", composed_env);
let response = overlay_env let response = overlay_env
.invoke("ui", "dragged", serde_json::json!({ "x": 5 }), &ctx) .invoke("ui", "dragged", serde_json::json!({ "x": 5 }), &ctx)
@@ -935,9 +933,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn drive_ws_session_round_trips_binary_call_requested_to_call_responded() { async fn drive_ws_session_round_trips_binary_call_requested_to_call_responded() {
let registry = echo_registry(); let registry = echo_registry();
let provider: Arc<dyn IdentityProvider> = Arc::new( let provider: Arc<dyn IdentityProvider> =
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
);
let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider)); let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider));
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
@@ -960,7 +957,10 @@ mod tests {
let env: EventEnvelope = serde_json::from_slice(&bytes).unwrap(); let env: EventEnvelope = serde_json::from_slice(&bytes).unwrap();
assert_eq!(env.r#type, EVENT_RESPONDED); assert_eq!(env.r#type, EVENT_RESPONDED);
assert_eq!(env.id, "ws-socket-1"); assert_eq!(env.id, "ws-socket-1");
assert_eq!(env.payload.get("output"), Some(&serde_json::json!({ "v": 7 }))); assert_eq!(
env.payload.get("output"),
Some(&serde_json::json!({ "v": 7 }))
);
} }
other => panic!("expected binary, got {other:?}"), other => panic!("expected binary, got {other:?}"),
} }
@@ -972,9 +972,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn drive_ws_session_rejects_text_with_protocol_close() { async fn drive_ws_session_rejects_text_with_protocol_close() {
let registry = echo_registry(); let registry = echo_registry();
let provider: Arc<dyn IdentityProvider> = Arc::new( let provider: Arc<dyn IdentityProvider> =
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
);
let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider)); let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider));
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
@@ -999,9 +998,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn drive_ws_session_disconnect_aborts_in_flight_pending() { async fn drive_ws_session_disconnect_aborts_in_flight_pending() {
let registry = echo_registry(); let registry = echo_registry();
let provider: Arc<dyn IdentityProvider> = Arc::new( let provider: Arc<dyn IdentityProvider> =
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
);
let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider)); let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider));
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
@@ -1036,9 +1034,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn drive_ws_session_subscription_streams_call_responded_events() { async fn drive_ws_session_subscription_streams_call_responded_events() {
let registry = registry_with_subscription(); let registry = registry_with_subscription();
let provider: Arc<dyn IdentityProvider> = Arc::new( let provider: Arc<dyn IdentityProvider> =
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
);
let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider)); let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider));
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
@@ -1077,9 +1074,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn drive_ws_session_invalid_binary_closes_with_protocol_error() { async fn drive_ws_session_invalid_binary_closes_with_protocol_error() {
let registry = echo_registry(); let registry = echo_registry();
let provider: Arc<dyn IdentityProvider> = Arc::new( let provider: Arc<dyn IdentityProvider> =
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
);
let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider)); let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider));
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
@@ -1102,9 +1098,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn drive_ws_session_client_close_terminates_server() { async fn drive_ws_session_client_close_terminates_server() {
let registry = echo_registry(); let registry = echo_registry();
let provider: Arc<dyn IdentityProvider> = Arc::new( let provider: Arc<dyn IdentityProvider> =
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
);
let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider)); let dp = dispatcher(Arc::clone(&registry), Arc::clone(&provider));
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer"))); let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
@@ -1164,17 +1159,11 @@ mod tests {
} }
async fn send_text(&mut self, text: String) { async fn send_text(&mut self, text: String) {
self.outbound_tx self.outbound_tx.send(Message::Text(text.into())).await.ok();
.send(Message::Text(text.into()))
.await
.ok();
} }
async fn send_close(&mut self) { async fn send_close(&mut self) {
self.outbound_tx self.outbound_tx.send(Message::Close(None)).await.ok();
.send(Message::Close(None))
.await
.ok();
} }
async fn close(&mut self) { async fn close(&mut self) {
@@ -1215,9 +1204,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn ws_upgrade_handler_returns_401_when_identity_is_none() { async fn ws_upgrade_handler_returns_401_when_identity_is_none() {
let registry = echo_registry(); let registry = echo_registry();
let provider: Arc<dyn IdentityProvider> = Arc::new( let provider: Arc<dyn IdentityProvider> =
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
);
let identity: Option<Identity> = None; let identity: Option<Identity> = None;
let response = ws_upgrade_handler_inner(registry, provider, identity, None).await; let response = ws_upgrade_handler_inner(registry, provider, identity, None).await;
@@ -1227,9 +1215,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn ws_upgrade_handler_does_not_reject_when_identity_present() { async fn ws_upgrade_handler_does_not_reject_when_identity_present() {
let registry = echo_registry(); let registry = echo_registry();
let provider: Arc<dyn IdentityProvider> = Arc::new( let provider: Arc<dyn IdentityProvider> =
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")), Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
);
let identity = identity("ws-peer"); let identity = identity("ws-peer");
let response = ws_upgrade_handler_inner(registry, provider, Some(identity), None).await; let response = ws_upgrade_handler_inner(registry, provider, Some(identity), None).await;

View File

@@ -9,11 +9,11 @@ use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use alknet_call::client::OperationAdapter;
use alknet_call::protocol::wire::ResponseEnvelope; use alknet_call::protocol::wire::ResponseEnvelope;
use alknet_call::registry::context::{AbortPolicy, OperationContext, ScopedPeerEnv}; use alknet_call::registry::context::{AbortPolicy, OperationContext, ScopedPeerEnv};
use alknet_call::registry::env::OperationEnv; use alknet_call::registry::env::OperationEnv;
use alknet_call::registry::registration::OperationProvenance; use alknet_call::registry::registration::OperationProvenance;
use alknet_call::client::OperationAdapter;
use alknet_core::types::Capabilities; use alknet_core::types::Capabilities;
use alknet_http::adapters::FromMCP; use alknet_http::adapters::FromMCP;
use axum::Router; use axum::Router;
@@ -22,8 +22,8 @@ use rmcp::model::{
}; };
use rmcp::service::RequestContext; use rmcp::service::RequestContext;
use rmcp::transport::{ use rmcp::transport::{
StreamableHttpServerConfig,
streamable_http_server::{session::local::LocalSessionManager, tower::StreamableHttpService}, streamable_http_server::{session::local::LocalSessionManager, tower::StreamableHttpService},
StreamableHttpServerConfig,
}; };
use rmcp::{RoleServer, ServerHandler}; use rmcp::{RoleServer, ServerHandler};
use serde_json::Value; use serde_json::Value;
@@ -72,18 +72,19 @@ impl ServerHandler for EchoServer {
&self, &self,
_request: Option<PaginatedRequestParams>, _request: Option<PaginatedRequestParams>,
_context: RequestContext<RoleServer>, _context: RequestContext<RoleServer>,
) -> impl std::future::Future< ) -> impl std::future::Future<Output = Result<ListToolsResult, rmcp::ErrorData>>
Output = Result<ListToolsResult, rmcp::ErrorData>, + rmcp::service::MaybeSendFuture
> + rmcp::service::MaybeSendFuture + '_ { + '_ {
let tools = vec![ let tools = vec![
Tool::new_with_raw( Tool::new_with_raw(
"echo", "echo",
Some("Echo the input back as structured content".into()), Some("Echo the input back as structured content".into()),
Arc::new(serde_json::Map::new()), Arc::new(serde_json::Map::new()),
) )
.with_raw_output_schema(Arc::new(serde_json::Map::from_iter([ .with_raw_output_schema(Arc::new(serde_json::Map::from_iter([(
("type".to_string(), Value::String("object".into())), "type".to_string(),
]))), Value::String("object".into()),
)]))),
Tool::new_with_raw( Tool::new_with_raw(
"legacy", "legacy",
Some("Legacy tool returning text content blocks".into()), Some("Legacy tool returning text content blocks".into()),
@@ -101,22 +102,17 @@ impl ServerHandler for EchoServer {
&self, &self,
request: CallToolRequestParams, request: CallToolRequestParams,
_context: RequestContext<RoleServer>, _context: RequestContext<RoleServer>,
) -> impl std::future::Future< ) -> impl std::future::Future<Output = Result<CallToolResult, rmcp::ErrorData>>
Output = Result<CallToolResult, rmcp::ErrorData>, + rmcp::service::MaybeSendFuture
> + rmcp::service::MaybeSendFuture + '_ { + '_ {
let name = request.name.to_string(); let name = request.name.to_string();
std::future::ready(Ok(match name.as_str() { std::future::ready(Ok(match name.as_str() {
"echo" => { "echo" => {
let args = request let args = request.arguments.map(Value::Object).unwrap_or(Value::Null);
.arguments
.map(Value::Object)
.unwrap_or(Value::Null);
CallToolResult::structured(serde_json::json!({ "echoed": args })) CallToolResult::structured(serde_json::json!({ "echoed": args }))
} }
"legacy" => CallToolResult::success(vec![Content::text("plain text result")]), "legacy" => CallToolResult::success(vec![Content::text("plain text result")]),
other => CallToolResult::error(vec![Content::text(format!( other => CallToolResult::error(vec![Content::text(format!("unknown tool: {other}"))]),
"unknown tool: {other}"
))]),
})) }))
} }