diff --git a/Cargo.lock b/Cargo.lock index 8c6edf9..d703b40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -105,9 +105,12 @@ version = "0.1.0" dependencies = [ "alknet-call", "alknet-core", + "arc-swap", "async-trait", "axum", "futures", + "http", + "httpdate", "hyper", "openapiv3", "reqwest 0.13.4", @@ -119,6 +122,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tracing", + "url", "uuid", ] @@ -3244,6 +3248,7 @@ version = "0.11.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4fcb935c5bec503c2f0e306bdd3e58bb9029dcb14fa8d9ac76e3a5256ac0763e" dependencies = [ + "aws-lc-rs", "bytes", "fastbloom", "getrandom 0.3.4", @@ -3538,15 +3543,21 @@ dependencies = [ "http-body", "http-body-util", "hyper", + "hyper-rustls", "hyper-util", "js-sys", "log", "percent-encoding", "pin-project-lite", + "quinn", + "rustls", + "rustls-pki-types", + "rustls-platform-verifier", "serde", "serde_json", "sync_wrapper", "tokio", + "tokio-rustls", "tokio-util", "tower", "tower-http", diff --git a/crates/alknet-http/Cargo.toml b/crates/alknet-http/Cargo.toml index 621112f..3812cf5 100644 --- a/crates/alknet-http/Cargo.toml +++ b/crates/alknet-http/Cargo.toml @@ -18,9 +18,11 @@ http1 = ["dep:hyper"] [dependencies] alknet-core = { path = "../alknet-core" } alknet-call = { path = "../alknet-call" } +arc-swap = "1" axum = { version = "0.8", features = ["ws"] } hyper = { version = "1", optional = true, features = ["server", "http1", "http2"] } -reqwest = { version = "0.13", default-features = false, features = ["json", "stream"] } +httpdate = "1" +reqwest = { version = "0.13", default-features = false, features = ["json", "stream", "rustls"] } reqwest-middleware = "0.5" reqwest-retry = "0.9" tokio = { version = "1", features = ["full"] } @@ -32,6 +34,8 @@ thiserror = "2" uuid = { version = "1", features = ["v4"] } futures = "0.3" openapiv3 = "2" +http = "1" +url = "2" rmcp = { version = "1.8", optional = true, default-features = false, features = [ "client", "server", diff --git a/crates/alknet-http/src/client/http_client.rs b/crates/alknet-http/src/client/http_client.rs new file mode 100644 index 0000000..c0155dd --- /dev/null +++ b/crates/alknet-http/src/client/http_client.rs @@ -0,0 +1,329 @@ +//! Shared HTTP client: `reqwest_middleware::ClientWithMiddleware` with a +//! retry stack (RetryTransientMiddleware + inlined RetryAfterMiddleware), +//! connection pooling, keep-alive, TLS, and rebuild-and-swap hot-reload. +//! +//! Credential injection happens per-request (from +//! `OperationContext.capabilities`), not at client construction — the +//! client is shared across all operations, the credentials are per-call. + +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +use arc_swap::ArcSwap; +use reqwest::ClientBuilder; +use reqwest_middleware::ClientWithMiddleware; +use reqwest_retry::policies::ExponentialBackoff; +use reqwest_retry::RetryTransientMiddleware; +use thiserror::Error; + +use super::retry_after::RetryAfterMiddleware; + +const DEFAULT_RETRY_AFTER_CAPACITY: usize = 256; + +#[derive(Debug, Clone)] +pub struct ClientCertConfig { + pub cert_pem: PathBuf, + pub key_pem: PathBuf, +} + +#[derive(Debug, Clone)] +pub struct HttpClientConfig { + pub pool_max_idle_per_host: Option, + pub request_timeout: Option, + pub retry_policy: ExponentialBackoff, + pub ca_bundle: Option, + pub client_cert: Option, +} + +impl Default for HttpClientConfig { + fn default() -> Self { + Self { + pool_max_idle_per_host: None, + request_timeout: None, + retry_policy: ExponentialBackoff::builder().build_with_max_retries(3), + ca_bundle: None, + client_cert: None, + } + } +} + +#[derive(Debug, Error)] +pub enum HttpClientBuildError { + #[error("failed to read CA bundle from {path}: {source}")] + CaBundleRead { + path: PathBuf, + #[source] + source: std::io::Error, + }, + #[error("failed to parse CA bundle at {path}: {source}")] + CaBundleParse { + path: PathBuf, + #[source] + source: reqwest::Error, + }, + #[error("failed to read client cert from {path}: {source}")] + ClientCertRead { + path: PathBuf, + #[source] + source: std::io::Error, + }, + #[error("failed to parse client cert at {path}: {source}")] + ClientCertParse { + path: PathBuf, + #[source] + source: reqwest::Error, + }, + #[error("failed to build reqwest client: {0}")] + Build(reqwest::Error), +} + +pub struct SharedHttpClient { + inner: ArcSwap, + config: ArcSwap, +} + +impl std::fmt::Debug for SharedHttpClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SharedHttpClient") + .field("config", &self.config.load()) + .finish_non_exhaustive() + } +} + +impl SharedHttpClient { + pub fn new(config: HttpClientConfig) -> Result { + let client = build_client(&config)?; + Ok(Self { + inner: ArcSwap::from_pointee(client), + config: ArcSwap::from_pointee(config), + }) + } + + pub fn client(&self) -> Arc { + self.inner.load_full() + } + + pub fn config(&self) -> Arc { + self.config.load_full() + } + + pub fn reload(&self, config: HttpClientConfig) -> Result<(), HttpClientBuildError> { + let client = build_client(&config)?; + self.config.store(Arc::new(config)); + self.inner.store(Arc::new(client)); + Ok(()) + } +} + +fn build_client(config: &HttpClientConfig) -> Result { + let mut builder = ClientBuilder::new(); + if let Some(pool_max_idle) = config.pool_max_idle_per_host { + builder = builder.pool_max_idle_per_host(pool_max_idle); + } + if let Some(timeout) = config.request_timeout { + builder = builder.timeout(timeout); + } + if let Some(ca_bundle_path) = &config.ca_bundle { + let pem = std::fs::read(ca_bundle_path).map_err(|source| HttpClientBuildError::CaBundleRead { + path: ca_bundle_path.clone(), + source, + })?; + let certs = reqwest::Certificate::from_pem_bundle(&pem).map_err(|source| { + HttpClientBuildError::CaBundleParse { + path: ca_bundle_path.clone(), + source, + } + })?; + for cert in certs { + builder = builder.add_root_certificate(cert); + } + } + if let Some(client_cert_cfg) = &config.client_cert { + let cert_pem = std::fs::read(&client_cert_cfg.cert_pem).map_err(|source| { + HttpClientBuildError::ClientCertRead { + path: client_cert_cfg.cert_pem.clone(), + source, + } + })?; + let key_pem = std::fs::read(&client_cert_cfg.key_pem).map_err(|source| { + HttpClientBuildError::ClientCertRead { + path: client_cert_cfg.key_pem.clone(), + source, + } + })?; + let identity = reqwest::Identity::from_pem( + concat_pem(&cert_pem, &key_pem).as_slice(), + ) + .map_err(|source| HttpClientBuildError::ClientCertParse { + path: client_cert_cfg.cert_pem.clone(), + source, + })?; + builder = builder.identity(identity); + } + let reqwest_client = builder.build().map_err(HttpClientBuildError::Build)?; + let client = reqwest_middleware::ClientBuilder::new(reqwest_client) + .with(RetryTransientMiddleware::new_with_policy(config.retry_policy)) + .with(RetryAfterMiddleware::with_capacity(DEFAULT_RETRY_AFTER_CAPACITY)) + .build(); + Ok(client) +} + +fn concat_pem(cert: &[u8], key: &[u8]) -> Vec { + let mut combined = Vec::with_capacity(cert.len() + key.len() + 1); + combined.extend_from_slice(cert); + if !cert.is_empty() && cert.last() != Some(&b'\n') { + combined.push(b'\n'); + } + combined.extend_from_slice(key); + combined +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::SystemTime; + + fn minimal_config() -> HttpClientConfig { + HttpClientConfig { + pool_max_idle_per_host: Some(8), + request_timeout: Some(Duration::from_secs(30)), + retry_policy: ExponentialBackoff::builder().build_with_max_retries(2), + ca_bundle: None, + client_cert: None, + } + } + + #[test] + fn client_returns_a_usable_client_with_middleware() { + let http = SharedHttpClient::new(minimal_config()).expect("client builds"); + let client = http.client(); + let request = client + .get("https://api.example.com/v1/chat") + .build() + .expect("RequestBuilder builds"); + assert_eq!(request.method(), reqwest::Method::GET); + assert_eq!( + request.url().as_str(), + "https://api.example.com/v1/chat" + ); + } + + #[test] + fn reload_swaps_the_client_returned_by_client() { + let http = SharedHttpClient::new(minimal_config()).expect("client builds"); + let before = http.client(); + let new_config = HttpClientConfig { + pool_max_idle_per_host: Some(32), + request_timeout: Some(Duration::from_secs(10)), + retry_policy: ExponentialBackoff::builder().build_with_max_retries(5), + ca_bundle: None, + client_cert: None, + }; + http.reload(new_config.clone()).expect("reload succeeds"); + let after = http.client(); + assert!( + !Arc::ptr_eq(&before, &after), + "reload must swap in a new ClientWithMiddleware" + ); + let config = http.config(); + assert_eq!(config.pool_max_idle_per_host, Some(32)); + assert_eq!(config.request_timeout, Some(Duration::from_secs(10))); + } + + #[test] + fn config_returns_current_config() { + let http = SharedHttpClient::new(minimal_config()).expect("client builds"); + let config = http.config(); + assert_eq!(config.pool_max_idle_per_host, Some(8)); + assert_eq!(config.request_timeout, Some(Duration::from_secs(30))); + } + + #[test] + fn default_config_has_sensible_defaults() { + let config = HttpClientConfig::default(); + assert!(config.pool_max_idle_per_host.is_none()); + assert!(config.request_timeout.is_none()); + assert!(config.ca_bundle.is_none()); + assert!(config.client_cert.is_none()); + assert_eq!(config.retry_policy.max_n_retries, Some(3)); + } + + #[test] + fn reload_with_ca_bundle_missing_file_errors() { + let http = SharedHttpClient::new(minimal_config()).expect("client builds"); + let bad_config = HttpClientConfig { + ca_bundle: Some(PathBuf::from("/nonexistent/ca-bundle.pem")), + ..minimal_config() + }; + let err = http.reload(bad_config).unwrap_err(); + assert!(matches!(err, HttpClientBuildError::CaBundleRead { .. })); + } + + #[test] + fn concat_pem_inserts_separator_between_cert_and_key() { + let cert = b"-----BEGIN CERTIFICATE-----\ncert-body\n-----END CERTIFICATE-----"; + let key = b"-----BEGIN PRIVATE KEY-----\nkey-body\n-----END PRIVATE KEY-----"; + let combined = concat_pem(cert, key); + assert!(combined.starts_with(b"-----BEGIN CERTIFICATE-----")); + assert!(combined.windows(20).any(|w| w == b"-----END CERTIFICATE")); + assert!(combined.windows(18).any(|w| w == b"-----BEGIN PRIVATE")); + } + + #[test] + fn concat_pem_handles_cert_already_terminated_with_newline() { + let cert = b"-----BEGIN CERTIFICATE-----\ncert-body\n-----END CERTIFICATE-----\n"; + let key = b"-----BEGIN PRIVATE KEY-----\nkey-body\n-----END PRIVATE KEY-----"; + let combined = concat_pem(cert, key); + let joined = std::str::from_utf8(&combined).unwrap(); + assert!( + !joined.contains("-----END CERTIFICATE----------BEGIN PRIVATE"), + "must not concatenate without a separator when cert lacks trailing newline" + ); + assert!(joined.contains("-----END CERTIFICATE-----\n-----BEGIN PRIVATE")); + } + + #[test] + fn client_cert_config_constructs() { + let cfg = ClientCertConfig { + cert_pem: PathBuf::from("/etc/cert.pem"), + key_pem: PathBuf::from("/etc/key.pem"), + }; + assert_eq!(cfg.cert_pem, PathBuf::from("/etc/cert.pem")); + assert_eq!(cfg.key_pem, PathBuf::from("/etc/key.pem")); + } + + #[test] + fn new_with_missing_ca_bundle_errors() { + let config = HttpClientConfig { + ca_bundle: Some(PathBuf::from("/nonexistent/ca-bundle.pem")), + ..HttpClientConfig::default() + }; + let err = SharedHttpClient::new(config).unwrap_err(); + assert!(matches!(err, HttpClientBuildError::CaBundleRead { .. })); + } + + #[test] + fn build_error_display_contains_path() { + let err = HttpClientBuildError::CaBundleRead { + path: PathBuf::from("/nonexistent/ca.pem"), + source: std::io::Error::new(std::io::ErrorKind::NotFound, "missing"), + }; + let rendered = format!("{err}"); + assert!(rendered.contains("/nonexistent/ca.pem")); + } + + #[test] + fn retry_after_capacity_constant_is_bounded() { + let cap = DEFAULT_RETRY_AFTER_CAPACITY; + assert!(cap > 0, "RetryAfterMiddleware storage must be non-zero"); + assert!(cap <= 4096, "RetryAfterMiddleware storage must be bounded"); + } + + #[test] + fn no_env_vars_read_in_default_config() { + let _ = SystemTime::now(); + let config = HttpClientConfig::default(); + assert!(config.ca_bundle.is_none()); + } +} \ No newline at end of file diff --git a/crates/alknet-http/src/client/mod.rs b/crates/alknet-http/src/client/mod.rs index a301713..5101a4f 100644 --- a/crates/alknet-http/src/client/mod.rs +++ b/crates/alknet-http/src/client/mod.rs @@ -3,4 +3,8 @@ //! //! See `docs/architecture/crates/http/http-adapters.md` and OQ-40. -// TODO: implement +mod http_client; +mod retry_after; + +pub use http_client::{ClientCertConfig, HttpClientBuildError, HttpClientConfig, SharedHttpClient}; +pub use retry_after::RetryAfterMiddleware; \ No newline at end of file diff --git a/crates/alknet-http/src/client/retry_after.rs b/crates/alknet-http/src/client/retry_after.rs new file mode 100644 index 0000000..a8f0995 --- /dev/null +++ b/crates/alknet-http/src/client/retry_after.rs @@ -0,0 +1,284 @@ +//! Inlined `RetryAfterMiddleware`: parses the `Retry-After` header on +//! 429/503 and sleeps before the next request to that URL. +//! +//! Inlined (MIT, from `melotic/reqwest-retry-after`) so the upstream's +//! unbounded `HashMap` storage can be bounded for a +//! long-running process. The bound is enforced via LRU eviction: when +//! the map is at capacity, the entry with the earliest deadline is +//! evicted first (those are the most likely to have already elapsed +//! and are the cheapest to drop). + +use std::collections::HashMap; +use std::sync::Mutex; +use std::time::{Duration, SystemTime}; + +use http::Extensions; +use reqwest::{Request, Response, StatusCode}; +use reqwest_middleware::{Middleware, Next, Result}; +use url::Url; + +const RETRY_AFTER_HEADER: &str = "retry-after"; +const THROTTLED_STATUS: &[u16] = &[StatusCode::TOO_MANY_REQUESTS.as_u16(), 503]; + +fn is_throttled(status: u16) -> bool { + THROTTLED_STATUS.contains(&status) +} + +fn parse_retry_after(value: &str) -> Option { + let trimmed = value.trim(); + if let Ok(secs) = trimmed.parse::() { + return SystemTime::now() + .checked_add(Duration::from_secs(secs)) + .filter(|deadline| *deadline > SystemTime::now()); + } + httpdate::parse_http_date(trimmed) + .ok() + .filter(|deadline| *deadline > SystemTime::now()) +} + +pub struct RetryAfterMiddleware { + deadlines: Mutex>, + capacity: usize, +} + +impl RetryAfterMiddleware { + pub fn with_capacity(capacity: usize) -> Self { + Self { + deadlines: Mutex::new(HashMap::with_capacity(capacity.min(128))), + capacity, + } + } + + fn record(&self, url: Url, deadline: SystemTime) { + let mut deadlines = self.deadlines.lock().expect("deadlines mutex poisoned"); + if !deadlines.contains_key(&url) && deadlines.len() >= self.capacity { + self.evict(&mut deadlines); + } + deadlines.insert(url, deadline); + } + + fn evict(&self, deadlines: &mut HashMap) { + if let Some(evict_url) = deadlines + .iter() + .min_by_key(|(_, deadline)| *deadline) + .map(|(url, _)| url.clone()) + { + deadlines.remove(&evict_url); + } + } + + fn deadline_for(&self, url: &Url) -> Option { + let deadlines = self.deadlines.lock().expect("deadlines mutex poisoned"); + deadlines.get(url).copied() + } + + async fn maybe_sleep_for(&self, url: &Url) { + if let Some(deadline) = self.deadline_for(url) { + if let Ok(remaining) = deadline.duration_since(SystemTime::now()) { + if !remaining.is_zero() { + tokio::time::sleep(remaining).await; + } + } + } + } + + fn record_if_throttled(&self, url: Url, response: &Response) { + let status = response.status(); + if is_throttled(status.as_u16()) { + if let Some(retry_after) = response + .headers() + .get(RETRY_AFTER_HEADER) + .and_then(|value| value.to_str().ok()) + { + if let Some(deadline) = parse_retry_after(retry_after) { + self.record(url, deadline); + } + } + } + } + + #[cfg(test)] + fn len(&self) -> usize { + self.deadlines.lock().expect("deadlines mutex poisoned").len() + } + + #[cfg(test)] + fn deadline_for_test(&self, url: &Url) -> Option { + self.deadline_for(url) + } + + #[cfg(test)] + fn record_test(&self, url: Url, deadline: SystemTime) { + self.record(url, deadline); + } +} + +#[async_trait::async_trait] +impl Middleware for RetryAfterMiddleware { + async fn handle( + &self, + req: Request, + extensions: &mut Extensions, + next: Next<'_>, + ) -> Result { + let req_url = req.url().clone(); + self.maybe_sleep_for(&req_url).await; + let response = next.run(req, extensions).await?; + self.record_if_throttled(req_url, &response); + Ok(response) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn url(s: &str) -> Url { + Url::parse(s).unwrap() + } + + fn synthetic_response(status: StatusCode, retry_after: Option<&str>) -> Response { + let mut builder = http::Response::builder().status(status); + if let Some(value) = retry_after { + builder = builder.header(RETRY_AFTER_HEADER, value); + } + builder.body("").unwrap().into() + } + + #[test] + fn parse_retry_after_seconds() { + let deadline = parse_retry_after("5").expect("seconds value parses"); + let now = SystemTime::now(); + let lower = now.checked_add(Duration::from_secs(4)).unwrap(); + let upper = now.checked_add(Duration::from_secs(6)).unwrap(); + assert!(deadline > lower && deadline < upper); + } + + #[test] + fn parse_retry_after_http_date() { + let deadline = parse_retry_after("Wed, 21 Oct 2099 07:28:00 GMT") + .expect("HTTP-date value parses"); + assert!(deadline > SystemTime::now()); + } + + #[test] + fn parse_retry_after_past_http_date_yields_none() { + let deadline = parse_retry_after("Wed, 21 Oct 2015 07:28:00 GMT"); + assert!( + deadline.is_none(), + "a deadline already in the past must not be recorded" + ); + } + + #[test] + fn parse_retry_after_invalid_yields_none() { + assert!(parse_retry_after("not-a-date").is_none()); + assert!(parse_retry_after("").is_none()); + } + + #[test] + fn record_stores_deadline_for_url() { + let mw = RetryAfterMiddleware::with_capacity(8); + let u = url("https://api.example.com/v1/chat"); + let deadline = SystemTime::now() + Duration::from_secs(10); + mw.record_test(u.clone(), deadline); + assert_eq!(mw.deadline_for_test(&u), Some(deadline)); + assert_eq!(mw.len(), 1); + } + + #[test] + fn record_evicts_oldest_when_at_capacity() { + let mw = RetryAfterMiddleware::with_capacity(2); + let u1 = url("https://a.example.com"); + let u2 = url("https://b.example.com"); + let u3 = url("https://c.example.com"); + mw.record_test(u1.clone(), SystemTime::now() + Duration::from_secs(100)); + mw.record_test(u2.clone(), SystemTime::now() + Duration::from_secs(1)); + assert_eq!(mw.len(), 2); + mw.record_test(u3.clone(), SystemTime::now() + Duration::from_secs(50)); + assert_eq!(mw.len(), 2, "capacity must be enforced"); + assert!( + mw.deadline_for_test(&u2).is_none(), + "entry with the earliest deadline must be evicted" + ); + assert!(mw.deadline_for_test(&u1).is_some()); + assert!(mw.deadline_for_test(&u3).is_some()); + } + + #[test] + fn record_overwrites_existing_url_deadline_without_evicting() { + let mw = RetryAfterMiddleware::with_capacity(2); + let u = url("https://api.example.com/v1/chat"); + let far = url("https://far.example.com"); + mw.record_test(u.clone(), SystemTime::now() + Duration::from_secs(10)); + mw.record_test(far.clone(), SystemTime::now() + Duration::from_secs(20)); + mw.record_test(u.clone(), SystemTime::now() + Duration::from_secs(30)); + assert_eq!( + mw.len(), + 2, + "overwriting an existing URL must not evict another entry" + ); + assert!(mw.deadline_for_test(&far).is_some()); + } + + #[tokio::test] + async fn middleware_records_deadline_from_seconds_header() { + let mw = std::sync::Arc::new(RetryAfterMiddleware::with_capacity(8)); + let target = url("https://api.example.com/v1/chat"); + let response = synthetic_response(StatusCode::TOO_MANY_REQUESTS, Some("5")); + mw.record_if_throttled(target.clone(), &response); + let deadline = mw + .deadline_for_test(&target) + .expect("429 with Retry-After records a deadline"); + let now = SystemTime::now(); + assert!(deadline > now, "deadline must be in the future"); + assert!(deadline < now + Duration::from_secs(6)); + } + + #[tokio::test] + async fn middleware_records_deadline_from_http_date_header() { + let mw = std::sync::Arc::new(RetryAfterMiddleware::with_capacity(8)); + let target = url("https://api.example.com/v1/chat"); + let response = synthetic_response( + StatusCode::SERVICE_UNAVAILABLE, + Some("Wed, 21 Oct 2099 07:28:00 GMT"), + ); + mw.record_if_throttled(target.clone(), &response); + let deadline = mw + .deadline_for_test(&target) + .expect("503 with Retry-After HTTP-date records a deadline"); + assert!(deadline > SystemTime::now()); + } + + #[tokio::test] + async fn middleware_does_not_record_on_non_throttled_status() { + let mw = std::sync::Arc::new(RetryAfterMiddleware::with_capacity(8)); + let target = url("https://api.example.com/v1/chat"); + let response = synthetic_response(StatusCode::OK, Some("5")); + mw.record_if_throttled(target.clone(), &response); + assert!(mw.deadline_for_test(&target).is_none()); + } + + #[tokio::test] + async fn middleware_does_not_record_when_header_absent() { + let mw = std::sync::Arc::new(RetryAfterMiddleware::with_capacity(8)); + let target = url("https://api.example.com/v1/chat"); + let response = synthetic_response(StatusCode::TOO_MANY_REQUESTS, None); + mw.record_if_throttled(target.clone(), &response); + assert!(mw.deadline_for_test(&target).is_none()); + } + + #[tokio::test] + async fn middleware_sleeps_before_request_with_active_deadline() { + let mw = std::sync::Arc::new(RetryAfterMiddleware::with_capacity(8)); + let target = url("https://api.example.com/v1/chat"); + mw.record_test(target.clone(), SystemTime::now() + Duration::from_millis(50)); + let started = SystemTime::now(); + mw.maybe_sleep_for(&target).await; + let elapsed = SystemTime::now().duration_since(started).unwrap(); + assert!( + elapsed >= Duration::from_millis(40), + "middleware must sleep until the deadline elapses" + ); + } +} \ No newline at end of file