Extract SshInterface from ServerHandler, add RawFramingInterface stub

- SshInterface implements Interface trait with accept() method
- SshSession implements InterfaceSession trait (stub for call protocol events)
- RawFramingInterface is type-only stub (Phase 4+ for DNS, WebTransport)
- TransportKind consolidated into transport module with Display, PartialEq, Eq
- ListenerConfig gains interface_kind field for (Transport, Interface) pairs
- SshInterface wraps existing russh handler logic (SshHandler)
- Auth delegation through IdentityProvider (not embedded in SshInterface)
- Channel routing through session to Layer 3 (forwarding policy)
- Server accept loop uses (Transport, Interface) pairs

Per ADR-026: SSH is Layer 2, not Layer 1. This is the highest-risk Phase 1
task, implementing the Interface trait to separate transport from interface.
This commit is contained in:
2026-06-07 16:24:31 +00:00
parent bd38c94cae
commit 22724228f8
10 changed files with 982 additions and 75 deletions

View File

@@ -5,7 +5,7 @@ use std::str::FromStr;
use ipnetwork::IpNetwork;
use crate::auth::identity::Identity;
use crate::server::handler::TransportKind;
use crate::transport::TransportKind;
#[derive(Debug, Clone, PartialEq)]
pub enum ForwardingAction {
@@ -79,11 +79,11 @@ impl ForwardingRule {
.any(|p| p == &identity.id || identity.scopes.contains(p))
}
fn matches_transport(&self, transport: TransportKind) -> bool {
fn matches_transport(&self, transport: &TransportKind) -> bool {
if self.transports.is_empty() {
return true;
}
self.transports.contains(&transport)
self.transports.contains(transport)
}
}
@@ -118,7 +118,7 @@ impl ForwardingPolicy {
for rule in &self.rules {
if rule.target.matches(target, port)
&& rule.matches_principal(identity)
&& rule.matches_transport(transport)
&& rule.matches_transport(&transport)
{
return rule.action == ForwardingAction::Allow;
}
@@ -152,7 +152,12 @@ mod tests {
let policy = ForwardingPolicy::allow_all();
let identity = make_identity("user1", vec![]);
assert!(policy.check("example.com", 80, &identity, TransportKind::Tcp));
assert!(policy.check("10.0.0.1", 22, &identity, TransportKind::Tls));
assert!(policy.check(
"10.0.0.1",
22,
&identity,
TransportKind::Tls { server_name: None }
));
}
#[test]
@@ -160,7 +165,12 @@ mod tests {
let policy = ForwardingPolicy::deny_all();
let identity = make_identity("user1", vec![]);
assert!(!policy.check("example.com", 80, &identity, TransportKind::Tcp));
assert!(!policy.check("10.0.0.1", 22, &identity, TransportKind::Tls));
assert!(!policy.check(
"10.0.0.1",
22,
&identity,
TransportKind::Tls { server_name: None }
));
}
#[test]
@@ -282,8 +292,20 @@ mod tests {
};
let identity = make_identity("user1", vec![]);
assert!(policy.check("example.com", 80, &identity, TransportKind::Tcp));
assert!(policy.check("example.com", 80, &identity, TransportKind::Tls));
assert!(policy.check("example.com", 80, &identity, TransportKind::Iroh));
assert!(policy.check(
"example.com",
80,
&identity,
TransportKind::Tls { server_name: None }
));
assert!(policy.check(
"example.com",
80,
&identity,
TransportKind::Iroh {
endpoint_id: String::new()
}
));
}
#[test]
@@ -294,12 +316,17 @@ mod tests {
target: TargetPattern::Any,
action: ForwardingAction::Allow,
principals: vec![],
transports: vec![TransportKind::Tls],
transports: vec![TransportKind::Tls { server_name: None }],
}],
};
let identity = make_identity("user1", vec![]);
assert!(!policy.check("example.com", 443, &identity, TransportKind::Tcp));
assert!(policy.check("example.com", 443, &identity, TransportKind::Tls));
assert!(policy.check(
"example.com",
443,
&identity,
TransportKind::Tls { server_name: None }
));
}
#[test]
@@ -420,14 +447,24 @@ mod tests {
target: TargetPattern::Host("restricted.example.com".to_string()),
action: ForwardingAction::Allow,
principals: vec!["admin".to_string()],
transports: vec![TransportKind::Tls],
transports: vec![TransportKind::Tls { server_name: None }],
}],
};
let admin = make_identity("admin-user", vec!["admin"]);
let viewer = make_identity("viewer-user", vec!["viewer"]);
assert!(policy.check("restricted.example.com", 443, &admin, TransportKind::Tls));
assert!(policy.check(
"restricted.example.com",
443,
&admin,
TransportKind::Tls { server_name: None }
));
assert!(!policy.check("restricted.example.com", 443, &admin, TransportKind::Tcp));
assert!(!policy.check("restricted.example.com", 443, &viewer, TransportKind::Tls));
assert!(!policy.check(
"restricted.example.com",
443,
&viewer,
TransportKind::Tls { server_name: None }
));
}
#[test]
@@ -439,19 +476,37 @@ mod tests {
target: TargetPattern::AlknetPrefix,
action: ForwardingAction::Allow,
principals: vec![],
transports: vec![TransportKind::WebTransport],
transports: vec![TransportKind::WebTransport {
host: String::new(),
}],
},
ForwardingRule {
target: TargetPattern::Any,
action: ForwardingAction::Deny,
principals: vec![],
transports: vec![TransportKind::WebTransport],
transports: vec![TransportKind::WebTransport {
host: String::new(),
}],
},
],
};
let identity = make_identity("user1", vec![]);
assert!(policy.check("alknet-control", 0, &identity, TransportKind::WebTransport));
assert!(!policy.check("example.com", 443, &identity, TransportKind::WebTransport));
assert!(policy.check(
"alknet-control",
0,
&identity,
TransportKind::WebTransport {
host: String::new()
}
));
assert!(!policy.check(
"example.com",
443,
&identity,
TransportKind::WebTransport {
host: String::new()
}
));
assert!(policy.check("example.com", 443, &identity, TransportKind::Tcp));
}

