diff --git a/Cargo.lock b/Cargo.lock index 5b099c6..0ae3e33 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5636,6 +5636,7 @@ dependencies = [ name = "wraith-napi" version = "0.1.0" dependencies = [ + "async-trait", "napi", "napi-derive", "russh", diff --git a/crates/wraith-napi/Cargo.toml b/crates/wraith-napi/Cargo.toml index e24ef2e..9ce6786 100644 --- a/crates/wraith-napi/Cargo.toml +++ b/crates/wraith-napi/Cargo.toml @@ -10,5 +10,6 @@ crate-type = ["cdylib"] wraith-core = { path = "../wraith-core", features = ["tls", "iroh"] } napi = { version = "3", features = ["async", "error_anyhow"] } napi-derive = "3" -tokio = { version = "1", features = ["io-util", "sync"] } -russh = "0.49" \ No newline at end of file +tokio = { version = "1", features = ["io-util", "sync", "rt", "macros", "net", "time", "signal"] } +russh = "0.49" +async-trait = "0.1" \ No newline at end of file diff --git a/crates/wraith-napi/src/connect.rs b/crates/wraith-napi/src/connect.rs index 3d7049a..ebc0469 100644 --- a/crates/wraith-napi/src/connect.rs +++ b/crates/wraith-napi/src/connect.rs @@ -58,9 +58,10 @@ impl WraithStream { pub async fn read(&self, size: u32) -> Result { let mut buf = vec![0u8; size as usize]; let mut guard = self.read.lock().await; - let n = guard.read(&mut buf).await.map_err(|e| { - Error::new(Status::GenericFailure, format!("read failed: {}", e)) - })?; + let n = guard + .read(&mut buf) + .await + .map_err(|e| Error::new(Status::GenericFailure, format!("read failed: {}", e)))?; if n == 0 { return Ok(Vec::::new().into()); } @@ -71,27 +72,30 @@ impl WraithStream { #[napi] pub async fn write(&self, data: Buffer) -> Result<()> { let mut guard = self.write.lock().await; - guard.write_all(&data).await.map_err(|e| { - Error::new(Status::GenericFailure, format!("write failed: {}", e)) - })?; + guard + .write_all(&data) + .await + .map_err(|e| Error::new(Status::GenericFailure, format!("write failed: {}", e)))?; Ok(()) } #[napi] pub async fn close(&self) -> Result<()> { let mut guard = self.write.lock().await; - guard.shutdown().await.map_err(|e| { - Error::new(Status::GenericFailure, format!("close failed: {}", e)) - }) + guard + .shutdown() + .await + .map_err(|e| Error::new(Status::GenericFailure, format!("close failed: {}", e))) } } #[napi] pub async fn connect(options: WraithConnectOptions) -> Result { let key_source = resolve_key_source(&options.identity)?; - let auth_config = Arc::new(ClientAuthConfig::from_key_source(key_source).map_err(|e| { - Error::new(Status::InvalidArg, format!("invalid identity key: {}", e)) - })?); + let auth_config = Arc::new( + ClientAuthConfig::from_key_source(key_source) + .map_err(|e| Error::new(Status::InvalidArg, format!("invalid identity key: {}", e)))?, + ); let transport_mode = options.transport.to_lowercase(); let handler = ClientHandler::from_config(&auth_config); @@ -151,7 +155,10 @@ pub async fn connect(options: WraithConnectOptions) -> Result { _ => { return Err(Error::new( Status::InvalidArg, - format!("unknown transport '{}'; expected tcp, tls, or iroh", transport_mode), + format!( + "unknown transport '{}'; expected tcp, tls, or iroh", + transport_mode + ), )); } }; @@ -159,11 +166,12 @@ pub async fn connect(options: WraithConnectOptions) -> Result { let auth_ok = auth_config .authenticate(&mut handle, &username) .await - .map_err(|e| { - Error::new(Status::GenericFailure, format!("ssh auth failed: {}", e)) - })?; + .map_err(|e| Error::new(Status::GenericFailure, format!("ssh auth failed: {}", e)))?; if !auth_ok { - return Err(Error::new(Status::GenericFailure, "ssh authentication rejected")); + return Err(Error::new( + Status::GenericFailure, + "ssh authentication rejected", + )); } let channel = handle @@ -204,7 +212,9 @@ mod tests { #[test] fn resolve_key_source_buffer() { - let identity = Some(Either::::B(Buffer::from(ED25519_PRIVATE_KEY.as_bytes().to_vec()))); + let identity = Some(Either::::B(Buffer::from( + ED25519_PRIVATE_KEY.as_bytes().to_vec(), + ))); let result = resolve_key_source(&identity); assert!(result.is_ok()); match result.unwrap() { @@ -246,4 +256,4 @@ mod tests { let config = ClientAuthConfig::from_key_source(source); assert!(config.is_err()); } -} \ No newline at end of file +} diff --git a/crates/wraith-napi/src/lib.rs b/crates/wraith-napi/src/lib.rs index fd2f6e6..a697854 100644 --- a/crates/wraith-napi/src/lib.rs +++ b/crates/wraith-napi/src/lib.rs @@ -2,4 +2,5 @@ #[macro_use] extern crate napi_derive; -mod connect; \ No newline at end of file +mod connect; +mod serve; diff --git a/crates/wraith-napi/src/serve.rs b/crates/wraith-napi/src/serve.rs new file mode 100644 index 0000000..3ca1cfe --- /dev/null +++ b/crates/wraith-napi/src/serve.rs @@ -0,0 +1,703 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +use napi::bindgen_prelude::*; +use napi::threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode}; +use napi_derive::napi; +use russh::server; +use russh::Channel; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::Mutex; + +use wraith_core::auth::keys::KeySource; +use wraith_core::auth::server_auth::ServerAuthConfig; +use wraith_core::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter}; +use wraith_core::server::serve::{ServeOptions, ServeTransportMode, Server}; +use wraith_core::transport::{TcpAcceptor, TransportAcceptor}; + +#[napi(object)] +pub struct WraithServeOptions { + pub transport: String, + pub host_key: Option>, + pub authorized_keys: Option>, + pub cert_authority: Option>, + pub tls_cert: Option, + pub tls_key: Option, + pub acme_domain: Option, + pub listen: Option, + pub iroh_relay: Option, +} + +fn resolve_key_source( + key: &Option>, + field: &str, +) -> napi::Result { + match key { + None => Err(napi::Error::new( + napi::Status::InvalidArg, + format!( + "{} is required: provide a file path (string) or key data (Buffer)", + field + ), + )), + Some(Either::A(path)) => Ok(KeySource::File(path.into())), + Some(Either::B(buf)) => Ok(KeySource::Memory(buf.to_vec())), + } +} + +fn resolve_optional_key_source(key: &Option>) -> Option { + match key { + None => None, + Some(Either::A(path)) => Some(KeySource::File(path.into())), + Some(Either::B(buf)) => Some(KeySource::Memory(buf.to_vec())), + } +} + +fn parse_addr(addr_str: &str) -> napi::Result { + addr_str.parse().map_err(|e| { + napi::Error::new( + napi::Status::InvalidArg, + format!("invalid address '{}': {}", addr_str, e), + ) + }) +} + +#[napi(object)] +pub struct ConnectionInfo { + pub remote_addr: Option, + pub transport_kind: String, +} + +#[napi] +pub struct WraithServerStream { + read: Arc>>>, + write: Arc>>>, +} + +#[napi] +impl WraithServerStream { + #[napi] + pub async fn read(&self, size: u32) -> napi::Result { + let mut buf = vec![0u8; size as usize]; + let mut guard = self.read.lock().await; + let n = guard.read(&mut buf).await.map_err(|e| { + napi::Error::new(napi::Status::GenericFailure, format!("read failed: {}", e)) + })?; + if n == 0 { + return Ok(Vec::::new().into()); + } + buf.truncate(n); + Ok(buf.into()) + } + + #[napi] + pub async fn write(&self, data: Buffer) -> napi::Result<()> { + let mut guard = self.write.lock().await; + guard.write_all(&data).await.map_err(|e| { + napi::Error::new(napi::Status::GenericFailure, format!("write failed: {}", e)) + })?; + Ok(()) + } + + #[napi] + pub async fn close(&self) -> napi::Result<()> { + let mut guard = self.write.lock().await; + guard.shutdown().await.map_err(|e| { + napi::Error::new(napi::Status::GenericFailure, format!("close failed: {}", e)) + }) + } +} + +struct NapiServerHandler { + auth_config: Arc, + remote_addr: Option, + connection_limiter: Arc, + connection_allowed: bool, + auth_limiter: AuthAttemptLimiter, + channel_sender: Arc>>>>, +} + +impl NapiServerHandler { + fn new( + auth_config: Arc, + remote_addr: Option, + connection_limiter: Arc, + max_auth_attempts: usize, + channel_sender: Arc< + Mutex>>>, + >, + ) -> Self { + let allowed = if let Some(addr) = remote_addr { + let ip = addr.ip(); + if connection_limiter.check(ip) { + connection_limiter.on_connect(ip); + true + } else { + false + } + } else { + true + }; + + Self { + auth_config, + remote_addr, + connection_limiter, + connection_allowed: allowed, + auth_limiter: AuthAttemptLimiter::new(max_auth_attempts), + channel_sender, + } + } + + fn is_connection_allowed(&self) -> bool { + self.connection_allowed + } +} + +impl Drop for NapiServerHandler { + fn drop(&mut self) { + if let Some(addr) = self.remote_addr { + if self.connection_allowed { + self.connection_limiter.on_disconnect(addr.ip()); + } + } + } +} + +#[async_trait::async_trait] +impl russh::server::Handler for NapiServerHandler { + type Error = russh::Error; + + async fn auth_publickey( + &mut self, + user: &str, + public_key: &russh::keys::ssh_key::PublicKey, + ) -> std::result::Result { + if !self.auth_limiter.check() { + return Ok(russh::server::Auth::Reject { + proceed_with_methods: None, + }); + } + + let russh_pub = russh::keys::PublicKey::new(public_key.key_data().clone(), user); + let result = self.auth_config.authenticate_publickey(&russh_pub); + + match result { + Ok(()) => Ok(russh::server::Auth::Accept), + Err(_) => { + self.auth_limiter.on_failure(); + Ok(russh::server::Auth::Reject { + proceed_with_methods: None, + }) + } + } + } + + async fn channel_open_direct_tcpip( + &mut self, + channel: Channel, + host_to_connect: &str, + _port_to_connect: u32, + _originator_address: &str, + _originator_port: u32, + _session: &mut russh::server::Session, + ) -> std::result::Result { + if host_to_connect.starts_with("wraith-") { + let guard = self.channel_sender.lock().await; + if let Some(ref tx) = *guard { + let _ = tx.send(channel); + } + return Ok(true); + } + + let _ = channel; + Ok(false) + } + + async fn channel_open_session( + &mut self, + _channel: Channel, + _session: &mut russh::server::Session, + ) -> std::result::Result { + Ok(false) + } + + async fn channel_open_x11( + &mut self, + _channel: Channel, + _originator_address: &str, + _originator_port: u32, + _session: &mut russh::server::Session, + ) -> std::result::Result { + Ok(false) + } + + async fn channel_open_forwarded_tcpip( + &mut self, + _channel: Channel, + _host_to_connect: &str, + _port_to_connect: u32, + _originator_address: &str, + _originator_port: u32, + _session: &mut russh::server::Session, + ) -> std::result::Result { + Ok(false) + } +} + +type ServerTsfn = ThreadsafeFunction; + +#[napi] +pub struct WraithServer { + shutdown_tx: tokio::sync::watch::Sender, + listen_addr: String, + on_connection_tsfn: Arc>>, +} + +struct ConnectionEventWrapper { + stream: WraithServerStream, + info: ConnectionInfo, +} + +impl ToNapiValue for ConnectionEventWrapper { + unsafe fn to_napi_value( + env: napi::sys::napi_env, + val: Self, + ) -> napi::Result { + let mut raw_obj: napi::sys::napi_value = std::ptr::null_mut(); + napi::check_status!( + napi::sys::napi_create_object(env, &mut raw_obj), + "Failed to create object" + )?; + + let stream_val = ::to_napi_value(env, val.stream)?; + let key_stream = std::ffi::CString::new("stream").unwrap(); + napi::check_status!( + napi::sys::napi_set_named_property(env, raw_obj, key_stream.as_ptr(), stream_val), + "Failed to set stream property" + )?; + + let info_val = ::to_napi_value(env, val.info)?; + let key_info = std::ffi::CString::new("info").unwrap(); + napi::check_status!( + napi::sys::napi_set_named_property(env, raw_obj, key_info.as_ptr(), info_val), + "Failed to set info property" + )?; + + Ok(raw_obj) + } +} + +impl TypeName for ConnectionEventWrapper { + fn type_name() -> &'static str { + "ConnectionEventWrapper" + } + + fn value_type() -> napi::ValueType { + napi::ValueType::Object + } +} + +impl ValidateNapiValue for ConnectionEventWrapper {} + +#[napi] +impl WraithServer { + #[napi] + pub async fn close(&self) -> napi::Result<()> { + let _ = self.shutdown_tx.send(true); + Ok(()) + } + + #[napi(ts_return_type = "void")] + pub fn on_connection(&self, callback: Function<(), ()>) -> napi::Result<()> { + let tsfn = callback + .build_threadsafe_function::() + .callee_handled::() + .build_callback(|ctx| Ok(ctx.value))?; + + let holder = self.on_connection_tsfn.clone(); + *holder.blocking_lock() = Some(tsfn); + Ok(()) + } + + #[napi(getter)] + pub fn listen_addr(&self) -> napi::Result { + Ok(self.listen_addr.clone()) + } +} + +#[napi] +pub async fn serve(options: WraithServeOptions) -> napi::Result { + let host_key_source = resolve_key_source(&options.host_key, "hostKey")?; + let authorized_keys_source = resolve_optional_key_source(&options.authorized_keys); + let cert_authority_source = resolve_optional_key_source(&options.cert_authority); + + let transport_mode = match options.transport.to_lowercase().as_str() { + "tcp" => ServeTransportMode::Tcp, + "tls" => ServeTransportMode::Tls, + "iroh" => ServeTransportMode::Iroh, + other => { + return Err(napi::Error::new( + napi::Status::InvalidArg, + format!("unknown transport '{}'; expected tcp, tls, or iroh", other), + )); + } + }; + + let listen_addr_str = options.listen.as_deref().unwrap_or("0.0.0.0:22"); + + let mut serve_opts = ServeOptions::new(host_key_source.clone()) + .transport_mode(transport_mode.clone()) + .listen_addr(listen_addr_str); + + if let Some(aks) = authorized_keys_source.clone() { + serve_opts = serve_opts.authorized_keys(aks); + } + if let Some(cas) = cert_authority_source.clone() { + serve_opts = serve_opts.cert_authority(cas); + } + if let Some(ref cert) = options.tls_cert { + serve_opts = serve_opts.tls_cert(cert); + } + if let Some(ref key) = options.tls_key { + serve_opts = serve_opts.tls_key(key); + } + if let Some(ref domain) = options.acme_domain { + serve_opts = serve_opts.acme_domain(domain); + } + if let Some(ref relay) = options.iroh_relay { + serve_opts = serve_opts.iroh_relay(relay); + } + + let _core_server = Server::new(serve_opts).map_err(|e| { + napi::Error::new( + napi::Status::InvalidArg, + format!("server config error: {}", e), + ) + })?; + + let shutdown_tx = _core_server.shutdown_sender(); + + match transport_mode { + ServeTransportMode::Tcp => { + let addr = parse_addr(listen_addr_str)?; + let acceptor = TcpAcceptor::bind(addr).await.map_err(|e| { + napi::Error::new( + napi::Status::GenericFailure, + format!("tcp bind failed: {}", e), + ) + })?; + let actual_listen = acceptor.listen_addr().to_string(); + + let auth_config = Arc::new( + ServerAuthConfig::from_keys_and_ca(authorized_keys_source, cert_authority_source) + .map_err(|e| { + napi::Error::new( + napi::Status::InvalidArg, + format!("auth config error: {}", e), + ) + })?, + ); + + let private_key = + wraith_core::auth::keys::load_private_key(host_key_source).map_err(|e| { + napi::Error::new(napi::Status::InvalidArg, format!("host key error: {}", e)) + })?; + + let config = Arc::new(server::Config { + keys: vec![private_key], + ..Default::default() + }); + + let connection_limiter = Arc::new(ConnectionRateLimiter::new(0)); + let shutdown_rx = shutdown_tx.subscribe(); + let tsfn_holder: Arc>> = Arc::new(Mutex::new(None)); + + let tsfn_for_loop = tsfn_holder.clone(); + + tokio::spawn(async move { + run_accept_loop( + acceptor, + config, + auth_config, + connection_limiter, + shutdown_rx, + tsfn_for_loop, + ) + .await; + }); + + Ok(WraithServer { + shutdown_tx, + listen_addr: actual_listen, + on_connection_tsfn: tsfn_holder, + }) + } + ServeTransportMode::Tls => Err(napi::Error::new( + napi::Status::GenericFailure, + "TLS transport is not yet supported in napi serve()".to_string(), + )), + ServeTransportMode::Iroh => Err(napi::Error::new( + napi::Status::GenericFailure, + "iroh transport is not yet supported in napi serve()".to_string(), + )), + } +} + +async fn run_accept_loop( + acceptor: A, + config: Arc, + auth_config: Arc, + connection_limiter: Arc, + mut shutdown_rx: tokio::sync::watch::Receiver, + tsfn_holder: Arc>>, +) where + A: TransportAcceptor + Send + Sync + 'static, +{ + loop { + if *shutdown_rx.borrow() { + break; + } + + let accept_result = tokio::select! { + result = acceptor.accept() => result, + _ = shutdown_rx.changed() => break, + }; + + let (stream, info) = match accept_result { + Ok(conn) => conn, + Err(_) => continue, + }; + + let remote_addr = info.remote_addr; + let (channel_tx, mut channel_rx) = + tokio::sync::mpsc::unbounded_channel::>(); + let channel_sender = Arc::new(Mutex::new(Some(channel_tx))); + + let handler = NapiServerHandler::new( + Arc::clone(&auth_config), + remote_addr, + Arc::clone(&connection_limiter), + 10, + channel_sender, + ); + + if !handler.is_connection_allowed() { + continue; + } + + let config = Arc::clone(&config); + let tsfn_holder = tsfn_holder.clone(); + let remote_addr_str = remote_addr.map(|a| a.to_string()); + let transport_kind_str = "tcp".to_string(); + + tokio::spawn(async move { + let running = match server::run_stream(config, stream, handler).await { + Ok(r) => r, + Err(_) => return, + }; + + loop { + let channel = channel_rx.recv().await; + match channel { + Some(ch) => { + let channel_stream = ch.into_stream(); + let (read_half, write_half) = tokio::io::split(channel_stream); + let server_stream = WraithServerStream { + read: Arc::new(Mutex::new(read_half)), + write: Arc::new(Mutex::new(write_half)), + }; + + let conn_info = ConnectionInfo { + remote_addr: remote_addr_str.clone(), + transport_kind: transport_kind_str.clone(), + }; + + let event = ConnectionEventWrapper { + stream: server_stream, + info: conn_info, + }; + + let tsfn_guard = tsfn_holder.lock().await; + if let Some(ref tsfn) = *tsfn_guard { + let _ = tsfn.call(Ok(event), ThreadsafeFunctionCallMode::NonBlocking); + } + } + None => break, + } + } + + let _ = running.await; + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use russh::server::Handler; + + const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n"; + + #[test] + fn resolve_key_source_file_path() { + let key = Some(Either::::A("/path/to/key".to_string())); + let result = resolve_key_source(&key, "hostKey"); + assert!(result.is_ok()); + match result.unwrap() { + KeySource::File(p) => assert_eq!(p.to_str(), Some("/path/to/key")), + _ => panic!("expected File variant"), + } + } + + #[test] + fn resolve_key_source_buffer() { + let key = Some(Either::::B(Buffer::from( + ED25519_PRIVATE_KEY.as_bytes().to_vec(), + ))); + let result = resolve_key_source(&key, "hostKey"); + assert!(result.is_ok()); + match result.unwrap() { + KeySource::Memory(data) => assert!(!data.is_empty()), + _ => panic!("expected Memory variant"), + } + } + + #[test] + fn resolve_key_source_missing() { + let key: Option> = None; + assert!(resolve_key_source(&key, "hostKey").is_err()); + } + + #[test] + fn resolve_optional_key_source_none() { + let key: Option> = None; + assert!(resolve_optional_key_source(&key).is_none()); + } + + #[test] + fn resolve_optional_key_source_file() { + let key = Some(Either::::A("/path/to/keys".to_string())); + let result = resolve_optional_key_source(&key); + assert!(result.is_some()); + match result.unwrap() { + KeySource::File(p) => assert_eq!(p.to_str(), Some("/path/to/keys")), + _ => panic!("expected File variant"), + } + } + + #[test] + fn resolve_optional_key_source_buffer() { + let key = Some(Either::::B(Buffer::from( + b"keydata".to_vec(), + ))); + let result = resolve_optional_key_source(&key); + assert!(result.is_some()); + match result.unwrap() { + KeySource::Memory(data) => assert_eq!(data, b"keydata".to_vec()), + _ => panic!("expected Memory variant"), + } + } + + #[test] + fn parse_addr_valid() { + let addr = parse_addr("127.0.0.1:22"); + assert!(addr.is_ok()); + assert_eq!(addr.unwrap().port(), 22); + } + + #[test] + fn parse_addr_invalid() { + assert!(parse_addr("not-an-address").is_err()); + } + + #[test] + fn connection_info_fields() { + let info = ConnectionInfo { + remote_addr: Some("127.0.0.1:12345".to_string()), + transport_kind: "tcp".to_string(), + }; + assert_eq!(info.remote_addr, Some("127.0.0.1:12345".to_string())); + assert_eq!(info.transport_kind, "tcp"); + } + + #[test] + fn napi_server_handler_allows_connection() { + let auth_config = Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap()); + let (tx, _) = tokio::sync::mpsc::unbounded_channel::>(); + let handler = NapiServerHandler::new( + auth_config, + None, + Arc::new(ConnectionRateLimiter::new(0)), + 10, + Arc::new(Mutex::new(Some(tx))), + ); + assert!(handler.is_connection_allowed()); + } + + #[tokio::test] + async fn napi_server_handler_rejects_unknown_key() { + let auth_config = Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap()); + let (tx, _) = tokio::sync::mpsc::unbounded_channel::>(); + let mut handler = NapiServerHandler::new( + auth_config, + None, + Arc::new(ConnectionRateLimiter::new(0)), + 10, + Arc::new(Mutex::new(Some(tx))), + ); + + let test_key_str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE test@host"; + let public_key = + russh::keys::parse_public_key_base64(test_key_str.split_whitespace().nth(1).unwrap()) + .unwrap(); + + let result = handler + .auth_publickey("testuser", &public_key) + .await + .unwrap(); + assert_eq!( + result, + russh::server::Auth::Reject { + proceed_with_methods: None + } + ); + } + + #[test] + fn napi_server_handler_connection_limiter() { + let limiter = Arc::new(ConnectionRateLimiter::new(1)); + let auth_config = Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap()); + let (tx, _) = tokio::sync::mpsc::unbounded_channel::>(); + let addr: SocketAddr = "10.0.0.1:22".parse().unwrap(); + + let h1 = NapiServerHandler::new( + auth_config.clone(), + Some(addr), + limiter.clone(), + 10, + Arc::new(Mutex::new(Some(tx.clone()))), + ); + assert!(h1.is_connection_allowed()); + + let h2 = NapiServerHandler::new( + auth_config, + Some(addr), + limiter, + 10, + Arc::new(Mutex::new(Some(tx))), + ); + assert!(!h2.is_connection_allowed()); + + drop(h1); + + let h3 = NapiServerHandler::new( + Arc::new(ServerAuthConfig::from_keys_and_ca(None, None).unwrap()), + Some(addr), + Arc::new(ConnectionRateLimiter::new(1)), + 10, + Arc::new(Mutex::new(None)), + ); + assert!(h3.is_connection_allowed()); + } +}