diff --git a/crates/alknet-core/src/call/frame.rs b/crates/alknet-core/src/call/frame.rs index 63168e6..8907011 100644 --- a/crates/alknet-core/src/call/frame.rs +++ b/crates/alknet-core/src/call/frame.rs @@ -1,3 +1,7 @@ +use std::io; + +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + use crate::call::envelope::EventEnvelope; pub fn encode(envelope: &EventEnvelope) -> Vec { @@ -58,6 +62,73 @@ pub enum FrameDecodeError { Json(#[from] serde_json::Error), } +pub struct FrameFramedReader { + stream: S, + buf: Vec, +} + +impl FrameFramedReader +where + S: AsyncRead + Unpin, +{ + pub fn new(stream: S) -> Self { + Self { + stream, + buf: Vec::with_capacity(4096), + } + } + + pub async fn read_frame(&mut self) -> io::Result> { + loop { + if self.buf.len() >= 4 { + let len = u32::from_be_bytes([self.buf[0], self.buf[1], self.buf[2], self.buf[3]]) + as usize; + let total = 4 + len; + if self.buf.len() >= total { + let body = &self.buf[4..total]; + match serde_json::from_slice(body) { + Ok(envelope) => { + self.buf.drain(..total); + return Ok(Some(envelope)); + } + Err(e) => { + self.buf.drain(..total); + return Err(io::Error::new(io::ErrorKind::InvalidData, e)); + } + } + } + } + + let mut tmp = [0u8; 4096]; + match self.stream.read(&mut tmp).await { + Ok(0) => return Ok(None), + Ok(n) => self.buf.extend_from_slice(&tmp[..n]), + Err(e) => return Err(e), + } + } + } +} + +pub struct FrameFramedWriter { + stream: S, +} + +impl FrameFramedWriter +where + S: AsyncWrite + Unpin, +{ + pub fn new(stream: S) -> Self { + Self { stream } + } + + pub async fn write_frame(&mut self, envelope: &EventEnvelope) -> io::Result<()> { + let frame = encode(envelope); + self.stream.write_all(&frame).await?; + self.stream.flush().await?; + Ok(()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/alknet-core/src/call/mod.rs b/crates/alknet-core/src/call/mod.rs index 2dabb2e..326c737 100644 --- a/crates/alknet-core/src/call/mod.rs +++ b/crates/alknet-core/src/call/mod.rs @@ -18,7 +18,9 @@ pub use context::OperationContext; pub use env::OperationEnv; pub use envelope::EventEnvelope; pub use events::{CALL_ABORTED, CALL_COMPLETED, CALL_ERROR, CALL_REQUESTED, CALL_RESPONDED}; -pub use frame::{decode, decode_with_remainder, encode, FrameDecodeError}; +pub use frame::{ + decode, decode_with_remainder, encode, FrameDecodeError, FrameFramedReader, FrameFramedWriter, +}; pub use pending::PendingRequestMap; pub use registry::{Handler, OperationRegistry, OperationRegistryBuilder}; pub use response::{CallError, ResponseEnvelope}; diff --git a/crates/alknet-core/src/interface/mod.rs b/crates/alknet-core/src/interface/mod.rs index 0184b1b..de549b0 100644 --- a/crates/alknet-core/src/interface/mod.rs +++ b/crates/alknet-core/src/interface/mod.rs @@ -35,7 +35,7 @@ pub use http::HttpInterface; pub use pairs::{is_valid_pair, TransportKindBase, VALID_TRANSPORT_INTERFACE_PAIRS}; pub use raw_framing::{RawFramingInterface, RawFramingSession}; pub use session::{InterfaceEvent, InterfaceSession}; -pub use ssh::{SshInterface, SshSession}; +pub use ssh::{ControlChannelBridge, SshInterface, SshSession}; pub trait TransportStream: AsyncRead + AsyncWrite + Unpin + Send + 'static {} diff --git a/crates/alknet-core/src/interface/ssh.rs b/crates/alknet-core/src/interface/ssh.rs index 26f80fc..559b5f1 100644 --- a/crates/alknet-core/src/interface/ssh.rs +++ b/crates/alknet-core/src/interface/ssh.rs @@ -9,13 +9,18 @@ use russh::keys::ssh_key::HashAlg; use russh::server::{self, Config}; use russh::Channel; use russh::ChannelId; +use tokio::sync::mpsc; use crate::auth::identity::{Identity, IdentityProvider}; +use crate::call::frame::{FrameFramedReader, FrameFramedWriter}; use crate::call::EventEnvelope; use crate::config::DynamicConfig; use crate::interface::session::{InterfaceEvent, InterfaceSession}; use crate::interface::{StreamInterface, StreamInterfaceConfig, TransportStream}; -use crate::server::control_channel::{ControlChannelRouter, ALKNET_PREFIX}; +use crate::server::control_channel::{ + ControlChannelHandler, ControlChannelRouter, DuplexStream, ALKNET_CONTROL_DESTINATION, + ALKNET_PREFIX, +}; use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter}; use crate::transport::TransportKind; @@ -30,6 +35,8 @@ struct SshHandler { auth_limiter: AuthAttemptLimiter, authenticated_identity: Option, control_channel_router: ControlChannelRouter, + bridge_event_tx: Option>, + bridge_envelope_rx: Option>, connected_at: Instant, } @@ -76,6 +83,8 @@ impl SshHandler { auth_limiter: AuthAttemptLimiter::new(max_auth_attempts), authenticated_identity: None, control_channel_router: ControlChannelRouter::without_handler(), + bridge_event_tx: None, + bridge_envelope_rx: None, connected_at: Instant::now(), } } @@ -85,6 +94,20 @@ impl SshHandler { self.control_channel_router = router; self } + + fn with_bridge_channels( + mut self, + event_tx: mpsc::Sender, + envelope_rx: mpsc::Receiver, + ) -> Self { + self.bridge_event_tx = Some(event_tx); + self.bridge_envelope_rx = Some(envelope_rx); + self + } + + fn has_control_channel_bridge(&self) -> bool { + self.bridge_event_tx.is_some() && self.bridge_envelope_rx.is_some() + } } impl Drop for SshHandler { @@ -176,11 +199,36 @@ impl server::Handler for SshHandler { _session: &mut server::Session, ) -> Result { if host_to_connect.starts_with(ALKNET_PREFIX) { - if !self.control_channel_router.has_handler() { - return Ok(false); + if host_to_connect == ALKNET_CONTROL_DESTINATION && self.has_control_channel_bridge() { + let event_tx = self.bridge_event_tx.take().unwrap(); + let envelope_rx = self.bridge_envelope_rx.take().unwrap(); + let identity = self.authenticated_identity.clone(); + tokio::spawn(async move { + let stream = channel.into_stream(); + let (read_half, write_half) = tokio::io::split(stream); + run_control_channel_bridge( + read_half, + write_half, + identity, + event_tx, + envelope_rx, + ) + .await; + }); + let _ = (originator_address, originator_port); + return Ok(true); } - let _ = channel; - return Ok(true); + if self.control_channel_router.has_handler() { + if let Some(handler) = self.control_channel_router.take_handler() { + let stream: Box = Box::new(channel.into_stream()); + tokio::spawn(async move { + handler.handle_channel(stream).await; + }); + } + let _ = (originator_address, originator_port); + return Ok(true); + } + return Ok(false); } let identity = self @@ -529,6 +577,9 @@ impl SshInterface { let identity_provider = Arc::clone(&ssh_config.auth); let _forwarding = Arc::clone(&ssh_config.forwarding); + let (event_tx, event_rx) = mpsc::channel::(256); + let (envelope_tx, envelope_rx) = mpsc::channel::(256); + let handler = SshHandler::new( Arc::clone(&self.dynamic), identity_provider, @@ -537,7 +588,8 @@ impl SshInterface { transport, Arc::clone(&self.connection_limiter), self.max_auth_attempts, - ); + ) + .with_bridge_channels(event_tx, envelope_rx); let running = server::run_stream(Arc::clone(&self.config), stream, handler).await?; let handle = running.handle(); @@ -548,6 +600,8 @@ impl SshInterface { Ok(SshSession { handle, _join: join, + event_rx, + envelope_tx, }) } } @@ -576,6 +630,8 @@ impl StreamInterface for SshInterface { pub struct SshSession { handle: server::Handle, _join: tokio::task::JoinHandle<()>, + event_rx: mpsc::Receiver, + envelope_tx: mpsc::Sender, } impl SshSession { @@ -586,26 +642,95 @@ impl SshSession { #[async_trait] impl InterfaceSession for SshSession { - /// Stub for Phase 1 — always returns `None`. - /// - /// TODO: Bridge `alknet-control:0` channel events to call protocol - /// `InterfaceEvent` frames. Planned for Phase 2/3. async fn recv(&mut self) -> Option { - None + self.event_rx.recv().await } - /// Stub for Phase 1 — accepts silently and discards. - /// - /// TODO: Bridge outgoing `EventEnvelope` frames to the SSH channel - /// established by the call protocol. Planned for Phase 2/3. - async fn send(&mut self, _envelope: EventEnvelope) -> Result<()> { - Ok(()) + async fn send(&mut self, envelope: EventEnvelope) -> Result<()> { + self.envelope_tx + .send(envelope) + .await + .map_err(|_| anyhow::anyhow!("control channel bridge closed")) + } +} + +async fn run_control_channel_bridge( + read_half: R, + write_half: W, + identity: Option, + event_tx: mpsc::Sender, + mut envelope_rx: mpsc::Receiver, +) where + R: tokio::io::AsyncRead + Unpin, + W: tokio::io::AsyncWrite + Unpin, +{ + let mut reader = FrameFramedReader::new(read_half); + let mut writer = FrameFramedWriter::new(write_half); + + loop { + tokio::select! { + frame = reader.read_frame() => { + match frame { + Ok(Some(envelope)) => { + let event = match &identity { + Some(id) => InterfaceEvent::with_identity(envelope, id.clone()), + None => InterfaceEvent::new(envelope), + }; + if event_tx.send(event).await.is_err() { + return; + } + } + Ok(None) => return, + Err(_) => return, + } + } + envelope = envelope_rx.recv() => { + match envelope { + Some(envelope) => { + if writer.write_frame(&envelope).await.is_err() { + return; + } + } + None => return, + } + } + } + } +} + +pub struct ControlChannelBridge { + identity: Option, +} + +impl ControlChannelBridge { + pub fn new(identity: Option) -> Self { + Self { identity } + } +} + +#[async_trait] +impl ControlChannelHandler for ControlChannelBridge { + async fn handle_channel(&self, stream: Box) { + let (event_tx, _event_rx) = mpsc::channel::(256); + let (_envelope_tx, envelope_rx) = mpsc::channel::(256); + + let identity = self.identity.clone(); + let (read_half, write_half) = tokio::io::split(stream); + tokio::spawn(run_control_channel_bridge( + read_half, + write_half, + identity, + event_tx, + envelope_rx, + )); } } #[cfg(test)] mod tests { use super::*; + use crate::call::frame::{FrameFramedReader, FrameFramedWriter}; + use tokio::io::duplex; #[test] fn ssh_interface_constructs_with_config() { @@ -742,4 +867,116 @@ mod tests { let result = iface.accept(stream, &raw_config).await; assert!(result.is_err()); } + + #[tokio::test] + async fn ssh_session_round_trip_event_envelope() { + let (client, server) = duplex(4096); + + let (event_tx, mut event_rx) = mpsc::channel::(256); + let (envelope_tx, envelope_rx) = mpsc::channel::(256); + + let identity = Identity { + id: "SHA256:test".to_string(), + scopes: vec![], + resources: std::collections::HashMap::new(), + }; + let identity_clone = identity.clone(); + + let (server_read, server_write) = tokio::io::split(server); + tokio::spawn(run_control_channel_bridge( + server_read, + server_write, + Some(identity_clone), + event_tx, + envelope_rx, + )); + + let (client_read, client_write) = tokio::io::split(client); + let mut client_reader = FrameFramedReader::new(client_read); + let mut client_writer = FrameFramedWriter::new(client_write); + + let envelope = EventEnvelope::call_requested("req-1", serde_json::json!({"op": "test"})); + client_writer.write_frame(&envelope).await.unwrap(); + + let received_event = + tokio::time::timeout(std::time::Duration::from_secs(2), event_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(received_event.envelope, envelope); + assert_eq!(received_event.identity.as_ref().unwrap().id, "SHA256:test"); + + let response = EventEnvelope::call_responded("req-1", serde_json::json!({"result": 42})); + envelope_tx.send(response.clone()).await.unwrap(); + + let read_back = tokio::time::timeout( + std::time::Duration::from_secs(2), + client_reader.read_frame(), + ) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(read_back, response); + } + + #[tokio::test] + async fn ssh_session_recv_without_identity() { + let (client, server) = duplex(4096); + + let (event_tx, mut event_rx) = mpsc::channel::(256); + let (_envelope_tx, envelope_rx) = mpsc::channel::(256); + + let (server_read, server_write) = tokio::io::split(server); + tokio::spawn(run_control_channel_bridge( + server_read, + server_write, + None, + event_tx, + envelope_rx, + )); + + let (client_read, client_write) = tokio::io::split(client); + let mut client_writer = FrameFramedWriter::new(client_write); + let _client_reader = FrameFramedReader::new(client_read); + + let envelope = EventEnvelope::call_requested("req-2", serde_json::json!({"op": "no-id"})); + client_writer.write_frame(&envelope).await.unwrap(); + + let received_event = + tokio::time::timeout(std::time::Duration::from_secs(2), event_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(received_event.envelope, envelope); + assert!(received_event.identity.is_none()); + } + + #[tokio::test] + async fn control_channel_router_with_handler_routes_data() { + let called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let called_clone = called.clone(); + + struct TrackingHandler { + called: std::sync::Arc, + } + + #[async_trait] + impl ControlChannelHandler for TrackingHandler { + async fn handle_channel(&self, _stream: Box) { + self.called.store(true, std::sync::atomic::Ordering::SeqCst); + } + } + + let router = ControlChannelRouter::with_handler(Box::new(TrackingHandler { + called: called_clone, + })); + assert!(router.has_handler()); + + let (_client, server) = duplex(64); + let stream: Box = Box::new(server); + let result = router.route(stream).await; + assert!(result.is_ok()); + assert!(called.load(std::sync::atomic::Ordering::SeqCst)); + } } diff --git a/crates/alknet-core/src/server/control_channel.rs b/crates/alknet-core/src/server/control_channel.rs index 9ba316a..6455f19 100644 --- a/crates/alknet-core/src/server/control_channel.rs +++ b/crates/alknet-core/src/server/control_channel.rs @@ -60,6 +60,10 @@ impl ControlChannelRouter { )), } } + + pub fn take_handler(&mut self) -> Option> { + self.handler.take() + } } #[cfg(test)]