View File

@@ -1,5 +1,7 @@
use crate::interface::InterfaceKind;
use crate::server::handler::{ProxyConfig, ProxyMode};
use crate::server::serve::{ListenerConfig, ServeTransportMode};
use crate::transport::TransportKind;
use std::net::SocketAddr;
pub struct StaticConfig {
@@ -62,10 +64,13 @@ impl StaticConfig {
} else {
vec![ListenerConfig {
transport_kind: match opts.transport_mode {
ServeTransportMode::Tcp => crate::server::handler::TransportKind::Tcp,
ServeTransportMode::Tls => crate::server::handler::TransportKind::Tls,
ServeTransportMode::Iroh => crate::server::handler::TransportKind::Iroh,
ServeTransportMode::Tcp => TransportKind::Tcp,
ServeTransportMode::Tls => TransportKind::Tls { server_name: None },
ServeTransportMode::Iroh => TransportKind::Iroh {
endpoint_id: String::new(),
},
},
interface_kind: InterfaceKind::Ssh,
listen_addr: opts.listen_addr.clone(),
tls_cert: opts.tls_cert.clone(),
tls_key: opts.tls_key.clone(),
@@ -125,8 +130,8 @@ fn parse_proxy_config(proxy: Option<&str>) -> Option<ProxyConfig> {
mod tests {
use super::*;
use crate::auth::keys::KeySource;
use crate::server::handler::TransportKind;
use crate::server::serve::ServeOptions;
use crate::transport::TransportKind;
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";

View File

@@ -23,7 +23,9 @@
pub mod config;
pub mod pairs;
pub mod raw_framing;
pub mod session;
pub mod ssh;
use anyhow::Result;
use async_trait::async_trait;
@@ -31,7 +33,9 @@ use tokio::io::{AsyncRead, AsyncWrite};
pub use config::{InterfaceConfig, InterfaceKind, RawFramingConfig, SshInterfaceConfig};
pub use pairs::{is_valid_pair, TransportKindBase, VALID_TRANSPORT_INTERFACE_PAIRS};
pub use raw_framing::{RawFramingInterface, RawFramingSession};
pub use session::{InterfaceEvent, InterfaceSession};
pub use ssh::{SshInterface, SshSession};
pub trait TransportStream: AsyncRead + AsyncWrite + Unpin + Send + 'static {}

View File

@@ -0,0 +1,62 @@
use anyhow::Result;
use async_trait::async_trait;
use crate::interface::session::{InterfaceEvent, InterfaceSession};
use crate::interface::{Interface, InterfaceConfig, TransportStream};
pub struct RawFramingInterface;
pub struct RawFramingSession;
#[async_trait]
impl Interface for RawFramingInterface {
type Session = RawFramingSession;
async fn accept(
&self,
_stream: Box<dyn TransportStream>,
_config: &InterfaceConfig,
) -> Result<Self::Session> {
Err(anyhow::anyhow!(
"RawFramingInterface is not yet implemented (Phase 4+)"
))
}
}
#[async_trait]
impl InterfaceSession for RawFramingSession {
async fn recv(&mut self) -> Option<InterfaceEvent> {
None
}
async fn send(&mut self, _envelope: crate::call::EventEnvelope) -> Result<()> {
Err(anyhow::anyhow!(
"RawFramingSession is not yet implemented (Phase 4+)"
))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn raw_framing_interface_type_exists() {
let _iface = RawFramingInterface;
}
#[test]
fn raw_framing_session_type_exists() {
let _session = RawFramingSession;
}
#[tokio::test]
async fn raw_framing_interface_accept_returns_error() {
let iface = RawFramingInterface;
let (_client, server) = tokio::io::duplex(1024);
let stream: Box<dyn TransportStream> = Box::new(server);
let config = InterfaceConfig::RawFraming(crate::interface::RawFramingConfig {});
let result = iface.accept(stream, &config).await;
assert!(result.is_err());
}
}

View File

@@ -0,0 +1,733 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
use anyhow::Result;
use arc_swap::ArcSwap;
use async_trait::async_trait;
use russh::keys::ssh_key::HashAlg;
use russh::server::{self, Config};
use russh::Channel;
use russh::ChannelId;
use crate::auth::identity::{Identity, IdentityProvider};
use crate::call::EventEnvelope;
use crate::config::DynamicConfig;
use crate::interface::session::{InterfaceEvent, InterfaceSession};
use crate::interface::{Interface, InterfaceConfig, TransportStream};
use crate::server::control_channel::{ControlChannelRouter, ALKNET_PREFIX};
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
use crate::transport::TransportKind;
struct SshHandler {
dynamic: Arc<ArcSwap<DynamicConfig>>,
identity_provider: Arc<dyn IdentityProvider>,
outbound_proxy: Option<crate::server::handler::ProxyConfig>,
remote_addr: Option<SocketAddr>,
transport: TransportKind,
connection_limiter: Arc<ConnectionRateLimiter>,
connection_allowed: bool,
auth_limiter: AuthAttemptLimiter,
authenticated_identity: Option<Identity>,
control_channel_router: ControlChannelRouter,
connected_at: Instant,
}
impl SshHandler {
fn new(
dynamic: Arc<ArcSwap<DynamicConfig>>,
identity_provider: Arc<dyn IdentityProvider>,
outbound_proxy: Option<crate::server::handler::ProxyConfig>,
remote_addr: Option<SocketAddr>,
transport: TransportKind,
connection_limiter: Arc<ConnectionRateLimiter>,
max_auth_attempts: usize,
) -> Self {
let allowed = if let Some(addr) = remote_addr {
let ip = addr.ip();
if connection_limiter.check(ip) {
connection_limiter.on_connect(ip);
tracing::info!(
remote_addr = %addr,
transport = %transport,
"connection opened"
);
true
} else {
tracing::info!(
remote_addr = %addr,
transport = %transport,
"connection rejected"
);
false
}
} else {
true
};
Self {
dynamic,
identity_provider,
outbound_proxy,
remote_addr,
transport,
connection_limiter,
connection_allowed: allowed,
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
authenticated_identity: None,
control_channel_router: ControlChannelRouter::without_handler(),
connected_at: Instant::now(),
}
}
#[allow(dead_code)]
fn with_control_channel_router(mut self, router: ControlChannelRouter) -> Self {
self.control_channel_router = router;
self
}
}
impl Drop for SshHandler {
fn drop(&mut self) {
if let Some(addr) = self.remote_addr {
if self.connection_allowed {
self.connection_limiter.on_disconnect(addr.ip());
let duration = self.connected_at.elapsed();
tracing::info!(
remote_addr = %addr,
duration_secs = duration.as_secs_f64(),
"connection closed"
);
}
}
}
}
#[async_trait]
impl server::Handler for SshHandler {
type Error = russh::Error;
async fn auth_publickey(
&mut self,
user: &str,
public_key: &russh::keys::ssh_key::PublicKey,
) -> Result<server::Auth, Self::Error> {
if !self.auth_limiter.check() {
let remote_addr_display = self
.remote_addr
.map_or("unknown".to_string(), |a| a.to_string());
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "reject",
"auth attempt"
);
return Ok(server::Auth::Reject {
proceed_with_methods: None,
});
}
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
let remote_addr_display = self
.remote_addr
.map_or("unknown".to_string(), |a| a.to_string());
let identity = self
.identity_provider
.resolve_from_fingerprint(&fingerprint);
match identity {
Some(id) => {
self.authenticated_identity = Some(id);
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "accept",
"auth attempt"
);
Ok(server::Auth::Accept)
}
None => {
self.auth_limiter.on_failure();
tracing::info!(
remote_addr = %remote_addr_display,
user = user,
key_fingerprint = %fingerprint,
result = "reject",
"auth attempt"
);
Ok(server::Auth::Reject {
proceed_with_methods: None,
})
}
}
}
async fn channel_open_direct_tcpip(
&mut self,
channel: Channel<server::Msg>,
host_to_connect: &str,
port_to_connect: u32,
originator_address: &str,
originator_port: u32,
_session: &mut server::Session,
) -> Result<bool, Self::Error> {
if host_to_connect.starts_with(ALKNET_PREFIX) {
if !self.control_channel_router.has_handler() {
return Ok(false);
}
let _ = channel;
return Ok(true);
}
let identity = self
.authenticated_identity
.clone()
.unwrap_or_else(|| Identity {
id: String::new(),
scopes: vec![],
resources: std::collections::HashMap::new(),
});
let policy = self.dynamic.load();
let allowed = policy.forwarding.check(
host_to_connect,
port_to_connect as u16,
&identity,
self.transport.clone(),
);
if !allowed {
tracing::info!(
remote_addr = ?self.remote_addr,
target = %format!("{host_to_connect}:{port_to_connect}"),
identity = %identity.id,
transport = %self.transport,
"forwarding denied by policy"
);
return Ok(false);
}
let target_host = host_to_connect.to_string();
let target_port = port_to_connect;
let proxy_config =
self.outbound_proxy
.clone()
.unwrap_or(crate::server::handler::ProxyConfig {
mode: crate::server::handler::ProxyMode::Direct,
});
tokio::spawn(async move {
let target = match format!("{target_host}:{target_port}")
.parse::<std::net::SocketAddr>()
{
Ok(addr) => addr,
Err(_) => {
match tokio::net::lookup_host((&target_host[..], target_port as u16)).await {
Ok(mut addrs) => match addrs.next() {
Some(addr) => addr,
None => return,
},
Err(_) => return,
}
}
};
crate::server::channel_proxy::proxy_channel(
channel.into_stream(),
target,
&proxy_config,
)
.await;
});
let _ = (originator_address, originator_port);
Ok(true)
}
async fn channel_open_session(
&mut self,
_channel: Channel<server::Msg>,
session: &mut server::Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
"rejected session channel (shell/exec not supported)"
);
let _ = session;
Ok(false)
}
async fn channel_open_x11(
&mut self,
_channel: Channel<server::Msg>,
_originator_address: &str,
_originator_port: u32,
session: &mut server::Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
"rejected x11 channel"
);
let _ = session;
Ok(false)
}
async fn channel_open_forwarded_tcpip(
&mut self,
_channel: Channel<server::Msg>,
host_to_connect: &str,
port_to_connect: u32,
_originator_address: &str,
_originator_port: u32,
session: &mut server::Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
target = %format!("{host_to_connect}:{port_to_connect}"),
"rejected forwarded-tcpip channel (remote port forwarding not supported)"
);
let _ = session;
Ok(false)
}
async fn exec_request(
&mut self,
channel: ChannelId,
data: &[u8],
session: &mut server::Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
data_len = data.len(),
"rejected exec request on channel (shell/exec not supported)"
);
let _ = session.channel_failure(channel);
Ok(())
}
async fn shell_request(
&mut self,
channel: ChannelId,
session: &mut server::Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
"rejected shell request on channel"
);
let _ = session.channel_failure(channel);
Ok(())
}
async fn subsystem_request(
&mut self,
channel: ChannelId,
name: &str,
session: &mut server::Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
subsystem = name,
"rejected subsystem request on channel"
);
let _ = session.channel_failure(channel);
Ok(())
}
async fn pty_request(
&mut self,
channel: ChannelId,
term: &str,
col_width: u32,
row_height: u32,
pix_width: u32,
pix_height: u32,
modes: &[(russh::Pty, u32)],
session: &mut server::Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
term = term,
"rejected pty request on channel"
);
let _ = (col_width, row_height, pix_width, pix_height, modes);
let _ = session.channel_failure(channel);
Ok(())
}
async fn env_request(
&mut self,
channel: ChannelId,
variable_name: &str,
variable_value: &str,
session: &mut server::Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
variable = variable_name,
"rejected env request on channel"
);
let _ = variable_value;
let _ = session.channel_failure(channel);
Ok(())
}
async fn x11_request(
&mut self,
channel: ChannelId,
single_connection: bool,
x11_auth_protocol: &str,
x11_auth_cookie: &str,
x11_screen_number: u32,
session: &mut server::Session,
) -> Result<(), Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
"rejected x11 request on channel"
);
let _ = (
single_connection,
x11_auth_protocol,
x11_auth_cookie,
x11_screen_number,
);
let _ = session.channel_failure(channel);
Ok(())
}
async fn agent_request(
&mut self,
channel: ChannelId,
session: &mut server::Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
channel = %channel,
"rejected agent forwarding request on channel"
);
let _ = session;
Ok(false)
}
async fn tcpip_forward(
&mut self,
address: &str,
port: &mut u32,
session: &mut server::Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
address = address,
port = *port,
"rejected tcpip-forward request (remote port forwarding not supported)"
);
let _ = session;
Ok(false)
}
async fn cancel_tcpip_forward(
&mut self,
address: &str,
port: u32,
session: &mut server::Session,
) -> Result<bool, Self::Error> {
let _ = (address, port, session);
Ok(false)
}
async fn streamlocal_forward(
&mut self,
socket_path: &str,
session: &mut server::Session,
) -> Result<bool, Self::Error> {
tracing::warn!(
remote_addr = ?self.remote_addr,
socket_path = socket_path,
"rejected streamlocal-forward request"
);
let _ = session;
Ok(false)
}
async fn signal(
&mut self,
channel: ChannelId,
signal: russh::Sig,
session: &mut server::Session,
) -> Result<(), Self::Error> {
tracing::debug!(
remote_addr = ?self.remote_addr,
channel = %channel,
signal = ?signal,
"received signal on channel (ignored)"
);
let _ = session;
Ok(())
}
}
pub struct SshInterface {
config: Arc<Config>,
dynamic: Arc<ArcSwap<DynamicConfig>>,
connection_limiter: Arc<ConnectionRateLimiter>,
outbound_proxy: Option<crate::server::handler::ProxyConfig>,
max_auth_attempts: usize,
}
impl SshInterface {
pub fn new(config: Arc<Config>, dynamic: Arc<ArcSwap<DynamicConfig>>) -> Self {
Self {
config,
dynamic,
connection_limiter: Arc::new(ConnectionRateLimiter::new(0)),
outbound_proxy: None,
max_auth_attempts: 10,
}
}
pub fn with_connection_limiter(mut self, limiter: Arc<ConnectionRateLimiter>) -> Self {
self.connection_limiter = limiter;
self
}
pub fn with_outbound_proxy(
mut self,
proxy: Option<crate::server::handler::ProxyConfig>,
) -> Self {
self.outbound_proxy = proxy;
self
}
pub fn with_max_auth_attempts(mut self, max: usize) -> Self {
self.max_auth_attempts = max;
self
}
pub fn config(&self) -> &Arc<Config> {
&self.config
}
pub fn dynamic(&self) -> &Arc<ArcSwap<DynamicConfig>> {
&self.dynamic
}
async fn accept_inner(
&self,
stream: Box<dyn TransportStream>,
ssh_config: &crate::interface::SshInterfaceConfig,
remote_addr: Option<SocketAddr>,
transport: TransportKind,
) -> Result<SshSession> {
let identity_provider = Arc::clone(&ssh_config.auth);
let _forwarding = Arc::clone(&ssh_config.forwarding);
let handler = SshHandler::new(
Arc::clone(&self.dynamic),
identity_provider,
self.outbound_proxy.clone(),
remote_addr,
transport,
Arc::clone(&self.connection_limiter),
self.max_auth_attempts,
);
let running = server::run_stream(Arc::clone(&self.config), stream, handler).await?;
let handle = running.handle();
let join = tokio::spawn(async {
let _ = running.await;
});
Ok(SshSession {
handle,
_join: join,
})
}
}
#[async_trait]
impl Interface for SshInterface {
type Session = SshSession;
async fn accept(
&self,
stream: Box<dyn TransportStream>,
config: &InterfaceConfig,
) -> Result<Self::Session> {
let ssh_config = match config {
InterfaceConfig::Ssh(c) => c,
InterfaceConfig::RawFraming(_) => {
return Err(anyhow::anyhow!("SshInterface received RawFramingConfig"));
}
};
self.accept_inner(stream, ssh_config, None, TransportKind::Tcp)
.await
}
}
pub struct SshSession {
handle: server::Handle,
_join: tokio::task::JoinHandle<()>,
}
impl SshSession {
pub fn handle(&self) -> &server::Handle {
&self.handle
}
}
#[async_trait]
impl InterfaceSession for SshSession {
async fn recv(&mut self) -> Option<InterfaceEvent> {
None
}
async fn send(&mut self, _envelope: EventEnvelope) -> Result<()> {
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn ssh_interface_constructs_with_config() {
let config = Arc::new(Config {
keys: vec![russh::keys::PrivateKey::random(
&mut rand_core::OsRng,
russh::keys::Algorithm::Ed25519,
)
.unwrap()],
..Default::default()
});
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
let iface = SshInterface::new(config, dynamic);
assert!(iface.config().keys.len() >= 1);
}
#[test]
fn ssh_interface_builder_pattern() {
let config = Arc::new(Config {
keys: vec![russh::keys::PrivateKey::random(
&mut rand_core::OsRng,
russh::keys::Algorithm::Ed25519,
)
.unwrap()],
..Default::default()
});
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
let limiter = Arc::new(ConnectionRateLimiter::new(5));
let iface = SshInterface::new(config, dynamic)
.with_connection_limiter(limiter)
.with_max_auth_attempts(3);
assert!(iface.config().keys.len() >= 1);
}
#[test]
fn ssh_handler_auth_delegates_to_identity_provider() {
use std::collections::HashMap;
struct MockProvider {
identities: HashMap<String, Identity>,
}
impl IdentityProvider for MockProvider {
fn resolve_from_fingerprint(&self, fp: &str) -> Option<Identity> {
self.identities.get(fp).cloned()
}
fn resolve_from_token(&self, _t: &crate::auth::AuthToken) -> Option<Identity> {
None
}
}
let mut ids = HashMap::new();
ids.insert(
"SHA256:testkey".to_string(),
Identity {
id: "SHA256:testkey".to_string(),
scopes: vec!["admin".to_string()],
resources: HashMap::new(),
},
);
let provider: Arc<dyn IdentityProvider> = Arc::new(MockProvider { identities: ids });
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
let limiter = Arc::new(ConnectionRateLimiter::new(0));
let handler = SshHandler::new(
dynamic,
provider,
None,
None,
TransportKind::Tcp,
limiter,
10,
);
assert!(handler.authenticated_identity.is_none());
}
#[test]
fn ssh_handler_connection_rate_limiting() {
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
let provider: Arc<dyn IdentityProvider> = Arc::new(
crate::auth::identity::ConfigIdentityProvider::new(Arc::clone(&dynamic)),
);
let limiter = Arc::new(ConnectionRateLimiter::new(1));
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
let h1 = SshHandler::new(
Arc::clone(&dynamic),
Arc::clone(&provider),
None,
Some(addr),
TransportKind::Tcp,
Arc::clone(&limiter),
10,
);
assert!(h1.connection_allowed);
let h2 = SshHandler::new(
dynamic,
provider,
None,
Some(addr),
TransportKind::Tcp,
limiter,
10,
);
assert!(!h2.connection_allowed);
}
#[tokio::test]
async fn ssh_interface_rejects_raw_framing_config() {
let config = Arc::new(Config {
keys: vec![russh::keys::PrivateKey::random(
&mut rand_core::OsRng,
russh::keys::Algorithm::Ed25519,
)
.unwrap()],
..Default::default()
});
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
let iface = SshInterface::new(config, dynamic);
let (_client, server) = tokio::io::duplex(1024);
let stream: Box<dyn TransportStream> = Box::new(server);
let raw_config = InterfaceConfig::RawFraming(crate::interface::RawFramingConfig {});
let result = iface.accept(stream, &raw_config).await;
assert!(result.is_err());
}
}

View File

@@ -87,8 +87,8 @@ pub use config::{
pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError};
pub use interface::{
is_valid_pair, Interface, InterfaceConfig, InterfaceEvent, InterfaceKind, InterfaceSession,
RawFramingConfig, SshInterfaceConfig, TransportKindBase, TransportStream,
VALID_TRANSPORT_INTERFACE_PAIRS,
RawFramingConfig, RawFramingInterface, RawFramingSession, SshInterface, SshInterfaceConfig,
SshSession, TransportKindBase, TransportStream, VALID_TRANSPORT_INTERFACE_PAIRS,
};
pub use server::serve::{ListenerConfig, ServeError, ServeOptions, ServeTransportMode, Server};
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};

