Compare commits
1 Commits
feat/napi/
...
feat/serve
| Author | SHA1 | Date | |
|---|---|---|---|
| 24b70f5651 |
5
Cargo.lock
generated
5
Cargo.lock
generated
@@ -2395,7 +2395,6 @@ version = "3.9.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f1d395473824516f38dd1071a1a37bc57daa7be65b293ebba4ead5f7abb017a2"
|
checksum = "f1d395473824516f38dd1071a1a37bc57daa7be65b293ebba4ead5f7abb017a2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
|
||||||
"bitflags 2.11.1",
|
"bitflags 2.11.1",
|
||||||
"ctor",
|
"ctor",
|
||||||
"futures",
|
"futures",
|
||||||
@@ -2403,7 +2402,6 @@ dependencies = [
|
|||||||
"napi-sys",
|
"napi-sys",
|
||||||
"nohash-hasher",
|
"nohash-hasher",
|
||||||
"rustc-hash",
|
"rustc-hash",
|
||||||
"tokio",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -5595,7 +5593,6 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"futures",
|
|
||||||
"ipnetwork",
|
"ipnetwork",
|
||||||
"iroh",
|
"iroh",
|
||||||
"rand 0.10.1",
|
"rand 0.10.1",
|
||||||
@@ -5623,8 +5620,6 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"napi",
|
"napi",
|
||||||
"napi-derive",
|
"napi-derive",
|
||||||
"russh",
|
|
||||||
"tokio",
|
|
||||||
"wraith-core",
|
"wraith-core",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ impl LocalForwarder {
|
|||||||
handle: Arc<Mutex<client::Handle<H>>>,
|
handle: Arc<Mutex<client::Handle<H>>>,
|
||||||
) -> Result<(), ForwardError> {
|
) -> Result<(), ForwardError> {
|
||||||
let listen_addr = self.spec.listen_addr()?;
|
let listen_addr = self.spec.listen_addr()?;
|
||||||
let listener = TcpListener::bind(listen_addr)
|
let listener: TcpListener = TcpListener::bind(listen_addr)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| ForwardError::BindFailed { source: e })?;
|
.map_err(|e| ForwardError::BindFailed { source: e })?;
|
||||||
self.listener = Some(listener);
|
self.listener = Some(listener);
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
use std::net::SocketAddr;
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use russh::keys::ssh_key::HashAlg;
|
use russh::keys::ssh_key::HashAlg;
|
||||||
@@ -7,6 +8,7 @@ use russh::server::{Auth, Handler, Msg, Session};
|
|||||||
use russh::Channel;
|
use russh::Channel;
|
||||||
|
|
||||||
use crate::auth::ServerAuthConfig;
|
use crate::auth::ServerAuthConfig;
|
||||||
|
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||||
|
|
||||||
const WRAITH_PREFIX: &str = "wraith-";
|
const WRAITH_PREFIX: &str = "wraith-";
|
||||||
|
|
||||||
@@ -22,10 +24,32 @@ pub struct ProxyConfig {
|
|||||||
pub mode: ProxyMode,
|
pub mode: ProxyMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||||
|
pub enum TransportKind {
|
||||||
|
Tcp,
|
||||||
|
Tls,
|
||||||
|
Iroh,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for TransportKind {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
TransportKind::Tcp => write!(f, "tcp"),
|
||||||
|
TransportKind::Tls => write!(f, "tls"),
|
||||||
|
TransportKind::Iroh => write!(f, "iroh"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ServerHandler {
|
pub struct ServerHandler {
|
||||||
auth_config: Arc<ServerAuthConfig>,
|
auth_config: Arc<ServerAuthConfig>,
|
||||||
outbound_proxy: Option<ProxyConfig>,
|
outbound_proxy: Option<ProxyConfig>,
|
||||||
remote_addr: Option<SocketAddr>,
|
remote_addr: Option<SocketAddr>,
|
||||||
|
transport: TransportKind,
|
||||||
|
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||||
|
connection_allowed: bool,
|
||||||
|
auth_limiter: AuthAttemptLimiter,
|
||||||
|
connected_at: Instant,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ServerHandler {
|
impl ServerHandler {
|
||||||
@@ -33,11 +57,65 @@ impl ServerHandler {
|
|||||||
auth_config: Arc<ServerAuthConfig>,
|
auth_config: Arc<ServerAuthConfig>,
|
||||||
outbound_proxy: Option<ProxyConfig>,
|
outbound_proxy: Option<ProxyConfig>,
|
||||||
remote_addr: Option<SocketAddr>,
|
remote_addr: Option<SocketAddr>,
|
||||||
|
transport: TransportKind,
|
||||||
|
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||||
|
max_auth_attempts: usize,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let allowed = if let Some(addr) = remote_addr {
|
||||||
|
let ip = addr.ip();
|
||||||
|
if connection_limiter.check(ip) {
|
||||||
|
connection_limiter.on_connect(ip);
|
||||||
|
tracing::info!(
|
||||||
|
remote_addr = %addr,
|
||||||
|
transport = %transport,
|
||||||
|
"connection opened"
|
||||||
|
);
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
tracing::info!(
|
||||||
|
remote_addr = %addr,
|
||||||
|
transport = %transport,
|
||||||
|
"connection rejected"
|
||||||
|
);
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
};
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
auth_config,
|
auth_config,
|
||||||
outbound_proxy,
|
outbound_proxy,
|
||||||
remote_addr,
|
remote_addr,
|
||||||
|
transport,
|
||||||
|
connection_limiter,
|
||||||
|
connection_allowed: allowed,
|
||||||
|
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
|
||||||
|
connected_at: Instant::now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_connection_allowed(&self) -> bool {
|
||||||
|
self.connection_allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remote_ip(&self) -> Option<IpAddr> {
|
||||||
|
self.remote_addr.map(|a| a.ip())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for ServerHandler {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Some(addr) = self.remote_addr {
|
||||||
|
if self.connection_allowed {
|
||||||
|
self.connection_limiter.on_disconnect(addr.ip());
|
||||||
|
}
|
||||||
|
let duration = self.connected_at.elapsed();
|
||||||
|
tracing::info!(
|
||||||
|
remote_addr = %addr,
|
||||||
|
duration_secs = duration.as_secs_f64(),
|
||||||
|
"connection closed"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -51,6 +129,23 @@ impl Handler for ServerHandler {
|
|||||||
user: &str,
|
user: &str,
|
||||||
public_key: &russh::keys::ssh_key::PublicKey,
|
public_key: &russh::keys::ssh_key::PublicKey,
|
||||||
) -> Result<Auth, Self::Error> {
|
) -> Result<Auth, Self::Error> {
|
||||||
|
if !self.auth_limiter.check() {
|
||||||
|
let remote_addr_display = self
|
||||||
|
.remote_addr
|
||||||
|
.map_or("unknown".to_string(), |a| a.to_string());
|
||||||
|
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||||
|
tracing::info!(
|
||||||
|
remote_addr = %remote_addr_display,
|
||||||
|
user = user,
|
||||||
|
key_fingerprint = %fingerprint,
|
||||||
|
result = "reject",
|
||||||
|
"auth attempt"
|
||||||
|
);
|
||||||
|
return Ok(Auth::Reject {
|
||||||
|
proceed_with_methods: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||||
let remote_addr_display = self
|
let remote_addr_display = self
|
||||||
.remote_addr
|
.remote_addr
|
||||||
@@ -63,6 +158,7 @@ impl Handler for ServerHandler {
|
|||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
remote_addr = %remote_addr_display,
|
remote_addr = %remote_addr_display,
|
||||||
|
user = user,
|
||||||
key_fingerprint = %fingerprint,
|
key_fingerprint = %fingerprint,
|
||||||
result = "accept",
|
result = "accept",
|
||||||
"auth attempt"
|
"auth attempt"
|
||||||
@@ -70,8 +166,10 @@ impl Handler for ServerHandler {
|
|||||||
Ok(Auth::Accept)
|
Ok(Auth::Accept)
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
|
self.auth_limiter.on_failure();
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
remote_addr = %remote_addr_display,
|
remote_addr = %remote_addr_display,
|
||||||
|
user = user,
|
||||||
key_fingerprint = %fingerprint,
|
key_fingerprint = %fingerprint,
|
||||||
result = "reject",
|
result = "reject",
|
||||||
"auth attempt"
|
"auth attempt"
|
||||||
@@ -188,10 +286,22 @@ mod tests {
|
|||||||
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
|
Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_limiter() -> Arc<ConnectionRateLimiter> {
|
||||||
|
Arc::new(ConnectionRateLimiter::new(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_handler(
|
||||||
|
auth_config: Arc<ServerAuthConfig>,
|
||||||
|
outbound_proxy: Option<ProxyConfig>,
|
||||||
|
remote_addr: Option<SocketAddr>,
|
||||||
|
) -> ServerHandler {
|
||||||
|
ServerHandler::new(auth_config, outbound_proxy, remote_addr, TransportKind::Tcp, default_limiter(), 10)
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn auth_delegation_accepts_known_key() {
|
async fn auth_delegation_accepts_known_key() {
|
||||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||||
let mut handler = ServerHandler::new(auth_config, None, None);
|
let mut handler = make_handler(auth_config, None, None);
|
||||||
|
|
||||||
let ssh_key = load_key().public_key().clone();
|
let ssh_key = load_key().public_key().clone();
|
||||||
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||||
@@ -201,7 +311,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn auth_delegation_rejects_unknown_key() {
|
async fn auth_delegation_rejects_unknown_key() {
|
||||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||||
let mut handler = ServerHandler::new(auth_config, None, None);
|
let mut handler = make_handler(auth_config, None, None);
|
||||||
|
|
||||||
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
||||||
let other_ssh_key = russh::keys::parse_public_key_base64(
|
let other_ssh_key = russh::keys::parse_public_key_base64(
|
||||||
@@ -224,7 +334,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn auth_delegation_empty_config_rejects_all() {
|
async fn auth_delegation_empty_config_rejects_all() {
|
||||||
let auth_config = make_empty_auth_config();
|
let auth_config = make_empty_auth_config();
|
||||||
let mut handler = ServerHandler::new(auth_config, None, None);
|
let mut handler = make_handler(auth_config, None, None);
|
||||||
|
|
||||||
let ssh_key = load_key().public_key().clone();
|
let ssh_key = load_key().public_key().clone();
|
||||||
let result = handler
|
let result = handler
|
||||||
@@ -243,7 +353,7 @@ mod tests {
|
|||||||
async fn auth_logging_includes_remote_addr() {
|
async fn auth_logging_includes_remote_addr() {
|
||||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||||
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
|
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
|
||||||
let mut handler = ServerHandler::new(auth_config, None, Some(remote_addr));
|
let mut handler = make_handler(auth_config, None, Some(remote_addr));
|
||||||
|
|
||||||
let ssh_key = load_key().public_key().clone();
|
let ssh_key = load_key().public_key().clone();
|
||||||
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
||||||
@@ -287,7 +397,7 @@ mod tests {
|
|||||||
});
|
});
|
||||||
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
|
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
|
||||||
|
|
||||||
let handler = ServerHandler::new(auth_config, proxy.clone(), remote);
|
let handler = make_handler(auth_config, proxy.clone(), remote);
|
||||||
assert!(handler.outbound_proxy.is_some());
|
assert!(handler.outbound_proxy.is_some());
|
||||||
assert!(handler.remote_addr.is_some());
|
assert!(handler.remote_addr.is_some());
|
||||||
}
|
}
|
||||||
@@ -295,9 +405,108 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn one_handler_per_connection() {
|
fn one_handler_per_connection() {
|
||||||
let auth_config = make_empty_auth_config();
|
let auth_config = make_empty_auth_config();
|
||||||
let handler1 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
|
let handler1 = make_handler(auth_config.clone(), None, Some("10.0.0.1:22".parse().unwrap()));
|
||||||
let handler2 = ServerHandler::new(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
|
let handler2 = make_handler(auth_config.clone(), None, Some("10.0.0.2:22".parse().unwrap()));
|
||||||
|
|
||||||
assert!(handler1.remote_addr != handler2.remote_addr);
|
assert!(handler1.remote_addr != handler2.remote_addr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn auth_rate_limit_rejects_after_max_failures() {
|
||||||
|
let auth_config = make_empty_auth_config();
|
||||||
|
let limiter = Arc::new(ConnectionRateLimiter::new(0));
|
||||||
|
let mut handler = ServerHandler::new(
|
||||||
|
auth_config,
|
||||||
|
None,
|
||||||
|
Some("10.0.0.1:22".parse().unwrap()),
|
||||||
|
TransportKind::Tcp,
|
||||||
|
limiter,
|
||||||
|
2,
|
||||||
|
);
|
||||||
|
|
||||||
|
let ssh_key = load_key().public_key().clone();
|
||||||
|
|
||||||
|
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||||
|
assert_eq!(r1, Auth::Reject { proceed_with_methods: None });
|
||||||
|
|
||||||
|
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||||
|
assert_eq!(r2, Auth::Reject { proceed_with_methods: None });
|
||||||
|
|
||||||
|
assert!(!handler.auth_limiter.check());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_rate_limit_blocks_over_limit() {
|
||||||
|
let limiter = Arc::new(ConnectionRateLimiter::new(1));
|
||||||
|
let auth_config = make_empty_auth_config();
|
||||||
|
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
|
||||||
|
|
||||||
|
let h1 = ServerHandler::new(
|
||||||
|
auth_config.clone(),
|
||||||
|
None,
|
||||||
|
Some(addr),
|
||||||
|
TransportKind::Tcp,
|
||||||
|
limiter.clone(),
|
||||||
|
10,
|
||||||
|
);
|
||||||
|
assert!(h1.is_connection_allowed());
|
||||||
|
|
||||||
|
let h2 = ServerHandler::new(
|
||||||
|
auth_config.clone(),
|
||||||
|
None,
|
||||||
|
Some(addr),
|
||||||
|
TransportKind::Tcp,
|
||||||
|
limiter.clone(),
|
||||||
|
10,
|
||||||
|
);
|
||||||
|
assert!(!h2.is_connection_allowed());
|
||||||
|
|
||||||
|
drop(h1);
|
||||||
|
|
||||||
|
let h3 = ServerHandler::new(
|
||||||
|
auth_config,
|
||||||
|
None,
|
||||||
|
Some(addr),
|
||||||
|
TransportKind::Tcp,
|
||||||
|
limiter,
|
||||||
|
10,
|
||||||
|
);
|
||||||
|
assert!(h3.is_connection_allowed());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn transport_kind_display() {
|
||||||
|
assert_eq!(TransportKind::Tcp.to_string(), "tcp");
|
||||||
|
assert_eq!(TransportKind::Tls.to_string(), "tls");
|
||||||
|
assert_eq!(TransportKind::Iroh.to_string(), "iroh");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn auth_log_includes_user_field() {
|
||||||
|
let auth_config = make_empty_auth_config();
|
||||||
|
let mut handler = ServerHandler::new(
|
||||||
|
auth_config,
|
||||||
|
None,
|
||||||
|
Some("203.0.113.50:12345".parse().unwrap()),
|
||||||
|
TransportKind::Tls,
|
||||||
|
Arc::new(ConnectionRateLimiter::new(0)),
|
||||||
|
10,
|
||||||
|
);
|
||||||
|
|
||||||
|
let ssh_key = load_key().public_key().clone();
|
||||||
|
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_closed_logs_duration_on_drop() {
|
||||||
|
let auth_config = make_empty_auth_config();
|
||||||
|
let _handler = ServerHandler::new(
|
||||||
|
auth_config,
|
||||||
|
None,
|
||||||
|
Some("203.0.113.50:12345".parse().unwrap()),
|
||||||
|
TransportKind::Tcp,
|
||||||
|
Arc::new(ConnectionRateLimiter::new(0)),
|
||||||
|
10,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
pub mod handler;
|
pub mod handler;
|
||||||
|
pub mod rate_limit;
|
||||||
|
|
||||||
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
|
pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
|
||||||
|
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||||
193
crates/wraith-core/src/server/rate_limit.rs
Normal file
193
crates/wraith-core/src/server/rate_limit.rs
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
pub struct ConnectionRateLimiter {
|
||||||
|
max_per_ip: usize,
|
||||||
|
active: Mutex<HashMap<IpAddr, usize>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectionRateLimiter {
|
||||||
|
pub fn new(max_per_ip: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
max_per_ip,
|
||||||
|
active: Mutex::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check(&self, ip: IpAddr) -> bool {
|
||||||
|
if self.max_per_ip == 0 {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
let active = self.active.lock().unwrap();
|
||||||
|
let count = active.get(&ip).copied().unwrap_or(0);
|
||||||
|
count < self.max_per_ip
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn on_connect(&self, ip: IpAddr) {
|
||||||
|
let mut active = self.active.lock().unwrap();
|
||||||
|
*active.entry(ip).or_insert(0) += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn on_disconnect(&self, ip: IpAddr) {
|
||||||
|
let mut active = self.active.lock().unwrap();
|
||||||
|
if let Some(count) = active.get_mut(&ip) {
|
||||||
|
if *count > 1 {
|
||||||
|
*count -= 1;
|
||||||
|
} else {
|
||||||
|
active.remove(&ip);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AuthAttemptLimiter {
|
||||||
|
max_attempts: usize,
|
||||||
|
failures: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthAttemptLimiter {
|
||||||
|
pub fn new(max_attempts: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
max_attempts,
|
||||||
|
failures: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check(&self) -> bool {
|
||||||
|
if self.max_attempts == 0 {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
self.failures < self.max_attempts
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn on_failure(&mut self) {
|
||||||
|
self.failures += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||||
|
|
||||||
|
fn ip(n: u8) -> IpAddr {
|
||||||
|
IpAddr::V4(Ipv4Addr::new(192, 168, 1, n))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_allows_when_under_limit() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(3);
|
||||||
|
assert!(limiter.check(ip(1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_blocks_when_at_limit() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(2);
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
assert!(!limiter.check(ip(1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_allows_after_disconnect() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(2);
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
assert!(!limiter.check(ip(1)));
|
||||||
|
limiter.on_disconnect(ip(1));
|
||||||
|
assert!(limiter.check(ip(1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_unlimited_when_zero() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(0);
|
||||||
|
for _ in 0..100 {
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
}
|
||||||
|
assert!(limiter.check(ip(1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_tracks_per_ip_independently() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(1);
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
assert!(!limiter.check(ip(1)));
|
||||||
|
assert!(limiter.check(ip(2)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_ipv6() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(1);
|
||||||
|
let ip6 = IpAddr::V6(Ipv6Addr::LOCALHOST);
|
||||||
|
limiter.on_connect(ip6);
|
||||||
|
assert!(!limiter.check(ip6));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_disconnect_removes_zero_entry() {
|
||||||
|
let limiter = ConnectionRateLimiter::new(3);
|
||||||
|
limiter.on_connect(ip(1));
|
||||||
|
limiter.on_disconnect(ip(1));
|
||||||
|
{
|
||||||
|
let active = limiter.active.lock().unwrap();
|
||||||
|
assert!(!active.contains_key(&ip(1)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_limiter_allows_when_under_limit() {
|
||||||
|
let limiter = AuthAttemptLimiter::new(3);
|
||||||
|
assert!(limiter.check());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_limiter_blocks_after_max_failures() {
|
||||||
|
let mut limiter = AuthAttemptLimiter::new(2);
|
||||||
|
limiter.on_failure();
|
||||||
|
limiter.on_failure();
|
||||||
|
assert!(!limiter.check());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_limiter_unlimited_when_zero() {
|
||||||
|
let mut limiter = AuthAttemptLimiter::new(0);
|
||||||
|
for _ in 0..100 {
|
||||||
|
limiter.on_failure();
|
||||||
|
}
|
||||||
|
assert!(limiter.check());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_limiter_still_allows_at_one_below_limit() {
|
||||||
|
let mut limiter = AuthAttemptLimiter::new(3);
|
||||||
|
limiter.on_failure();
|
||||||
|
limiter.on_failure();
|
||||||
|
assert!(limiter.check());
|
||||||
|
limiter.on_failure();
|
||||||
|
assert!(!limiter.check());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_limiter_thread_safety() {
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::thread;
|
||||||
|
|
||||||
|
let limiter = Arc::new(ConnectionRateLimiter::new(100));
|
||||||
|
let mut handles = vec![];
|
||||||
|
|
||||||
|
for i in 0..10 {
|
||||||
|
let lim = Arc::clone(&limiter);
|
||||||
|
handles.push(thread::spawn(move || {
|
||||||
|
let ip_addr = ip((i % 3) as u8 + 1);
|
||||||
|
lim.on_connect(ip_addr);
|
||||||
|
assert!(lim.check(ip_addr));
|
||||||
|
lim.on_disconnect(ip_addr);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for h in handles {
|
||||||
|
h.join().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,8 +7,6 @@ edition = "2021"
|
|||||||
crate-type = ["cdylib"]
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
wraith-core = { path = "../wraith-core", features = ["tls", "iroh"] }
|
wraith-core = { path = "../wraith-core" }
|
||||||
napi = { version = "3", features = ["async", "error_anyhow"] }
|
napi = "3"
|
||||||
napi-derive = "3"
|
napi-derive = "3"
|
||||||
tokio = { version = "1", features = ["io-util", "sync"] }
|
|
||||||
russh = "0.49"
|
|
||||||
@@ -1,249 +0,0 @@
|
|||||||
use std::net::SocketAddr;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use napi::bindgen_prelude::*;
|
|
||||||
use napi_derive::napi;
|
|
||||||
use russh::client;
|
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
||||||
use tokio::sync::Mutex;
|
|
||||||
|
|
||||||
use wraith_core::auth::client_auth::{ClientAuthConfig, ClientHandler};
|
|
||||||
use wraith_core::auth::keys::KeySource;
|
|
||||||
use wraith_core::transport::{TcpTransport, TlsTransport, Transport};
|
|
||||||
|
|
||||||
const DEFAULT_HOST: &str = "wraith-control";
|
|
||||||
const DEFAULT_PORT: u32 = 0;
|
|
||||||
|
|
||||||
#[napi(object)]
|
|
||||||
pub struct WraithConnectOptions {
|
|
||||||
pub server: Option<String>,
|
|
||||||
pub peer: Option<String>,
|
|
||||||
pub transport: String,
|
|
||||||
pub identity: Option<Either<String, Buffer>>,
|
|
||||||
pub tls_server_name: Option<String>,
|
|
||||||
pub insecure: Option<bool>,
|
|
||||||
pub iroh_relay: Option<String>,
|
|
||||||
pub proxy: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn resolve_key_source(identity: &Option<Either<String, Buffer>>) -> Result<KeySource> {
|
|
||||||
match identity {
|
|
||||||
None => Err(Error::new(
|
|
||||||
Status::InvalidArg,
|
|
||||||
"identity is required: provide a file path (string) or key data (Buffer)",
|
|
||||||
)),
|
|
||||||
Some(Either::A(path)) => Ok(KeySource::File(path.into())),
|
|
||||||
Some(Either::B(buf)) => Ok(KeySource::Memory(buf.to_vec())),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_addr(addr_str: &str) -> Result<SocketAddr> {
|
|
||||||
addr_str.parse().map_err(|e| {
|
|
||||||
Error::new(
|
|
||||||
Status::InvalidArg,
|
|
||||||
format!("invalid server address '{}': {}", addr_str, e),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
#[napi]
|
|
||||||
pub struct WraithStream {
|
|
||||||
read: Arc<Mutex<tokio::io::ReadHalf<russh::ChannelStream<client::Msg>>>>,
|
|
||||||
write: Arc<Mutex<tokio::io::WriteHalf<russh::ChannelStream<client::Msg>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[napi]
|
|
||||||
impl WraithStream {
|
|
||||||
#[napi]
|
|
||||||
pub async fn read(&self, size: u32) -> Result<Buffer> {
|
|
||||||
let mut buf = vec![0u8; size as usize];
|
|
||||||
let mut guard = self.read.lock().await;
|
|
||||||
let n = guard.read(&mut buf).await.map_err(|e| {
|
|
||||||
Error::new(Status::GenericFailure, format!("read failed: {}", e))
|
|
||||||
})?;
|
|
||||||
if n == 0 {
|
|
||||||
return Ok(Vec::<u8>::new().into());
|
|
||||||
}
|
|
||||||
buf.truncate(n);
|
|
||||||
Ok(buf.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[napi]
|
|
||||||
pub async fn write(&self, data: Buffer) -> Result<()> {
|
|
||||||
let mut guard = self.write.lock().await;
|
|
||||||
guard.write_all(&data).await.map_err(|e| {
|
|
||||||
Error::new(Status::GenericFailure, format!("write failed: {}", e))
|
|
||||||
})?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[napi]
|
|
||||||
pub async fn close(&self) -> Result<()> {
|
|
||||||
let mut guard = self.write.lock().await;
|
|
||||||
guard.shutdown().await.map_err(|e| {
|
|
||||||
Error::new(Status::GenericFailure, format!("close failed: {}", e))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[napi]
|
|
||||||
pub async fn connect(options: WraithConnectOptions) -> Result<WraithStream> {
|
|
||||||
let key_source = resolve_key_source(&options.identity)?;
|
|
||||||
let auth_config = Arc::new(ClientAuthConfig::from_key_source(key_source).map_err(|e| {
|
|
||||||
Error::new(Status::InvalidArg, format!("invalid identity key: {}", e))
|
|
||||||
})?);
|
|
||||||
|
|
||||||
let transport_mode = options.transport.to_lowercase();
|
|
||||||
let handler = ClientHandler::from_config(&auth_config);
|
|
||||||
let username = "wraith".to_string();
|
|
||||||
|
|
||||||
let config = Arc::new(client::Config::default());
|
|
||||||
|
|
||||||
let mut handle: client::Handle<ClientHandler> = match transport_mode.as_str() {
|
|
||||||
"tcp" => {
|
|
||||||
let server = options.server.as_ref().ok_or_else(|| {
|
|
||||||
Error::new(Status::InvalidArg, "server is required for tcp transport")
|
|
||||||
})?;
|
|
||||||
let addr = parse_addr(server)?;
|
|
||||||
let transport = TcpTransport::new(addr);
|
|
||||||
let stream = transport.connect().await.map_err(|e| {
|
|
||||||
Error::new(Status::GenericFailure, format!("tcp connect failed: {}", e))
|
|
||||||
})?;
|
|
||||||
client::connect_stream(config, stream, handler)
|
|
||||||
.await
|
|
||||||
.map_err(|e| {
|
|
||||||
Error::new(
|
|
||||||
Status::GenericFailure,
|
|
||||||
format!("ssh handshake failed: {}", e),
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
}
|
|
||||||
"tls" => {
|
|
||||||
let server = options.server.as_ref().ok_or_else(|| {
|
|
||||||
Error::new(Status::InvalidArg, "server is required for tls transport")
|
|
||||||
})?;
|
|
||||||
let addr = parse_addr(server)?;
|
|
||||||
let mut transport = TlsTransport::new(addr);
|
|
||||||
if let Some(ref name) = options.tls_server_name {
|
|
||||||
transport = transport.with_server_name(name);
|
|
||||||
}
|
|
||||||
if let Some(true) = options.insecure {
|
|
||||||
transport = transport.with_insecure(true);
|
|
||||||
}
|
|
||||||
let stream = transport.connect().await.map_err(|e| {
|
|
||||||
Error::new(Status::GenericFailure, format!("tls connect failed: {}", e))
|
|
||||||
})?;
|
|
||||||
client::connect_stream(config, stream, handler)
|
|
||||||
.await
|
|
||||||
.map_err(|e| {
|
|
||||||
Error::new(
|
|
||||||
Status::GenericFailure,
|
|
||||||
format!("ssh handshake failed: {}", e),
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
}
|
|
||||||
"iroh" => {
|
|
||||||
return Err(Error::new(
|
|
||||||
Status::GenericFailure,
|
|
||||||
"iroh transport is not yet supported in napi connect()".to_string(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
return Err(Error::new(
|
|
||||||
Status::InvalidArg,
|
|
||||||
format!("unknown transport '{}'; expected tcp, tls, or iroh", transport_mode),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let auth_ok = auth_config
|
|
||||||
.authenticate(&mut handle, &username)
|
|
||||||
.await
|
|
||||||
.map_err(|e| {
|
|
||||||
Error::new(Status::GenericFailure, format!("ssh auth failed: {}", e))
|
|
||||||
})?;
|
|
||||||
if !auth_ok {
|
|
||||||
return Err(Error::new(Status::GenericFailure, "ssh authentication rejected"));
|
|
||||||
}
|
|
||||||
|
|
||||||
let channel = handle
|
|
||||||
.channel_open_direct_tcpip(DEFAULT_HOST, DEFAULT_PORT, "127.0.0.1", 0)
|
|
||||||
.await
|
|
||||||
.map_err(|e| {
|
|
||||||
Error::new(
|
|
||||||
Status::GenericFailure,
|
|
||||||
format!("failed to open ssh channel: {}", e),
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let stream = channel.into_stream();
|
|
||||||
let (read_half, write_half) = tokio::io::split(stream);
|
|
||||||
|
|
||||||
Ok(WraithStream {
|
|
||||||
read: Arc::new(Mutex::new(read_half)),
|
|
||||||
write: Arc::new(Mutex::new(write_half)),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn resolve_key_source_file_path() {
|
|
||||||
let identity = Some(Either::<String, Buffer>::A("/path/to/key".to_string()));
|
|
||||||
let result = resolve_key_source(&identity);
|
|
||||||
assert!(result.is_ok());
|
|
||||||
match result.unwrap() {
|
|
||||||
KeySource::File(p) => assert_eq!(p.to_str(), Some("/path/to/key")),
|
|
||||||
_ => panic!("expected File variant"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn resolve_key_source_buffer() {
|
|
||||||
let identity = Some(Either::<String, Buffer>::B(Buffer::from(ED25519_PRIVATE_KEY.as_bytes().to_vec())));
|
|
||||||
let result = resolve_key_source(&identity);
|
|
||||||
assert!(result.is_ok());
|
|
||||||
match result.unwrap() {
|
|
||||||
KeySource::Memory(data) => assert!(!data.is_empty()),
|
|
||||||
_ => panic!("expected Memory variant"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn resolve_key_source_missing() {
|
|
||||||
let identity: Option<Either<String, Buffer>> = None;
|
|
||||||
let result = resolve_key_source(&identity);
|
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn parse_addr_valid() {
|
|
||||||
let addr = parse_addr("127.0.0.1:22");
|
|
||||||
assert!(addr.is_ok());
|
|
||||||
assert_eq!(addr.unwrap().port(), 22);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn parse_addr_invalid() {
|
|
||||||
let addr = parse_addr("not-an-address");
|
|
||||||
assert!(addr.is_err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn auth_config_from_memory_key() {
|
|
||||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
|
||||||
let config = ClientAuthConfig::from_key_source(source);
|
|
||||||
assert!(config.is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn auth_config_from_invalid_key() {
|
|
||||||
let source = KeySource::Memory(b"not-a-key".to_vec());
|
|
||||||
let config = ClientAuthConfig::from_key_source(source);
|
|
||||||
assert!(config.is_err());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate napi_derive;
|
extern crate napi_derive;
|
||||||
|
|
||||||
mod connect;
|
|
||||||
Reference in New Issue
Block a user