diff --git a/src/main.rs b/src/main.rs index c554acf..97d825f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,33 @@ +use std::net::SocketAddr; use std::sync::Arc; +use anyhow::{Context, Result}; use arc_swap::ArcSwap; -use reverse_proxy::admin::{start_admin_socket, AdminSocket}; -use reverse_proxy::cli; -use reverse_proxy::config::{ConfigReloadHandle, DynamicConfig}; -use reverse_proxy::health::start_health_check_listener; -use reverse_proxy::proxy::{create_http_client, create_https_client, proxy_router, ProxyState}; -use reverse_proxy::rate_limit::{start_eviction_task, RateLimiter}; -use reverse_proxy::shutdown::GracefulShutdown; -use reverse_proxy::tls::redirect::start_http_redirect_listener; use tokio::net::TcpListener; -use tracing::info; +use tokio_rustls::TlsAcceptor; +use tracing::{error, info, warn}; + +use reverse_proxy::admin::{start_admin_socket, AdminSocket, AdminSocketError}; +use reverse_proxy::cli; +use reverse_proxy::config::ConfigReloadHandle; +use reverse_proxy::config::DynamicConfig; +use reverse_proxy::health; +use reverse_proxy::logging; +use reverse_proxy::proxy::{build_router, create_http_client, create_https_client, ProxyState}; +use reverse_proxy::rate_limit::{start_eviction_task, RateLimiter}; +use reverse_proxy::server::serve_https_listener; +use reverse_proxy::shutdown::GracefulShutdown; +use reverse_proxy::tls::acceptor::{setup_tls, TlsMode}; +use reverse_proxy::tls::redirect; + +fn notify_systemd_ready() { + if std::env::var("NOTIFY_SOCKET").is_ok() { + match sd_notify::notify(true, &[sd_notify::NotifyState::Ready]) { + Ok(()) => info!("sd_notify: READY=1 sent"), + Err(e) => warn!("sd_notify: failed to notify systemd: {}", e), + } + } +} fn main() { let args = cli::parse(); @@ -34,27 +51,137 @@ fn main() { rt.block_on(async move { if let Err(e) = run_server(loaded_config, &args.config).await { - tracing::error!("fatal error: {e:#}"); + error!("fatal error: {e:#}"); std::process::exit(1); } }); } -async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> anyhow::Result<()> { - let shutdown = Arc::new(GracefulShutdown::new( - loaded_config.static_config.shutdown_timeout_secs, - )); +async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> Result<()> { + logging::init(&loaded_config.static_config.logging).context("failed to initialize logging")?; + + info!("reverse-proxy starting"); let dynamic_config: DynamicConfig = loaded_config.dynamic_config; let config_arc = Arc::new(ArcSwap::from_pointee(dynamic_config)); + + let rate_limiter = Arc::new(RateLimiter::new(config_arc.clone())); + + let http_client = create_http_client(); + let https_client = create_https_client(); + + let proxy_state = Arc::new(ProxyState { + config: config_arc.clone(), + http_client, + https_client, + }); + let reload_handle = Arc::new(ConfigReloadHandle::new( config_arc.clone(), loaded_config.static_config.clone(), )); - reverse_proxy::logging::init(&loaded_config.static_config.logging)?; + if loaded_config.static_config.health_check_port > 0 { + let (health_addr, _health_handle) = + health::start_health_check_listener(loaded_config.static_config.health_check_port) + .await + .context("failed to bind health check port")?; + info!(addr = %health_addr, "Health check listener bound"); + } - info!("reverse-proxy starting"); + if !loaded_config.static_config.admin_socket_path.is_empty() { + let admin_socket = Arc::new(AdminSocket::new( + loaded_config.static_config.admin_socket_path.clone(), + reload_handle.clone(), + config_path.to_string(), + )); + let admin_socket_clone = admin_socket.clone(); + tokio::spawn(async move { + if let Err(e) = start_admin_socket(admin_socket_clone).await { + match e { + AdminSocketError::Disabled => {} + AdminSocketError::SocketInUse(path) => { + warn!("admin socket disabled: {} is in use", path); + } + AdminSocketError::BindFailed(msg) => { + error!("admin socket bind failed: {}", msg); + } + AdminSocketError::Io(e) => { + error!("admin socket IO error: {}", e); + } + } + } + }); + } + + let mut bound_listeners = Vec::new(); + + for listener_config in &loaded_config.static_config.listeners { + if listener_config.http_port > 0 { + let (http_addr, _http_handle) = redirect::start_http_redirect_listener(listener_config) + .await + .context(format!( + "failed to bind HTTP redirect listener on {}:{}", + listener_config.bind_addr, listener_config.http_port + ))?; + info!(addr = %http_addr, "HTTP redirect listener bound"); + } + + let https_addr: SocketAddr = format!( + "{}:{}", + listener_config.bind_addr, listener_config.https_port + ) + .parse() + .context(format!( + "invalid HTTPS bind address {}:{}", + listener_config.bind_addr, listener_config.https_port + ))?; + + let https_tcp = TcpListener::bind(https_addr).await.context(format!( + "failed to bind HTTPS listener on {}:{}", + listener_config.bind_addr, listener_config.https_port + ))?; + + let local_addr = https_tcp.local_addr()?; + info!(addr = %local_addr, "HTTPS listener bound"); + + bound_listeners.push((listener_config.clone(), https_tcp)); + } + + let mut tls_acceptors = Vec::new(); + for (listener_config, _) in &bound_listeners { + let tls_mode = setup_tls(&listener_config.tls).context(format!( + "failed to setup TLS for listener {}", + listener_config.bind_addr + ))?; + + match tls_mode { + TlsMode::Manual(server_config) => { + let acceptor = TlsAcceptor::from(server_config); + tls_acceptors.push(acceptor); + info!( + addr = %listener_config.bind_addr, + "Manual TLS configured" + ); + } + TlsMode::Acme { + default_config, + challenge_config: _, + resolver: _, + } => { + let acceptor = TlsAcceptor::from(default_config); + tls_acceptors.push(acceptor); + info!( + addr = %listener_config.bind_addr, + "ACME TLS configured" + ); + } + } + } + + let shutdown = Arc::new(GracefulShutdown::new( + loaded_config.static_config.shutdown_timeout_secs, + )); reverse_proxy::shutdown::register_signal_handlers( shutdown.clone(), @@ -62,80 +189,39 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> anyh config_path.to_string(), )?; - let rate_limiter = Arc::new(RateLimiter::new(config_arc.clone())); - - let proxy_state = Arc::new(ProxyState { - config: config_arc.clone(), - http_client: create_http_client(), - https_client: create_https_client(), - }); - - let mut server_handles: Vec>> = Vec::new(); - let mut tcp_listeners: Vec = Vec::new(); - - if loaded_config.static_config.health_check_port > 0 { - let (addr, handle) = - start_health_check_listener(loaded_config.static_config.health_check_port).await?; - info!(addr = %addr, "Health check listener started"); - server_handles.push(handle); - } - - let admin_socket = Arc::new(AdminSocket::new( - loaded_config.static_config.admin_socket_path.clone(), - reload_handle.clone(), - config_path.to_string(), - )); - - let admin_handle = tokio::spawn(start_admin_socket(admin_socket)); - - let eviction_handle = start_eviction_task( + let _eviction_handle = start_eviction_task( rate_limiter.clone(), std::time::Duration::from_secs(60), std::time::Duration::from_secs(300), ); - for listener_config in &loaded_config.static_config.listeners { - if listener_config.http_port > 0 { - let (addr, handle) = start_http_redirect_listener(listener_config).await?; - info!(addr = %addr, "HTTP redirect listener started"); - server_handles.push(handle); - } - - let https_bind_addr: std::net::SocketAddr = format!( - "{}:{}", - listener_config.bind_addr, listener_config.https_port - ) - .parse() - .map_err(|e| { - anyhow::anyhow!( - "invalid bind address {}:{}: {}", - listener_config.bind_addr, - listener_config.https_port, - e - ) - })?; - - let tcp_listener = TcpListener::bind(https_bind_addr).await?; - let local_addr = tcp_listener.local_addr()?; - info!(addr = %local_addr, "HTTPS listener bound"); - tcp_listeners.push(tcp_listener); - } - - let app = proxy_router(proxy_state); - let app = reverse_proxy::proxy::router_with_body_limit(app, config_arc); + let app = build_router(proxy_state.clone(), config_arc.clone(), rate_limiter); let mut https_server_handles = Vec::new(); - for tcp_listener in tcp_listeners { + + for ((listener_config, tcp_listener), tls_acceptor) in + bound_listeners.into_iter().zip(tls_acceptors.into_iter()) + { let shutdown_rx = shutdown.subscribe(); - let handle = tokio::spawn(serve_with_graceful_shutdown( + + let handle = tokio::spawn(serve_https_listener( tcp_listener, + tls_acceptor, app.clone(), shutdown_rx, )); + + info!( + bind_addr = %listener_config.bind_addr, + https_port = listener_config.https_port, + "HTTPS listener accepting connections" + ); + https_server_handles.push(handle); } - info!("reverse-proxy ready"); + info!("all listeners started"); + notify_systemd_ready(); let mut shutdown_rx = shutdown.subscribe(); shutdown_rx @@ -145,30 +231,9 @@ async fn run_server(loaded_config: cli::LoadedConfig, config_path: &str) -> anyh info!("shutdown signal received, starting graceful shutdown"); - drop(https_server_handles); - - for handle in server_handles { + for handle in https_server_handles { handle.abort(); } - admin_handle.abort(); - eviction_handle.abort(); - - info!("all connections closed, exiting"); - std::process::exit(0); -} - -async fn serve_with_graceful_shutdown( - listener: TcpListener, - app: axum::Router, - mut shutdown_rx: tokio::sync::watch::Receiver, -) -> anyhow::Result<()> { - let local_addr = listener.local_addr()?; - axum::serve(listener, app) - .with_graceful_shutdown(async move { - shutdown_rx.changed().await.ok(); - info!(addr = %local_addr, "HTTPS server shutting down"); - }) - .await - .map_err(anyhow::Error::from) + Ok(()) } diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 17ae170..d8ce9c2 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -11,6 +11,17 @@ use std::sync::Arc; use arc_swap::ArcSwap; use crate::config::DynamicConfig; +use crate::rate_limit::RateLimiter; + +pub fn build_router( + proxy_state: Arc, + config: Arc>, + rate_limiter: Arc, +) -> axum::Router { + let router = proxy_router(proxy_state); + let router = router_with_body_limit(router, config); + router_with_rate_limit(router, rate_limiter) +} pub fn router_with_body_limit( router: axum::Router, @@ -21,3 +32,13 @@ pub fn router_with_body_limit( body_limit::body_limit_middleware, )) } + +pub fn router_with_rate_limit( + router: axum::Router, + rate_limiter: Arc, +) -> axum::Router { + router.layer(axum::middleware::from_fn_with_state( + rate_limiter, + crate::rate_limit::rate_limit_middleware, + )) +} diff --git a/src/server.rs b/src/server.rs index c0cb464..6e3dc63 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,240 +1,73 @@ use std::net::SocketAddr; -use std::sync::Arc; -use anyhow::{Context, Result}; -use arc_swap::ArcSwap; use axum::extract::ConnectInfo; use axum::http::Request; use axum::response::Response; +use axum::Router; use hyper_util::rt::TokioExecutor; use hyper_util::service::TowerToHyperService; use tokio::net::TcpListener; use tokio_rustls::TlsAcceptor; use tower::Service; -use tracing::{error, info, warn}; +use tracing::{error, warn}; -use crate::admin::{start_admin_socket, AdminSocket, AdminSocketError}; -use crate::config::dynamic_config::DynamicConfig; -use crate::config::static_config::StaticConfig; -use crate::config::ConfigReloadHandle; -use crate::health; -use crate::logging; -use crate::proxy::{create_http_client, create_https_client, proxy_router, ProxyState}; -use crate::rate_limit::{start_eviction_task, RateLimiter}; -use crate::tls::acceptor::{setup_tls, TlsMode}; -use crate::tls::redirect; - -fn notify_systemd_ready() { - if std::env::var("NOTIFY_SOCKET").is_ok() { - match sd_notify::notify(true, &[sd_notify::NotifyState::Ready]) { - Ok(()) => info!("sd_notify: READY=1 sent"), - Err(e) => warn!("sd_notify: failed to notify systemd: {}", e), - } - } -} - -pub async fn run(static_config: StaticConfig, dynamic_config: DynamicConfig) -> Result<()> { - logging::init(&static_config.logging).context("failed to initialize logging")?; - - info!("reverse-proxy starting"); - - let dynamic_config = Arc::new(ArcSwap::from_pointee(dynamic_config)); - - let http_client = create_http_client(); - let https_client = create_https_client(); - - let rate_limiter = Arc::new(RateLimiter::new(dynamic_config.clone())); - - let proxy_state = Arc::new(ProxyState { - config: dynamic_config.clone(), - http_client, - https_client, - }); - - if static_config.health_check_port > 0 { - let (health_addr, _health_handle) = - health::start_health_check_listener(static_config.health_check_port) - .await - .context("failed to bind health check port")?; - info!(addr = %health_addr, "Health check listener started"); - } - - let reload_handle = Arc::new(ConfigReloadHandle::new( - dynamic_config.clone(), - static_config.clone(), - )); - - if !static_config.admin_socket_path.is_empty() { - let admin_socket = Arc::new(AdminSocket::new( - static_config.admin_socket_path.clone(), - reload_handle.clone(), - std::env::args().next().unwrap_or_default(), - )); - let admin_socket_clone = admin_socket.clone(); - tokio::spawn(async move { - if let Err(e) = start_admin_socket(admin_socket_clone).await { - match e { - AdminSocketError::Disabled => {} - AdminSocketError::SocketInUse(path) => { - warn!("admin socket disabled: {} is in use", path); - } - AdminSocketError::BindFailed(msg) => { - error!("admin socket bind failed: {}", msg); - } - AdminSocketError::Io(e) => { - error!("admin socket IO error: {}", e); - } - } - } - }); - } - - let _eviction_handle = start_eviction_task( - rate_limiter.clone(), - std::time::Duration::from_secs(60), - std::time::Duration::from_secs(300), - ); - - let mut bound_https_listeners = Vec::new(); - - for listener_config in &static_config.listeners { - let https_addr: SocketAddr = format!( - "{}:{}", - listener_config.bind_addr, listener_config.https_port - ) - .parse() - .context(format!( - "invalid HTTPS bind address {}:{}", - listener_config.bind_addr, listener_config.https_port - ))?; - - let https_tcp = TcpListener::bind(https_addr).await.context(format!( - "failed to bind HTTPS listener on {}:{}", - listener_config.bind_addr, listener_config.https_port - ))?; - - let local_addr = https_tcp.local_addr()?; - info!(addr = %local_addr, "HTTPS listener bound"); - - bound_https_listeners.push((listener_config.clone(), https_tcp)); - } - - for listener_config in &static_config.listeners { - if listener_config.http_port > 0 { - let (http_addr, _http_handle) = redirect::start_http_redirect_listener(listener_config) - .await - .context(format!( - "failed to start HTTP redirect listener for {}:{}", - listener_config.bind_addr, listener_config.http_port - ))?; - info!(addr = %http_addr, "HTTP redirect listener started"); - } - } - - let mut tls_acceptors = Vec::new(); - for (listener_config, _) in &bound_https_listeners { - let tls_mode = setup_tls(&listener_config.tls).context(format!( - "failed to setup TLS for listener {}", - listener_config.bind_addr - ))?; - - match tls_mode { - TlsMode::Manual(server_config) => { - let acceptor = TlsAcceptor::from(server_config); - tls_acceptors.push(acceptor); - info!( - addr = %listener_config.bind_addr, - "Manual TLS configured" - ); - } - TlsMode::Acme { - default_config, - challenge_config: _, - resolver: _, - } => { - let acceptor = TlsAcceptor::from(default_config); - tls_acceptors.push(acceptor); - info!( - addr = %listener_config.bind_addr, - "ACME TLS configured" - ); - } - } - } - - for ((listener_config, tcp_listener), tls_acceptor) in bound_https_listeners - .into_iter() - .zip(tls_acceptors.into_iter()) - { - let state = proxy_state.clone(); - - tokio::spawn(serve_https_listener(tcp_listener, tls_acceptor, state)); - - info!( - bind_addr = %listener_config.bind_addr, - https_port = listener_config.https_port, - "HTTPS listener accepting connections" - ); - } - - info!("all listeners started"); - notify_systemd_ready(); - - tokio::signal::ctrl_c() - .await - .context("failed to listen for ctrl-c")?; - info!("shutting down"); - - Ok(()) -} - -async fn serve_https_listener( +pub async fn serve_https_listener( tcp_listener: TcpListener, tls_acceptor: TlsAcceptor, - state: Arc, + router: Router, + mut shutdown_rx: tokio::sync::watch::Receiver, ) { - let router = proxy_router(state); + let local_addr = tcp_listener.local_addr(); loop { - let (tcp_stream, remote_addr) = match tcp_listener.accept().await { - Ok(conn) => conn, - Err(e) => { - error!(error = %e, "failed to accept TCP connection"); - continue; + tokio::select! { + accept_result = tcp_listener.accept() => { + let (tcp_stream, remote_addr) = match accept_result { + Ok(conn) => conn, + Err(e) => { + error!(error = %e, "failed to accept TCP connection"); + continue; + } + }; + + let tls_acceptor = tls_acceptor.clone(); + let router = router.clone(); + + tokio::spawn(async move { + let tls_stream = match tls_acceptor.accept(tcp_stream).await { + Ok(stream) => stream, + Err(e) => { + warn!(error = %e, "TLS handshake failed"); + return; + } + }; + + let svc = ConnectInfoService { + inner: router.into_service::(), + remote_addr, + }; + + let svc = TowerToHyperService::new(svc); + let io = hyper_util::rt::TokioIo::new(tls_stream); + + if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(io, svc) + .await + { + if e.to_string().contains("incomplete message") { + return; + } + error!(error = %e, "HTTPS connection error"); + } + }); } - }; - - let tls_acceptor = tls_acceptor.clone(); - let router = router.clone(); - - tokio::spawn(async move { - let tls_stream = match tls_acceptor.accept(tcp_stream).await { - Ok(stream) => stream, - Err(e) => { - warn!(error = %e, "TLS handshake failed"); - return; + _ = shutdown_rx.changed() => { + if let Ok(addr) = local_addr { + tracing::info!(addr = %addr, "HTTPS listener shutting down"); } - }; - - let svc = ConnectInfoService { - inner: router.into_service::(), - remote_addr, - }; - - let svc = TowerToHyperService::new(svc); - - let io = hyper_util::rt::TokioIo::new(tls_stream); - - if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) - .serve_connection_with_upgrades(io, svc) - .await - { - if e.to_string().contains("incomplete message") { - return; - } - error!(error = %e, "HTTPS connection error"); + break; } - }); + } } }