diff --git a/crates/alknet-secret/src/cache.rs b/crates/alknet-secret/src/cache.rs new file mode 100644 index 0000000..383f4bb --- /dev/null +++ b/crates/alknet-secret/src/cache.rs @@ -0,0 +1,339 @@ +//! TTL-based key cache with LRU eviction for SecretService. +//! +//! The `KeyCache` stores derived key material keyed by derivation path. Entries +//! expire after a configurable TTL (default: 1 hour) and are evicted lazily on +//! access. When the cache exceeds `max_entries` (default: 64), the least recently +//! used entry is evicted. All entries are zeroized on removal per ADR-038. + +use std::collections::HashMap; +use std::time::{Duration, Instant}; + +use zeroize::Zeroize; + +use crate::protocol::KeyType; + +/// Default TTL for cached keys (1 hour). +pub const DEFAULT_TTL: Duration = Duration::from_secs(3600); + +/// Default maximum number of cache entries. +pub const DEFAULT_MAX_ENTRIES: usize = 64; + +/// A cached derived key with metadata for TTL and LRU tracking. +/// +/// The `private_key` field is zeroized on drop via `#[zeroize(drop)]`. +/// This is a separate internal type from `DerivedKey` — it holds the same +/// data but is managed within the cache lifecycle. +#[derive(Zeroize)] +#[zeroize(drop)] +pub struct CachedKey { + /// When this key was derived (for TTL checking). + #[zeroize(skip)] + pub derived_at: Instant, + /// The type of key that was derived. + #[zeroize(skip)] + pub key_type: KeyType, + /// The private key bytes (sensitive — zeroized on drop). + #[zeroize] + pub private_key: Vec, + /// The public key bytes. + #[zeroize(skip)] + pub public_key: Vec, + /// Last access time for LRU ordering. + #[zeroize(skip)] + last_accessed: Instant, +} + +impl CachedKey { + /// Create a new `CachedKey` from derived key material. + pub fn new(key_type: KeyType, private_key: Vec, public_key: Vec) -> Self { + let now = Instant::now(); + Self { + derived_at: now, + key_type, + private_key, + public_key, + last_accessed: now, + } + } + + /// Check whether this cached entry has expired. + pub fn is_expired(&self, ttl: Duration) -> bool { + Instant::now().duration_since(self.derived_at) > ttl + } + + /// Touch the entry to update its last-accessed time (for LRU). + pub fn touch(&mut self) { + self.last_accessed = Instant::now(); + } +} + +/// Configuration for the key cache. +#[derive(Debug, Clone)] +pub struct CacheConfig { + /// Time-to-live for cached entries. Expired entries are evicted lazily on access. + pub ttl: Duration, + /// Maximum number of entries. When exceeded, the least recently used entry is evicted. + pub max_entries: usize, +} + +impl Default for CacheConfig { + fn default() -> Self { + Self { + ttl: DEFAULT_TTL, + max_entries: DEFAULT_MAX_ENTRIES, + } + } +} + +impl CacheConfig { + /// Create a new `CacheConfig` with the given TTL and max entries. + pub fn new(ttl: Duration, max_entries: usize) -> Self { + Self { ttl, max_entries } + } +} + +/// LRU key cache backed by a HashMap with access-order tracking. +/// +/// The cache uses a `HashMap` for O(1) lookups and a separate ordering list +/// for LRU eviction. For the default 64 entries, this is efficient enough +/// without needing the `lru` crate. +pub struct KeyCache { + entries: HashMap, + /// Access order: most recently used at the back, least recently at the front. + order: Vec, + config: CacheConfig, +} + +impl KeyCache { + /// Create a new empty `KeyCache` with the given configuration. + pub fn new(config: CacheConfig) -> Self { + Self { + entries: HashMap::new(), + order: Vec::with_capacity(config.max_entries), + config, + } + } + + /// Create a new empty `KeyCache` with default configuration. + pub fn with_defaults() -> Self { + Self::new(CacheConfig::default()) + } + + /// Get a cached entry by derivation path if it exists and is within TTL. + /// + /// Returns `None` if the entry does not exist or has expired (expired entries + /// are evicted). A successful get updates the LRU ordering. + pub fn get(&mut self, path: &str) -> Option<&CachedKey> { + if let Some(entry) = self.entries.get_mut(path) { + if entry.is_expired(self.config.ttl) { + self.remove_entry(path); + return None; + } + entry.touch(); + self.move_to_back(path); + Some(self.entries.get(path)?) + } else { + None + } + } + + /// Insert a cached key by derivation path. + /// + /// If the cache is at capacity, the least recently used entry is evicted + /// (and zeroized). If an entry with the same path already exists, it is + /// replaced (the old entry is zeroized on drop). + pub fn insert(&mut self, path: &str, key: CachedKey) { + if self.entries.contains_key(path) { + self.remove_entry(path); + } else if self.entries.len() >= self.config.max_entries { + self.evict_lru(); + } + self.entries.insert(path.to_string(), key); + self.order.push(path.to_string()); + } + + /// Remove all entries that have exceeded the TTL, zeroizing them. + pub fn evict_expired(&mut self) { + let ttl = self.config.ttl; + let expired: Vec = self + .entries + .iter() + .filter(|(_, v)| v.is_expired(ttl)) + .map(|(k, _)| k.clone()) + .collect(); + + for path in expired { + self.remove_entry(&path); + } + } + + /// Clear all cache entries, zeroizing each one before removal. + pub fn clear(&mut self) { + self.entries.clear(); + self.order.clear(); + } + + /// Returns the number of entries currently in the cache. + pub fn len(&self) -> usize { + self.entries.len() + } + + /// Returns `true` if the cache contains no entries. + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } + + fn remove_entry(&mut self, path: &str) { + self.entries.remove(path); + self.order.retain(|p| p != path); + } + + fn evict_lru(&mut self) { + if let Some(lru_path) = self.order.first().cloned() { + self.remove_entry(&lru_path); + } + } + + fn move_to_back(&mut self, path: &str) { + self.order.retain(|p| p != path); + self.order.push(path.to_string()); + } +} + +impl Default for KeyCache { + fn default() -> Self { + Self::with_defaults() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_cached_key(key_type: KeyType) -> CachedKey { + CachedKey::new(key_type, vec![0xABu8; 32], vec![0xCDu8; 32]) + } + + #[test] + fn test_cache_insert_and_get() { + let mut cache = KeyCache::with_defaults(); + cache.insert("m/74'/0'/0'/0'", make_cached_key(KeyType::Ed25519)); + + let entry = cache.get("m/74'/0'/0'/0'").unwrap(); + assert_eq!(entry.key_type, KeyType::Ed25519); + } + + #[test] + fn test_cache_miss_returns_none() { + let mut cache = KeyCache::with_defaults(); + assert!(cache.get("m/74'/0'/0'/0'").is_none()); + } + + #[test] + fn test_cache_expired_entry_evicted_on_access() { + let mut config = CacheConfig::default(); + config.ttl = Duration::from_millis(1); + + let mut cache = KeyCache::new(config); + cache.insert("m/74'/0'/0'/0'", make_cached_key(KeyType::Ed25519)); + + std::thread::sleep(Duration::from_millis(5)); + + assert!(cache.get("m/74'/0'/0'/0'").is_none()); + assert_eq!(cache.len(), 0); + } + + #[test] + fn test_cache_lru_eviction() { + let mut config = CacheConfig::default(); + config.max_entries = 3; + + let mut cache = KeyCache::new(config); + + cache.insert("path1", make_cached_key(KeyType::Ed25519)); + cache.insert("path2", make_cached_key(KeyType::Aes256Gcm)); + cache.insert("path3", make_cached_key(KeyType::Secp256k1)); + + assert_eq!(cache.len(), 3); + + cache.insert("path4", make_cached_key(KeyType::Ed25519)); + + assert_eq!(cache.len(), 3); + assert!(cache.get("path1").is_none()); + assert!(cache.get("path2").is_some()); + assert!(cache.get("path3").is_some()); + assert!(cache.get("path4").is_some()); + } + + #[test] + fn test_cache_lru_access_reorders() { + let mut config = CacheConfig::default(); + config.max_entries = 3; + + let mut cache = KeyCache::new(config); + + cache.insert("path1", make_cached_key(KeyType::Ed25519)); + cache.insert("path2", make_cached_key(KeyType::Aes256Gcm)); + cache.insert("path3", make_cached_key(KeyType::Secp256k1)); + + cache.get("path1"); + + cache.insert("path4", make_cached_key(KeyType::Ed25519)); + + assert_eq!(cache.len(), 3); + assert!(cache.get("path1").is_some()); + assert!(cache.get("path2").is_none()); + assert!(cache.get("path3").is_some()); + assert!(cache.get("path4").is_some()); + } + + #[test] + fn test_cache_clear_zeroizes_and_removes_all() { + let mut cache = KeyCache::with_defaults(); + cache.insert("path1", make_cached_key(KeyType::Ed25519)); + cache.insert("path2", make_cached_key(KeyType::Aes256Gcm)); + + assert_eq!(cache.len(), 2); + + cache.clear(); + + assert_eq!(cache.len(), 0); + assert!(cache.is_empty()); + } + + #[test] + fn test_evict_expired_removes_only_expired() { + let mut config = CacheConfig::default(); + config.ttl = Duration::from_millis(10); + + let mut cache = KeyCache::new(config); + cache.insert("path1", make_cached_key(KeyType::Ed25519)); + + std::thread::sleep(Duration::from_millis(20)); + + cache.insert("path2", make_cached_key(KeyType::Aes256Gcm)); + + cache.evict_expired(); + + assert_eq!(cache.len(), 1); + assert!(cache.get("path2").is_some()); + } + + #[test] + fn test_cache_replace_existing_path() { + let mut cache = KeyCache::with_defaults(); + cache.insert( + "path1", + CachedKey::new(KeyType::Ed25519, vec![1u8; 32], vec![2u8; 32]), + ); + cache.insert( + "path1", + CachedKey::new(KeyType::Aes256Gcm, vec![3u8; 32], vec![4u8; 32]), + ); + + let entry = cache.get("path1").unwrap(); + assert_eq!(entry.key_type, KeyType::Aes256Gcm); + assert_eq!(entry.private_key, vec![3u8; 32]); + assert_eq!(cache.len(), 1); + } +} diff --git a/crates/alknet-secret/src/lib.rs b/crates/alknet-secret/src/lib.rs index c23ad96..60db39f 100644 --- a/crates/alknet-secret/src/lib.rs +++ b/crates/alknet-secret/src/lib.rs @@ -28,6 +28,7 @@ //! - [`service`] — `SecretService` implementation with Unlock/Lock lifecycle //! - [`ethereum`] — BIP-0032 secp256k1 HD key derivation (behind `secp256k1` feature) +pub mod cache; pub mod derivation; pub mod encryption; pub mod mnemonic; @@ -38,6 +39,7 @@ pub mod service; pub mod ethereum; // Re-export primary public API +pub use cache::CacheConfig; pub use derivation::{DerivationError, ExtendedPrivKey, PATHS}; pub use encryption::{EncryptedData, EncryptionError}; pub use mnemonic::{Language, Mnemonic, Seed}; diff --git a/crates/alknet-secret/src/service.rs b/crates/alknet-secret/src/service.rs index 26e6d42..40afcfa 100644 --- a/crates/alknet-secret/src/service.rs +++ b/crates/alknet-secret/src/service.rs @@ -37,6 +37,7 @@ use std::sync::{Arc, RwLock}; use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; +use crate::cache::{CacheConfig, CachedKey, KeyCache}; use crate::derivation::{self, DerivationError, PATHS}; use crate::encryption::{self, EncryptedData, EncryptionKey}; use crate::mnemonic::{Language, Mnemonic, Seed}; @@ -59,6 +60,8 @@ struct SecretServiceInner { seed: Option, /// Whether the service is unlocked. unlocked: bool, + /// TTL-based key cache with LRU eviction. + cache: KeyCache, } /// Errors that can occur during secret service operations. @@ -99,13 +102,19 @@ impl From for SecretServiceError { } impl SecretServiceHandle { - /// Create a new SecretServiceHandle in the locked state. + /// Create a new SecretServiceHandle in the locked state with default cache config. pub fn new() -> Self { + Self::with_cache_config(CacheConfig::default()) + } + + /// Create a new SecretServiceHandle with the given cache configuration. + pub fn with_cache_config(config: CacheConfig) -> Self { Self { inner: Arc::new(RwLock::new(SecretServiceInner { mnemonic: None, seed: None, unlocked: false, + cache: KeyCache::new(config), })), } } @@ -156,6 +165,7 @@ impl SecretServiceHandle { /// material per ADR-038. pub fn lock(&self) { let mut inner = self.inner.write().unwrap(); + inner.cache.clear(); inner.seed = None; // Seed's Zeroize drop handles the zeroization inner.mnemonic = None; // Mnemonic's Zeroize drop handles the zeroization inner.unlocked = false; @@ -168,39 +178,63 @@ impl SecretServiceHandle { /// Derive an Ed25519 keypair at the given path. pub fn derive_ed25519(&self, path: &str) -> Result { - let inner = self.inner.read().unwrap(); + let mut inner = self.inner.write().unwrap(); if !inner.unlocked { return Err(SecretServiceError::ServiceLocked); } + + if let Some(cached) = inner.cache.get(path) { + return Ok(DerivedKey { + key_type: cached.key_type.clone(), + private_key: cached.private_key.clone(), + public_key: cached.public_key.clone(), + }); + } + let seed = inner .seed .as_ref() .ok_or(SecretServiceError::ServiceLocked)?; - let key = derivation::derive_path_from_seed(seed.as_bytes(), path)?; + let private_key = key.private_key().to_vec(); + let public_key = key.public_key().to_vec(); + let cached = CachedKey::new(KeyType::Ed25519, private_key.clone(), public_key.clone()); + inner.cache.insert(path, cached); Ok(DerivedKey { key_type: KeyType::Ed25519, - private_key: key.private_key().to_vec(), - public_key: key.public_key().to_vec(), + private_key, + public_key, }) } /// Derive an AES-256-GCM encryption key at the given path. pub fn derive_encryption_key(&self, path: &str) -> Result { - let inner = self.inner.read().unwrap(); + let mut inner = self.inner.write().unwrap(); if !inner.unlocked { return Err(SecretServiceError::ServiceLocked); } + + if let Some(cached) = inner.cache.get(path) { + return Ok(DerivedKey { + key_type: cached.key_type.clone(), + private_key: cached.private_key.clone(), + public_key: cached.public_key.clone(), + }); + } + let seed = inner .seed .as_ref() .ok_or(SecretServiceError::ServiceLocked)?; - let key = derivation::derive_path_from_seed(seed.as_bytes(), path)?; + let private_key = key.private_key().to_vec(); + let public_key = key.public_key().to_vec(); + let cached = CachedKey::new(KeyType::Aes256Gcm, private_key.clone(), public_key.clone()); + inner.cache.insert(path, cached); Ok(DerivedKey { key_type: KeyType::Aes256Gcm, - private_key: key.private_key().to_vec(), - public_key: key.public_key().to_vec(), + private_key, + public_key, }) } @@ -212,20 +246,33 @@ impl SecretServiceHandle { pub fn derive_ethereum_key(&self, path: &str) -> Result { #[cfg(feature = "secp256k1")] { - let inner = self.inner.read().unwrap(); + let mut inner = self.inner.write().unwrap(); if !inner.unlocked { return Err(SecretServiceError::ServiceLocked); } + + if let Some(cached) = inner.cache.get(path) { + return Ok(DerivedKey { + key_type: cached.key_type.clone(), + private_key: cached.private_key.clone(), + public_key: cached.public_key.clone(), + }); + } + let seed = inner .seed .as_ref() .ok_or(SecretServiceError::ServiceLocked)?; let key = crate::ethereum::derive_secp256k1_path(seed.as_bytes(), path)?; + let private_key = key.private_key().to_vec(); + let public_key = key.public_key().to_vec(); + let cached = CachedKey::new(KeyType::Secp256k1, private_key.clone(), public_key.clone()); + inner.cache.insert(path, cached); Ok(DerivedKey { key_type: KeyType::Secp256k1, - private_key: key.private_key().to_vec(), - public_key: key.public_key().to_vec(), + private_key, + public_key, }) } @@ -274,35 +321,54 @@ impl SecretServiceHandle { plaintext: &str, key_version: u32, ) -> Result { - let inner = self.inner.read().unwrap(); + let mut inner = self.inner.write().unwrap(); if !inner.unlocked { return Err(SecretServiceError::ServiceLocked); } - let seed = inner - .seed - .as_ref() - .ok_or(SecretServiceError::ServiceLocked)?; - let derived = derivation::derive_path_from_seed(seed.as_bytes(), PATHS::ENCRYPTION)?; - let enc_key = EncryptionKey::from_derived_bytes(derived.private_key(), key_version); + let private_key = if let Some(cached) = inner.cache.get(PATHS::ENCRYPTION) { + cached.private_key.clone() + } else { + let seed = inner + .seed + .as_ref() + .ok_or(SecretServiceError::ServiceLocked)?; + let derived = derivation::derive_path_from_seed(seed.as_bytes(), PATHS::ENCRYPTION)?; + let pk = derived.private_key().to_vec(); + let pubk = derived.public_key().to_vec(); + let cached = CachedKey::new(KeyType::Aes256Gcm, pk.clone(), pubk); + inner.cache.insert(PATHS::ENCRYPTION, cached); + pk + }; + + let enc_key = EncryptionKey::from_derived_bytes(&private_key, key_version); encryption::encrypt(plaintext, &enc_key).map_err(|e| e.into()) } /// Decrypt an EncryptedData blob using the derived encryption key. pub fn decrypt(&self, encrypted: &EncryptedData) -> Result { - let inner = self.inner.read().unwrap(); + let mut inner = self.inner.write().unwrap(); if !inner.unlocked { return Err(SecretServiceError::ServiceLocked); } - let seed = inner - .seed - .as_ref() - .ok_or(SecretServiceError::ServiceLocked)?; - let derived = derivation::derive_path_from_seed(seed.as_bytes(), PATHS::ENCRYPTION)?; - let enc_key = - EncryptionKey::from_derived_bytes(derived.private_key(), encrypted.key_version); + let private_key = if let Some(cached) = inner.cache.get(PATHS::ENCRYPTION) { + cached.private_key.clone() + } else { + let seed = inner + .seed + .as_ref() + .ok_or(SecretServiceError::ServiceLocked)?; + let derived = derivation::derive_path_from_seed(seed.as_bytes(), PATHS::ENCRYPTION)?; + let pk = derived.private_key().to_vec(); + let pubk = derived.public_key().to_vec(); + let cached = CachedKey::new(KeyType::Aes256Gcm, pk.clone(), pubk); + inner.cache.insert(PATHS::ENCRYPTION, cached); + pk + }; + + let enc_key = EncryptionKey::from_derived_bytes(&private_key, encrypted.key_version); encryption::decrypt(encrypted, &enc_key).map_err(|e| e.into()) } @@ -396,7 +462,7 @@ mod tests { assert!(service.derive_ed25519(PATHS::IDENTITY).is_err()); // Unlock - let phrase = service.unlock_new(24).unwrap(); + let _phrase = service.unlock_new(24).unwrap(); assert!(service.is_unlocked()); // Can derive while unlocked @@ -556,4 +622,95 @@ mod tests { Err(SecretServiceError::UnsupportedKeyType) )); } + + #[test] + fn test_cache_hit_avoids_re_derivation() { + let service = SecretServiceHandle::new(); + service.unlock_new(24).unwrap(); + + let key1 = service.derive_ed25519(PATHS::IDENTITY).unwrap(); + let key2 = service.derive_ed25519(PATHS::IDENTITY).unwrap(); + + assert_eq!(key1.private_key, key2.private_key); + assert_eq!(key1.public_key, key2.public_key); + + let cache_len = service.inner.read().unwrap().cache.len(); + assert_eq!(cache_len, 1); + } + + #[test] + fn test_cache_miss_derives_and_caches() { + let service = SecretServiceHandle::new(); + service.unlock_new(24).unwrap(); + + assert_eq!(service.inner.read().unwrap().cache.len(), 0); + + service.derive_ed25519(PATHS::IDENTITY).unwrap(); + + assert_eq!(service.inner.read().unwrap().cache.len(), 1); + } + + #[test] + fn test_expired_entry_evicted_on_access() { + let config = crate::cache::CacheConfig::new(std::time::Duration::from_millis(5), 64); + let service = SecretServiceHandle::with_cache_config(config); + service.unlock_new(24).unwrap(); + + let key1 = service.derive_ed25519(PATHS::IDENTITY).unwrap(); + assert_eq!(service.inner.read().unwrap().cache.len(), 1); + + std::thread::sleep(std::time::Duration::from_millis(10)); + + let key2 = service.derive_ed25519(PATHS::IDENTITY).unwrap(); + assert_eq!(key1.private_key, key2.private_key); + assert_eq!(service.inner.read().unwrap().cache.len(), 1); + } + + #[test] + fn test_lru_eviction_when_over_max_entries() { + let config = crate::cache::CacheConfig::new(std::time::Duration::from_secs(3600), 2); + let service = SecretServiceHandle::with_cache_config(config); + service.unlock_new(24).unwrap(); + + service.derive_ed25519(PATHS::IDENTITY).unwrap(); + service.derive_ed25519(PATHS::SSH_HOST).unwrap(); + assert_eq!(service.inner.read().unwrap().cache.len(), 2); + + service.derive_ed25519(PATHS::ENCRYPTION).unwrap(); + assert_eq!(service.inner.read().unwrap().cache.len(), 2); + + let mut inner = service.inner.write().unwrap(); + assert!(inner.cache.get(PATHS::IDENTITY).is_none()); + assert!(inner.cache.get(PATHS::SSH_HOST).is_some()); + assert!(inner.cache.get(PATHS::ENCRYPTION).is_some()); + } + + #[test] + fn test_lock_clears_all_cache_entries() { + let service = SecretServiceHandle::new(); + service.unlock_new(24).unwrap(); + + service.derive_ed25519(PATHS::IDENTITY).unwrap(); + service.derive_ed25519(PATHS::SSH_HOST).unwrap(); + assert_eq!(service.inner.read().unwrap().cache.len(), 2); + + service.lock(); + + assert_eq!(service.inner.read().unwrap().cache.len(), 0); + } + + #[test] + fn test_encrypt_decrypt_uses_cached_encryption_key() { + let service = SecretServiceHandle::new(); + service.unlock_new(24).unwrap(); + + let plaintext = "cached-encryption-test"; + let encrypted = service.encrypt(plaintext, 1).unwrap(); + assert_eq!(service.inner.read().unwrap().cache.len(), 1); + + let decrypted = service.decrypt(&encrypted).unwrap(); + assert_eq!(decrypted, plaintext); + + assert_eq!(service.inner.read().unwrap().cache.len(), 1); + } }