diff --git a/crates/alknet-core/src/call/env.rs b/crates/alknet-core/src/call/env.rs index 94511f4..89410da 100644 --- a/crates/alknet-core/src/call/env.rs +++ b/crates/alknet-core/src/call/env.rs @@ -5,19 +5,44 @@ use serde_json::Value; use crate::call::context::OperationContext; use crate::call::registry::OperationRegistry; use crate::call::response::ResponseEnvelope; +use crate::credentials::{CredentialProvider, CredentialSet, SecretStoreCredentialProvider}; -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct OperationEnv { registry: Arc, + credential_provider: Arc, +} + +impl std::fmt::Debug for OperationEnv { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OperationEnv") + .field("registry", &self.registry) + .finish() + } } impl OperationEnv { pub fn local(registry: OperationRegistry) -> Self { Self { registry: Arc::new(registry), + credential_provider: Arc::new(SecretStoreCredentialProvider::new()), } } + pub fn with_credential_provider( + registry: OperationRegistry, + credential_provider: Arc, + ) -> Self { + Self { + registry: Arc::new(registry), + credential_provider, + } + } + + pub fn credentials(&self, service: &str) -> Option { + self.credential_provider.get_credentials(service) + } + pub fn invoke(&self, namespace: &str, operation: &str, input: Value) -> ResponseEnvelope { let name = format!("/{namespace}/{operation}"); let request_id = format!("env{name}"); @@ -42,6 +67,10 @@ mod tests { use super::*; use crate::call::registry::OperationRegistryBuilder; use crate::call::spec::{AccessControl, OperationSpec, OperationType}; + use crate::config::{AuthPolicy, DynamicConfig}; + use crate::credentials::ConfigCredentialProvider; + use arc_swap::ArcSwap; + use std::collections::HashMap; fn make_spec(name: &str, namespace: &str) -> OperationSpec { OperationSpec { @@ -101,4 +130,61 @@ mod tests { let result = env.invoke("auth", "verify", serde_json::json!(null)); assert!(result.result.is_ok()); } + + #[test] + fn operation_env_provides_credentials_from_handler_context() { + let mut credentials = HashMap::new(); + credentials.insert( + "vast-ai".to_string(), + CredentialSet::Bearer { + token: "test-token".to_string(), + }, + ); + let config = DynamicConfig::new(AuthPolicy::empty()).with_credentials(credentials); + let dynamic = Arc::new(ArcSwap::new(Arc::new(config))); + let provider = Arc::new(ConfigCredentialProvider::new(dynamic)); + + let registry = OperationRegistryBuilder::new() + .with( + make_spec("/test/creds", "test"), + Arc::new(|_input, ctx| { + let creds = ctx.env.credentials("vast-ai"); + match creds { + Some(CredentialSet::Bearer { token }) => ResponseEnvelope::ok( + &ctx.request_id, + serde_json::json!({"token": token}), + ), + _ => ResponseEnvelope::ok( + &ctx.request_id, + serde_json::json!({"found": false}), + ), + } + }), + ) + .build(); + + let env = OperationEnv::with_credential_provider(registry, provider); + let result = env.invoke("test", "creds", serde_json::json!(null)); + assert!(result.result.is_ok()); + let value = result.result.unwrap(); + assert_eq!(value["token"], "test-token"); + } + + #[test] + fn operation_env_credentials_returns_none_for_missing_service() { + let config = DynamicConfig::default(); + let dynamic = Arc::new(ArcSwap::new(Arc::new(config))); + let provider = Arc::new(ConfigCredentialProvider::new(dynamic)); + + let registry = OperationRegistry::new(); + let env = OperationEnv::with_credential_provider(registry, provider); + assert!(env.credentials("nonexistent").is_none()); + } + + #[test] + fn operation_env_default_credentials_returns_none() { + let registry = OperationRegistry::new(); + let env = OperationEnv::local(registry); + assert!(env.credentials("vast-ai").is_none()); + } } diff --git a/crates/alknet-core/src/config/config_service.rs b/crates/alknet-core/src/config/config_service.rs index fe55c2b..8615ca8 100644 --- a/crates/alknet-core/src/config/config_service.rs +++ b/crates/alknet-core/src/config/config_service.rs @@ -79,6 +79,7 @@ mod tests { auth: AuthPolicy::empty(), forwarding: ForwardingPolicy::deny_all(), rate_limits: RateLimitConfig::default(), + credentials: std::collections::HashMap::new(), }; service.reload(new_config); diff --git a/crates/alknet-core/src/config/dynamic_config.rs b/crates/alknet-core/src/config/dynamic_config.rs index ca7f588..97332d7 100644 --- a/crates/alknet-core/src/config/dynamic_config.rs +++ b/crates/alknet-core/src/config/dynamic_config.rs @@ -11,6 +11,7 @@ use russh::keys::ssh_key::HashAlg; use crate::auth::identity::Identity; use crate::auth::ServerAuthConfig; use crate::config::forwarding::ForwardingPolicy; +use crate::credentials::CredentialSet; pub struct AuthPolicy { pub authorized_keys: std::collections::HashSet, @@ -238,6 +239,7 @@ pub struct DynamicConfig { pub auth: AuthPolicy, pub forwarding: ForwardingPolicy, pub rate_limits: RateLimitConfig, + pub credentials: HashMap, } impl DynamicConfig { @@ -246,6 +248,7 @@ impl DynamicConfig { auth, forwarding: ForwardingPolicy::allow_all(), rate_limits: RateLimitConfig::default(), + credentials: HashMap::new(), } } @@ -258,6 +261,7 @@ impl DynamicConfig { auth, forwarding, rate_limits, + credentials: HashMap::new(), } } @@ -270,6 +274,11 @@ impl DynamicConfig { self.rate_limits = limits; self } + + pub fn with_credentials(mut self, credentials: HashMap) -> Self { + self.credentials = credentials; + self + } } impl Default for DynamicConfig { @@ -278,6 +287,7 @@ impl Default for DynamicConfig { auth: AuthPolicy::empty(), forwarding: ForwardingPolicy::allow_all(), rate_limits: RateLimitConfig::default(), + credentials: HashMap::new(), } } } @@ -351,6 +361,7 @@ mod tests { auth: AuthPolicy::empty(), forwarding: ForwardingPolicy::deny_all(), rate_limits: RateLimitConfig::default(), + credentials: HashMap::new(), }; handle.reload(new_config); diff --git a/crates/alknet-core/src/credentials/mod.rs b/crates/alknet-core/src/credentials/mod.rs new file mode 100644 index 0000000..babb404 --- /dev/null +++ b/crates/alknet-core/src/credentials/mod.rs @@ -0,0 +1,241 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use arc_swap::ArcSwap; +use serde::{Deserialize, Serialize}; + +use crate::config::DynamicConfig; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[non_exhaustive] +pub enum CredentialSet { + ApiKey { + header_name: String, + token: String, + }, + Basic { + username: String, + password: String, + }, + Bearer { + token: String, + }, + S3AccessKey { + access_key: String, + secret_key: String, + session_token: Option, + }, + OidcToken { + access_token: String, + refresh_token: Option, + expires_at: Option, + }, + Custom { + scheme: String, + params: HashMap, + }, +} + +pub trait CredentialProvider: Send + Sync + 'static { + fn get_credentials(&self, service: &str) -> Option; + fn refresh_credentials(&self, service: &str) -> Option; +} + +pub struct ConfigCredentialProvider { + dynamic: Arc>, +} + +impl ConfigCredentialProvider { + pub fn new(dynamic: Arc>) -> Self { + Self { dynamic } + } +} + +impl CredentialProvider for ConfigCredentialProvider { + fn get_credentials(&self, service: &str) -> Option { + let config = self.dynamic.load(); + config.credentials.get(service).cloned() + } + + fn refresh_credentials(&self, service: &str) -> Option { + self.get_credentials(service) + } +} + +pub struct SecretStoreCredentialProvider; + +impl SecretStoreCredentialProvider { + pub fn new() -> Self { + Self + } +} + +impl Default for SecretStoreCredentialProvider { + fn default() -> Self { + Self::new() + } +} + +impl CredentialProvider for SecretStoreCredentialProvider { + fn get_credentials(&self, _service: &str) -> Option { + None + } + + fn refresh_credentials(&self, _service: &str) -> Option { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::AuthPolicy; + + fn make_dynamic_with_credentials() -> Arc> { + let mut credentials = HashMap::new(); + credentials.insert( + "vast-ai".to_string(), + CredentialSet::Bearer { + token: "secret-token".to_string(), + }, + ); + credentials.insert( + "custom-service".to_string(), + CredentialSet::ApiKey { + header_name: "X-API-Key".to_string(), + token: "api-key-123".to_string(), + }, + ); + let config = DynamicConfig::new(AuthPolicy::empty()).with_credentials(credentials); + Arc::new(ArcSwap::new(Arc::new(config))) + } + + fn make_dynamic_empty() -> Arc> { + let config = DynamicConfig::default(); + Arc::new(ArcSwap::new(Arc::new(config))) + } + + #[test] + fn config_credential_provider_returns_configured_credentials() { + let dynamic = make_dynamic_with_credentials(); + let provider = ConfigCredentialProvider::new(dynamic); + let creds = provider.get_credentials("vast-ai"); + assert!(creds.is_some()); + match creds.unwrap() { + CredentialSet::Bearer { token } => assert_eq!(token, "secret-token"), + _ => panic!("expected Bearer variant"), + } + } + + #[test] + fn config_credential_provider_returns_api_key_variant() { + let dynamic = make_dynamic_with_credentials(); + let provider = ConfigCredentialProvider::new(dynamic); + let creds = provider.get_credentials("custom-service"); + assert!(creds.is_some()); + match creds.unwrap() { + CredentialSet::ApiKey { header_name, token } => { + assert_eq!(header_name, "X-API-Key"); + assert_eq!(token, "api-key-123"); + } + _ => panic!("expected ApiKey variant"), + } + } + + #[test] + fn config_credential_provider_returns_none_for_unknown_service() { + let dynamic = make_dynamic_with_credentials(); + let provider = ConfigCredentialProvider::new(dynamic); + let creds = provider.get_credentials("nonexistent"); + assert!(creds.is_none()); + } + + #[test] + fn config_credential_provider_empty_config_returns_none() { + let dynamic = make_dynamic_empty(); + let provider = ConfigCredentialProvider::new(dynamic); + let creds = provider.get_credentials("vast-ai"); + assert!(creds.is_none()); + } + + #[test] + fn secret_store_credential_provider_returns_none() { + let provider = SecretStoreCredentialProvider::new(); + assert!(provider.get_credentials("vast-ai").is_none()); + assert!(provider.get_credentials("rustfs").is_none()); + assert!(provider.get_credentials("gitea").is_none()); + } + + #[test] + fn secret_store_credential_provider_refresh_returns_none() { + let provider = SecretStoreCredentialProvider::new(); + assert!(provider.refresh_credentials("vast-ai").is_none()); + } + + #[test] + fn credential_set_bearer_serialization() { + let creds = CredentialSet::Bearer { + token: "tok".to_string(), + }; + let json = serde_json::to_string(&creds).unwrap(); + let deserialized: CredentialSet = serde_json::from_str(&json).unwrap(); + assert_eq!(creds, deserialized); + } + + #[test] + fn credential_set_s3_access_key_serialization() { + let creds = CredentialSet::S3AccessKey { + access_key: "AKIA123".to_string(), + secret_key: "secret".to_string(), + session_token: Some("session".to_string()), + }; + let json = serde_json::to_string(&creds).unwrap(); + let deserialized: CredentialSet = serde_json::from_str(&json).unwrap(); + assert_eq!(creds, deserialized); + } + + #[test] + fn credential_set_oidc_token_serialization() { + let creds = CredentialSet::OidcToken { + access_token: "access".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(1234567890), + }; + let json = serde_json::to_string(&creds).unwrap(); + let deserialized: CredentialSet = serde_json::from_str(&json).unwrap(); + assert_eq!(creds, deserialized); + } + + #[test] + fn credential_set_custom_serialization() { + let mut params = HashMap::new(); + params.insert("key1".to_string(), "val1".to_string()); + let creds = CredentialSet::Custom { + scheme: "X-Custom".to_string(), + params, + }; + let json = serde_json::to_string(&creds).unwrap(); + let deserialized: CredentialSet = serde_json::from_str(&json).unwrap(); + assert_eq!(creds, deserialized); + } + + #[test] + fn credential_set_basic_serialization() { + let creds = CredentialSet::Basic { + username: "user".to_string(), + password: "pass".to_string(), + }; + let json = serde_json::to_string(&creds).unwrap(); + let deserialized: CredentialSet = serde_json::from_str(&json).unwrap(); + assert_eq!(creds, deserialized); + } + + #[test] + fn credential_set_clone() { + let creds = CredentialSet::Bearer { + token: "tok".to_string(), + }; + let cloned = creds.clone(); + assert_eq!(creds, cloned); + } +} diff --git a/crates/alknet-core/src/lib.rs b/crates/alknet-core/src/lib.rs index 7dbbdf9..c4e41f5 100644 --- a/crates/alknet-core/src/lib.rs +++ b/crates/alknet-core/src/lib.rs @@ -55,6 +55,7 @@ pub mod auth; pub mod call; pub mod client; pub mod config; +pub mod credentials; pub mod error; pub mod interface; pub mod server; @@ -84,6 +85,9 @@ pub use config::{ AuthPolicy, ConfigReloadHandle, ConfigServiceImpl, DynamicConfig, ForwardingAction, ForwardingPolicy, ForwardingRule, RateLimitConfig, StaticConfig, TargetPattern, }; +pub use credentials::{ + ConfigCredentialProvider, CredentialProvider, CredentialSet, SecretStoreCredentialProvider, +}; pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError}; pub use interface::{ is_valid_pair, DnsInterface, DnsInterfaceConfig, HttpInterface, HttpInterfaceConfig, diff --git a/crates/alknet-core/src/server/handler.rs b/crates/alknet-core/src/server/handler.rs index e4f0076..16be80c 100644 --- a/crates/alknet-core/src/server/handler.rs +++ b/crates/alknet-core/src/server/handler.rs @@ -869,6 +869,7 @@ mod tests { auth: dynamic.auth.clone(), forwarding: deny_policy, rate_limits: dynamic.rate_limits.clone(), + credentials: dynamic.credentials.clone(), }; drop(dynamic); auth_config.store(Arc::new(new_dynamic));