review(http): mark http/review-mcp completed + fix formatting across crate
Review-mcp verification complete: all 12 checklist items pass (from_mcp/to_mcp conformance, ADR-037/041/014/023/034, feature gate isolation, GatewayDispatch concrete struct, test coverage 223+5). Applied cargo fmt across crate.
This commit is contained in:
@@ -22,7 +22,11 @@ fn make_tool(name: &str, input: Value, output: Option<Value>) -> Tool {
|
|||||||
tool
|
tool
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call_tool_result(content: Vec<Content>, structured: Option<Value>, is_error: Option<bool>) -> CallToolResult {
|
fn call_tool_result(
|
||||||
|
content: Vec<Content>,
|
||||||
|
structured: Option<Value>,
|
||||||
|
is_error: Option<bool>,
|
||||||
|
) -> CallToolResult {
|
||||||
let json = serde_json::json!({
|
let json = serde_json::json!({
|
||||||
"content": content,
|
"content": content,
|
||||||
"structuredContent": structured,
|
"structuredContent": structured,
|
||||||
@@ -204,7 +208,9 @@ fn build_spec_output_schema_present_shape() {
|
|||||||
let tool = make_tool(
|
let tool = make_tool(
|
||||||
"get_weather",
|
"get_weather",
|
||||||
serde_json::json!({ "type": "object", "properties": { "city": { "type": "string" } } }),
|
serde_json::json!({ "type": "object", "properties": { "city": { "type": "string" } } }),
|
||||||
Some(serde_json::json!({ "type": "object", "properties": { "temperature": { "type": "number" } } })),
|
Some(
|
||||||
|
serde_json::json!({ "type": "object", "properties": { "temperature": { "type": "number" } } }),
|
||||||
|
),
|
||||||
);
|
);
|
||||||
let spec = build_spec(&tool, "weather");
|
let spec = build_spec(&tool, "weather");
|
||||||
assert_eq!(spec.name, "weather/get_weather");
|
assert_eq!(spec.name, "weather/get_weather");
|
||||||
|
|||||||
@@ -17,10 +17,10 @@ use std::sync::Arc;
|
|||||||
use alknet_call::client::{AdapterError, OperationAdapter};
|
use alknet_call::client::{AdapterError, OperationAdapter};
|
||||||
use alknet_call::protocol::wire::{CallError, ResponseEnvelope};
|
use alknet_call::protocol::wire::{CallError, ResponseEnvelope};
|
||||||
use alknet_call::registry::context::OperationContext;
|
use alknet_call::registry::context::OperationContext;
|
||||||
use alknet_call::registry::registration::{
|
use alknet_call::registry::registration::{make_handler, HandlerRegistration, OperationProvenance};
|
||||||
make_handler, HandlerRegistration, OperationProvenance,
|
use alknet_call::registry::spec::{
|
||||||
|
AccessControl, ErrorDefinition, OperationSpec, OperationType, Visibility,
|
||||||
};
|
};
|
||||||
use alknet_call::registry::spec::{AccessControl, ErrorDefinition, OperationSpec, OperationType, Visibility};
|
|
||||||
use alknet_core::types::Capabilities;
|
use alknet_core::types::Capabilities;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
@@ -128,9 +128,7 @@ impl OpenAPISpec {
|
|||||||
.to_string(),
|
.to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let paths_raw = raw
|
let paths_raw = raw.get("paths").ok_or_else(|| AdapterError::SchemaParse {
|
||||||
.get("paths")
|
|
||||||
.ok_or_else(|| AdapterError::SchemaParse {
|
|
||||||
message: "OpenAPI document missing `paths`".into(),
|
message: "OpenAPI document missing `paths`".into(),
|
||||||
})?;
|
})?;
|
||||||
if !paths_raw.is_object() {
|
if !paths_raw.is_object() {
|
||||||
@@ -155,14 +153,13 @@ impl OpenAPISpec {
|
|||||||
if operations.is_empty() {
|
if operations.is_empty() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
paths.insert(
|
paths.insert(path.clone(), PathItem { operations });
|
||||||
path.clone(),
|
|
||||||
PathItem { operations },
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let components = raw.get("components").and_then(|c| c.get("schemas")).and_then(
|
let components = raw
|
||||||
|schemas| {
|
.get("components")
|
||||||
|
.and_then(|c| c.get("schemas"))
|
||||||
|
.and_then(|schemas| {
|
||||||
if !schemas.is_object() {
|
if !schemas.is_object() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
@@ -171,8 +168,7 @@ impl OpenAPISpec {
|
|||||||
map.insert(k.clone(), v.clone());
|
map.insert(k.clone(), v.clone());
|
||||||
}
|
}
|
||||||
Some(Components { schemas: map })
|
Some(Components { schemas: map })
|
||||||
},
|
});
|
||||||
);
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
info,
|
info,
|
||||||
@@ -190,9 +186,7 @@ impl OpenAPISpec {
|
|||||||
}
|
}
|
||||||
let mut current: &Value = &self.raw;
|
let mut current: &Value = &self.raw;
|
||||||
for part in reference.trim_start_matches("#/").split('/') {
|
for part in reference.trim_start_matches("#/").split('/') {
|
||||||
current = current
|
current = current.get(part).ok_or_else(|| AdapterError::SchemaParse {
|
||||||
.get(part)
|
|
||||||
.ok_or_else(|| AdapterError::SchemaParse {
|
|
||||||
message: format!("cannot resolve $ref: {reference}"),
|
message: format!("cannot resolve $ref: {reference}"),
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
@@ -241,10 +235,7 @@ fn parse_operation(raw: &Value) -> Option<Operation> {
|
|||||||
.filter_map(|p| {
|
.filter_map(|p| {
|
||||||
let name = p.get("name")?.as_str()?.to_string();
|
let name = p.get("name")?.as_str()?.to_string();
|
||||||
let in_ = p.get("in")?.as_str()?.to_string();
|
let in_ = p.get("in")?.as_str()?.to_string();
|
||||||
let required = p
|
let required = p.get("required").and_then(|v| v.as_bool()).unwrap_or(false);
|
||||||
.get("required")
|
|
||||||
.and_then(|v| v.as_bool())
|
|
||||||
.unwrap_or(false);
|
|
||||||
let schema = p.get("schema").cloned();
|
let schema = p.get("schema").cloned();
|
||||||
Some(Parameter {
|
Some(Parameter {
|
||||||
name,
|
name,
|
||||||
@@ -297,7 +288,11 @@ pub struct FromOpenAPI {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl FromOpenAPI {
|
impl FromOpenAPI {
|
||||||
pub fn new(spec: OpenAPISpec, config: HttpServiceConfig, http_client: Arc<SharedHttpClient>) -> Self {
|
pub fn new(
|
||||||
|
spec: OpenAPISpec,
|
||||||
|
config: HttpServiceConfig,
|
||||||
|
http_client: Arc<SharedHttpClient>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
spec,
|
spec,
|
||||||
config,
|
config,
|
||||||
@@ -322,10 +317,7 @@ impl FromOpenAPI {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn detect_op_type(method: &str, op: &Operation) -> OperationType {
|
fn detect_op_type(method: &str, op: &Operation) -> OperationType {
|
||||||
let success = op
|
let success = op.responses.get("200").or_else(|| op.responses.get("201"));
|
||||||
.responses
|
|
||||||
.get("200")
|
|
||||||
.or_else(|| op.responses.get("201"));
|
|
||||||
if let Some(resp) = success {
|
if let Some(resp) = success {
|
||||||
if resp.content.contains_key("text/event-stream") {
|
if resp.content.contains_key("text/event-stream") {
|
||||||
return OperationType::Subscription;
|
return OperationType::Subscription;
|
||||||
@@ -531,9 +523,8 @@ fn build_request(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let base = Url::parse(base_url).map_err(|e| {
|
let base = Url::parse(base_url)
|
||||||
CallError::internal(format!("invalid base_url `{base_url}`: {e}"))
|
.map_err(|e| CallError::internal(format!("invalid base_url `{base_url}`: {e}")))?;
|
||||||
})?;
|
|
||||||
let mut url = base
|
let mut url = base
|
||||||
.join(url_path.trim_start_matches('/'))
|
.join(url_path.trim_start_matches('/'))
|
||||||
.map_err(|e| CallError::internal(format!("invalid path `{url_path}`: {e}")))?;
|
.map_err(|e| CallError::internal(format!("invalid path `{url_path}`: {e}")))?;
|
||||||
@@ -683,11 +674,12 @@ async fn forward(
|
|||||||
.find(|(s, _)| *s == status.as_u16())
|
.find(|(s, _)| *s == status.as_u16())
|
||||||
.map(|(_, code)| code.clone())
|
.map(|(_, code)| code.clone())
|
||||||
.unwrap_or_else(|| format!("HTTP_{}", status.as_u16()));
|
.unwrap_or_else(|| format!("HTTP_{}", status.as_u16()));
|
||||||
let message = format!("HTTP {}: {}", status.as_u16(), status.canonical_reason().unwrap_or(""));
|
let message = format!(
|
||||||
return ResponseEnvelope::error(
|
"HTTP {}: {}",
|
||||||
request_id,
|
status.as_u16(),
|
||||||
CallError::new(code, message, false),
|
status.canonical_reason().unwrap_or("")
|
||||||
);
|
);
|
||||||
|
return ResponseEnvelope::error(request_id, CallError::new(code, message, false));
|
||||||
}
|
}
|
||||||
|
|
||||||
let content_type = response
|
let content_type = response
|
||||||
@@ -716,10 +708,7 @@ async fn forward(
|
|||||||
} else {
|
} else {
|
||||||
match response.bytes().await {
|
match response.bytes().await {
|
||||||
Ok(b) => {
|
Ok(b) => {
|
||||||
let arr: Vec<Value> = b
|
let arr: Vec<Value> = b.iter().map(|byte| Value::Number((*byte).into())).collect();
|
||||||
.iter()
|
|
||||||
.map(|byte| Value::Number((*byte).into()))
|
|
||||||
.collect();
|
|
||||||
ResponseEnvelope::ok(request_id, Value::Array(arr))
|
ResponseEnvelope::ok(request_id, Value::Array(arr))
|
||||||
}
|
}
|
||||||
Err(err) => ResponseEnvelope::error(
|
Err(err) => ResponseEnvelope::error(
|
||||||
@@ -744,7 +733,8 @@ async fn stream_subscription(request_id: String, response: reqwest::Response) ->
|
|||||||
let parsed = if event.data.trim().is_empty() {
|
let parsed = if event.data.trim().is_empty() {
|
||||||
Value::Null
|
Value::Null
|
||||||
} else {
|
} else {
|
||||||
serde_json::from_str(&event.data).unwrap_or(Value::String(event.data.clone()))
|
serde_json::from_str(&event.data)
|
||||||
|
.unwrap_or(Value::String(event.data.clone()))
|
||||||
};
|
};
|
||||||
last_event = Some(parsed.clone());
|
last_event = Some(parsed.clone());
|
||||||
}
|
}
|
||||||
@@ -1040,7 +1030,12 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
let body = props.get("body").unwrap();
|
let body = props.get("body").unwrap();
|
||||||
assert_eq!(body.get("type").unwrap(), "object");
|
assert_eq!(body.get("type").unwrap(), "object");
|
||||||
assert!(body.get("properties").unwrap().as_object().unwrap().contains_key("name"));
|
assert!(body
|
||||||
|
.get("properties")
|
||||||
|
.unwrap()
|
||||||
|
.as_object()
|
||||||
|
.unwrap()
|
||||||
|
.contains_key("name"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -1074,14 +1069,19 @@ mod tests {
|
|||||||
"https://api.vast.ai",
|
"https://api.vast.ai",
|
||||||
"/machines",
|
"/machines",
|
||||||
"GET",
|
"GET",
|
||||||
&Some(HttpAuthScheme::ApiKey { header_name: "X-API-Key".to_string() }),
|
&Some(HttpAuthScheme::ApiKey {
|
||||||
|
header_name: "X-API-Key".to_string(),
|
||||||
|
}),
|
||||||
&HashMap::new(),
|
&HashMap::new(),
|
||||||
"vastai",
|
"vastai",
|
||||||
&serde_json::json!({}),
|
&serde_json::json!({}),
|
||||||
&ctx,
|
&ctx,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(headers.get("X-API-Key").unwrap().to_str().unwrap(), "key-xyz");
|
assert_eq!(
|
||||||
|
headers.get("X-API-Key").unwrap().to_str().unwrap(),
|
||||||
|
"key-xyz"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -1267,7 +1267,11 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn http_service_config_struct_fields() {
|
fn http_service_config_struct_fields() {
|
||||||
let cfg = config("ns", "https://api.example.com", Some(HttpAuthScheme::Bearer));
|
let cfg = config(
|
||||||
|
"ns",
|
||||||
|
"https://api.example.com",
|
||||||
|
Some(HttpAuthScheme::Bearer),
|
||||||
|
);
|
||||||
assert_eq!(cfg.namespace, "ns");
|
assert_eq!(cfg.namespace, "ns");
|
||||||
assert_eq!(cfg.base_url, "https://api.example.com");
|
assert_eq!(cfg.base_url, "https://api.example.com");
|
||||||
assert!(matches!(cfg.auth, Some(HttpAuthScheme::Bearer)));
|
assert!(matches!(cfg.auth, Some(HttpAuthScheme::Bearer)));
|
||||||
@@ -1289,7 +1293,12 @@ mod tests {
|
|||||||
}"#;
|
}"#;
|
||||||
let spec = OpenAPISpec::from_json(doc).unwrap();
|
let spec = OpenAPISpec::from_json(doc).unwrap();
|
||||||
assert!(spec.components.is_some());
|
assert!(spec.components.is_some());
|
||||||
assert!(spec.components.as_ref().unwrap().schemas.contains_key("Foo"));
|
assert!(spec
|
||||||
|
.components
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.schemas
|
||||||
|
.contains_key("Foo"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -1342,7 +1351,9 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn resolve_ref_missing_target_returns_schema_parse() {
|
async fn resolve_ref_missing_target_returns_schema_parse() {
|
||||||
let spec = OpenAPISpec::from_json(minimal_spec_json()).unwrap();
|
let spec = OpenAPISpec::from_json(minimal_spec_json()).unwrap();
|
||||||
let err = spec.resolve_ref("#/components/schemas/Missing").unwrap_err();
|
let err = spec
|
||||||
|
.resolve_ref("#/components/schemas/Missing")
|
||||||
|
.unwrap_err();
|
||||||
assert!(matches!(err, AdapterError::SchemaParse { .. }));
|
assert!(matches!(err, AdapterError::SchemaParse { .. }));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1409,7 +1420,8 @@ mod tests {
|
|||||||
headers,
|
headers,
|
||||||
body,
|
body,
|
||||||
});
|
});
|
||||||
let response = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 2\r\n\r\n{}";
|
let response =
|
||||||
|
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 2\r\n\r\n{}";
|
||||||
sock.write_all(response.as_bytes()).await.unwrap();
|
sock.write_all(response.as_bytes()).await.unwrap();
|
||||||
sock.flush().await.unwrap();
|
sock.flush().await.unwrap();
|
||||||
});
|
});
|
||||||
@@ -1440,12 +1452,19 @@ mod tests {
|
|||||||
ctx,
|
ctx,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
assert!(response.result.is_ok(), "expected Ok, got {:?}", response.result);
|
assert!(
|
||||||
|
response.result.is_ok(),
|
||||||
|
"expected Ok, got {:?}",
|
||||||
|
response.result
|
||||||
|
);
|
||||||
let captured = rx.await.unwrap();
|
let captured = rx.await.unwrap();
|
||||||
assert_eq!(captured.method, "POST");
|
assert_eq!(captured.method, "POST");
|
||||||
assert_eq!(captured.path, "/items/42");
|
assert_eq!(captured.path, "/items/42");
|
||||||
assert_eq!(captured.query, "filter=new");
|
assert_eq!(captured.query, "filter=new");
|
||||||
assert_eq!(captured.headers.get("content-type").unwrap(), "application/json");
|
assert_eq!(
|
||||||
|
captured.headers.get("content-type").unwrap(),
|
||||||
|
"application/json"
|
||||||
|
);
|
||||||
assert!(captured.body.contains("\"name\":\"widget\""));
|
assert!(captured.body.contains("\"name\":\"widget\""));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1457,10 +1476,7 @@ mod tests {
|
|||||||
}"#;
|
}"#;
|
||||||
let (base, rx) = spawn_capturing_server().await;
|
let (base, rx) = spawn_capturing_server().await;
|
||||||
let spec = OpenAPISpec::from_json(doc).unwrap();
|
let spec = OpenAPISpec::from_json(doc).unwrap();
|
||||||
let bundles = adapter(
|
let bundles = adapter(spec, config("openai", &base, Some(HttpAuthScheme::Bearer)))
|
||||||
spec,
|
|
||||||
config("openai", &base, Some(HttpAuthScheme::Bearer)),
|
|
||||||
)
|
|
||||||
.import()
|
.import()
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1469,7 +1485,10 @@ mod tests {
|
|||||||
let ctx = noop_context("req-17", caps);
|
let ctx = noop_context("req-17", caps);
|
||||||
let _ = (registration.handler)(serde_json::json!({}), ctx).await;
|
let _ = (registration.handler)(serde_json::json!({}), ctx).await;
|
||||||
let captured = rx.await.unwrap();
|
let captured = rx.await.unwrap();
|
||||||
assert_eq!(captured.headers.get("authorization").unwrap(), "Bearer sk-test-token");
|
assert_eq!(
|
||||||
|
captured.headers.get("authorization").unwrap(),
|
||||||
|
"Bearer sk-test-token"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
@@ -22,4 +22,4 @@ pub use from_openapi::{FromOpenAPI, HttpAuthScheme, HttpServiceConfig, OpenAPISp
|
|||||||
pub use from_mcp::FromMCP;
|
pub use from_mcp::FromMCP;
|
||||||
|
|
||||||
#[cfg(feature = "mcp")]
|
#[cfg(feature = "mcp")]
|
||||||
pub use to_mcp::{ToMcpGateway, ToMcpService, to_mcp_service};
|
pub use to_mcp::{to_mcp_service, ToMcpGateway, ToMcpService};
|
||||||
|
|||||||
@@ -36,8 +36,8 @@ use rmcp::model::{
|
|||||||
};
|
};
|
||||||
use rmcp::service::{RequestContext, RoleServer};
|
use rmcp::service::{RequestContext, RoleServer};
|
||||||
use rmcp::transport::{
|
use rmcp::transport::{
|
||||||
StreamableHttpServerConfig,
|
|
||||||
streamable_http_server::{session::local::LocalSessionManager, tower::StreamableHttpService},
|
streamable_http_server::{session::local::LocalSessionManager, tower::StreamableHttpService},
|
||||||
|
StreamableHttpServerConfig,
|
||||||
};
|
};
|
||||||
use serde_json::{Map, Value};
|
use serde_json::{Map, Value};
|
||||||
|
|
||||||
@@ -133,7 +133,10 @@ impl ToMcpGateway {
|
|||||||
|
|
||||||
fn extract_identity_from_extensions(extensions: &rmcp::model::Extensions) -> Option<Identity> {
|
fn extract_identity_from_extensions(extensions: &rmcp::model::Extensions) -> Option<Identity> {
|
||||||
let parts = extensions.get::<http::request::Parts>()?;
|
let parts = extensions.get::<http::request::Parts>()?;
|
||||||
parts.extensions.get::<Option<Identity>>().and_then(Option::clone)
|
parts
|
||||||
|
.extensions
|
||||||
|
.get::<Option<Identity>>()
|
||||||
|
.and_then(Option::clone)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_search(&self, identity: Option<Identity>) -> CallToolResult {
|
async fn handle_search(&self, identity: Option<Identity>) -> CallToolResult {
|
||||||
@@ -144,8 +147,15 @@ impl ToMcpGateway {
|
|||||||
map_search_response(response, identity.as_ref())
|
map_search_response(response, identity.as_ref())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_schema(&self, arguments: Option<JsonObject>, identity: Option<Identity>) -> CallToolResult {
|
async fn handle_schema(
|
||||||
let name = match arguments.and_then(|mut a| a.remove("name")).and_then(|v| v.as_str().map(str::to_string)) {
|
&self,
|
||||||
|
arguments: Option<JsonObject>,
|
||||||
|
identity: Option<Identity>,
|
||||||
|
) -> CallToolResult {
|
||||||
|
let name = match arguments
|
||||||
|
.and_then(|mut a| a.remove("name"))
|
||||||
|
.and_then(|v| v.as_str().map(str::to_string))
|
||||||
|
{
|
||||||
Some(n) => n,
|
Some(n) => n,
|
||||||
None => {
|
None => {
|
||||||
return CallToolResult::structured_error(serde_json::json!({
|
return CallToolResult::structured_error(serde_json::json!({
|
||||||
@@ -156,12 +166,20 @@ impl ToMcpGateway {
|
|||||||
};
|
};
|
||||||
let response = self
|
let response = self
|
||||||
.dispatch
|
.dispatch
|
||||||
.invoke(identity, OP_SERVICES_SCHEMA, serde_json::json!({ "name": name }))
|
.invoke(
|
||||||
|
identity,
|
||||||
|
OP_SERVICES_SCHEMA,
|
||||||
|
serde_json::json!({ "name": name }),
|
||||||
|
)
|
||||||
.await;
|
.await;
|
||||||
envelope_to_call_tool_result(response)
|
envelope_to_call_tool_result(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_call(&self, arguments: Option<JsonObject>, identity: Option<Identity>) -> CallToolResult {
|
async fn handle_call(
|
||||||
|
&self,
|
||||||
|
arguments: Option<JsonObject>,
|
||||||
|
identity: Option<Identity>,
|
||||||
|
) -> CallToolResult {
|
||||||
let (operation, input) = match parse_call_arguments(arguments) {
|
let (operation, input) = match parse_call_arguments(arguments) {
|
||||||
Ok(pair) => pair,
|
Ok(pair) => pair,
|
||||||
Err(err) => return err,
|
Err(err) => return err,
|
||||||
@@ -170,7 +188,11 @@ impl ToMcpGateway {
|
|||||||
envelope_to_call_tool_result(response)
|
envelope_to_call_tool_result(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_batch(&self, arguments: Option<JsonObject>, identity: Option<Identity>) -> CallToolResult {
|
async fn handle_batch(
|
||||||
|
&self,
|
||||||
|
arguments: Option<JsonObject>,
|
||||||
|
identity: Option<Identity>,
|
||||||
|
) -> CallToolResult {
|
||||||
let calls = match arguments
|
let calls = match arguments
|
||||||
.and_then(|mut a| a.remove("calls"))
|
.and_then(|mut a| a.remove("calls"))
|
||||||
.and_then(|v| v.as_array().cloned())
|
.and_then(|v| v.as_array().cloned())
|
||||||
@@ -193,7 +215,10 @@ impl ToMcpGateway {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let response = self.dispatch.invoke(identity.clone(), &operation, input).await;
|
let response = self
|
||||||
|
.dispatch
|
||||||
|
.invoke(identity.clone(), &operation, input)
|
||||||
|
.await;
|
||||||
results.push(envelope_to_value(response));
|
results.push(envelope_to_value(response));
|
||||||
}
|
}
|
||||||
CallToolResult::structured(Value::Array(results))
|
CallToolResult::structured(Value::Array(results))
|
||||||
@@ -210,7 +235,10 @@ fn parse_call_arguments(arguments: Option<JsonObject>) -> Result<(String, Value)
|
|||||||
})));
|
})));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let operation = match map.remove("operation").and_then(|v| v.as_str().map(str::to_string)) {
|
let operation = match map
|
||||||
|
.remove("operation")
|
||||||
|
.and_then(|v| v.as_str().map(str::to_string))
|
||||||
|
{
|
||||||
Some(s) => s,
|
Some(s) => s,
|
||||||
None => {
|
None => {
|
||||||
return Err(CallToolResult::structured_error(serde_json::json!({
|
return Err(CallToolResult::structured_error(serde_json::json!({
|
||||||
@@ -359,7 +387,11 @@ impl rmcp::handler::server::ServerHandler for ToMcpGateway {
|
|||||||
TOOL_CALL => this.handle_call(arguments, identity).await,
|
TOOL_CALL => this.handle_call(arguments, identity).await,
|
||||||
TOOL_BATCH => this.handle_batch(arguments, identity).await,
|
TOOL_BATCH => this.handle_batch(arguments, identity).await,
|
||||||
unknown => {
|
unknown => {
|
||||||
let err = CallError::new("NOT_FOUND", format!("unknown gateway tool: {unknown}"), false);
|
let err = CallError::new(
|
||||||
|
"NOT_FOUND",
|
||||||
|
format!("unknown gateway tool: {unknown}"),
|
||||||
|
false,
|
||||||
|
);
|
||||||
call_error_to_structured_error(err)
|
call_error_to_structured_error(err)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -368,9 +400,7 @@ impl rmcp::handler::server::ServerHandler for ToMcpGateway {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn get_info(&self) -> ServerInfo {
|
fn get_info(&self) -> ServerInfo {
|
||||||
let capabilities = ServerCapabilities::builder()
|
let capabilities = ServerCapabilities::builder().enable_tools().build();
|
||||||
.enable_tools()
|
|
||||||
.build();
|
|
||||||
ServerInfo::new(capabilities)
|
ServerInfo::new(capabilities)
|
||||||
.with_server_info(Implementation::new(
|
.with_server_info(Implementation::new(
|
||||||
"alknet-to-mcp",
|
"alknet-to-mcp",
|
||||||
@@ -462,10 +492,14 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn make_echo_handler() -> alknet_call::registry::registration::Handler {
|
fn make_echo_handler() -> alknet_call::registry::registration::Handler {
|
||||||
make_handler(|input, context| async move { ResponseEnvelope::ok(context.request_id, input) })
|
make_handler(
|
||||||
|
|input, context| async move { ResponseEnvelope::ok(context.request_id, input) },
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn full_registry_with_ops(specs: Vec<(String, OperationType, AccessControl)>) -> Arc<OperationRegistry> {
|
fn full_registry_with_ops(
|
||||||
|
specs: Vec<(String, OperationType, AccessControl)>,
|
||||||
|
) -> Arc<OperationRegistry> {
|
||||||
let mut inner = OperationRegistry::new();
|
let mut inner = OperationRegistry::new();
|
||||||
for (name, op_type, acl) in specs {
|
for (name, op_type, acl) in specs {
|
||||||
inner.register(HandlerRegistration::new(
|
inner.register(HandlerRegistration::new(
|
||||||
@@ -509,7 +543,10 @@ mod tests {
|
|||||||
Arc::new(dispatch_registry)
|
Arc::new(dispatch_registry)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dispatch(registry: Arc<OperationRegistry>, provider: Arc<dyn IdentityProvider>) -> Arc<GatewayDispatch> {
|
fn dispatch(
|
||||||
|
registry: Arc<OperationRegistry>,
|
||||||
|
provider: Arc<dyn IdentityProvider>,
|
||||||
|
) -> Arc<GatewayDispatch> {
|
||||||
Arc::new(GatewayDispatch::new(registry, provider))
|
Arc::new(GatewayDispatch::new(registry, provider))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -542,7 +579,11 @@ mod tests {
|
|||||||
TOOL_CALL => gateway.handle_call(arguments, identity).await,
|
TOOL_CALL => gateway.handle_call(arguments, identity).await,
|
||||||
TOOL_BATCH => gateway.handle_batch(arguments, identity).await,
|
TOOL_BATCH => gateway.handle_batch(arguments, identity).await,
|
||||||
unknown => {
|
unknown => {
|
||||||
let err = CallError::new("NOT_FOUND", format!("unknown gateway tool: {unknown}"), false);
|
let err = CallError::new(
|
||||||
|
"NOT_FOUND",
|
||||||
|
format!("unknown gateway tool: {unknown}"),
|
||||||
|
false,
|
||||||
|
);
|
||||||
call_error_to_structured_error(err)
|
call_error_to_structured_error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -550,10 +591,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn list_tools_returns_exactly_four_gateway_tools() {
|
async fn list_tools_returns_exactly_four_gateway_tools() {
|
||||||
let _gateway = ToMcpGateway::new(dispatch(
|
let _gateway = ToMcpGateway::new(dispatch(full_registry_with_ops(vec![]), provider()));
|
||||||
full_registry_with_ops(vec![]),
|
|
||||||
provider(),
|
|
||||||
));
|
|
||||||
let tools = gateway_tools();
|
let tools = gateway_tools();
|
||||||
let names: Vec<String> = tools.iter().map(|t| t.name.to_string()).collect();
|
let names: Vec<String> = tools.iter().map(|t| t.name.to_string()).collect();
|
||||||
assert_eq!(names.len(), 4);
|
assert_eq!(names.len(), 4);
|
||||||
@@ -583,7 +621,11 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn search_returns_access_control_filtered_ops_excluding_subscriptions() {
|
async fn search_returns_access_control_filtered_ops_excluding_subscriptions() {
|
||||||
let registry = full_registry_with_ops(vec![
|
let registry = full_registry_with_ops(vec![
|
||||||
("public/echo".to_string(), OperationType::Query, AccessControl::default()),
|
(
|
||||||
|
"public/echo".to_string(),
|
||||||
|
OperationType::Query,
|
||||||
|
AccessControl::default(),
|
||||||
|
),
|
||||||
(
|
(
|
||||||
"admin/secret".to_string(),
|
"admin/secret".to_string(),
|
||||||
OperationType::Query,
|
OperationType::Query,
|
||||||
@@ -592,12 +634,21 @@ mod tests {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
("events/stream".to_string(), OperationType::Subscription, AccessControl::default()),
|
(
|
||||||
|
"events/stream".to_string(),
|
||||||
|
OperationType::Subscription,
|
||||||
|
AccessControl::default(),
|
||||||
|
),
|
||||||
]);
|
]);
|
||||||
let idp: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
|
let idp: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
|
||||||
let gateway = ToMcpGateway::new(dispatch(registry, idp));
|
let gateway = ToMcpGateway::new(dispatch(registry, idp));
|
||||||
|
|
||||||
let result = invoke_tool(&gateway, "search", None, Some(identity_with_scopes("user", &["user"])))
|
let result = invoke_tool(
|
||||||
|
&gateway,
|
||||||
|
"search",
|
||||||
|
None,
|
||||||
|
Some(identity_with_scopes("user", &["user"])),
|
||||||
|
)
|
||||||
.await;
|
.await;
|
||||||
assert_eq!(result.is_error, Some(false));
|
assert_eq!(result.is_error, Some(false));
|
||||||
let structured = result.structured_content.expect("structured present");
|
let structured = result.structured_content.expect("structured present");
|
||||||
@@ -610,11 +661,23 @@ mod tests {
|
|||||||
.filter_map(|o| o.get("name").and_then(Value::as_str))
|
.filter_map(|o| o.get("name").and_then(Value::as_str))
|
||||||
.collect();
|
.collect();
|
||||||
assert!(names.contains(&"public/echo"));
|
assert!(names.contains(&"public/echo"));
|
||||||
assert!(!names.contains(&"admin/secret"), "ACL-filtered op must not appear");
|
assert!(
|
||||||
assert!(!names.contains(&"events/stream"), "Subscription op must be excluded");
|
!names.contains(&"admin/secret"),
|
||||||
|
"ACL-filtered op must not appear"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!names.contains(&"events/stream"),
|
||||||
|
"Subscription op must be excluded"
|
||||||
|
);
|
||||||
for op in ops {
|
for op in ops {
|
||||||
assert!(op.get("description").is_some(), "each entry has a description");
|
assert!(
|
||||||
assert!(op.get("input_schema").is_none(), "search must not return full schemas");
|
op.get("description").is_some(),
|
||||||
|
"each entry has a description"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
op.get("input_schema").is_none(),
|
||||||
|
"search must not return full schemas"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -632,7 +695,10 @@ mod tests {
|
|||||||
let result = invoke_tool(&gateway, "schema", Some(args), None).await;
|
let result = invoke_tool(&gateway, "schema", Some(args), None).await;
|
||||||
assert_eq!(result.is_error, Some(false));
|
assert_eq!(result.is_error, Some(false));
|
||||||
let structured = result.structured_content.expect("structured present");
|
let structured = result.structured_content.expect("structured present");
|
||||||
assert_eq!(structured.get("name"), Some(&Value::String("fs/readFile".to_string())));
|
assert_eq!(
|
||||||
|
structured.get("name"),
|
||||||
|
Some(&Value::String("fs/readFile".to_string()))
|
||||||
|
);
|
||||||
assert!(structured.get("input_schema").is_some());
|
assert!(structured.get("input_schema").is_some());
|
||||||
assert!(structured.get("output_schema").is_some());
|
assert!(structured.get("output_schema").is_some());
|
||||||
assert!(structured.get("error_schemas").is_some());
|
assert!(structured.get("error_schemas").is_some());
|
||||||
@@ -649,7 +715,10 @@ mod tests {
|
|||||||
let gateway = ToMcpGateway::new(dispatch(registry, provider()));
|
let gateway = ToMcpGateway::new(dispatch(registry, provider()));
|
||||||
|
|
||||||
let mut args = Map::new();
|
let mut args = Map::new();
|
||||||
args.insert("operation".to_string(), Value::String("echo/run".to_string()));
|
args.insert(
|
||||||
|
"operation".to_string(),
|
||||||
|
Value::String("echo/run".to_string()),
|
||||||
|
);
|
||||||
args.insert("input".to_string(), serde_json::json!({ "msg": "hi" }));
|
args.insert("input".to_string(), serde_json::json!({ "msg": "hi" }));
|
||||||
let result = invoke_tool(&gateway, "call", Some(args), None).await;
|
let result = invoke_tool(&gateway, "call", Some(args), None).await;
|
||||||
assert_eq!(result.is_error, Some(false));
|
assert_eq!(result.is_error, Some(false));
|
||||||
@@ -665,12 +734,18 @@ mod tests {
|
|||||||
let gateway = ToMcpGateway::new(dispatch(registry, provider()));
|
let gateway = ToMcpGateway::new(dispatch(registry, provider()));
|
||||||
|
|
||||||
let mut args = Map::new();
|
let mut args = Map::new();
|
||||||
args.insert("operation".to_string(), Value::String("no/such".to_string()));
|
args.insert(
|
||||||
|
"operation".to_string(),
|
||||||
|
Value::String("no/such".to_string()),
|
||||||
|
);
|
||||||
args.insert("input".to_string(), Value::Object(Map::new()));
|
args.insert("input".to_string(), Value::Object(Map::new()));
|
||||||
let result = invoke_tool(&gateway, "call", Some(args), None).await;
|
let result = invoke_tool(&gateway, "call", Some(args), None).await;
|
||||||
assert_eq!(result.is_error, Some(true));
|
assert_eq!(result.is_error, Some(true));
|
||||||
let structured = result.structured_content.expect("structured error present");
|
let structured = result.structured_content.expect("structured error present");
|
||||||
assert_eq!(structured.get("code"), Some(&Value::String("NOT_FOUND".to_string())));
|
assert_eq!(
|
||||||
|
structured.get("code"),
|
||||||
|
Some(&Value::String("NOT_FOUND".to_string()))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -713,12 +788,18 @@ mod tests {
|
|||||||
let gateway = ToMcpGateway::new(dispatch(registry, idp));
|
let gateway = ToMcpGateway::new(dispatch(registry, idp));
|
||||||
|
|
||||||
let mut args = Map::new();
|
let mut args = Map::new();
|
||||||
args.insert("operation".to_string(), Value::String("admin/run".to_string()));
|
args.insert(
|
||||||
|
"operation".to_string(),
|
||||||
|
Value::String("admin/run".to_string()),
|
||||||
|
);
|
||||||
args.insert("input".to_string(), Value::Object(Map::new()));
|
args.insert("input".to_string(), Value::Object(Map::new()));
|
||||||
let result = invoke_tool(&gateway, "call", Some(args), None).await;
|
let result = invoke_tool(&gateway, "call", Some(args), None).await;
|
||||||
assert_eq!(result.is_error, Some(true));
|
assert_eq!(result.is_error, Some(true));
|
||||||
let structured = result.structured_content.expect("structured error present");
|
let structured = result.structured_content.expect("structured error present");
|
||||||
assert_eq!(structured.get("code"), Some(&Value::String("FORBIDDEN".to_string())));
|
assert_eq!(
|
||||||
|
structured.get("code"),
|
||||||
|
Some(&Value::String("FORBIDDEN".to_string()))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -727,7 +808,10 @@ mod tests {
|
|||||||
let result = invoke_tool(&gateway, "bogus", None, None).await;
|
let result = invoke_tool(&gateway, "bogus", None, None).await;
|
||||||
assert_eq!(result.is_error, Some(true));
|
assert_eq!(result.is_error, Some(true));
|
||||||
let structured = result.structured_content.expect("structured error present");
|
let structured = result.structured_content.expect("structured error present");
|
||||||
assert_eq!(structured.get("code"), Some(&Value::String("NOT_FOUND".to_string())));
|
assert_eq!(
|
||||||
|
structured.get("code"),
|
||||||
|
Some(&Value::String("NOT_FOUND".to_string()))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -749,10 +833,16 @@ mod tests {
|
|||||||
let admin_identity = identity_with_scopes("admin-peer", &["admin"]);
|
let admin_identity = identity_with_scopes("admin-peer", &["admin"]);
|
||||||
let extensions = extensions_with_identity(Some(admin_identity.clone()));
|
let extensions = extensions_with_identity(Some(admin_identity.clone()));
|
||||||
let extracted = ToMcpGateway::extract_identity_from_extensions(&extensions);
|
let extracted = ToMcpGateway::extract_identity_from_extensions(&extensions);
|
||||||
assert_eq!(extracted.as_ref().map(|i| &i.id), Some(&"admin-peer".to_string()));
|
assert_eq!(
|
||||||
|
extracted.as_ref().map(|i| &i.id),
|
||||||
|
Some(&"admin-peer".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
let mut args = Map::new();
|
let mut args = Map::new();
|
||||||
args.insert("operation".to_string(), Value::String("admin/run".to_string()));
|
args.insert(
|
||||||
|
"operation".to_string(),
|
||||||
|
Value::String("admin/run".to_string()),
|
||||||
|
);
|
||||||
args.insert("input".to_string(), serde_json::json!({ "ok": 1 }));
|
args.insert("input".to_string(), serde_json::json!({ "ok": 1 }));
|
||||||
let result = gateway.handle_call(Some(args), extracted).await;
|
let result = gateway.handle_call(Some(args), extracted).await;
|
||||||
assert_eq!(result.is_error, Some(false));
|
assert_eq!(result.is_error, Some(false));
|
||||||
@@ -779,7 +869,10 @@ mod tests {
|
|||||||
let id = identity_with_scopes("caller", &["read"]);
|
let id = identity_with_scopes("caller", &["read"]);
|
||||||
let extensions = extensions_with_identity(Some(id.clone()));
|
let extensions = extensions_with_identity(Some(id.clone()));
|
||||||
let extracted = ToMcpGateway::extract_identity_from_extensions(&extensions);
|
let extracted = ToMcpGateway::extract_identity_from_extensions(&extensions);
|
||||||
assert_eq!(extracted.as_ref().map(|i| i.id.clone()), Some("caller".to_string()));
|
assert_eq!(
|
||||||
|
extracted.as_ref().map(|i| i.id.clone()),
|
||||||
|
Some("caller".to_string())
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
extracted.as_ref().map(|i| i.scopes.clone()),
|
extracted.as_ref().map(|i| i.scopes.clone()),
|
||||||
Some(vec!["read".to_string()])
|
Some(vec!["read".to_string()])
|
||||||
@@ -834,8 +927,14 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let mut call_args = Map::new();
|
let mut call_args = Map::new();
|
||||||
call_args.insert("operation".to_string(), Value::String(first_name.to_string()));
|
call_args.insert(
|
||||||
call_args.insert("input".to_string(), serde_json::json!({ "path": "/etc/hosts" }));
|
"operation".to_string(),
|
||||||
|
Value::String(first_name.to_string()),
|
||||||
|
);
|
||||||
|
call_args.insert(
|
||||||
|
"input".to_string(),
|
||||||
|
serde_json::json!({ "path": "/etc/hosts" }),
|
||||||
|
);
|
||||||
let call_result = invoke_tool(&gateway, "call", Some(call_args), None).await;
|
let call_result = invoke_tool(&gateway, "call", Some(call_args), None).await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
call_result.structured_content,
|
call_result.structured_content,
|
||||||
|
|||||||
@@ -125,7 +125,8 @@ fn build_client(config: &HttpClientConfig) -> Result<ClientWithMiddleware, HttpC
|
|||||||
builder = builder.timeout(timeout);
|
builder = builder.timeout(timeout);
|
||||||
}
|
}
|
||||||
if let Some(ca_bundle_path) = &config.ca_bundle {
|
if let Some(ca_bundle_path) = &config.ca_bundle {
|
||||||
let pem = std::fs::read(ca_bundle_path).map_err(|source| HttpClientBuildError::CaBundleRead {
|
let pem =
|
||||||
|
std::fs::read(ca_bundle_path).map_err(|source| HttpClientBuildError::CaBundleRead {
|
||||||
path: ca_bundle_path.clone(),
|
path: ca_bundle_path.clone(),
|
||||||
source,
|
source,
|
||||||
})?;
|
})?;
|
||||||
@@ -152,9 +153,7 @@ fn build_client(config: &HttpClientConfig) -> Result<ClientWithMiddleware, HttpC
|
|||||||
source,
|
source,
|
||||||
}
|
}
|
||||||
})?;
|
})?;
|
||||||
let identity = reqwest::Identity::from_pem(
|
let identity = reqwest::Identity::from_pem(concat_pem(&cert_pem, &key_pem).as_slice())
|
||||||
concat_pem(&cert_pem, &key_pem).as_slice(),
|
|
||||||
)
|
|
||||||
.map_err(|source| HttpClientBuildError::ClientCertParse {
|
.map_err(|source| HttpClientBuildError::ClientCertParse {
|
||||||
path: client_cert_cfg.cert_pem.clone(),
|
path: client_cert_cfg.cert_pem.clone(),
|
||||||
source,
|
source,
|
||||||
@@ -163,8 +162,12 @@ fn build_client(config: &HttpClientConfig) -> Result<ClientWithMiddleware, HttpC
|
|||||||
}
|
}
|
||||||
let reqwest_client = builder.build().map_err(HttpClientBuildError::Build)?;
|
let reqwest_client = builder.build().map_err(HttpClientBuildError::Build)?;
|
||||||
let client = reqwest_middleware::ClientBuilder::new(reqwest_client)
|
let client = reqwest_middleware::ClientBuilder::new(reqwest_client)
|
||||||
.with(RetryTransientMiddleware::new_with_policy(config.retry_policy))
|
.with(RetryTransientMiddleware::new_with_policy(
|
||||||
.with(RetryAfterMiddleware::with_capacity(DEFAULT_RETRY_AFTER_CAPACITY))
|
config.retry_policy,
|
||||||
|
))
|
||||||
|
.with(RetryAfterMiddleware::with_capacity(
|
||||||
|
DEFAULT_RETRY_AFTER_CAPACITY,
|
||||||
|
))
|
||||||
.build();
|
.build();
|
||||||
Ok(client)
|
Ok(client)
|
||||||
}
|
}
|
||||||
@@ -203,10 +206,7 @@ mod tests {
|
|||||||
.build()
|
.build()
|
||||||
.expect("RequestBuilder builds");
|
.expect("RequestBuilder builds");
|
||||||
assert_eq!(request.method(), reqwest::Method::GET);
|
assert_eq!(request.method(), reqwest::Method::GET);
|
||||||
assert_eq!(
|
assert_eq!(request.url().as_str(), "https://api.example.com/v1/chat");
|
||||||
request.url().as_str(),
|
|
||||||
"https://api.example.com/v1/chat"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -99,7 +99,10 @@ impl RetryAfterMiddleware {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
fn len(&self) -> usize {
|
fn len(&self) -> usize {
|
||||||
self.deadlines.lock().expect("deadlines mutex poisoned").len()
|
self.deadlines
|
||||||
|
.lock()
|
||||||
|
.expect("deadlines mutex poisoned")
|
||||||
|
.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -156,8 +159,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_retry_after_http_date() {
|
fn parse_retry_after_http_date() {
|
||||||
let deadline = parse_retry_after("Wed, 21 Oct 2099 07:28:00 GMT")
|
let deadline =
|
||||||
.expect("HTTP-date value parses");
|
parse_retry_after("Wed, 21 Oct 2099 07:28:00 GMT").expect("HTTP-date value parses");
|
||||||
assert!(deadline > SystemTime::now());
|
assert!(deadline > SystemTime::now());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -272,7 +275,10 @@ mod tests {
|
|||||||
async fn middleware_sleeps_before_request_with_active_deadline() {
|
async fn middleware_sleeps_before_request_with_active_deadline() {
|
||||||
let mw = std::sync::Arc::new(RetryAfterMiddleware::with_capacity(8));
|
let mw = std::sync::Arc::new(RetryAfterMiddleware::with_capacity(8));
|
||||||
let target = url("https://api.example.com/v1/chat");
|
let target = url("https://api.example.com/v1/chat");
|
||||||
mw.record_test(target.clone(), SystemTime::now() + Duration::from_millis(50));
|
mw.record_test(
|
||||||
|
target.clone(),
|
||||||
|
SystemTime::now() + Duration::from_millis(50),
|
||||||
|
);
|
||||||
let started = SystemTime::now();
|
let started = SystemTime::now();
|
||||||
mw.maybe_sleep_for(&target).await;
|
mw.maybe_sleep_for(&target).await;
|
||||||
let elapsed = SystemTime::now().duration_since(started).unwrap();
|
let elapsed = SystemTime::now().duration_since(started).unwrap();
|
||||||
|
|||||||
@@ -83,11 +83,7 @@ impl GatewayDispatch {
|
|||||||
r.capabilities.clone(),
|
r.capabilities.clone(),
|
||||||
r.scoped_env.clone().unwrap_or_else(ScopedPeerEnv::empty),
|
r.scoped_env.clone().unwrap_or_else(ScopedPeerEnv::empty),
|
||||||
),
|
),
|
||||||
None => (
|
None => (None, Capabilities::new(), ScopedPeerEnv::empty()),
|
||||||
None,
|
|
||||||
Capabilities::new(),
|
|
||||||
ScopedPeerEnv::empty(),
|
|
||||||
),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let env: Arc<dyn alknet_call::registry::env::OperationEnv + Send + Sync> =
|
let env: Arc<dyn alknet_call::registry::env::OperationEnv + Send + Sync> =
|
||||||
@@ -254,10 +250,7 @@ mod tests {
|
|||||||
.invoke(None, "echo/run", serde_json::json!({ "msg": "hi" }))
|
.invoke(None, "echo/run", serde_json::json!({ "msg": "hi" }))
|
||||||
.await;
|
.await;
|
||||||
assert!(response.result.is_ok());
|
assert!(response.result.is_ok());
|
||||||
assert_eq!(
|
assert_eq!(response.result.unwrap(), serde_json::json!({ "msg": "hi" }));
|
||||||
response.result.unwrap(),
|
|
||||||
serde_json::json!({ "msg": "hi" })
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -270,9 +263,7 @@ mod tests {
|
|||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
|
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
|
||||||
let dp = dispatch(registry, provider);
|
let dp = dispatch(registry, provider);
|
||||||
|
|
||||||
let response = dp
|
let response = dp.invoke(None, "/echo/run", serde_json::json!({})).await;
|
||||||
.invoke(None, "/echo/run", serde_json::json!({}))
|
|
||||||
.await;
|
|
||||||
assert!(response.result.is_ok());
|
assert!(response.result.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -369,9 +360,7 @@ mod tests {
|
|||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
|
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
|
||||||
let dp = dispatch(registry, provider);
|
let dp = dispatch(registry, provider);
|
||||||
|
|
||||||
let response = dp
|
let response = dp.invoke(None, "no/such", serde_json::json!({})).await;
|
||||||
.invoke(None, "no/such", serde_json::json!({}))
|
|
||||||
.await;
|
|
||||||
match response.result {
|
match response.result {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
assert_eq!(e.code, "NOT_FOUND");
|
assert_eq!(e.code, "NOT_FOUND");
|
||||||
@@ -398,9 +387,7 @@ mod tests {
|
|||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
|
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
|
||||||
let dp = dispatch(registry, provider);
|
let dp = dispatch(registry, provider);
|
||||||
|
|
||||||
let response = dp
|
let response = dp.invoke(None, "secret/op", serde_json::json!({})).await;
|
||||||
.invoke(None, "secret/op", serde_json::json!({}))
|
|
||||||
.await;
|
|
||||||
match response.result {
|
match response.result {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
assert_eq!(e.code, "NOT_FOUND");
|
assert_eq!(e.code, "NOT_FOUND");
|
||||||
@@ -423,9 +410,7 @@ mod tests {
|
|||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
|
let provider: Arc<dyn IdentityProvider> = Arc::new(StaticIdentityProvider::new());
|
||||||
let dp = dispatch(registry, provider);
|
let dp = dispatch(registry, provider);
|
||||||
|
|
||||||
let response = dp
|
let response = dp.invoke(None, "admin/run", serde_json::json!({})).await;
|
||||||
.invoke(None, "admin/run", serde_json::json!({}))
|
|
||||||
.await;
|
|
||||||
match response.result {
|
match response.result {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
assert_eq!(e.code, "FORBIDDEN");
|
assert_eq!(e.code, "FORBIDDEN");
|
||||||
@@ -506,8 +491,10 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn build_root_context_carries_registration_bundle_fields() {
|
fn build_root_context_carries_registration_bundle_fields() {
|
||||||
let authority =
|
let authority = alknet_call::registry::context::CompositionAuthority::new(
|
||||||
alknet_call::registry::context::CompositionAuthority::new("agent", ["fs:read".to_string()]);
|
"agent",
|
||||||
|
["fs:read".to_string()],
|
||||||
|
);
|
||||||
let scoped = ScopedPeerEnv::new(["fs/readFile"]);
|
let scoped = ScopedPeerEnv::new(["fs/readFile"]);
|
||||||
let caps = Capabilities::new().with_api_key("google", "k".to_string());
|
let caps = Capabilities::new().with_api_key("google", "k".to_string());
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,10 @@ pub fn call_error_to_http_status(error: &CallError) -> u16 {
|
|||||||
call_error_to_http_status_with_identity(error, None)
|
call_error_to_http_status_with_identity(error, None)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn call_error_to_http_status_with_identity(error: &CallError, identity: Option<&Identity>) -> u16 {
|
pub fn call_error_to_http_status_with_identity(
|
||||||
|
error: &CallError,
|
||||||
|
identity: Option<&Identity>,
|
||||||
|
) -> u16 {
|
||||||
match error.code.as_str() {
|
match error.code.as_str() {
|
||||||
PROTOCOL_CODE_NOT_FOUND => STATUS_NOT_FOUND,
|
PROTOCOL_CODE_NOT_FOUND => STATUS_NOT_FOUND,
|
||||||
PROTOCOL_CODE_FORBIDDEN => {
|
PROTOCOL_CODE_FORBIDDEN => {
|
||||||
@@ -59,8 +62,8 @@ pub fn call_error_to_http_response(error: &CallError) -> Response {
|
|||||||
let retry_after = retry_after_value(error, status_code);
|
let retry_after = retry_after_value(error, status_code);
|
||||||
|
|
||||||
if let Some(retry_after) = retry_after {
|
if let Some(retry_after) = retry_after {
|
||||||
let header_value = HeaderValue::from_str(&retry_after)
|
let header_value =
|
||||||
.unwrap_or_else(|_| HeaderValue::from_static("0"));
|
HeaderValue::from_str(&retry_after).unwrap_or_else(|_| HeaderValue::from_static("0"));
|
||||||
(status, [(header::RETRY_AFTER, header_value)], Json(body)).into_response()
|
(status, [(header::RETRY_AFTER, header_value)], Json(body)).into_response()
|
||||||
} else {
|
} else {
|
||||||
(status, Json(body)).into_response()
|
(status, Json(body)).into_response()
|
||||||
@@ -139,7 +142,10 @@ mod tests {
|
|||||||
fn forbidden_with_some_identity_maps_to_403() {
|
fn forbidden_with_some_identity_maps_to_403() {
|
||||||
let error = CallError::forbidden("insufficient scopes");
|
let error = CallError::forbidden("insufficient scopes");
|
||||||
let id = identity();
|
let id = identity();
|
||||||
assert_eq!(call_error_to_http_status_with_identity(&error, Some(&id)), 403);
|
assert_eq!(
|
||||||
|
call_error_to_http_status_with_identity(&error, Some(&id)),
|
||||||
|
403
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -213,7 +219,10 @@ mod tests {
|
|||||||
let error = CallError::new("HTTP_503", "slow down", true);
|
let error = CallError::new("HTTP_503", "slow down", true);
|
||||||
let response = call_error_to_http_response(&error);
|
let response = call_error_to_http_response(&error);
|
||||||
assert_eq!(response.status(), StatusCode::from_u16(503).unwrap());
|
assert_eq!(response.status(), StatusCode::from_u16(503).unwrap());
|
||||||
assert!(response.headers().get(axum::http::header::RETRY_AFTER).is_none());
|
assert!(response
|
||||||
|
.headers()
|
||||||
|
.get(axum::http::header::RETRY_AFTER)
|
||||||
|
.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -221,7 +230,10 @@ mod tests {
|
|||||||
let error = CallError::new("HTTP_503", "down", false)
|
let error = CallError::new("HTTP_503", "down", false)
|
||||||
.with_details(serde_json::json!({ "retry_after": "5" }));
|
.with_details(serde_json::json!({ "retry_after": "5" }));
|
||||||
let response = call_error_to_http_response(&error);
|
let response = call_error_to_http_response(&error);
|
||||||
assert!(response.headers().get(axum::http::header::RETRY_AFTER).is_none());
|
assert!(response
|
||||||
|
.headers()
|
||||||
|
.get(axum::http::header::RETRY_AFTER)
|
||||||
|
.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -241,7 +253,10 @@ mod tests {
|
|||||||
let error = CallError::timeout("timed out");
|
let error = CallError::timeout("timed out");
|
||||||
let response = call_error_to_http_response(&error);
|
let response = call_error_to_http_response(&error);
|
||||||
assert_eq!(response.status(), StatusCode::from_u16(504).unwrap());
|
assert_eq!(response.status(), StatusCode::from_u16(504).unwrap());
|
||||||
assert!(response.headers().get(axum::http::header::RETRY_AFTER).is_none());
|
assert!(response
|
||||||
|
.headers()
|
||||||
|
.get(axum::http::header::RETRY_AFTER)
|
||||||
|
.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -33,12 +33,12 @@ use super::auth::bearer_auth_middleware;
|
|||||||
use super::decoy::decoy_fallback;
|
use super::decoy::decoy_fallback;
|
||||||
use super::gateway_routes;
|
use super::gateway_routes;
|
||||||
use super::healthz::healthz;
|
use super::healthz::healthz;
|
||||||
use crate::websocket::upgrade::ws_upgrade_handler;
|
|
||||||
use crate::websocket::upgrade::WS_UPGRADE_PATH;
|
|
||||||
#[cfg(feature = "mcp")]
|
#[cfg(feature = "mcp")]
|
||||||
use crate::adapters::to_mcp_service;
|
use crate::adapters::to_mcp_service;
|
||||||
#[cfg(feature = "mcp")]
|
#[cfg(feature = "mcp")]
|
||||||
use crate::gateway::GatewayDispatch;
|
use crate::gateway::GatewayDispatch;
|
||||||
|
use crate::websocket::upgrade::ws_upgrade_handler;
|
||||||
|
use crate::websocket::upgrade::WS_UPGRADE_PATH;
|
||||||
|
|
||||||
const ALPN_HTTP1: &[u8] = b"http/1.1";
|
const ALPN_HTTP1: &[u8] = b"http/1.1";
|
||||||
const ALPN_H2: &[u8] = b"h2";
|
const ALPN_H2: &[u8] = b"h2";
|
||||||
@@ -47,8 +47,12 @@ const ALPN_H2: &[u8] = b"h2";
|
|||||||
pub enum DecoyConfig {
|
pub enum DecoyConfig {
|
||||||
#[default]
|
#[default]
|
||||||
NotFound,
|
NotFound,
|
||||||
StaticSite { root: PathBuf },
|
StaticSite {
|
||||||
Redirect { to: String },
|
root: PathBuf,
|
||||||
|
},
|
||||||
|
Redirect {
|
||||||
|
to: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@@ -87,11 +91,17 @@ pub struct HttpAdapter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl HttpAdapter {
|
impl HttpAdapter {
|
||||||
pub fn new(identity_provider: Arc<dyn IdentityProvider>, registry: Arc<OperationRegistry>) -> Self {
|
pub fn new(
|
||||||
|
identity_provider: Arc<dyn IdentityProvider>,
|
||||||
|
registry: Arc<OperationRegistry>,
|
||||||
|
) -> Self {
|
||||||
Self::for_alpn(identity_provider, registry, ALPN_HTTP1)
|
Self::for_alpn(identity_provider, registry, ALPN_HTTP1)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn h2(identity_provider: Arc<dyn IdentityProvider>, registry: Arc<OperationRegistry>) -> Self {
|
pub fn h2(
|
||||||
|
identity_provider: Arc<dyn IdentityProvider>,
|
||||||
|
registry: Arc<OperationRegistry>,
|
||||||
|
) -> Self {
|
||||||
Self::for_alpn(identity_provider, registry, ALPN_H2)
|
Self::for_alpn(identity_provider, registry, ALPN_H2)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,7 +173,10 @@ fn build_router(state: RouterState, extra_routes: Option<Router>) -> Router {
|
|||||||
));
|
));
|
||||||
Router::new()
|
Router::new()
|
||||||
.nest_service("/mcp", to_mcp_service(dispatch))
|
.nest_service("/mcp", to_mcp_service(dispatch))
|
||||||
.layer(from_fn_with_state(auth_state.clone(), bearer_auth_middleware))
|
.layer(from_fn_with_state(
|
||||||
|
auth_state.clone(),
|
||||||
|
bearer_auth_middleware,
|
||||||
|
))
|
||||||
};
|
};
|
||||||
#[cfg(not(feature = "mcp"))]
|
#[cfg(not(feature = "mcp"))]
|
||||||
let mcp_router: Router<RouterState> = Router::new();
|
let mcp_router: Router<RouterState> = Router::new();
|
||||||
@@ -172,7 +185,10 @@ fn build_router(state: RouterState, extra_routes: Option<Router>) -> Router {
|
|||||||
.merge(gateway_routes::gateway_router())
|
.merge(gateway_routes::gateway_router())
|
||||||
.route("/openapi.json", get(not_implemented))
|
.route("/openapi.json", get(not_implemented))
|
||||||
.route(WS_UPGRADE_PATH, get(ws_upgrade_handler))
|
.route(WS_UPGRADE_PATH, get(ws_upgrade_handler))
|
||||||
.route_layer(from_fn_with_state(auth_state.clone(), bearer_auth_middleware))
|
.route_layer(from_fn_with_state(
|
||||||
|
auth_state.clone(),
|
||||||
|
bearer_auth_middleware,
|
||||||
|
))
|
||||||
.route("/healthz", get(healthz))
|
.route("/healthz", get(healthz))
|
||||||
.fallback(decoy_fallback)
|
.fallback(decoy_fallback)
|
||||||
.merge(mcp_router);
|
.merge(mcp_router);
|
||||||
@@ -203,7 +219,10 @@ impl ProtocolHandler for HttpAdapter {
|
|||||||
let _ = connection.set_identity(identity);
|
let _ = connection.set_identity(identity);
|
||||||
}
|
}
|
||||||
|
|
||||||
let (send, recv) = connection.accept_bi().await.map_err(stream_error_to_handler)?;
|
let (send, recv) = connection
|
||||||
|
.accept_bi()
|
||||||
|
.await
|
||||||
|
.map_err(stream_error_to_handler)?;
|
||||||
let io = QuicStream::new(send, recv);
|
let io = QuicStream::new(send, recv);
|
||||||
self.serve_io(io).await
|
self.serve_io(io).await
|
||||||
}
|
}
|
||||||
@@ -295,7 +314,10 @@ mod tests {
|
|||||||
fn resolve_from_fingerprint(&self, _: &str) -> Option<alknet_core::auth::Identity> {
|
fn resolve_from_fingerprint(&self, _: &str) -> Option<alknet_core::auth::Identity> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
fn resolve_from_token(&self, _: &alknet_core::auth::AuthToken) -> Option<alknet_core::auth::Identity> {
|
fn resolve_from_token(
|
||||||
|
&self,
|
||||||
|
_: &alknet_core::auth::AuthToken,
|
||||||
|
) -> Option<alknet_core::auth::Identity> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -341,7 +363,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn with_decoy_updates_decoy() {
|
fn with_decoy_updates_decoy() {
|
||||||
let adapter = HttpAdapter::new(provider(), empty_registry());
|
let adapter = HttpAdapter::new(provider(), empty_registry());
|
||||||
let adapter = adapter.with_decoy(DecoyConfig::Redirect { to: "https://example.com".to_string() });
|
let adapter = adapter.with_decoy(DecoyConfig::Redirect {
|
||||||
|
to: "https://example.com".to_string(),
|
||||||
|
});
|
||||||
assert!(matches!(adapter.decoy(), DecoyConfig::Redirect { .. }));
|
assert!(matches!(adapter.decoy(), DecoyConfig::Redirect { .. }));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -386,7 +410,10 @@ mod tests {
|
|||||||
) -> (String, tokio::task::JoinHandle<()>) {
|
) -> (String, tokio::task::JoinHandle<()>) {
|
||||||
let (mut client_send, server_recv) = duplex(8 * 1024);
|
let (mut client_send, server_recv) = duplex(8 * 1024);
|
||||||
let (server_send, mut client_recv) = duplex(8 * 1024);
|
let (server_send, mut client_recv) = duplex(8 * 1024);
|
||||||
let server_io = QuicStreamDuplex { read: server_recv, write: server_send };
|
let server_io = QuicStreamDuplex {
|
||||||
|
read: server_recv,
|
||||||
|
write: server_send,
|
||||||
|
};
|
||||||
|
|
||||||
let adapter = HttpAdapter::new(provider(), empty_registry());
|
let adapter = HttpAdapter::new(provider(), empty_registry());
|
||||||
let handle = tokio::spawn(async move {
|
let handle = tokio::spawn(async move {
|
||||||
@@ -399,7 +426,12 @@ mod tests {
|
|||||||
let mut response = Vec::new();
|
let mut response = Vec::new();
|
||||||
let mut buf = [0u8; 4096];
|
let mut buf = [0u8; 4096];
|
||||||
loop {
|
loop {
|
||||||
match tokio::time::timeout(std::time::Duration::from_secs(5), client_recv.read(&mut buf)).await {
|
match tokio::time::timeout(
|
||||||
|
std::time::Duration::from_secs(5),
|
||||||
|
client_recv.read(&mut buf),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(Ok(0)) => break,
|
Ok(Ok(0)) => break,
|
||||||
Ok(Ok(n)) => response.extend_from_slice(&buf[..n]),
|
Ok(Ok(n)) => response.extend_from_slice(&buf[..n]),
|
||||||
Ok(Err(_)) => break,
|
Ok(Err(_)) => break,
|
||||||
@@ -455,21 +487,24 @@ mod tests {
|
|||||||
let request = b"GET /healthz HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
|
let request = b"GET /healthz HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
|
||||||
let (response, handle) = send_request_and_read_response(request).await;
|
let (response, handle) = send_request_and_read_response(request).await;
|
||||||
handle.await.ok();
|
handle.await.ok();
|
||||||
assert!(response.starts_with("HTTP/1.1 200 "), "expected 200, got: {response}");
|
assert!(
|
||||||
|
response.starts_with("HTTP/1.1 200 "),
|
||||||
|
"expected 200, got: {response}"
|
||||||
|
);
|
||||||
assert!(response.contains("\r\n\r\nok"));
|
assert!(response.contains("\r\n\r\nok"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn custom_route_v1_foo_coexists_with_default_surface() {
|
async fn custom_route_v1_foo_coexists_with_default_surface() {
|
||||||
let extra = Router::new().route(
|
let extra = Router::new().route("/v1/foo", get(|| async { (StatusCode::OK, "foo-body") }));
|
||||||
"/v1/foo",
|
|
||||||
get(|| async { (StatusCode::OK, "foo-body") }),
|
|
||||||
);
|
|
||||||
let adapter = HttpAdapter::new(provider(), empty_registry()).with_extra_routes(extra);
|
let adapter = HttpAdapter::new(provider(), empty_registry()).with_extra_routes(extra);
|
||||||
|
|
||||||
let (mut client_send, server_recv) = duplex(8 * 1024);
|
let (mut client_send, server_recv) = duplex(8 * 1024);
|
||||||
let (server_send, mut client_recv) = duplex(8 * 1024);
|
let (server_send, mut client_recv) = duplex(8 * 1024);
|
||||||
let server_io = QuicStreamDuplex { read: server_recv, write: server_send };
|
let server_io = QuicStreamDuplex {
|
||||||
|
read: server_recv,
|
||||||
|
write: server_send,
|
||||||
|
};
|
||||||
|
|
||||||
let handle = tokio::spawn(async move {
|
let handle = tokio::spawn(async move {
|
||||||
adapter.serve_io(server_io).await.ok();
|
adapter.serve_io(server_io).await.ok();
|
||||||
@@ -482,7 +517,12 @@ mod tests {
|
|||||||
let mut response = Vec::new();
|
let mut response = Vec::new();
|
||||||
let mut buf = [0u8; 4096];
|
let mut buf = [0u8; 4096];
|
||||||
loop {
|
loop {
|
||||||
match tokio::time::timeout(std::time::Duration::from_secs(5), client_recv.read(&mut buf)).await {
|
match tokio::time::timeout(
|
||||||
|
std::time::Duration::from_secs(5),
|
||||||
|
client_recv.read(&mut buf),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(Ok(0)) => break,
|
Ok(Ok(0)) => break,
|
||||||
Ok(Ok(n)) => response.extend_from_slice(&buf[..n]),
|
Ok(Ok(n)) => response.extend_from_slice(&buf[..n]),
|
||||||
Ok(Err(_)) => break,
|
Ok(Err(_)) => break,
|
||||||
@@ -491,7 +531,10 @@ mod tests {
|
|||||||
}
|
}
|
||||||
handle.await.ok();
|
handle.await.ok();
|
||||||
let response_str = String::from_utf8_lossy(&response);
|
let response_str = String::from_utf8_lossy(&response);
|
||||||
assert!(response_str.starts_with("HTTP/1.1 200 "), "expected 200, got: {response_str}");
|
assert!(
|
||||||
|
response_str.starts_with("HTTP/1.1 200 "),
|
||||||
|
"expected 200, got: {response_str}"
|
||||||
|
);
|
||||||
assert!(response_str.contains("foo-body"));
|
assert!(response_str.contains("foo-body"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -505,7 +548,10 @@ mod tests {
|
|||||||
|
|
||||||
let (mut client_send, server_recv) = duplex(8 * 1024);
|
let (mut client_send, server_recv) = duplex(8 * 1024);
|
||||||
let (server_send, mut client_recv) = duplex(8 * 1024);
|
let (server_send, mut client_recv) = duplex(8 * 1024);
|
||||||
let server_io = QuicStreamDuplex { read: server_recv, write: server_send };
|
let server_io = QuicStreamDuplex {
|
||||||
|
read: server_recv,
|
||||||
|
write: server_send,
|
||||||
|
};
|
||||||
|
|
||||||
let handle = tokio::spawn(async move {
|
let handle = tokio::spawn(async move {
|
||||||
adapter.serve_io(server_io).await.ok();
|
adapter.serve_io(server_io).await.ok();
|
||||||
@@ -518,7 +564,12 @@ mod tests {
|
|||||||
let mut response = Vec::new();
|
let mut response = Vec::new();
|
||||||
let mut buf = [0u8; 4096];
|
let mut buf = [0u8; 4096];
|
||||||
loop {
|
loop {
|
||||||
match tokio::time::timeout(std::time::Duration::from_secs(5), client_recv.read(&mut buf)).await {
|
match tokio::time::timeout(
|
||||||
|
std::time::Duration::from_secs(5),
|
||||||
|
client_recv.read(&mut buf),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(Ok(0)) => break,
|
Ok(Ok(0)) => break,
|
||||||
Ok(Ok(n)) => response.extend_from_slice(&buf[..n]),
|
Ok(Ok(n)) => response.extend_from_slice(&buf[..n]),
|
||||||
Ok(Err(_)) => break,
|
Ok(Err(_)) => break,
|
||||||
@@ -527,7 +578,10 @@ mod tests {
|
|||||||
}
|
}
|
||||||
handle.await.ok();
|
handle.await.ok();
|
||||||
let response_str = String::from_utf8_lossy(&response);
|
let response_str = String::from_utf8_lossy(&response);
|
||||||
assert!(response_str.starts_with("HTTP/1.1 200 "), "default GET /healthz wins, got: {response_str}");
|
assert!(
|
||||||
|
response_str.starts_with("HTTP/1.1 200 "),
|
||||||
|
"default GET /healthz wins, got: {response_str}"
|
||||||
|
);
|
||||||
assert!(response_str.contains("\r\n\r\nok"));
|
assert!(response_str.contains("\r\n\r\nok"));
|
||||||
assert!(!response_str.contains("custom-healthz"));
|
assert!(!response_str.contains("custom-healthz"));
|
||||||
}
|
}
|
||||||
@@ -547,7 +601,12 @@ mod tests {
|
|||||||
let mut response = Vec::new();
|
let mut response = Vec::new();
|
||||||
let mut buf = [0u8; 4096];
|
let mut buf = [0u8; 4096];
|
||||||
loop {
|
loop {
|
||||||
match tokio::time::timeout(std::time::Duration::from_secs(5), client_recv.read(&mut buf)).await {
|
match tokio::time::timeout(
|
||||||
|
std::time::Duration::from_secs(5),
|
||||||
|
client_recv.read(&mut buf),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(Ok(0)) => break,
|
Ok(Ok(0)) => break,
|
||||||
Ok(Ok(n)) => response.extend_from_slice(&buf[..n]),
|
Ok(Ok(n)) => response.extend_from_slice(&buf[..n]),
|
||||||
Ok(Err(_)) => break,
|
Ok(Err(_)) => break,
|
||||||
@@ -569,7 +628,10 @@ mod tests {
|
|||||||
.with_extra_routes(extra);
|
.with_extra_routes(extra);
|
||||||
let request = b"POST /v1/chat/completions HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nContent-Length: 0\r\n\r\n";
|
let request = b"POST /v1/chat/completions HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nContent-Length: 0\r\n\r\n";
|
||||||
let response = serve_and_read(adapter, request).await;
|
let response = serve_and_read(adapter, request).await;
|
||||||
assert!(response.starts_with("HTTP/1.1 200"), "expected 200, got: {response}");
|
assert!(
|
||||||
|
response.starts_with("HTTP/1.1 200"),
|
||||||
|
"expected 200, got: {response}"
|
||||||
|
);
|
||||||
assert!(response.contains("oai-proxy"));
|
assert!(response.contains("oai-proxy"));
|
||||||
assert!(!response.contains("404 Not Found"));
|
assert!(!response.contains("404 Not Found"));
|
||||||
}
|
}
|
||||||
@@ -583,32 +645,43 @@ mod tests {
|
|||||||
let adapter = HttpAdapter::new(provider(), empty_registry())
|
let adapter = HttpAdapter::new(provider(), empty_registry())
|
||||||
.with_decoy(DecoyConfig::NotFound)
|
.with_decoy(DecoyConfig::NotFound)
|
||||||
.with_extra_routes(extra);
|
.with_extra_routes(extra);
|
||||||
let request = b"GET /totally/unknown HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
|
let request =
|
||||||
|
b"GET /totally/unknown HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
|
||||||
let response = serve_and_read(adapter, request).await;
|
let response = serve_and_read(adapter, request).await;
|
||||||
assert!(response.starts_with("HTTP/1.1 404"), "expected 404 decoy, got: {response}");
|
assert!(
|
||||||
|
response.starts_with("HTTP/1.1 404"),
|
||||||
|
"expected 404 decoy, got: {response}"
|
||||||
|
);
|
||||||
assert!(response.contains("404 Not Found"));
|
assert!(response.contains("404 Not Found"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn healthz_takes_precedence_over_decoy() {
|
async fn healthz_takes_precedence_over_decoy() {
|
||||||
let adapter = HttpAdapter::new(provider(), empty_registry())
|
let adapter =
|
||||||
.with_decoy(DecoyConfig::Redirect {
|
HttpAdapter::new(provider(), empty_registry()).with_decoy(DecoyConfig::Redirect {
|
||||||
to: "https://example.com".to_string(),
|
to: "https://example.com".to_string(),
|
||||||
});
|
});
|
||||||
let request = b"GET /healthz HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
|
let request = b"GET /healthz HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
|
||||||
let response = serve_and_read(adapter, request).await;
|
let response = serve_and_read(adapter, request).await;
|
||||||
assert!(response.starts_with("HTTP/1.1 200"), "expected 200 healthz, got: {response}");
|
assert!(
|
||||||
|
response.starts_with("HTTP/1.1 200"),
|
||||||
|
"expected 200 healthz, got: {response}"
|
||||||
|
);
|
||||||
assert!(response.contains("\r\n\r\nok"));
|
assert!(response.contains("\r\n\r\nok"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn unknown_path_with_redirect_decoy_returns_redirect_over_wire() {
|
async fn unknown_path_with_redirect_decoy_returns_redirect_over_wire() {
|
||||||
let adapter = HttpAdapter::new(provider(), empty_registry()).with_decoy(DecoyConfig::Redirect {
|
let adapter =
|
||||||
|
HttpAdapter::new(provider(), empty_registry()).with_decoy(DecoyConfig::Redirect {
|
||||||
to: "https://example.com".to_string(),
|
to: "https://example.com".to_string(),
|
||||||
});
|
});
|
||||||
let request = b"GET /nope HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
|
let request = b"GET /nope HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
|
||||||
let response = serve_and_read(adapter, request).await;
|
let response = serve_and_read(adapter, request).await;
|
||||||
assert!(response.starts_with("HTTP/1.1 302"), "expected 302 redirect, got: {response}");
|
assert!(
|
||||||
|
response.starts_with("HTTP/1.1 302"),
|
||||||
|
"expected 302 redirect, got: {response}"
|
||||||
|
);
|
||||||
assert!(response.contains("location: https://example.com"));
|
assert!(response.contains("location: https://example.com"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -80,11 +80,12 @@ where
|
|||||||
{
|
{
|
||||||
type Rejection = Infallible;
|
type Rejection = Infallible;
|
||||||
|
|
||||||
async fn from_request_parts(
|
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||||
parts: &mut Parts,
|
let identity = parts
|
||||||
_state: &S,
|
.extensions
|
||||||
) -> Result<Self, Self::Rejection> {
|
.get::<Option<Identity>>()
|
||||||
let identity = parts.extensions.get::<Option<Identity>>().cloned().flatten();
|
.cloned()
|
||||||
|
.flatten();
|
||||||
Ok(ResolvedIdentity(identity))
|
Ok(ResolvedIdentity(identity))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -174,15 +175,16 @@ mod tests {
|
|||||||
assert!(identity.is_none());
|
assert!(identity.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_middleware(
|
async fn run_middleware(idp: Arc<dyn IdentityProvider>, request: Request) -> Response {
|
||||||
idp: Arc<dyn IdentityProvider>,
|
|
||||||
request: Request,
|
|
||||||
) -> Response {
|
|
||||||
let app: Router<()> = Router::new()
|
let app: Router<()> = Router::new()
|
||||||
.route(
|
.route(
|
||||||
"/",
|
"/",
|
||||||
get(|req: Request| async move {
|
get(|req: Request| async move {
|
||||||
let identity = req.extensions().get::<Option<Identity>>().cloned().flatten();
|
let identity = req
|
||||||
|
.extensions()
|
||||||
|
.get::<Option<Identity>>()
|
||||||
|
.cloned()
|
||||||
|
.flatten();
|
||||||
if let Some(id) = identity {
|
if let Some(id) = identity {
|
||||||
(StatusCode::OK, id.id)
|
(StatusCode::OK, id.id)
|
||||||
} else {
|
} else {
|
||||||
@@ -261,14 +263,12 @@ mod tests {
|
|||||||
let app: Router<()> = Router::new()
|
let app: Router<()> = Router::new()
|
||||||
.route(
|
.route(
|
||||||
"/",
|
"/",
|
||||||
get(
|
get(|ResolvedIdentity(identity): ResolvedIdentity| async move {
|
||||||
|ResolvedIdentity(identity): ResolvedIdentity| async move {
|
|
||||||
match identity {
|
match identity {
|
||||||
Some(id) => (StatusCode::OK, id.id),
|
Some(id) => (StatusCode::OK, id.id),
|
||||||
None => (StatusCode::OK, "none".to_string()),
|
None => (StatusCode::OK, "none".to_string()),
|
||||||
}
|
}
|
||||||
},
|
}),
|
||||||
),
|
|
||||||
)
|
)
|
||||||
.layer(from_fn_with_state(idp, bearer_auth_middleware));
|
.layer(from_fn_with_state(idp, bearer_auth_middleware));
|
||||||
|
|
||||||
@@ -287,14 +287,12 @@ mod tests {
|
|||||||
let app: Router<()> = Router::new()
|
let app: Router<()> = Router::new()
|
||||||
.route(
|
.route(
|
||||||
"/",
|
"/",
|
||||||
get(
|
get(|ResolvedIdentity(identity): ResolvedIdentity| async move {
|
||||||
|ResolvedIdentity(identity): ResolvedIdentity| async move {
|
|
||||||
match identity {
|
match identity {
|
||||||
Some(id) => (StatusCode::OK, id.id),
|
Some(id) => (StatusCode::OK, id.id),
|
||||||
None => (StatusCode::OK, "none".to_string()),
|
None => (StatusCode::OK, "none".to_string()),
|
||||||
}
|
}
|
||||||
},
|
}),
|
||||||
),
|
|
||||||
)
|
)
|
||||||
.layer(from_fn_with_state(idp, bearer_auth_middleware));
|
.layer(from_fn_with_state(idp, bearer_auth_middleware));
|
||||||
|
|
||||||
|
|||||||
@@ -33,10 +33,8 @@ pub fn fake_nginx_404() -> Response {
|
|||||||
header::CONTENT_TYPE,
|
header::CONTENT_TYPE,
|
||||||
HeaderValue::from_static("text/html; charset=utf-8"),
|
HeaderValue::from_static("text/html; charset=utf-8"),
|
||||||
);
|
);
|
||||||
resp.headers_mut().insert(
|
resp.headers_mut()
|
||||||
header::SERVER,
|
.insert(header::SERVER, HeaderValue::from_static("nginx"));
|
||||||
HeaderValue::from_static("nginx"),
|
|
||||||
);
|
|
||||||
resp
|
resp
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,10 +59,8 @@ pub async fn serve_static(root: &Path, request: Request) -> Response {
|
|||||||
let content_type = mime_for_path(&resolved);
|
let content_type = mime_for_path(&resolved);
|
||||||
let mut resp = Response::new(Body::from(bytes));
|
let mut resp = Response::new(Body::from(bytes));
|
||||||
*resp.status_mut() = StatusCode::OK;
|
*resp.status_mut() = StatusCode::OK;
|
||||||
resp.headers_mut().insert(
|
resp.headers_mut()
|
||||||
header::CONTENT_TYPE,
|
.insert(header::CONTENT_TYPE, HeaderValue::from_static(content_type));
|
||||||
HeaderValue::from_static(content_type),
|
|
||||||
);
|
|
||||||
resp
|
resp
|
||||||
}
|
}
|
||||||
Err(_) => fake_nginx_404(),
|
Err(_) => fake_nginx_404(),
|
||||||
@@ -173,10 +169,7 @@ mod tests {
|
|||||||
async fn send(router: axum::Router, uri: &str) -> axum::response::Response {
|
async fn send(router: axum::Router, uri: &str) -> axum::response::Response {
|
||||||
tower::ServiceExt::<Request<Body>>::oneshot(
|
tower::ServiceExt::<Request<Body>>::oneshot(
|
||||||
router,
|
router,
|
||||||
Request::builder()
|
Request::builder().uri(uri).body(Body::empty()).unwrap(),
|
||||||
.uri(uri)
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap(),
|
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@@ -220,9 +213,7 @@ mod tests {
|
|||||||
async fn unknown_path_with_static_site_decoy_serves_file() {
|
async fn unknown_path_with_static_site_decoy_serves_file() {
|
||||||
let dir = tempfile_dir();
|
let dir = tempfile_dir();
|
||||||
let file = dir.join("index.html");
|
let file = dir.join("index.html");
|
||||||
tokio::fs::write(&file, "<h1>hello</h1>")
|
tokio::fs::write(&file, "<h1>hello</h1>").await.unwrap();
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let decoy = DecoyConfig::StaticSite { root: dir.clone() };
|
let decoy = DecoyConfig::StaticSite { root: dir.clone() };
|
||||||
let resp = send(decoy_router(decoy), "/").await;
|
let resp = send(decoy_router(decoy), "/").await;
|
||||||
@@ -293,10 +284,8 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn tempfile_dir() -> PathBuf {
|
fn tempfile_dir() -> PathBuf {
|
||||||
let dir = PathBuf::from("/tmp").join(format!(
|
let dir =
|
||||||
"alknet-http-decoy-test-{}",
|
PathBuf::from("/tmp").join(format!("alknet-http-decoy-test-{}", uuid::Uuid::new_v4()));
|
||||||
uuid::Uuid::new_v4()
|
|
||||||
));
|
|
||||||
std::fs::create_dir_all(&dir).unwrap();
|
std::fs::create_dir_all(&dir).unwrap();
|
||||||
dir
|
dir
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,13 +52,19 @@ impl GatewayState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn dispatch(&self) -> GatewayDispatch {
|
fn dispatch(&self) -> GatewayDispatch {
|
||||||
GatewayDispatch::new(Arc::clone(&self.registry), Arc::clone(&self.identity_provider))
|
GatewayDispatch::new(
|
||||||
|
Arc::clone(&self.registry),
|
||||||
|
Arc::clone(&self.identity_provider),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FromRef<RouterState> for GatewayState {
|
impl FromRef<RouterState> for GatewayState {
|
||||||
fn from_ref(state: &RouterState) -> Self {
|
fn from_ref(state: &RouterState) -> Self {
|
||||||
GatewayState::new(Arc::clone(&state.registry), Arc::clone(&state.identity_provider))
|
GatewayState::new(
|
||||||
|
Arc::clone(&state.registry),
|
||||||
|
Arc::clone(&state.identity_provider),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,7 +98,9 @@ pub(crate) async fn call_handler(
|
|||||||
return not_found_response(&request.operation);
|
return not_found_response(&request.operation);
|
||||||
}
|
}
|
||||||
let dispatch = state.dispatch();
|
let dispatch = state.dispatch();
|
||||||
let envelope = dispatch.invoke(identity.clone(), &request.operation, request.input).await;
|
let envelope = dispatch
|
||||||
|
.invoke(identity.clone(), &request.operation, request.input)
|
||||||
|
.await;
|
||||||
envelope_to_response(envelope, identity.as_ref())
|
envelope_to_response(envelope, identity.as_ref())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,7 +109,9 @@ pub(crate) async fn search_handler(
|
|||||||
ResolvedIdentity(identity): ResolvedIdentity,
|
ResolvedIdentity(identity): ResolvedIdentity,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let dispatch = state.dispatch();
|
let dispatch = state.dispatch();
|
||||||
let envelope = dispatch.invoke(identity.clone(), SERVICES_LIST, json!({})).await;
|
let envelope = dispatch
|
||||||
|
.invoke(identity.clone(), SERVICES_LIST, json!({}))
|
||||||
|
.await;
|
||||||
envelope_to_response(envelope, identity.as_ref())
|
envelope_to_response(envelope, identity.as_ref())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,7 +125,11 @@ pub(crate) async fn schema_handler(
|
|||||||
}
|
}
|
||||||
let dispatch = state.dispatch();
|
let dispatch = state.dispatch();
|
||||||
let envelope = dispatch
|
let envelope = dispatch
|
||||||
.invoke(identity.clone(), SERVICES_SCHEMA, json!({ "name": query.name }))
|
.invoke(
|
||||||
|
identity.clone(),
|
||||||
|
SERVICES_SCHEMA,
|
||||||
|
json!({ "name": query.name }),
|
||||||
|
)
|
||||||
.await;
|
.await;
|
||||||
envelope_to_response(envelope, identity.as_ref())
|
envelope_to_response(envelope, identity.as_ref())
|
||||||
}
|
}
|
||||||
@@ -149,7 +163,9 @@ pub(crate) async fn subscribe_handler(
|
|||||||
subscribe_stream_internal_error(request.operation)
|
subscribe_stream_internal_error(request.operation)
|
||||||
} else {
|
} else {
|
||||||
let dispatch = state.dispatch();
|
let dispatch = state.dispatch();
|
||||||
let envelope = dispatch.invoke(identity, &request.operation, request.input).await;
|
let envelope = dispatch
|
||||||
|
.invoke(identity, &request.operation, request.input)
|
||||||
|
.await;
|
||||||
subscribe_stream_from_envelope(envelope)
|
subscribe_stream_from_envelope(envelope)
|
||||||
};
|
};
|
||||||
Sse::new(stream)
|
Sse::new(stream)
|
||||||
@@ -221,8 +237,7 @@ fn not_found_response(operation: &str) -> Response {
|
|||||||
fn forbidden_response(message: String, identity: Option<&Identity>) -> Response {
|
fn forbidden_response(message: String, identity: Option<&Identity>) -> Response {
|
||||||
let error = CallError::forbidden(message);
|
let error = CallError::forbidden(message);
|
||||||
let status_code = call_error_to_http_status_with_identity(&error, identity);
|
let status_code = call_error_to_http_status_with_identity(&error, identity);
|
||||||
let status =
|
let status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
|
||||||
let body = serde_json::to_value(&error).unwrap_or(Value::Null);
|
let body = serde_json::to_value(&error).unwrap_or(Value::Null);
|
||||||
(status, Json(body)).into_response()
|
(status, Json(body)).into_response()
|
||||||
}
|
}
|
||||||
@@ -248,7 +263,9 @@ fn is_internal_op(registry: &OperationRegistry, operation: &str) -> bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn envelope_to_sse_stream(envelope: ResponseEnvelope) -> impl Stream<Item = Result<Event, Infallible>> {
|
fn envelope_to_sse_stream(
|
||||||
|
envelope: ResponseEnvelope,
|
||||||
|
) -> impl Stream<Item = Result<Event, Infallible>> {
|
||||||
stream::once(async move {
|
stream::once(async move {
|
||||||
match envelope.result {
|
match envelope.result {
|
||||||
Ok(output) => {
|
Ok(output) => {
|
||||||
@@ -756,7 +773,10 @@ mod tests {
|
|||||||
.get(axum::http::header::CONTENT_TYPE)
|
.get(axum::http::header::CONTENT_TYPE)
|
||||||
.map(|v| v.to_str().unwrap().to_string());
|
.map(|v| v.to_str().unwrap().to_string());
|
||||||
assert!(
|
assert!(
|
||||||
ctype.as_deref().unwrap_or("").starts_with("text/event-stream"),
|
ctype
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or("")
|
||||||
|
.starts_with("text/event-stream"),
|
||||||
"expected text/event-stream, got {ctype:?}"
|
"expected text/event-stream, got {ctype:?}"
|
||||||
);
|
);
|
||||||
let bytes = resp.into_body().collect().await.unwrap().to_bytes();
|
let bytes = resp.into_body().collect().await.unwrap().to_bytes();
|
||||||
|
|||||||
@@ -128,7 +128,10 @@ mod tests {
|
|||||||
let out: EventEnvelope = response.into();
|
let out: EventEnvelope = response.into();
|
||||||
assert_eq!(out.r#type, EVENT_RESPONDED);
|
assert_eq!(out.r#type, EVENT_RESPONDED);
|
||||||
assert_eq!(out.id, "ws-rt-1");
|
assert_eq!(out.id, "ws-rt-1");
|
||||||
assert_eq!(out.payload.get("output"), Some(&serde_json::json!({ "v": 7 })));
|
assert_eq!(
|
||||||
|
out.payload.get("output"),
|
||||||
|
Some(&serde_json::json!({ "v": 7 }))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -160,7 +163,10 @@ mod tests {
|
|||||||
async fn ws_overlay_only_connection_holds_overlay_and_pending() {
|
async fn ws_overlay_only_connection_holds_overlay_and_pending() {
|
||||||
let conn = CallConnection::new_overlay_only(identity("ws-peer"));
|
let conn = CallConnection::new_overlay_only(identity("ws-peer"));
|
||||||
assert!(conn.connection().is_none());
|
assert!(conn.connection().is_none());
|
||||||
assert_eq!(conn.identity().map(|i| i.id.clone()), Some("ws-peer".to_string()));
|
assert_eq!(
|
||||||
|
conn.identity().map(|i| i.id.clone()),
|
||||||
|
Some("ws-peer".to_string())
|
||||||
|
);
|
||||||
assert!(conn.pending().lock().is_empty());
|
assert!(conn.pending().lock().is_empty());
|
||||||
|
|
||||||
let env = conn.overlay_env();
|
let env = conn.overlay_env();
|
||||||
|
|||||||
@@ -84,8 +84,9 @@ async fn ws_upgrade_handler_inner(
|
|||||||
};
|
};
|
||||||
|
|
||||||
match ws_upgrade {
|
match ws_upgrade {
|
||||||
Some(upgrade) => upgrade
|
Some(upgrade) => upgrade.on_upgrade(move |socket| {
|
||||||
.on_upgrade(move |socket| run_ws_session(socket, registry, identity_provider, identity)),
|
run_ws_session(socket, registry, identity_provider, identity)
|
||||||
|
}),
|
||||||
None => {
|
None => {
|
||||||
let _ = registry;
|
let _ = registry;
|
||||||
let _ = identity_provider;
|
let _ = identity_provider;
|
||||||
@@ -240,19 +241,19 @@ fn serialize_envelope(envelope: &EventEnvelope) -> Result<Vec<u8>, serde_json::E
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use alknet_call::registry::context::{
|
||||||
|
AbortPolicy, CompositionAuthority, OperationContext, ScopedPeerEnv,
|
||||||
|
};
|
||||||
use alknet_call::registry::discovery::{
|
use alknet_call::registry::discovery::{
|
||||||
services_list_handler, services_list_spec, services_schema_handler, services_schema_spec,
|
services_list_handler, services_list_spec, services_schema_handler, services_schema_spec,
|
||||||
};
|
};
|
||||||
|
use alknet_call::registry::env::OperationEnv;
|
||||||
use alknet_call::registry::registration::{
|
use alknet_call::registry::registration::{
|
||||||
make_handler, HandlerRegistration, OperationProvenance,
|
make_handler, HandlerRegistration, OperationProvenance,
|
||||||
};
|
};
|
||||||
use alknet_call::registry::spec::{AccessControl, OperationSpec, OperationType, Visibility};
|
use alknet_call::registry::spec::{AccessControl, OperationSpec, OperationType, Visibility};
|
||||||
use alknet_core::auth::{AuthToken, Identity};
|
use alknet_core::auth::{AuthToken, Identity};
|
||||||
use alknet_core::types::Capabilities;
|
use alknet_core::types::Capabilities;
|
||||||
use alknet_call::registry::context::{
|
|
||||||
AbortPolicy, CompositionAuthority, OperationContext, ScopedPeerEnv,
|
|
||||||
};
|
|
||||||
use alknet_call::registry::env::OperationEnv;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Mutex as StdMutex;
|
use std::sync::Mutex as StdMutex;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
@@ -331,9 +332,7 @@ mod tests {
|
|||||||
let mut registry = OperationRegistry::new();
|
let mut registry = OperationRegistry::new();
|
||||||
registry.register(HandlerRegistration::new(
|
registry.register(HandlerRegistration::new(
|
||||||
external_spec("echo/run", AccessControl::default()),
|
external_spec("echo/run", AccessControl::default()),
|
||||||
make_handler(|input, ctx| async move {
|
make_handler(|input, ctx| async move { ResponseEnvelope::ok(ctx.request_id, input) }),
|
||||||
ResponseEnvelope::ok(ctx.request_id, input)
|
|
||||||
}),
|
|
||||||
OperationProvenance::Local,
|
OperationProvenance::Local,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
@@ -352,9 +351,7 @@ mod tests {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
make_handler(|input, ctx| async move {
|
make_handler(|input, ctx| async move { ResponseEnvelope::ok(ctx.request_id, input) }),
|
||||||
ResponseEnvelope::ok(ctx.request_id, input)
|
|
||||||
}),
|
|
||||||
OperationProvenance::Local,
|
OperationProvenance::Local,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
@@ -519,9 +516,8 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn handle_inbound_envelope_forbidden_yields_call_error() {
|
async fn handle_inbound_envelope_forbidden_yields_call_error() {
|
||||||
let registry = registry_with_restricted_op();
|
let registry = registry_with_restricted_op();
|
||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(
|
let provider: Arc<dyn IdentityProvider> =
|
||||||
StaticIdentityProvider::new().with_token("none", identity("unpriv")),
|
Arc::new(StaticIdentityProvider::new().with_token("none", identity("unpriv")));
|
||||||
);
|
|
||||||
let dp = dispatcher(registry, provider);
|
let dp = dispatcher(registry, provider);
|
||||||
let conn = Arc::new(CallConnection::new_overlay_only(identity("unpriv")));
|
let conn = Arc::new(CallConnection::new_overlay_only(identity("unpriv")));
|
||||||
|
|
||||||
@@ -727,9 +723,8 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn round_trip_call_requested_to_call_responded_over_ws_message_stream() {
|
async fn round_trip_call_requested_to_call_responded_over_ws_message_stream() {
|
||||||
let registry = echo_registry();
|
let registry = echo_registry();
|
||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(
|
let provider: Arc<dyn IdentityProvider> =
|
||||||
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")),
|
Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
|
||||||
);
|
|
||||||
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
||||||
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
||||||
|
|
||||||
@@ -753,9 +748,8 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn subscription_streams_multiple_call_responded_events() {
|
async fn subscription_streams_multiple_call_responded_events() {
|
||||||
let registry = registry_with_subscription();
|
let registry = registry_with_subscription();
|
||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(
|
let provider: Arc<dyn IdentityProvider> =
|
||||||
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")),
|
Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
|
||||||
);
|
|
||||||
let dp = dispatcher(registry, provider);
|
let dp = dispatcher(registry, provider);
|
||||||
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
||||||
|
|
||||||
@@ -782,8 +776,10 @@ mod tests {
|
|||||||
.with_token("no-admin", identity_with_scopes("user", &["user"])),
|
.with_token("no-admin", identity_with_scopes("user", &["user"])),
|
||||||
);
|
);
|
||||||
let dp = dispatcher(registry, provider);
|
let dp = dispatcher(registry, provider);
|
||||||
let conn =
|
let conn = Arc::new(CallConnection::new_overlay_only(identity_with_scopes(
|
||||||
Arc::new(CallConnection::new_overlay_only(identity_with_scopes("user", &["user"])));
|
"user",
|
||||||
|
&["user"],
|
||||||
|
)));
|
||||||
|
|
||||||
let request = EventEnvelope::requested(
|
let request = EventEnvelope::requested(
|
||||||
"req-admin",
|
"req-admin",
|
||||||
@@ -882,8 +878,10 @@ mod tests {
|
|||||||
let overlay_env = conn.overlay_env();
|
let overlay_env = conn.overlay_env();
|
||||||
assert!(overlay_env.contains("ui/dragged"));
|
assert!(overlay_env.contains("ui/dragged"));
|
||||||
|
|
||||||
let composed_env: Arc<dyn OperationEnv + Send + Sync> = dp
|
let composed_env: Arc<dyn OperationEnv + Send + Sync> = dp.compose_root_env(
|
||||||
.compose_root_env(&conn, &root_context_for_compose("hub-call-1", overlay_env.clone()));
|
&conn,
|
||||||
|
&root_context_for_compose("hub-call-1", overlay_env.clone()),
|
||||||
|
);
|
||||||
let ctx = root_context_with_env("hub-call-1", composed_env);
|
let ctx = root_context_with_env("hub-call-1", composed_env);
|
||||||
let response = overlay_env
|
let response = overlay_env
|
||||||
.invoke("ui", "dragged", serde_json::json!({ "x": 5 }), &ctx)
|
.invoke("ui", "dragged", serde_json::json!({ "x": 5 }), &ctx)
|
||||||
@@ -935,9 +933,8 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn drive_ws_session_round_trips_binary_call_requested_to_call_responded() {
|
async fn drive_ws_session_round_trips_binary_call_requested_to_call_responded() {
|
||||||
let registry = echo_registry();
|
let registry = echo_registry();
|
||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(
|
let provider: Arc<dyn IdentityProvider> =
|
||||||
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")),
|
Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
|
||||||
);
|
|
||||||
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
||||||
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
||||||
|
|
||||||
@@ -960,7 +957,10 @@ mod tests {
|
|||||||
let env: EventEnvelope = serde_json::from_slice(&bytes).unwrap();
|
let env: EventEnvelope = serde_json::from_slice(&bytes).unwrap();
|
||||||
assert_eq!(env.r#type, EVENT_RESPONDED);
|
assert_eq!(env.r#type, EVENT_RESPONDED);
|
||||||
assert_eq!(env.id, "ws-socket-1");
|
assert_eq!(env.id, "ws-socket-1");
|
||||||
assert_eq!(env.payload.get("output"), Some(&serde_json::json!({ "v": 7 })));
|
assert_eq!(
|
||||||
|
env.payload.get("output"),
|
||||||
|
Some(&serde_json::json!({ "v": 7 }))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
other => panic!("expected binary, got {other:?}"),
|
other => panic!("expected binary, got {other:?}"),
|
||||||
}
|
}
|
||||||
@@ -972,9 +972,8 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn drive_ws_session_rejects_text_with_protocol_close() {
|
async fn drive_ws_session_rejects_text_with_protocol_close() {
|
||||||
let registry = echo_registry();
|
let registry = echo_registry();
|
||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(
|
let provider: Arc<dyn IdentityProvider> =
|
||||||
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")),
|
Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
|
||||||
);
|
|
||||||
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
||||||
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
||||||
|
|
||||||
@@ -999,9 +998,8 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn drive_ws_session_disconnect_aborts_in_flight_pending() {
|
async fn drive_ws_session_disconnect_aborts_in_flight_pending() {
|
||||||
let registry = echo_registry();
|
let registry = echo_registry();
|
||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(
|
let provider: Arc<dyn IdentityProvider> =
|
||||||
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")),
|
Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
|
||||||
);
|
|
||||||
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
||||||
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
||||||
|
|
||||||
@@ -1036,9 +1034,8 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn drive_ws_session_subscription_streams_call_responded_events() {
|
async fn drive_ws_session_subscription_streams_call_responded_events() {
|
||||||
let registry = registry_with_subscription();
|
let registry = registry_with_subscription();
|
||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(
|
let provider: Arc<dyn IdentityProvider> =
|
||||||
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")),
|
Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
|
||||||
);
|
|
||||||
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
||||||
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
||||||
|
|
||||||
@@ -1077,9 +1074,8 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn drive_ws_session_invalid_binary_closes_with_protocol_error() {
|
async fn drive_ws_session_invalid_binary_closes_with_protocol_error() {
|
||||||
let registry = echo_registry();
|
let registry = echo_registry();
|
||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(
|
let provider: Arc<dyn IdentityProvider> =
|
||||||
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")),
|
Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
|
||||||
);
|
|
||||||
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
||||||
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
||||||
|
|
||||||
@@ -1102,9 +1098,8 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn drive_ws_session_client_close_terminates_server() {
|
async fn drive_ws_session_client_close_terminates_server() {
|
||||||
let registry = echo_registry();
|
let registry = echo_registry();
|
||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(
|
let provider: Arc<dyn IdentityProvider> =
|
||||||
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")),
|
Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
|
||||||
);
|
|
||||||
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
let dp = dispatcher(Arc::clone(®istry), Arc::clone(&provider));
|
||||||
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
let conn = Arc::new(CallConnection::new_overlay_only(identity("ws-peer")));
|
||||||
|
|
||||||
@@ -1164,17 +1159,11 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn send_text(&mut self, text: String) {
|
async fn send_text(&mut self, text: String) {
|
||||||
self.outbound_tx
|
self.outbound_tx.send(Message::Text(text.into())).await.ok();
|
||||||
.send(Message::Text(text.into()))
|
|
||||||
.await
|
|
||||||
.ok();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_close(&mut self) {
|
async fn send_close(&mut self) {
|
||||||
self.outbound_tx
|
self.outbound_tx.send(Message::Close(None)).await.ok();
|
||||||
.send(Message::Close(None))
|
|
||||||
.await
|
|
||||||
.ok();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn close(&mut self) {
|
async fn close(&mut self) {
|
||||||
@@ -1215,9 +1204,8 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn ws_upgrade_handler_returns_401_when_identity_is_none() {
|
async fn ws_upgrade_handler_returns_401_when_identity_is_none() {
|
||||||
let registry = echo_registry();
|
let registry = echo_registry();
|
||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(
|
let provider: Arc<dyn IdentityProvider> =
|
||||||
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")),
|
Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
|
||||||
);
|
|
||||||
let identity: Option<Identity> = None;
|
let identity: Option<Identity> = None;
|
||||||
|
|
||||||
let response = ws_upgrade_handler_inner(registry, provider, identity, None).await;
|
let response = ws_upgrade_handler_inner(registry, provider, identity, None).await;
|
||||||
@@ -1227,9 +1215,8 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn ws_upgrade_handler_does_not_reject_when_identity_present() {
|
async fn ws_upgrade_handler_does_not_reject_when_identity_present() {
|
||||||
let registry = echo_registry();
|
let registry = echo_registry();
|
||||||
let provider: Arc<dyn IdentityProvider> = Arc::new(
|
let provider: Arc<dyn IdentityProvider> =
|
||||||
StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")),
|
Arc::new(StaticIdentityProvider::new().with_token("ws-token", identity("ws-peer")));
|
||||||
);
|
|
||||||
let identity = identity("ws-peer");
|
let identity = identity("ws-peer");
|
||||||
|
|
||||||
let response = ws_upgrade_handler_inner(registry, provider, Some(identity), None).await;
|
let response = ws_upgrade_handler_inner(registry, provider, Some(identity), None).await;
|
||||||
|
|||||||
@@ -9,11 +9,11 @@ use std::collections::HashMap;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use alknet_call::client::OperationAdapter;
|
||||||
use alknet_call::protocol::wire::ResponseEnvelope;
|
use alknet_call::protocol::wire::ResponseEnvelope;
|
||||||
use alknet_call::registry::context::{AbortPolicy, OperationContext, ScopedPeerEnv};
|
use alknet_call::registry::context::{AbortPolicy, OperationContext, ScopedPeerEnv};
|
||||||
use alknet_call::registry::env::OperationEnv;
|
use alknet_call::registry::env::OperationEnv;
|
||||||
use alknet_call::registry::registration::OperationProvenance;
|
use alknet_call::registry::registration::OperationProvenance;
|
||||||
use alknet_call::client::OperationAdapter;
|
|
||||||
use alknet_core::types::Capabilities;
|
use alknet_core::types::Capabilities;
|
||||||
use alknet_http::adapters::FromMCP;
|
use alknet_http::adapters::FromMCP;
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
@@ -22,8 +22,8 @@ use rmcp::model::{
|
|||||||
};
|
};
|
||||||
use rmcp::service::RequestContext;
|
use rmcp::service::RequestContext;
|
||||||
use rmcp::transport::{
|
use rmcp::transport::{
|
||||||
StreamableHttpServerConfig,
|
|
||||||
streamable_http_server::{session::local::LocalSessionManager, tower::StreamableHttpService},
|
streamable_http_server::{session::local::LocalSessionManager, tower::StreamableHttpService},
|
||||||
|
StreamableHttpServerConfig,
|
||||||
};
|
};
|
||||||
use rmcp::{RoleServer, ServerHandler};
|
use rmcp::{RoleServer, ServerHandler};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
@@ -72,18 +72,19 @@ impl ServerHandler for EchoServer {
|
|||||||
&self,
|
&self,
|
||||||
_request: Option<PaginatedRequestParams>,
|
_request: Option<PaginatedRequestParams>,
|
||||||
_context: RequestContext<RoleServer>,
|
_context: RequestContext<RoleServer>,
|
||||||
) -> impl std::future::Future<
|
) -> impl std::future::Future<Output = Result<ListToolsResult, rmcp::ErrorData>>
|
||||||
Output = Result<ListToolsResult, rmcp::ErrorData>,
|
+ rmcp::service::MaybeSendFuture
|
||||||
> + rmcp::service::MaybeSendFuture + '_ {
|
+ '_ {
|
||||||
let tools = vec![
|
let tools = vec![
|
||||||
Tool::new_with_raw(
|
Tool::new_with_raw(
|
||||||
"echo",
|
"echo",
|
||||||
Some("Echo the input back as structured content".into()),
|
Some("Echo the input back as structured content".into()),
|
||||||
Arc::new(serde_json::Map::new()),
|
Arc::new(serde_json::Map::new()),
|
||||||
)
|
)
|
||||||
.with_raw_output_schema(Arc::new(serde_json::Map::from_iter([
|
.with_raw_output_schema(Arc::new(serde_json::Map::from_iter([(
|
||||||
("type".to_string(), Value::String("object".into())),
|
"type".to_string(),
|
||||||
]))),
|
Value::String("object".into()),
|
||||||
|
)]))),
|
||||||
Tool::new_with_raw(
|
Tool::new_with_raw(
|
||||||
"legacy",
|
"legacy",
|
||||||
Some("Legacy tool returning text content blocks".into()),
|
Some("Legacy tool returning text content blocks".into()),
|
||||||
@@ -101,22 +102,17 @@ impl ServerHandler for EchoServer {
|
|||||||
&self,
|
&self,
|
||||||
request: CallToolRequestParams,
|
request: CallToolRequestParams,
|
||||||
_context: RequestContext<RoleServer>,
|
_context: RequestContext<RoleServer>,
|
||||||
) -> impl std::future::Future<
|
) -> impl std::future::Future<Output = Result<CallToolResult, rmcp::ErrorData>>
|
||||||
Output = Result<CallToolResult, rmcp::ErrorData>,
|
+ rmcp::service::MaybeSendFuture
|
||||||
> + rmcp::service::MaybeSendFuture + '_ {
|
+ '_ {
|
||||||
let name = request.name.to_string();
|
let name = request.name.to_string();
|
||||||
std::future::ready(Ok(match name.as_str() {
|
std::future::ready(Ok(match name.as_str() {
|
||||||
"echo" => {
|
"echo" => {
|
||||||
let args = request
|
let args = request.arguments.map(Value::Object).unwrap_or(Value::Null);
|
||||||
.arguments
|
|
||||||
.map(Value::Object)
|
|
||||||
.unwrap_or(Value::Null);
|
|
||||||
CallToolResult::structured(serde_json::json!({ "echoed": args }))
|
CallToolResult::structured(serde_json::json!({ "echoed": args }))
|
||||||
}
|
}
|
||||||
"legacy" => CallToolResult::success(vec![Content::text("plain text result")]),
|
"legacy" => CallToolResult::success(vec![Content::text("plain text result")]),
|
||||||
other => CallToolResult::error(vec![Content::text(format!(
|
other => CallToolResult::error(vec![Content::text(format!("unknown tool: {other}"))]),
|
||||||
"unknown tool: {other}"
|
|
||||||
))]),
|
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user