Add read timeout and line length limit to admin socket (ADR-027)

This commit is contained in:
2026-06-12 14:03:22 +00:00
parent 54f1725173
commit 4c6b55a780

View File

@@ -164,14 +164,20 @@ async fn is_socket_active(path: &str) -> bool {
} }
async fn handle_connection(stream: tokio::net::UnixStream, admin_socket: Arc<AdminSocket>) { async fn handle_connection(stream: tokio::net::UnixStream, admin_socket: Arc<AdminSocket>) {
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
let (reader, mut writer) = stream.into_split(); let (reader, mut writer) = stream.into_split();
let mut reader = BufReader::new(reader); let mut reader = BufReader::new(reader.take(4096));
let mut line = String::new(); let mut line = String::new();
match reader.read_line(&mut line).await { let read_result = tokio::time::timeout(
Ok(0) | Err(_) => { std::time::Duration::from_secs(5),
reader.read_line(&mut line),
)
.await;
match read_result {
Ok(Ok(0)) | Ok(Err(_)) => {
let _ = writer let _ = writer
.write_all( .write_all(
format!( format!(
@@ -187,7 +193,42 @@ async fn handle_connection(stream: tokio::net::UnixStream, admin_socket: Arc<Adm
.await; .await;
return; return;
} }
_ => {} Err(_) => {
tracing::debug!("admin socket connection timed out");
let _ = writer
.write_all(
format!(
"{}\n",
serde_json::to_string(&ErrorResponse {
status: "error",
message: "read timeout".to_string(),
})
.unwrap()
)
.as_bytes(),
)
.await;
return;
}
Ok(Ok(n)) => {
if !line.ends_with('\n') && n > 0 {
tracing::warn!("admin socket command exceeded 4096 byte limit");
let _ = writer
.write_all(
format!(
"{}\n",
serde_json::to_string(&ErrorResponse {
status: "error",
message: "command too long".to_string(),
})
.unwrap()
)
.as_bytes(),
)
.await;
return;
}
}
} }
let command = line.trim(); let command = line.trim();
@@ -680,4 +721,105 @@ upstream = "127.0.0.1:8080"
assert_eq!(parsed["sites"], 1); assert_eq!(parsed["sites"], 1);
assert!(parsed["uptime_secs"].is_number()); assert!(parsed["uptime_secs"].is_number());
} }
#[tokio::test]
async fn test_read_timeout() {
let dir = tempfile::tempdir().unwrap();
let admin_socket = Arc::new(create_test_admin_socket(dir.path()));
let socket_path = dir.path().join("admin.sock");
let listener = UnixListener::bind(&socket_path).unwrap();
let admin_socket_clone = admin_socket.clone();
let handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
handle_connection(stream, admin_socket_clone).await;
});
let stream = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
use tokio::io::{AsyncBufReadExt, BufReader};
let mut response = String::new();
let mut reader = BufReader::new(stream);
let result =
tokio::time::timeout(Duration::from_secs(10), reader.read_line(&mut response)).await;
handle.await.unwrap();
assert!(result.is_ok());
let parsed: serde_json::Value = serde_json::from_str(response.trim()).unwrap();
assert_eq!(parsed["status"], "error");
assert_eq!(parsed["message"], "read timeout");
}
#[tokio::test]
async fn test_command_too_long() {
let dir = tempfile::tempdir().unwrap();
let admin_socket = Arc::new(create_test_admin_socket(dir.path()));
let socket_path = dir.path().join("admin.sock");
let listener = UnixListener::bind(&socket_path).unwrap();
let admin_socket_clone = admin_socket.clone();
let handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
handle_connection(stream, admin_socket_clone).await;
});
let mut stream = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
let long_data = "A".repeat(5000);
stream.write_all(long_data.as_bytes()).await.unwrap();
stream.shutdown().await.unwrap();
let mut response = String::new();
let mut reader = BufReader::new(stream);
reader.read_line(&mut response).await.unwrap();
handle.await.unwrap();
let parsed: serde_json::Value = serde_json::from_str(response.trim()).unwrap();
assert_eq!(parsed["status"], "error");
assert_eq!(parsed["message"], "command too long");
}
#[tokio::test]
async fn test_command_at_limit_boundary() {
let dir = tempfile::tempdir().unwrap();
let admin_socket = Arc::new(create_test_admin_socket(dir.path()));
let socket_path = dir.path().join("admin.sock");
let listener = UnixListener::bind(&socket_path).unwrap();
let admin_socket_clone = admin_socket.clone();
let handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
handle_connection(stream, admin_socket_clone).await;
});
let mut stream = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
let at_limit = format!("{}\n", "A".repeat(4095));
stream.write_all(at_limit.as_bytes()).await.unwrap();
stream.shutdown().await.unwrap();
let mut response = String::new();
let mut reader = BufReader::new(stream);
reader.read_line(&mut response).await.unwrap();
handle.await.unwrap();
let parsed: serde_json::Value = serde_json::from_str(response.trim()).unwrap();
assert_eq!(parsed["status"], "error");
assert_eq!(
parsed["message"]
.as_str()
.unwrap()
.starts_with("unknown command:"),
true
);
}
} }