diff --git a/Cargo.lock b/Cargo.lock index d2ac1c0..8275536 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -77,8 +77,12 @@ dependencies = [ "anyhow", "arc-swap", "async-trait", + "axum", "futures", "hex", + "http-body-util", + "hyper", + "hyper-util", "ipnetwork", "iroh", "rand 0.10.1", @@ -97,6 +101,7 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-util", + "tower", "tracing", "url", "webpki-roots 0.26.11", @@ -402,6 +407,58 @@ dependencies = [ "fs_extra", ] +[[package]] +name = "axum" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31b698c5f9a010f6573133b09e0de5408834d0c82f8d7475a89fc1867a71cd90" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http 1.4.1", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http 1.4.1", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backon" version = "1.6.0" @@ -2380,6 +2437,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "md5" version = "0.7.0" @@ -2392,6 +2455,12 @@ version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b947ae49db0d222b1dbc6b113ce7248a3fc3a6ca21b696717bfc000ba4484d8" +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -4111,6 +4180,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -4738,6 +4818,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] diff --git a/crates/alknet-core/Cargo.toml b/crates/alknet-core/Cargo.toml index 158371f..d493596 100644 --- a/crates/alknet-core/Cargo.toml +++ b/crates/alknet-core/Cargo.toml @@ -14,6 +14,7 @@ default = [] tls = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-pki-types", "dep:webpki-roots"] iroh = ["dep:iroh", "dep:url"] acme = ["dep:rustls-acme", "dep:futures", "tls"] +http = ["dep:axum", "dep:hyper", "dep:hyper-util", "dep:tower", "dep:http-body-util"] irpc = [] testutil = [] transport-traits = [] @@ -40,9 +41,14 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" sha2 = "0.10" hex = "0.4" +axum = { version = "0.8", optional = true } +hyper = { version = "1", optional = true } +hyper-util = { version = "0.1", features = ["tokio", "server", "service"], optional = true } +tower = { version = "0.5", optional = true } +http-body-util = { version = "0.1", optional = true } [dev-dependencies] -alknet-core = { path = ".", features = ["testutil", "tls", "iroh"] } +alknet-core = { path = ".", features = ["testutil", "tls", "iroh", "http"] } tempfile = "3" rcgen = "0.14" rand_core = "0.6" diff --git a/crates/alknet-core/src/http/auth.rs b/crates/alknet-core/src/http/auth.rs new file mode 100644 index 0000000..b1357d3 --- /dev/null +++ b/crates/alknet-core/src/http/auth.rs @@ -0,0 +1,182 @@ +use axum::extract::Request; +use axum::middleware::Next; +use axum::response::{IntoResponse, Response}; + +use crate::auth::{AuthToken, Identity, IdentityProvider}; + +#[derive(Clone)] +pub struct IdentityExt(pub Identity); + +pub async fn auth_middleware( + axum::extract::State(identity_provider): axum::extract::State< + std::sync::Arc, + >, + mut request: Request, + next: Next, +) -> Response { + let auth_header = request + .headers() + .get(axum::http::header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()); + + let token_str = match auth_header { + Some(h) if h.starts_with("Bearer ") => &h[7..], + _ => { + return axum::http::StatusCode::UNAUTHORIZED.into_response(); + } + }; + + let token = AuthToken { + raw: token_str.as_bytes().to_vec(), + }; + + match identity_provider.resolve_from_token(&token) { + Some(identity) => { + request.extensions_mut().insert(IdentityExt(identity)); + next.run(request).await + } + None => axum::http::StatusCode::UNAUTHORIZED.into_response(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::body::Body; + use axum::http::{Request as HttpRequest, StatusCode}; + use axum::routing::get; + use axum::Router; + use std::collections::HashMap; + use std::sync::Arc; + use tower::ServiceExt; + + struct MockIdentityProvider { + valid_token: String, + identity: Identity, + } + + impl IdentityProvider for MockIdentityProvider { + fn resolve_from_fingerprint(&self, _fingerprint: &str) -> Option { + None + } + + fn resolve_from_token(&self, token: &AuthToken) -> Option { + let token_str = String::from_utf8_lossy(&token.raw); + if token_str == self.valid_token { + Some(self.identity.clone()) + } else { + None + } + } + } + + fn make_provider(valid_token: &str) -> Arc { + let identity = Identity { + id: "test-user".to_string(), + scopes: vec!["relay:connect".to_string()], + resources: HashMap::new(), + }; + Arc::new(MockIdentityProvider { + valid_token: valid_token.to_string(), + identity, + }) + } + + #[tokio::test] + async fn auth_middleware_extracts_bearer_token() { + let provider = make_provider("alk_validtoken1"); + let app = Router::new() + .route( + "/test", + get(|request: Request| async move { + let has_identity = request.extensions().get::().is_some(); + if has_identity { + StatusCode::OK.into_response() + } else { + StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + }), + ) + .layer(axum::middleware::from_fn_with_state( + provider, + auth_middleware, + )); + + let req = HttpRequest::builder() + .uri("/test") + .header("authorization", "Bearer alk_validtoken1") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[tokio::test] + async fn auth_middleware_returns_401_for_missing_token() { + let provider = make_provider("alk_validtoken1"); + let app = Router::new() + .route("/test", get(|| async { StatusCode::OK.into_response() })) + .layer(axum::middleware::from_fn_with_state( + provider, + auth_middleware, + )); + + let req = HttpRequest::builder() + .uri("/test") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn auth_middleware_returns_401_for_invalid_token() { + let provider = make_provider("alk_validtoken1"); + let app = Router::new() + .route("/test", get(|| async { StatusCode::OK.into_response() })) + .layer(axum::middleware::from_fn_with_state( + provider, + auth_middleware, + )); + + let req = HttpRequest::builder() + .uri("/test") + .header("authorization", "Bearer alk_wrongtoken1") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn auth_middleware_attaches_identity_to_extensions() { + let provider = make_provider("alk_testidentity1"); + let app = Router::new() + .route( + "/test", + get(|request: Request| async move { + let identity = request.extensions().get::().unwrap(); + identity.0.id.clone() + }), + ) + .layer(axum::middleware::from_fn_with_state( + provider, + auth_middleware, + )); + + let req = HttpRequest::builder() + .uri("/test") + .header("authorization", "Bearer alk_testidentity1") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap(); + assert_eq!(&body[..], b"test-user"); + } +} diff --git a/crates/alknet-core/src/http/mod.rs b/crates/alknet-core/src/http/mod.rs new file mode 100644 index 0000000..745ac71 --- /dev/null +++ b/crates/alknet-core/src/http/mod.rs @@ -0,0 +1,5 @@ +pub mod auth; +pub mod router; + +pub use auth::IdentityExt; +pub use router::{build_router, serve_connection}; diff --git a/crates/alknet-core/src/http/router.rs b/crates/alknet-core/src/http/router.rs new file mode 100644 index 0000000..32653b4 --- /dev/null +++ b/crates/alknet-core/src/http/router.rs @@ -0,0 +1,150 @@ +use std::sync::Arc; + +use axum::response::IntoResponse; +use axum::Router; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::conn::auto::Builder; +use hyper_util::service::TowerToHyperService; +use tokio::io::{AsyncRead, AsyncWrite, BufReader}; + +use crate::auth::IdentityProvider; +use crate::http::auth::auth_middleware; + +async fn default_404() -> impl IntoResponse { + axum::http::StatusCode::NOT_FOUND +} + +pub fn build_router(identity_provider: Arc) -> Router { + Router::new() + .fallback(default_404) + .layer(axum::middleware::from_fn_with_state( + identity_provider, + auth_middleware, + )) +} + +pub async fn serve_connection(stream: S, identity_provider: Arc) +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + let app = build_router(identity_provider); + let io = TokioIo::new(stream); + + let hyper_service = TowerToHyperService::new(app.into_service::()); + + let result = Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(io, hyper_service) + .await; + + if let Err(e) = result { + tracing::debug!("http connection error: {e}"); + } +} + +pub async fn serve_connection_from_reader( + reader: BufReader, + identity_provider: Arc, +) where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + serve_connection(reader, identity_provider).await +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::auth::{AuthToken, Identity}; + use axum::body::Body; + use axum::http::{Request as HttpRequest, StatusCode}; + use axum::response::IntoResponse; + use std::collections::HashMap; + use std::sync::Arc; + use tower::ServiceExt; + + struct NullIdentityProvider; + + impl IdentityProvider for NullIdentityProvider { + fn resolve_from_fingerprint(&self, _fingerprint: &str) -> Option { + None + } + + fn resolve_from_token(&self, _token: &AuthToken) -> Option { + None + } + } + + #[tokio::test] + async fn default_404_handler_returns_not_found() { + let provider: Arc = Arc::new(MockValidProvider); + let app = build_router(provider); + + let req = HttpRequest::builder() + .uri("/anything") + .header("authorization", "Bearer alk_sometoken1") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + #[tokio::test] + async fn missing_auth_returns_401_before_404() { + let provider: Arc = Arc::new(MockValidProvider); + let app = build_router(provider); + + let req = HttpRequest::builder() + .uri("/anything") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn invalid_auth_returns_401_before_404() { + let provider: Arc = Arc::new(NullIdentityProvider); + let app = build_router(provider); + + let req = HttpRequest::builder() + .uri("/anything") + .header("authorization", "Bearer alk_sometoken1") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn unmatched_route_returns_404_with_valid_auth() { + let provider: Arc = Arc::new(MockValidProvider); + let app = build_router(provider); + + let req = HttpRequest::builder() + .uri("/v1/unknown/op") + .header("authorization", "Bearer alk_valid") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + struct MockValidProvider; + + impl IdentityProvider for MockValidProvider { + fn resolve_from_fingerprint(&self, _fingerprint: &str) -> Option { + None + } + + fn resolve_from_token(&self, _token: &AuthToken) -> Option { + Some(Identity { + id: "test".to_string(), + scopes: vec![], + resources: HashMap::new(), + }) + } + } +} diff --git a/crates/alknet-core/src/interface/http.rs b/crates/alknet-core/src/interface/http.rs index 109ed2f..ad26466 100644 --- a/crates/alknet-core/src/interface/http.rs +++ b/crates/alknet-core/src/interface/http.rs @@ -27,6 +27,13 @@ impl MessageInterface for HttpInterface { } } +#[cfg(feature = "http")] +impl HttpInterface { + pub fn build_router(&self) -> axum::Router { + crate::http::router::build_router(Arc::clone(&self.identity_provider)) + } +} + #[cfg(test)] mod tests { use super::*; @@ -42,4 +49,18 @@ mod tests { registry, }; } + + #[cfg(feature = "http")] + #[test] + fn http_interface_builds_router() { + let registry = Arc::new(crate::call::OperationRegistry::new()); + let iface = HttpInterface { + identity_provider: Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new( + arc_swap::ArcSwap::new(Arc::new(crate::config::DynamicConfig::default())), + ))), + env: OperationEnv::local(crate::call::OperationRegistry::new()), + registry, + }; + let _router = iface.build_router(); + } } diff --git a/crates/alknet-core/src/lib.rs b/crates/alknet-core/src/lib.rs index c4e41f5..b7d76e2 100644 --- a/crates/alknet-core/src/lib.rs +++ b/crates/alknet-core/src/lib.rs @@ -62,6 +62,12 @@ pub mod server; pub mod socks5; pub mod transport; +#[cfg(feature = "http")] +pub mod http; + +#[cfg(feature = "http")] +pub use http::IdentityExt; + #[cfg(feature = "testutil")] pub mod testutil; diff --git a/crates/alknet-core/src/server/mod.rs b/crates/alknet-core/src/server/mod.rs index 90cb6c1..f03f2eb 100644 --- a/crates/alknet-core/src/server/mod.rs +++ b/crates/alknet-core/src/server/mod.rs @@ -28,5 +28,6 @@ pub use serve::{ pub use crate::transport::TransportKind; pub use stealth::{ - detect_protocol, send_fake_nginx_404, validate_stealth_config, ProtocolDetection, + detect_protocol, handle_http_stealth, send_fake_nginx_404, validate_stealth_config, + ProtocolDetection, }; diff --git a/crates/alknet-core/src/server/serve.rs b/crates/alknet-core/src/server/serve.rs index 5f96a98..297b936 100644 --- a/crates/alknet-core/src/server/serve.rs +++ b/crates/alknet-core/src/server/serve.rs @@ -15,7 +15,9 @@ use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{error, info, warn}; +use crate::auth::identity::ConfigIdentityProvider; use crate::auth::keys::KeySource; +use crate::auth::IdentityProvider; use crate::config::{ConfigReloadHandle, DynamicConfig}; use crate::error::ConfigError; use crate::interface::pairs::is_valid_pair; @@ -522,6 +524,7 @@ struct ActiveSession { pub struct Server { config: Arc, dynamic: Arc>, + identity_provider: Arc, connection_limiter: Arc, outbound_proxy: Option, listeners: Vec, @@ -551,10 +554,13 @@ impl Server { let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); let dynamic = Arc::new(ArcSwap::new(Arc::new(dynamic_config))); + let identity_provider: Arc = + Arc::new(ConfigIdentityProvider::new(Arc::clone(&dynamic))); Ok(Self { config, dynamic, + identity_provider, connection_limiter, outbound_proxy: static_config.proxy_config, listeners: static_config.listeners, @@ -734,12 +740,20 @@ impl Server { let config = Arc::clone(&server.config); let sessions = Arc::clone(&server.sessions); + let identity_provider = Arc::clone(&server.identity_provider); let transport_is_tls = matches!(transport_kind, TransportKind::Tls { .. }); tokio::spawn(async move { - let result = - handle_connection(stream, config, handler, sessions, stealth, transport_is_tls) - .await; + let result = handle_connection( + stream, + config, + handler, + sessions, + identity_provider, + stealth, + transport_is_tls, + ) + .await; if let Err(e) = result { warn!("connection error: {e}"); @@ -765,6 +779,7 @@ async fn handle_connection( config: Arc, handler: ServerHandler, sessions: Arc>>, + identity_provider: Arc, stealth: bool, transport_is_tls: bool, ) -> Result<(), anyhow::Error> @@ -772,10 +787,10 @@ where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { if stealth && transport_is_tls { - let (protocol, mut reader) = stealth::detect_protocol(stream).await; + let (protocol, reader) = stealth::detect_protocol(stream).await; match protocol { ProtocolDetection::Http => { - stealth::send_fake_nginx_404(&mut reader).await; + stealth::handle_http_stealth(reader, identity_provider).await; return Ok(()); } ProtocolDetection::Ssh => { diff --git a/crates/alknet-core/src/server/stealth.rs b/crates/alknet-core/src/server/stealth.rs index 1481205..37db011 100644 --- a/crates/alknet-core/src/server/stealth.rs +++ b/crates/alknet-core/src/server/stealth.rs @@ -2,12 +2,17 @@ //! //! When stealth mode is enabled with TLS transport, the server peeks at the first //! bytes after the TLS handshake to determine whether the client is speaking SSH -//! or HTTP. Non-SSH connections receive a fake nginx 404 response, making the -//! server appear as an ordinary web server to port scanners and DPI systems. -//! See ADR-017. +//! or HTTP. When the `http` feature is enabled, detected HTTP traffic is routed to +//! the axum router. When `http` is disabled, non-SSH connections receive a fake +//! nginx 404 response, making the server appear as an ordinary web server to port +//! scanners and DPI systems. See ADR-017. + +use std::sync::Arc; use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; +use crate::auth::IdentityProvider; + const SSH_BANNER_PREFIX: &[u8] = b"SSH-2.0-"; const FAKE_NGINX_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nServer: nginx\r\n\r\n"; @@ -52,6 +57,26 @@ where let _ = reader.get_mut().shutdown().await; } +#[cfg(feature = "http")] +pub async fn handle_http_stealth( + reader: BufReader, + identity_provider: Arc, +) where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + crate::http::router::serve_connection_from_reader(reader, identity_provider).await +} + +#[cfg(not(feature = "http"))] +pub async fn handle_http_stealth( + mut reader: BufReader, + _identity_provider: Arc, +) where + S: AsyncRead + AsyncWrite + Unpin, +{ + send_fake_nginx_404(&mut reader).await +} + pub fn validate_stealth_config(stealth: bool, transport_is_tls: bool) -> Result<(), &'static str> { if stealth && !transport_is_tls { return Err("stealth mode requires TLS transport (--transport tls)"); @@ -232,4 +257,60 @@ mod tests { let result = client.read(&mut extra).await; assert!(result.is_err() || result.unwrap() == 0); } + + #[cfg(feature = "http")] + #[tokio::test] + async fn stealth_handoff_routes_http_to_axum() { + use crate::auth::{AuthToken, IdentityProvider}; + use std::sync::Arc; + use tokio::io::AsyncWriteExt; + + struct NullProvider; + + impl IdentityProvider for NullProvider { + fn resolve_from_fingerprint( + &self, + _fingerprint: &str, + ) -> Option { + None + } + + fn resolve_from_token(&self, _token: &AuthToken) -> Option { + None + } + } + + let (client, server) = duplex(4096); + let (mut client_read, mut client_write) = tokio::io::split(client); + + client_write + .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + drop(client_write); + + let (detection, reader) = detect_protocol(server).await; + assert_eq!(detection, ProtocolDetection::Http); + + let provider: Arc = Arc::new(NullProvider); + let handle = tokio::spawn(async move { + handle_http_stealth(reader, provider).await; + }); + + let mut buf = Vec::new(); + tokio::io::AsyncReadExt::read_to_end(&mut client_read, &mut buf) + .await + .unwrap(); + let response = String::from_utf8_lossy(&buf); + assert!( + response.contains("401"), + "expected 401 from axum auth middleware, got: {response}" + ); + assert!( + !response.contains("nginx"), + "should not contain fake nginx response when http feature is enabled" + ); + + let _ = handle.await; + } }