View File

@@ -14,6 +14,8 @@ use crate::config::DynamicConfig;
use crate::server::control_channel::{ControlChannelHandler, ControlChannelRouter, ALKNET_PREFIX};
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
pub use crate::transport::TransportKind;
#[derive(Debug, Clone)]
pub enum ProxyMode {
Direct,
@@ -26,27 +28,6 @@ pub struct ProxyConfig {
pub mode: ProxyMode,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TransportKind {
Tcp,
Tls,
Iroh,
Dns,
WebTransport,
}
impl std::fmt::Display for TransportKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportKind::Tcp => write!(f, "tcp"),
TransportKind::Tls => write!(f, "tls"),
TransportKind::Iroh => write!(f, "iroh"),
TransportKind::Dns => write!(f, "dns"),
TransportKind::WebTransport => write!(f, "webtransport"),
}
}
}
pub struct ServerHandler {
dynamic: Arc<ArcSwap<DynamicConfig>>,
identity_provider: Arc<dyn IdentityProvider>,
@@ -252,7 +233,7 @@ impl Handler for ServerHandler {
host_to_connect,
port_to_connect as u16,
&identity,
self.transport,
self.transport.clone(),
);
if !allowed {
@@ -784,10 +765,28 @@ mod tests {
#[test]
fn transport_kind_display() {
assert_eq!(TransportKind::Tcp.to_string(), "tcp");
assert_eq!(TransportKind::Tls.to_string(), "tls");
assert_eq!(TransportKind::Iroh.to_string(), "iroh");
assert_eq!(TransportKind::Dns.to_string(), "dns");
assert_eq!(TransportKind::WebTransport.to_string(), "webtransport");
assert_eq!(TransportKind::Tls { server_name: None }.to_string(), "tls");
assert_eq!(
TransportKind::Iroh {
endpoint_id: String::new()
}
.to_string(),
"iroh"
);
assert_eq!(
TransportKind::Dns {
domain: String::new()
}
.to_string(),
"dns"
);
assert_eq!(
TransportKind::WebTransport {
host: String::new()
}
.to_string(),
"webtransport"
);
}
#[tokio::test]
@@ -797,7 +796,7 @@ mod tests {
auth_config,
None,
Some("203.0.113.50:12345".parse().unwrap()),
TransportKind::Tls,
TransportKind::Tls { server_name: None },
Arc::new(ConnectionRateLimiter::new(0)),
10,
);

View File

@@ -19,9 +19,11 @@ pub use control_channel::{
is_reserved_destination, ControlChannelHandler, ControlChannelRouter, DuplexStream,
ALKNET_CONTROL_DESTINATION, ALKNET_PREFIX,
};
pub use handler::{ProxyConfig, ProxyMode, ServerHandler, TransportKind};
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
pub use serve::{ListenerConfig, ServeError, ServeOptions, ServeTransportMode, Server};
pub use crate::transport::TransportKind;
pub use stealth::{
detect_protocol, send_fake_nginx_404, validate_stealth_config, ProtocolDetection,
};

View File

@@ -16,9 +16,11 @@ use tracing::{error, info, warn};
use crate::auth::keys::KeySource;
use crate::config::{ConfigReloadHandle, DynamicConfig};
use crate::error::ConfigError;
use crate::server::handler::{ProxyConfig, ServerHandler, TransportKind};
use crate::interface::InterfaceKind;
use crate::server::handler::{ProxyConfig, ServerHandler};
use crate::server::rate_limit::ConnectionRateLimiter;
use crate::server::stealth::{self, ProtocolDetection};
use crate::transport::TransportKind;
const DEFAULT_LISTEN_ADDR: &str = "0.0.0.0:22";
const DRAIN_TIMEOUT: Duration = Duration::from_secs(2);
@@ -43,6 +45,7 @@ impl std::fmt::Display for ServeTransportMode {
#[derive(Debug, Clone, PartialEq)]
pub struct ListenerConfig {
pub transport_kind: TransportKind,
pub interface_kind: InterfaceKind,
pub listen_addr: String,
pub tls_cert: Option<String>,
pub tls_key: Option<String>,
@@ -55,6 +58,7 @@ impl ListenerConfig {
pub fn tcp(addr: impl Into<String>) -> Self {
Self {
transport_kind: TransportKind::Tcp,
interface_kind: InterfaceKind::Ssh,
listen_addr: addr.into(),
tls_cert: None,
tls_key: None,
@@ -66,7 +70,8 @@ impl ListenerConfig {
pub fn tls(addr: impl Into<String>) -> Self {
Self {
transport_kind: TransportKind::Tls,
transport_kind: TransportKind::Tls { server_name: None },
interface_kind: InterfaceKind::Ssh,
listen_addr: addr.into(),
tls_cert: None,
tls_key: None,
@@ -78,7 +83,10 @@ impl ListenerConfig {
pub fn iroh(addr: impl Into<String>) -> Self {
Self {
transport_kind: TransportKind::Iroh,
transport_kind: TransportKind::Iroh {
endpoint_id: String::new(),
},
interface_kind: InterfaceKind::Ssh,
listen_addr: addr.into(),
tls_cert: None,
tls_key: None,
@@ -90,7 +98,10 @@ impl ListenerConfig {
pub fn dns(domain: impl Into<String>) -> Self {
Self {
transport_kind: TransportKind::Dns,
transport_kind: TransportKind::Dns {
domain: String::new(),
},
interface_kind: InterfaceKind::RawFraming,
listen_addr: domain.into(),
tls_cert: None,
tls_key: None,
@@ -102,7 +113,10 @@ impl ListenerConfig {
pub fn webtransport(host: impl Into<String>) -> Self {
Self {
transport_kind: TransportKind::WebTransport,
transport_kind: TransportKind::WebTransport {
host: String::new(),
},
interface_kind: InterfaceKind::Ssh,
listen_addr: host.into(),
tls_cert: None,
tls_key: None,
@@ -138,14 +152,14 @@ impl ListenerConfig {
}
pub fn validate(&self) -> Result<(), ConfigError> {
if self.stealth && self.transport_kind != TransportKind::Tls {
if self.stealth && !matches!(self.transport_kind, TransportKind::Tls { .. }) {
return Err(ConfigError::InvalidFlag {
name: "stealth mode requires TLS transport".to_string(),
});
}
match self.transport_kind {
TransportKind::Tls => {
TransportKind::Tls { .. } => {
if self.tls_cert.is_none() && self.acme_domain.is_none() {
return Err(ConfigError::InvalidFlag {
name: "TLS transport requires tls_cert/tls_key or acme_domain".to_string(),
@@ -163,9 +177,9 @@ impl ListenerConfig {
}
}
TransportKind::Tcp
| TransportKind::Iroh
| TransportKind::Dns
| TransportKind::WebTransport => {
| TransportKind::Iroh { .. }
| TransportKind::Dns { .. }
| TransportKind::WebTransport { .. } => {
if self.tls_cert.is_some() || self.tls_key.is_some() || self.acme_domain.is_some() {
return Err(ConfigError::IncompatibleOptions);
}
@@ -179,9 +193,9 @@ impl ListenerConfig {
impl std::fmt::Display for ListenerConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.transport_kind {
TransportKind::Iroh => write!(f, "{} (iroh)", self.listen_addr),
TransportKind::Dns => write!(f, "{} (dns)", self.listen_addr),
TransportKind::WebTransport => write!(f, "{} (webtransport)", self.listen_addr),
TransportKind::Iroh { .. } => write!(f, "{} (iroh)", self.listen_addr),
TransportKind::Dns { .. } => write!(f, "{} (dns)", self.listen_addr),
TransportKind::WebTransport { .. } => write!(f, "{} (webtransport)", self.listen_addr),
_ => write!(f, "{} ({})", self.listen_addr, self.transport_kind),
}
}
@@ -474,11 +488,11 @@ impl Server {
.first()
.expect("at least one listener required");
let transport_kind = listener.transport_kind;
let transport_kind = listener.transport_kind.clone();
let stealth = listener.stealth;
let listen_addr = listener.listen_addr.clone();
if matches!(transport_kind, TransportKind::Iroh) {
if matches!(transport_kind, TransportKind::Iroh { .. }) {
if let Some(id) = endpoint_info {
info!("alknet server running: transport=iroh endpoint_id={}", id);
} else {
@@ -538,7 +552,7 @@ impl Server {
};
let remote_addr = info.remote_addr;
let handler_transport_kind = transport_kind;
let handler_transport_kind = transport_kind.clone();
let handler = ServerHandler::new(
Arc::clone(&server.dynamic),
@@ -555,7 +569,7 @@ impl Server {
let config = Arc::clone(&server.config);
let sessions = Arc::clone(&server.sessions);
let transport_is_tls = matches!(transport_kind, TransportKind::Tls);
let transport_is_tls = matches!(transport_kind, TransportKind::Tls { .. });
tokio::spawn(async move {
let result =
@@ -830,7 +844,7 @@ mod tests {
.tls_cert("/cert.pem")
.tls_key("/key.pem")
.stealth(true);
assert_eq!(lc.transport_kind, TransportKind::Tls);
assert_eq!(lc.transport_kind, TransportKind::Tls { server_name: None });
assert_eq!(lc.listen_addr, "0.0.0.0:443");
assert!(lc.stealth);
assert_eq!(lc.tls_cert.as_deref(), Some("/cert.pem"));
@@ -840,21 +854,36 @@ mod tests {
#[test]
fn listener_config_iroh_constructor() {
let lc = ListenerConfig::iroh("0.0.0.0:0").iroh_relay("https://relay.example.com");
assert_eq!(lc.transport_kind, TransportKind::Iroh);
assert_eq!(
lc.transport_kind,
TransportKind::Iroh {
endpoint_id: String::new()
}
);
assert_eq!(lc.iroh_relay.as_deref(), Some("https://relay.example.com"));
}
#[test]
fn listener_config_dns_constructor() {
let lc = ListenerConfig::dns("example.com");
assert_eq!(lc.transport_kind, TransportKind::Dns);
assert_eq!(
lc.transport_kind,
TransportKind::Dns {
domain: String::new()
}
);
assert_eq!(lc.listen_addr, "example.com");
}
#[test]
fn listener_config_webtransport_constructor() {
let lc = ListenerConfig::webtransport("example.com");
assert_eq!(lc.transport_kind, TransportKind::WebTransport);
assert_eq!(
lc.transport_kind,
TransportKind::WebTransport {
host: String::new()
}
);
assert_eq!(lc.listen_addr, "example.com");
}
@@ -1006,7 +1035,10 @@ mod tests {
.stealth(true);
let server = Server::new(opts).unwrap();
assert_eq!(server.listeners.len(), 1);
assert_eq!(server.listeners[0].transport_kind, TransportKind::Tls);
assert_eq!(
server.listeners[0].transport_kind,
TransportKind::Tls { server_name: None }
);
assert!(server.listeners[0].stealth);
assert_eq!(server.listeners[0].tls_cert.as_deref(), Some("/cert.pem"));
}
@@ -1025,7 +1057,10 @@ mod tests {
let server = Server::new(opts).unwrap();
assert_eq!(server.listeners().len(), 2);
assert_eq!(server.listeners()[0].transport_kind, TransportKind::Tcp);
assert_eq!(server.listeners()[1].transport_kind, TransportKind::Tls);
assert_eq!(
server.listeners()[1].transport_kind,
TransportKind::Tls { server_name: None }
);
}
#[test]

View File

@@ -86,7 +86,7 @@ pub struct TransportInfo {
/// Each variant identifies the transport mechanism. Used by the
/// server handler for logging and authorization decisions.
/// See ADR-001 and ADR-004.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransportKind {
Tcp,
Tls { server_name: Option<String> },
@@ -95,6 +95,18 @@ pub enum TransportKind {
WebTransport { host: String },
}
impl std::fmt::Display for TransportKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportKind::Tcp => write!(f, "tcp"),
TransportKind::Tls { .. } => write!(f, "tls"),
TransportKind::Iroh { .. } => write!(f, "iroh"),
TransportKind::Dns { .. } => write!(f, "dns"),
TransportKind::WebTransport { .. } => write!(f, "webtransport"),
}
}
}
#[cfg(test)]
mod tests {
use super::*;