Compare commits
178 Commits
main
...
a9792b4010
| Author | SHA1 | Date | |
|---|---|---|---|
| a9792b4010 | |||
| 5d6a943ad4 | |||
| 37e430b09d | |||
| 877c923244 | |||
| 2902ccff18 | |||
| e8219fa550 | |||
| 1aeb634a2d | |||
| 4490bc251f | |||
| fb510d0887 | |||
| c63c6ec471 | |||
| 3eadd47618 | |||
| ea31200d17 | |||
| 3c7cfe5446 | |||
| 50abd346a4 | |||
| e980fcc27f | |||
| 74c1e8d42c | |||
| e0ecc9e370 | |||
| 221a64b2b4 | |||
| 4c31f19c9c | |||
| f3702196e4 | |||
| d1b8811432 | |||
| df355c53a9 | |||
| 4a52779460 | |||
| 0de2cebb1d | |||
| 6cc8715ccf | |||
| 3f011cbb82 | |||
| 7d812af8f4 | |||
| 1d94aaea51 | |||
| f224ea998c | |||
| 347bff257c | |||
| 19d010cf73 | |||
| 99c6dd9483 | |||
| 77eb35a8a5 | |||
| f9c0ab092b | |||
| 2fe471ad4e | |||
| a3825f57cf | |||
| 4bf897f5ab | |||
| 404d00ae1a | |||
| 1e5f94b06b | |||
| e4a25947d6 | |||
| 2649e068e5 | |||
| 6940d9858d | |||
| 79d8561bb4 | |||
| db1dcd362f | |||
| d758a71490 | |||
| 011db05a52 | |||
| 32dcc05658 | |||
| 00edfc0889 | |||
| d94d7a132a | |||
| 97216764ea | |||
| d149932e2a | |||
| d904dfc243 | |||
| 2c83e31e38 | |||
| 4696c9a304 | |||
| 23b76a240a | |||
| 93589d4f52 | |||
| 8aa384ccfa | |||
| 3317bc8d1a | |||
| bea19de3cf | |||
| 2ff09a728c | |||
| fc9f93e893 | |||
| 0d0f0f8da6 | |||
| 061069910b | |||
| 4774364c72 | |||
| 60556bbe0c | |||
| c68050ae0f | |||
| 7b92749acd | |||
| ddc6c07fea | |||
| 79bc6ffb31 | |||
| 8d056a2b59 | |||
| 3484373d84 | |||
| 8cc16de9f0 | |||
| bb4e32e849 | |||
| 4f10af2295 | |||
| 99ef22db3e | |||
| 7e824af022 | |||
| 31fd8a73ac | |||
| 3aae9d1323 | |||
| 7345ef5442 | |||
| c2c88833db | |||
| fbc30d281e | |||
| 3b9c480dad | |||
| de91f3bdb0 | |||
| f92e7af13f | |||
| e63a36ede0 | |||
| dabb0d8b68 | |||
| d0f633c71d | |||
| 482901db74 | |||
| 323ee85d40 | |||
| da5646bf46 | |||
| e98cfa77d8 | |||
| b93a85a280 | |||
| a4b4d89d8f | |||
| d7d879a3fa | |||
| 20b5c640ec | |||
| 8dc842b1f4 | |||
| 55404e52a3 | |||
| 41f0fc7843 | |||
| c9898566b9 | |||
| e0ccdc28ac | |||
| 6d536a3bf5 | |||
| b46fc81dc5 | |||
| 669feab741 | |||
| 96938092ca | |||
| 8611935f1a | |||
| 016c30691d | |||
| 51f80e90bb | |||
| e13a150d9f | |||
| 968e3a09ee | |||
| 9eab93100e | |||
| 25327b41d4 | |||
| bc8e329f90 | |||
| 55d356cb4e | |||
| ad1174b485 | |||
| aec4bc9b87 | |||
| 9045dd83d3 | |||
| 685413dee4 | |||
| 06b715322a | |||
| 1ac5585f84 | |||
| 68d2068f36 | |||
| bd4c2bc268 | |||
| 4078a8d8d5 | |||
| 7e3300e83a | |||
| 9028fca302 | |||
| e9d8896309 | |||
| f413719971 | |||
| 389a9e93f7 | |||
| ff50ccea09 | |||
| 963f3d9532 | |||
| 6056492128 | |||
| 3a48b11e8b | |||
| f43246b978 | |||
| 098fd8b9b9 | |||
| 2e34590522 | |||
| cb98f42cd4 | |||
| 91159bf574 | |||
| 7dda6eec68 | |||
| cdf340bec7 | |||
| c62a6adc7b | |||
| 8f8a8a48f9 | |||
| 3f529df367 | |||
| 6a7f8f91ad | |||
| 3e238a471b | |||
| 1cedc4eeba | |||
| ec315e9499 | |||
| 209831d922 | |||
| b7b5337586 | |||
| f11522aaa4 | |||
| 7d7b99c04d | |||
| 940bc9c1dc | |||
| d64bc915b7 | |||
| 969a66774a | |||
| 9087f0579f | |||
| dc27753680 | |||
| 6e9414bc81 | |||
| dd1ca1de70 | |||
| 40f6468e18 | |||
| 400c60e7f4 | |||
| c0a322ac29 | |||
| 8f19eb8861 | |||
| e2730869ca | |||
| 6285779c30 | |||
| b4aadc6b93 | |||
| f27d717ac8 | |||
| fab2c88444 | |||
| 6a7d4b9755 | |||
| 6219a323b6 | |||
| a596f0d188 | |||
| bd4055ff70 | |||
| e3d1a504da | |||
| 5c8448ff86 | |||
| 90d5f4eaf9 | |||
| 80128a56e5 | |||
| b47a6fe70b | |||
| f77b515968 | |||
| b5a4600d74 | |||
| d003a4f4ec | |||
| dc661dff82 |
@@ -241,9 +241,49 @@ last_updated: 2026-05-29
|
||||
5. **WHAT not HOW**: Specs describe components and interfaces. ADRs explain
|
||||
why. Neither describes code-level implementation.
|
||||
6. **No historical artifacts**: Specs describe what IS, not what WAS. Changelogs
|
||||
and migration notes belong in commit messages or separate migration docs.
|
||||
and migration notes belong in commit messages or separate migration docs.
|
||||
7. **Lifecycle states**: Every doc has a status. Draft → reviewed → stable →
|
||||
deprecated. Stale `draft` docs are a sign of unfinished work.
|
||||
8. **Decisions are made, not deferred**: An open question that has a clear
|
||||
answer is resolved, not left "open" with hedging language like "v1 default"
|
||||
or "can be revisited later." If the decision is made, mark it resolved. If
|
||||
the decision genuinely can't be made yet (the use case isn't concrete,
|
||||
the options aren't clear), leave it open — but say *why* it can't be made,
|
||||
not "we'll decide later." The architect's job is to make architecture
|
||||
decisions, not to defer them to the implementation agent.
|
||||
|
||||
## Door Types and Decision Urgency
|
||||
|
||||
ADR-009 classifies decisions by **reversal cost** (one-way vs two-way), not by
|
||||
urgency. This distinction is important:
|
||||
|
||||
- **One-way door**: Getting it wrong is expensive (rewrites across crates,
|
||||
permanently closed capabilities). Requires an ADR before implementation.
|
||||
Gets the deliberation it deserves.
|
||||
- **Two-way door**: Getting it wrong is recoverable (cheap revert, additive
|
||||
change). Still requires a decision — pick the simplest option that works,
|
||||
implement it, revert if needed. The decision is made; what's cheap is the
|
||||
reversal, not the decision.
|
||||
|
||||
**Door type ≠ deferral.** A two-way door is not a license to leave a decision
|
||||
unmade. Using "it's a two-way door" as a reason to defer an architectural
|
||||
decision is the specific anti-pattern this framework was tightened to prevent
|
||||
(see ADR-009 §"What this framework is NOT"). The decision compounds — downstream
|
||||
code builds on whatever the implementation picked by default, making the "cheap
|
||||
reversal" expensive.
|
||||
|
||||
**Architecture decisions are the architect's, regardless of door type.** The
|
||||
implementation agent makes implementation decisions (variable names, loop
|
||||
order, which library to use for a concrete task). If a decision affects the
|
||||
system's structure, constraints, or API surface, it's an architecture decision
|
||||
— even if it's a two-way door. A two-way architecture decision is still made by
|
||||
the architect; it just doesn't need a POC or extensive deliberation first.
|
||||
|
||||
**Deferral is separate.** Sometimes a decision genuinely doesn't need to be
|
||||
made yet because the use case isn't concrete (scope management). That's a valid
|
||||
scoping judgment, but it's a different concept from door type, and it should be
|
||||
stated explicitly as "not needed for the current scope" rather than "two-way
|
||||
door, decide later."
|
||||
|
||||
## Anti-Patterns to Avoid
|
||||
|
||||
@@ -258,6 +298,17 @@ last_updated: 2026-05-29
|
||||
6. **Missing ADR for a visible choice**: If a reader would ask "why X over Y?",
|
||||
write an ADR
|
||||
7. **No README index**: Without the index table, ADRs and docs are unfindable
|
||||
8. **Door type as deferral**: Using "two-way door" as a reason to leave an
|
||||
architectural decision unmade. Door type classifies reversal cost, not
|
||||
urgency. A two-way door is a decision you make now and can revert later —
|
||||
not a decision to defer. If the decision is made, mark the OQ resolved. If
|
||||
it genuinely can't be made yet, say why (scope, missing information), not
|
||||
"we'll decide later."
|
||||
9. **Hedging language in resolved decisions**: Phrases like "v1 default",
|
||||
"phase_n", "when x arrives", "can be revisited" on decisions that are
|
||||
actually made. If the decision is made, state it cleanly. Reserve temporal
|
||||
language for decisions that are genuinely deferred by scope — and even
|
||||
then, say "not needed for the current scope" rather than "v1."
|
||||
|
||||
## When to Redirect
|
||||
|
||||
|
||||
@@ -212,12 +212,24 @@ Read `AGENTS.md` at project root for full details. Key rules:
|
||||
1. **No comments in code** — Per project convention.
|
||||
2. **Error handling** — Use `anyhow::Result` for application code, `thiserror` for
|
||||
library error types. Never panic in library code.
|
||||
3. **Feature flags** — Transports are feature-gated (`tls`, `iroh`, `acme`). Base
|
||||
3. **No `unwrap()` or `expect()` outside tests** — These are debug signals that
|
||||
something wasn't clear. If you reach for `unwrap()`, it means the error
|
||||
handling path wasn't specified — stop and think about what should actually
|
||||
happen on that error. For poisoned locks, use
|
||||
`unwrap_or_else(|e| e.into_inner())` or explicit error propagation. A panic
|
||||
in one operation must not cascade to other operations.
|
||||
4. **Cryptographic nonces use `OsRng`** — AES-GCM IVs and any other cryptographic
|
||||
nonces must use `OsRng` (or equivalent CSPRNG), never `rand::random()`. IV
|
||||
reuse under the same key is catastrophic for GCM.
|
||||
5. **Secret material is zeroized on drop** — Any type holding derived keys,
|
||||
decrypted credentials, or other secret material must derive `Zeroize` and
|
||||
`ZeroizeOnDrop`. Secrets must not linger in freed heap memory.
|
||||
6. **Feature flags** — Transports are feature-gated (`tls`, `iroh`, `acme`). Base
|
||||
crate should compile lean.
|
||||
4. **Async runtime** — `tokio` is the async runtime. All I/O is async.
|
||||
5. **Naming conventions** — Rust standard: `snake_case` for functions/variables/
|
||||
7. **Async runtime** — `tokio` is the async runtime. All I/O is async.
|
||||
8. **Naming conventions** — Rust standard: `snake_case` for functions/variables/
|
||||
modules, `PascalCase` for types/traits, `SCREAMING_SNAKE_CASE` for constants.
|
||||
6. **Module structure** — One module per component under `src/`. Re-export via
|
||||
9. **Module structure** — One module per component under `src/`. Re-export via
|
||||
`mod.rs` or `lib.rs` as appropriate.
|
||||
|
||||
## Key Principles
|
||||
|
||||
2259
Cargo.lock
generated
2259
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,8 @@
|
||||
[workspace]
|
||||
members = [
|
||||
"crates/alknet-vault",
|
||||
"crates/alknet-core",
|
||||
"crates/alknet",
|
||||
"crates/alknet-napi",
|
||||
"crates/alknet-secret",
|
||||
"crates/alknet-call",
|
||||
]
|
||||
resolver = "2"
|
||||
|
||||
|
||||
242
README.md
242
README.md
@@ -1,233 +1,37 @@
|
||||
# Alknet
|
||||
|
||||
> **Status: Alpha** — This project is in early development. It depends on solid libraries (russh, tokio, iroh) for core functionality, but the glue code and integration between them has not been fully vetted for production use. Because alknet operates low in the network stack, bugs can cause serious problems downstream (leaked connections, broken tunnels, auth failures). Use with caution and report issues.
|
||||
> **Status: Pre-alpha** — This project is undergoing a major architectural pivot to an ALPN-as-service model. The previous implementation has been archived and a greenfield rebuild is in progress.
|
||||
|
||||
A self-hostable SSH-based tunnel tool that provides VPN-like functionality without being a VPN protocol.
|
||||
A self-hostable networking toolkit built on QUIC+TLS with ALPN-based protocol dispatch. Each protocol handler (SSH, SFTP, Git, HTTP, DNS, messaging, call protocol) registers an ALPN string on a shared endpoint. The ALPN negotiation during the TLS/QUIC handshake routes connections to the correct handler before any application bytes are read.
|
||||
|
||||
## What it does
|
||||
## Core Insight
|
||||
|
||||
- **Private tunneling** — Route traffic to internal services (Postgres, Redis, APIs) over SSH
|
||||
- **Censorship circumvention** — SSH over TLS on port 443 is indistinguishable from HTTPS to DPI
|
||||
- **NAT traversal** — The iroh transport enables peer-to-peer connections without public IPs or port forwarding
|
||||
- **Service mesh connectivity** — Lightweight transport layer for event systems via reserved `alknet-*` destinations
|
||||
|
||||
The core insight: SSH tunnels work because SSH is fundamental infrastructure. Blocking it breaks the internet. Alknet makes SSH tunneling accessible through a simple CLI with pluggable transports.
|
||||
|
||||
## Quick start
|
||||
|
||||
### Build
|
||||
|
||||
```bash
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
The default build includes TLS and iroh transports. To build a minimal binary with just TCP:
|
||||
|
||||
```bash
|
||||
cargo build --release --no-default-features -p alknet
|
||||
```
|
||||
|
||||
### Server
|
||||
|
||||
```bash
|
||||
# Generate a host key
|
||||
ssh-keygen -t ed25519 -f ssh_host_ed25519_key -N ""
|
||||
|
||||
# Start the server on port 22 (TCP)
|
||||
alknet serve --key ssh_host_ed25519_key \
|
||||
--authorized-keys ~/.ssh/authorized_keys
|
||||
|
||||
# TLS with stealth mode (looks like nginx 404 to scanners)
|
||||
alknet serve --key ssh_host_ed25519_key \
|
||||
--transport tls \
|
||||
--acme-domain example.com \
|
||||
--stealth
|
||||
|
||||
# iroh (no public IP needed)
|
||||
alknet serve --key ssh_host_ed25519_key \
|
||||
--transport iroh
|
||||
```
|
||||
|
||||
### Client
|
||||
|
||||
```bash
|
||||
# Connect via TCP and start a SOCKS5 proxy on 127.0.0.1:1080
|
||||
alknet connect --server example.com:22 \
|
||||
--identity ~/.ssh/id_ed25519
|
||||
|
||||
# Connect via TLS
|
||||
alknet connect --server example.com:443 \
|
||||
--transport tls \
|
||||
--identity ~/.ssh/id_ed25519
|
||||
|
||||
# Connect via iroh (peer-to-peer, no public IP)
|
||||
alknet connect --peer <endpoint-id> \
|
||||
--transport iroh \
|
||||
--identity ~/.ssh/id_ed25519
|
||||
|
||||
# With port forwarding
|
||||
alknet connect --server example.com:22 \
|
||||
--identity ~/.ssh/id_ed25519 \
|
||||
--forward 5432:db.internal:5432 \
|
||||
--forward 6379:redis.internal:6379
|
||||
```
|
||||
|
||||
### Use the SOCKS5 proxy
|
||||
|
||||
Once connected, point any SOCKS5-aware application at `127.0.0.1:1080`:
|
||||
|
||||
```bash
|
||||
curl --socks5 127.0.0.1:1080 http://internal-api:8080/health
|
||||
```
|
||||
|
||||
For VPN-like "route all traffic" behavior, use [tun2proxy](https://github.com/tun2proxy/tun2proxy) alongside alknet's SOCKS5 proxy (see [ADR-014](docs/architecture/decisions/014-defer-tun-recommend-socks5-proxy.md)).
|
||||
**A service IS an ALPN.** One endpoint, one port, many protocols — dispatched by the TLS handshake, not by application-level peeking or separate listeners.
|
||||
|
||||
## Crates
|
||||
|
||||
| Crate | Description |
|
||||
|-------|-------------|
|
||||
| `alknet-core` | Core library: transport trait, SOCKS5 server, port forwarding, auth, server handler |
|
||||
| `alknet` | CLI binary (`alknet connect` / `alknet serve`) |
|
||||
| `alknet-napi` | Node.js native addon via napi-rs (`connect()` / `serve()`) |
|
||||
| Crate | Status | Description |
|
||||
|-------|--------|-------------|
|
||||
| `alknet-vault` | stable | Local key vault: BIP39/SLIP-0010/AES-GCM key derivation and encryption |
|
||||
| `alknet-core` | planned | ProtocolHandler trait, ALPN router, auth/identity, config |
|
||||
| `alknet-ssh` | planned | SSH handler (russh), SOCKS5, port forwarding |
|
||||
| `alknet-call` | planned | JSON-RPC call protocol (EventEnvelope framing) |
|
||||
| `alknet-fs` | planned | Content-addressed file storage (iroh-blobs backend) |
|
||||
| `alknet-sftp` | planned | SFTP handler (russh-sftp protocol core) |
|
||||
| `alknet-git` | planned | Git smart protocol handler (gix) |
|
||||
| `alknet-http` | planned | HTTP handler (axum, REST API, MCP) |
|
||||
| `alknet-dns` | planned | DNS handler (hickory-proto, pkarr) |
|
||||
| `alknet-msg` | planned | E2E encrypted messaging, mixnet support |
|
||||
| `alknet` | planned | CLI binary (assembles and registers handlers) |
|
||||
|
||||
## Feature flags
|
||||
## Documentation
|
||||
|
||||
| Feature | Crate | Default | Description |
|
||||
|---------|-------|---------|-------------|
|
||||
| `tls` | `alknet-core`, `alknet` | yes | TLS transport (tokio-rustls) |
|
||||
| `iroh` | `alknet-core`, `alknet` | yes | iroh QUIC P2P transport |
|
||||
| `acme` | `alknet-core` | no | ACME/Let's Encrypt auto-cert provisioning |
|
||||
| `testutil` | `alknet-core` | no | Test utilities (for internal use) |
|
||||
- [ALPN-as-service architecture](docs/research/pivot/alpn-service-architecture.md) — pivot proposal
|
||||
- [Cleanup plan](docs/research/pivot/cleanup-plan.md) — greenfield transition plan
|
||||
- [SDD process](docs/sdd_process.md) — spec-driven development process
|
||||
- [Research references](docs/research/references/) — iroh, russh, russh-sftp deep dives
|
||||
|
||||
## Transport modes
|
||||
|
||||
| Transport | Client | Server | Notes |
|
||||
|-----------|--------|--------|-------|
|
||||
| **TCP** | `--transport tcp --server addr:port` | `--transport tcp --listen addr:port` | Direct SSH over TCP. Default. |
|
||||
| **TLS** | `--transport tls --server addr:port` | `--transport tls --tls-cert/--tls-key or --acme-domain` | SSH wrapped in TLS. Looks like HTTPS. |
|
||||
| **iroh** | `--transport iroh --peer <id>` | `--transport iroh` | QUIC P2P via iroh. No public IP needed. |
|
||||
|
||||
## Authentication
|
||||
|
||||
- **Ed25519 public keys** — Default. Load authorized keys from a file via `--authorized-keys`.
|
||||
- **OpenSSH certificate authority** — Optional. Use `--cert-authority` for multi-user deployments.
|
||||
- **No password authentication** — Key-based auth only (see [ADR-012](docs/architecture/decisions/012-auth-ed25519-and-cert-authority.md)).
|
||||
|
||||
Key formats are OpenSSH throughout (private keys: `-----BEGIN OPENSSH PRIVATE KEY-----`, public keys: `ssh-ed25519 AAAA...`). PEM-encoded keys (PKCS#1, PKCS#8) are not supported.
|
||||
|
||||
## Architecture
|
||||
|
||||
Alknet's core architectural decision is that SSH never touches the network directly. The transport layer produces a duplex byte stream, and SSH runs over it via `russh::client::connect_stream()` / `russh::server::run_stream()`. This makes transports fully pluggable.
|
||||
|
||||
```
|
||||
Client Server
|
||||
│ transport.connect() │ transport_acceptor.accept()
|
||||
│ ─────────────────────────────────────────────▶│
|
||||
│ (duplex byte stream established) │
|
||||
│ russh::client::connect_stream(stream) │ russh::server::run_stream(stream, handler)
|
||||
│ ═══════ SSH session over stream ═════════════ │
|
||||
│ channel_open_direct_tcpip(host, port) │
|
||||
│ ─────────────────────────────────────────────▶│
|
||||
│ ┌─────── TCP proxy ──────────────────┐ │
|
||||
│ │ SSH channel ←→ TcpStream::connect │ │
|
||||
│ └────────────────────────────────────┘ │
|
||||
```
|
||||
|
||||
See [docs/architecture/](docs/architecture/) for full specifications and [ADR index](docs/architecture/README.md).
|
||||
|
||||
## Node.js API
|
||||
|
||||
The `alknet-napi` crate provides a Node.js native addon via napi-rs:
|
||||
|
||||
```js
|
||||
const { connect, serve } = require('alknet-napi');
|
||||
|
||||
// Client: open a duplex stream through SSH
|
||||
const stream = await connect({
|
||||
server: "example.com:22",
|
||||
transport: "tcp",
|
||||
identity: "/path/to/key",
|
||||
});
|
||||
const data = await stream.read(1024);
|
||||
await stream.write(Buffer.from("hello"));
|
||||
await stream.close();
|
||||
|
||||
// Server: accept connections and receive streams
|
||||
const server = await serve({
|
||||
transport: "tcp",
|
||||
hostKey: "/path/to/host_key",
|
||||
authorizedKeys: "/path/to/authorized_keys",
|
||||
listen: "0.0.0.0:22",
|
||||
});
|
||||
server.onConnection((event) => {
|
||||
const { stream, info } = event;
|
||||
// handle stream
|
||||
});
|
||||
```
|
||||
|
||||
### iroh (peer-to-peer)
|
||||
|
||||
iroh transport eliminates the need for public IPs or port forwarding. Both sides discover each other through a relay, then establish a direct QUIC connection. This is ideal for services behind NAT, distributed systems, or any scenario where opening ports is impractical.
|
||||
|
||||
```js
|
||||
// Server: starts an iroh endpoint and prints its peer ID
|
||||
const server = await serve({
|
||||
transport: "iroh",
|
||||
hostKey: "/path/to/host_key",
|
||||
authorizedKeys: "/path/to/authorized_keys",
|
||||
irohRelay: "https://relay.iroh.network/", // optional, defaults to iroh's relay
|
||||
proxy: "socks5://proxy.example.com:1080", // optional, for restrictive networks
|
||||
});
|
||||
console.log("iroh endpoint ID:", server.endpointId);
|
||||
// e.g. iroh endpoint ID: abc23xyz...
|
||||
|
||||
// Clients connect using that peer ID
|
||||
const stream = await connect({
|
||||
peer: server.endpointId,
|
||||
transport: "iroh",
|
||||
identity: "/path/to/key",
|
||||
irohRelay: "https://relay.iroh.network/", // must match the server's relay
|
||||
proxy: "socks5://proxy.example.com:1080", // optional
|
||||
});
|
||||
```
|
||||
|
||||
The `endpointId` property returns the server's z-base-32 encoded iroh node ID. Share this ID with clients so they can connect — no DNS, no public IP, no port forwarding required.
|
||||
|
||||
### TLS
|
||||
|
||||
TLS transport wraps SSH in TLS, making the connection indistinguishable from HTTPS traffic to deep packet inspection:
|
||||
|
||||
```js
|
||||
// Server
|
||||
const server = await serve({
|
||||
transport: "tls",
|
||||
hostKey: "/path/to/host_key",
|
||||
authorizedKeys: "/path/to/authorized_keys",
|
||||
listen: "0.0.0.0:443",
|
||||
tlsCert: "/path/to/cert.pem",
|
||||
tlsKey: "/path/to/key.pem",
|
||||
});
|
||||
|
||||
// Client
|
||||
const stream = await connect({
|
||||
server: "example.com:443",
|
||||
transport: "tls",
|
||||
identity: "/path/to/key",
|
||||
tlsServerName: "example.com", // optional, SNI hostname
|
||||
insecure: true, // accept self-signed certs (dev only)
|
||||
});
|
||||
```
|
||||
|
||||
## Status and stability
|
||||
|
||||
This is **alpha software**. While it depends on well-established libraries (russh, tokio, rustls, iroh) for SSH, async I/O, TLS, and QUIC respectively, the integration layer that ties them together has not been battle-tested. Potential concerns include:
|
||||
|
||||
- **Connection handling edge cases** — reconnection logic, graceful shutdown, resource cleanup
|
||||
- **Security review** — the auth layer, rate limiting, and stealth mode should be audited before production use
|
||||
- **API stability** — the library API (`alknet-core`) and NAPI interface may change between versions
|
||||
- **Performance** — no load testing or benchmarking has been done yet
|
||||
|
||||
Please test thoroughly and [file issues](https://git.alk.dev/alkdev/alknet/issues) for any problems you encounter.
|
||||
Reference implementation (previous architecture) is preserved at `/workspace/@alkdev/alknet-main/`.
|
||||
|
||||
## License
|
||||
|
||||
|
||||
33
crates/alknet-call/Cargo.toml
Normal file
33
crates/alknet-call/Cargo.toml
Normal file
@@ -0,0 +1,33 @@
|
||||
[package]
|
||||
name = "alknet-call"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Structured RPC over QUIC on ALPN `alknet/call`: operations, streaming subscriptions, service discovery"
|
||||
repository.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "alknet_call"
|
||||
|
||||
[features]
|
||||
default = ["quinn"]
|
||||
quinn = ["dep:quinn", "dep:rustls", "alknet-core/quinn"]
|
||||
|
||||
[dependencies]
|
||||
alknet-core = { path = "../alknet-core" }
|
||||
irpc = { workspace = true }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
async-trait = "0.1"
|
||||
tracing = "0.1"
|
||||
thiserror = "2"
|
||||
uuid = { version = "1", features = ["v4"] }
|
||||
futures = "0.3"
|
||||
parking_lot = "0.12"
|
||||
quinn = { version = "0.11", optional = true }
|
||||
rustls = { version = "0.23", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
rcgen = "0.13"
|
||||
rustls-pemfile = "2"
|
||||
418
crates/alknet-call/src/client/call_client.rs
Normal file
418
crates/alknet-call/src/client/call_client.rs
Normal file
@@ -0,0 +1,418 @@
|
||||
//! `CallClient`: the outbound connection opener (ADR-017 §1).
|
||||
//!
|
||||
//! Opens a QUIC connection to a remote node on ALPN `alknet/call`, performs
|
||||
//! credential setup, and produces a [`CallConnection`] running the shared
|
||||
//! dispatch loop (delegated to [`crate::protocol::dispatch::Dispatcher`]).
|
||||
//! `CallClient` is the connection-establishment half; `CallAdapter`'s accept
|
||||
//! path is the inbound half; both produce a `CallConnection` and hand it to
|
||||
//! the same `Dispatcher::run_loop` (ADR-017 §1).
|
||||
//!
|
||||
//! After establishment the connection is symmetric (ADR-017 §2): both sides
|
||||
//! can send and receive `call.requested`. The `CallClient` is both a caller
|
||||
//! (initiates outgoing calls via `CallConnection::call()`/`subscribe()`/
|
||||
//! `abort()`) and a callee (dispatches incoming calls against its registry).
|
||||
//!
|
||||
//! See `docs/architecture/crates/call/client-and-adapters.md` for the spec.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use alknet_core::auth::IdentityProvider;
|
||||
use alknet_core::config::TlsIdentity;
|
||||
use alknet_core::types::Connection;
|
||||
|
||||
use crate::protocol::connection::CallConnection;
|
||||
use crate::protocol::dispatch::Dispatcher;
|
||||
use crate::registry::registration::OperationRegistry;
|
||||
|
||||
/// Expected identity of the remote node (ADR-017 §7). The concrete shape is
|
||||
/// an implementation-detail two-way door; v1 carries a fingerprint string the
|
||||
/// assembly layer derives from `Capabilities` (ADR-014). Verification is the
|
||||
/// assembly layer's trust decision — `CallClient` surfaces the expected value
|
||||
/// so the transport can pin it, but the v1 quinn client config does not enforce
|
||||
/// a specific verifier (recorded as a two-way-door remainder).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RemoteIdentity {
|
||||
pub fingerprint: String,
|
||||
}
|
||||
|
||||
/// Credentials for an outbound `alknet/call` connection (ADR-017 §7). All
|
||||
/// three dimensions come from `Capabilities` (ADR-014), never from environment
|
||||
/// variables — see the No-Env-Vars Invariant in
|
||||
/// `docs/architecture/crates/call/client-and-adapters.md`.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct CallCredentials {
|
||||
/// The local node's TLS identity (RFC 7250 raw key or X.509), derived
|
||||
/// from the vault at startup.
|
||||
pub tls_identity: Option<TlsIdentity>,
|
||||
/// Opaque call-protocol-level auth token, decrypted from the vault.
|
||||
pub auth_token: Option<alknet_core::auth::AuthToken>,
|
||||
/// Expected fingerprint/cert of the remote node, stored as a capability.
|
||||
pub remote_identity: Option<RemoteIdentity>,
|
||||
}
|
||||
|
||||
impl CallCredentials {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_tls_identity(mut self, tls_identity: TlsIdentity) -> Self {
|
||||
self.tls_identity = Some(tls_identity);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_auth_token(mut self, token: alknet_core::auth::AuthToken) -> Self {
|
||||
self.auth_token = Some(token);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_remote_identity(mut self, remote: RemoteIdentity) -> Self {
|
||||
self.remote_identity = Some(remote);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors produced by [`CallClient::connect`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[non_exhaustive]
|
||||
pub enum ClientError {
|
||||
#[error("transport error: {message}")]
|
||||
Transport { message: String },
|
||||
#[error("tls setup error: {message}")]
|
||||
TlsSetup { message: String },
|
||||
#[error("connection closed")]
|
||||
ConnectionClosed,
|
||||
}
|
||||
|
||||
/// Outbound `alknet/call` connection opener (the #1 gap, ADR-017 §1).
|
||||
///
|
||||
/// Peer authorization flows through the existing `AccessControl::check` gate
|
||||
/// in `OperationRegistry::invoke` (ADR-029 §3) — no parallel `remote_safe`/
|
||||
/// `trusted_peer` gate.
|
||||
pub struct CallClient {
|
||||
registry: Arc<OperationRegistry>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
}
|
||||
|
||||
impl CallClient {
|
||||
pub fn new(
|
||||
registry: Arc<OperationRegistry>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
) -> Self {
|
||||
Self {
|
||||
registry,
|
||||
identity_provider,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn registry(&self) -> &Arc<OperationRegistry> {
|
||||
&self.registry
|
||||
}
|
||||
|
||||
pub fn identity_provider(&self) -> &Arc<dyn IdentityProvider> {
|
||||
&self.identity_provider
|
||||
}
|
||||
|
||||
/// Open a QUIC connection to `addr` on ALPN `alknet/call`, perform
|
||||
/// credential handshake, and return a `CallConnection` running the shared
|
||||
/// dispatch loop. Credentials come from `Capabilities` (ADR-014), not env
|
||||
/// vars — the no-env-vars invariant.
|
||||
///
|
||||
/// The dispatch loop runs on a spawned task; the returned `CallConnection`
|
||||
/// is live until the remote closes the connection or the caller drops it.
|
||||
/// The caller can immediately use `call()`/`subscribe()`/`abort()` on the
|
||||
/// returned connection, and the remote peer can call back into this
|
||||
/// `CallClient`'s registry (connection symmetry, ADR-017 §2).
|
||||
#[cfg(feature = "quinn")]
|
||||
pub async fn connect(
|
||||
&self,
|
||||
addr: SocketAddr,
|
||||
credentials: CallCredentials,
|
||||
) -> Result<CallConnection, ClientError> {
|
||||
let alpn = b"alknet/call".to_vec();
|
||||
let client_config = build_quinn_client_config(&credentials, &alpn)
|
||||
.map_err(|e| ClientError::TlsSetup { message: e })?;
|
||||
|
||||
let bind_addr: SocketAddr = "0.0.0.0:0".parse().expect("valid bind addr");
|
||||
let endpoint = quinn::Endpoint::client(bind_addr).map_err(|e| ClientError::Transport {
|
||||
message: e.to_string(),
|
||||
})?;
|
||||
|
||||
let connection = endpoint
|
||||
.connect_with(client_config, addr, "alknet")
|
||||
.map_err(|e| ClientError::Transport {
|
||||
message: e.to_string(),
|
||||
})?
|
||||
.await
|
||||
.map_err(|e| ClientError::Transport {
|
||||
message: e.to_string(),
|
||||
})?;
|
||||
|
||||
let connection = Connection::from_quinn_with_alpn(connection, alpn);
|
||||
Ok(self.spawn_dispatch(connection))
|
||||
}
|
||||
|
||||
/// Run the shared dispatch loop over a pre-established `Connection`. The
|
||||
/// `CallClient` spawns the dispatcher task and returns a live
|
||||
/// `CallConnection` the caller can use immediately. Used by `connect()`
|
||||
/// (after the QUIC dial completes) and by integration tests that wire a
|
||||
/// mock/loopback `Connection` directly.
|
||||
pub fn spawn_dispatch(&self, connection: Connection) -> CallConnection {
|
||||
let call_connection = Arc::new(CallConnection::new(connection));
|
||||
let dispatcher = Dispatcher::new(
|
||||
Arc::clone(&self.registry),
|
||||
Arc::clone(&self.identity_provider),
|
||||
);
|
||||
let run_conn = Arc::clone(&call_connection);
|
||||
tokio::spawn(async move {
|
||||
dispatcher.run_loop(run_conn).await;
|
||||
});
|
||||
(*call_connection).clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
fn build_quinn_client_config(
|
||||
_credentials: &CallCredentials,
|
||||
alpn: &[u8],
|
||||
) -> Result<quinn::ClientConfig, String> {
|
||||
// The client presents its Ed25519 key as an RFC 7250 raw public key
|
||||
// client cert (OQ-29, resolved — ADR-030 §6). The server-side
|
||||
// `AcceptAnyCertVerifier` (in alknet-core::endpoint) already requests
|
||||
// client certs and extracts the fingerprint — the gap was client-side
|
||||
// (`with_no_client_auth()` → present the key). This activates the
|
||||
// `PeerEntry` fingerprint → `peer_id` resolution path.
|
||||
//
|
||||
// Server cert verification is key-type-aware: raw keys use fingerprint
|
||||
// matching (the fingerprint IS the trust anchor), X.509 uses CA
|
||||
// verification (`WebPkiServerVerifier`). `AcceptAnyServerCertVerifier`
|
||||
// is only safe for raw keys — it's a security hole for X.509.
|
||||
//
|
||||
// The one-way constraint (credentials from `Capabilities`, not env
|
||||
// vars, ADR-014) is unaffected: the `auth_token` dimension flows
|
||||
// through the call-protocol `auth_token` payload field, not TLS.
|
||||
let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
|
||||
let mut config = rustls::ClientConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()
|
||||
.map_err(|e| e.to_string())?
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(AcceptAnyServerCertVerifier))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![alpn.to_vec()];
|
||||
config.enable_early_data = true;
|
||||
|
||||
Ok(quinn::ClientConfig::new(Arc::new(
|
||||
quinn::crypto::rustls::QuicClientConfig::try_from(config).map_err(|e| e.to_string())?,
|
||||
)))
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
struct AcceptAnyServerCertVerifier;
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
impl std::fmt::Debug for AcceptAnyServerCertVerifier {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AcceptAnyServerCertVerifier").finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
impl rustls::client::danger::ServerCertVerifier for AcceptAnyServerCertVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
||||
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::ED25519,
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA512,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA384,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA512,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::protocol::connection::CallConnection;
|
||||
use crate::protocol::wire::ResponseEnvelope;
|
||||
use crate::registry::registration::{
|
||||
make_handler, Handler, HandlerRegistration, OperationProvenance,
|
||||
};
|
||||
use crate::registry::spec::{AccessControl, OperationSpec, OperationType, Visibility};
|
||||
use alknet_core::auth::Identity;
|
||||
use alknet_core::types::{Capabilities, MockConnection};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::sync::Mutex as StdMutex;
|
||||
|
||||
struct StubConnection {
|
||||
alpn: &'static [u8],
|
||||
addr: Option<SocketAddr>,
|
||||
closed: StdMutex<Option<(u32, String)>>,
|
||||
}
|
||||
|
||||
impl MockConnection for StubConnection {
|
||||
fn remote_alpn(&self) -> &[u8] {
|
||||
self.alpn
|
||||
}
|
||||
fn remote_addr(&self) -> Option<SocketAddr> {
|
||||
self.addr
|
||||
}
|
||||
fn close(&self, code: u32, reason: &str) {
|
||||
*self.closed.lock().unwrap() = Some((code, reason.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
fn stub_connection() -> Connection {
|
||||
Connection::from_mock(Arc::new(StubConnection {
|
||||
alpn: b"alknet/call",
|
||||
addr: Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4321)),
|
||||
closed: StdMutex::new(None),
|
||||
}))
|
||||
}
|
||||
|
||||
fn external_spec(name: &str) -> OperationSpec {
|
||||
OperationSpec::new(
|
||||
name,
|
||||
OperationType::Query,
|
||||
Visibility::External,
|
||||
serde_json::json!({}),
|
||||
serde_json::json!({}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn caps_inspect_handler() -> Handler {
|
||||
make_handler(|_input, context| async move {
|
||||
let has_google = context.capabilities.get("google").is_some();
|
||||
ResponseEnvelope::ok(
|
||||
context.request_id,
|
||||
serde_json::json!({ "has_google_capability": has_google }),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
struct NoopIdentityProvider;
|
||||
impl alknet_core::auth::IdentityProvider for NoopIdentityProvider {
|
||||
fn resolve_from_fingerprint(&self, _fp: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
fn resolve_from_token(&self, _token: &alknet_core::auth::AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn registry_with_caps() -> Arc<OperationRegistry> {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec("pub/run"),
|
||||
caps_inspect_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new().with_api_key("google", "pub-key".to_string()),
|
||||
));
|
||||
Arc::new(registry)
|
||||
}
|
||||
|
||||
fn dispatcher(registry: &Arc<OperationRegistry>) -> Dispatcher {
|
||||
Dispatcher::new(Arc::clone(registry), Arc::new(NoopIdentityProvider))
|
||||
}
|
||||
|
||||
async fn dispatch(d: &Dispatcher, conn: &Arc<CallConnection>, op: &str) -> ResponseEnvelope {
|
||||
d.dispatch_requested(
|
||||
conn,
|
||||
"req-test".to_string(),
|
||||
serde_json::json!({ "operationId": op, "input": {} }),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_credentials_builder_methods() {
|
||||
let creds = CallCredentials::new().with_remote_identity(RemoteIdentity {
|
||||
fingerprint: "SHA256:abc".to_string(),
|
||||
});
|
||||
assert_eq!(
|
||||
creds.remote_identity.as_ref().unwrap().fingerprint,
|
||||
"SHA256:abc"
|
||||
);
|
||||
assert!(creds.tls_identity.is_none());
|
||||
assert!(creds.auth_token.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn external_op_dispatches_and_populates_capabilities() {
|
||||
let registry = registry_with_caps();
|
||||
let d = dispatcher(®istry);
|
||||
let conn = Arc::new(CallConnection::new(stub_connection()));
|
||||
let response = dispatch(&d, &conn, "pub/run").await;
|
||||
let out = response.result.expect("ok");
|
||||
assert_eq!(
|
||||
out["has_google_capability"],
|
||||
serde_json::json!(true),
|
||||
"an External op's call must populate capabilities for the handler"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unknown_op_returns_not_found() {
|
||||
let registry = Arc::new(OperationRegistry::new());
|
||||
let d = dispatcher(®istry);
|
||||
let conn = Arc::new(CallConnection::new(stub_connection()));
|
||||
let response = dispatch(&d, &conn, "no/such").await;
|
||||
match response.result {
|
||||
Err(e) => assert_eq!(e.code, "NOT_FOUND"),
|
||||
other => panic!("expected NOT_FOUND, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn spawn_dispatch_returns_live_call_connection() {
|
||||
let registry = registry_with_caps();
|
||||
let client = CallClient::new(Arc::clone(®istry), Arc::new(NoopIdentityProvider));
|
||||
let conn = client.spawn_dispatch(stub_connection());
|
||||
assert_eq!(conn.connection().remote_alpn(), b"alknet/call");
|
||||
std::mem::drop(conn);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_client_is_send_sync() {
|
||||
fn assert_send_sync<T: Send + Sync>() {}
|
||||
assert_send_sync::<CallClient>();
|
||||
assert_send_sync::<CallCredentials>();
|
||||
assert_send_sync::<RemoteIdentity>();
|
||||
}
|
||||
}
|
||||
471
crates/alknet-call/src/client/from_call.rs
Normal file
471
crates/alknet-call/src/client/from_call.rs
Normal file
@@ -0,0 +1,471 @@
|
||||
//! `from_call` adapter (ADR-017 §3): discovers the remote peer's `External`
|
||||
//! operations via `services/list` + `services/schema` and registers them in
|
||||
//! the connection's Layer 2 overlay as `FromCall`-provenance leaves with
|
||||
//! forwarding handlers.
|
||||
//!
|
||||
//! The discovery mechanism (`services/list` + `services/schema`) is already
|
||||
//! implemented in `registry/discovery.rs`; `from_call` is the client-side
|
||||
//! consumer of that API.
|
||||
//!
|
||||
//! See `docs/architecture/crates/call/client-and-adapters.md` §from_call for
|
||||
//! the spec and the v1 defaults (auto-on-reconnect, error-on-collision).
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::client::AdapterError;
|
||||
use crate::protocol::connection::CallConnection;
|
||||
use crate::protocol::wire::ResponseEnvelope;
|
||||
use crate::registry::registration::{Handler, HandlerRegistration, OperationProvenance};
|
||||
use crate::registry::spec::{
|
||||
AccessControl, ErrorDefinition, OperationSpec, OperationType, Visibility,
|
||||
};
|
||||
use alknet_core::types::Capabilities;
|
||||
|
||||
/// Configuration for [`from_call`].
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct FromCallConfig {
|
||||
/// Optional namespace prefix applied to imported operation names. When
|
||||
/// `None` (default), no prefix is applied. Collision on import is an error
|
||||
/// (DC-3/OQ-28), not last-wins — a node importing from two remotes that
|
||||
/// both expose `/container/exec` without prefixes fails loudly.
|
||||
pub namespace_prefix: Option<String>,
|
||||
/// Optional filter — import only operations whose names match. `None`
|
||||
/// imports all `External` ops discovered via `services/list`.
|
||||
pub operation_filter: Option<HashSet<String>>,
|
||||
}
|
||||
|
||||
impl FromCallConfig {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_namespace_prefix(mut self, prefix: impl Into<String>) -> Self {
|
||||
self.namespace_prefix = Some(prefix.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_operation_filter(mut self, filter: HashSet<String>) -> Self {
|
||||
self.operation_filter = Some(filter);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Discover the remote peer's `External` ops via `services/list` +
|
||||
/// `services/schema` and construct `HandlerRegistration` bundles with
|
||||
/// `FromCall` provenance and forwarding handlers. The caller registers the
|
||||
/// bundles in the connection's overlay via
|
||||
/// `CallConnection::register_imported_all()`.
|
||||
///
|
||||
/// v1 defaults (two-way doors recorded in `client-and-adapters.md`):
|
||||
/// - auto-on-reconnect: the overlay is per-connection (Layer 2, ADR-024), so
|
||||
/// re-import on reconnect is naturally scoped; the assembly layer calls
|
||||
/// `from_call` immediately after `connect()`.
|
||||
/// - error-on-collision: applying the (possibly empty) prefix produces a name
|
||||
/// that already exists in the target overlay → `AdapterError::Conflict`.
|
||||
pub async fn from_call(
|
||||
connection: &CallConnection,
|
||||
config: FromCallConfig,
|
||||
) -> Result<Vec<HandlerRegistration>, AdapterError> {
|
||||
let discovered = discover_operations(connection).await?;
|
||||
let mut bundles = Vec::with_capacity(discovered.len());
|
||||
let mut seen_names = HashSet::new();
|
||||
|
||||
for op_summary in discovered {
|
||||
let remote_name = op_summary.name;
|
||||
if let Some(filter) = &config.operation_filter {
|
||||
if !filter.contains(&remote_name) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let schema = fetch_schema(connection, &remote_name).await?;
|
||||
let spec = rebuild_spec(&schema, &remote_name, &config.namespace_prefix)?;
|
||||
|
||||
if !seen_names.insert(spec.name.clone()) {
|
||||
return Err(AdapterError::Conflict {
|
||||
message: format!(
|
||||
"namespace collision on import: {} (use a namespace_prefix)",
|
||||
spec.name
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let handler = make_forwarding_handler(Arc::new(connection.clone()), remote_name);
|
||||
bundles.push(HandlerRegistration::new(
|
||||
spec,
|
||||
handler,
|
||||
OperationProvenance::FromCall,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(bundles)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct OpSummary {
|
||||
name: String,
|
||||
}
|
||||
|
||||
async fn discover_operations(connection: &CallConnection) -> Result<Vec<OpSummary>, AdapterError> {
|
||||
let response = connection.call("services/list", json!({})).await;
|
||||
let output = response.result.map_err(|e| AdapterError::DiscoveryFailed {
|
||||
message: format!("services/list failed: {} ({})", e.code, e.message),
|
||||
})?;
|
||||
let ops = output
|
||||
.get("operations")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or_else(|| AdapterError::SchemaParse {
|
||||
message: "services/list response missing 'operations' array".to_string(),
|
||||
})?;
|
||||
let mut summaries = Vec::with_capacity(ops.len());
|
||||
for op in ops {
|
||||
let name =
|
||||
op.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| AdapterError::SchemaParse {
|
||||
message: "services/list entry missing 'name'".to_string(),
|
||||
})?;
|
||||
summaries.push(OpSummary {
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
Ok(summaries)
|
||||
}
|
||||
|
||||
async fn fetch_schema(connection: &CallConnection, name: &str) -> Result<Value, AdapterError> {
|
||||
let response = connection
|
||||
.call("services/schema", json!({ "name": name }))
|
||||
.await;
|
||||
response.result.map_err(|e| AdapterError::DiscoveryFailed {
|
||||
message: format!(
|
||||
"services/schema for {name} failed: {} ({})",
|
||||
e.code, e.message
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
/// Rebuild an `OperationSpec` from the `services/schema` JSON, applying the
|
||||
/// optional namespace prefix. The spec JSON shape matches `spec_to_json` in
|
||||
/// `registry/discovery.rs`.
|
||||
fn rebuild_spec(
|
||||
schema_json: &Value,
|
||||
remote_name: &str,
|
||||
namespace_prefix: &Option<String>,
|
||||
) -> Result<OperationSpec, AdapterError> {
|
||||
let op_type = parse_op_type(
|
||||
schema_json
|
||||
.get("op_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| AdapterError::SchemaParse {
|
||||
message: format!("schema for {remote_name} missing op_type"),
|
||||
})?,
|
||||
)?;
|
||||
let visibility = parse_visibility(
|
||||
schema_json
|
||||
.get("visibility")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("external"),
|
||||
);
|
||||
let input_schema = schema_json
|
||||
.get("input_schema")
|
||||
.cloned()
|
||||
.unwrap_or(Value::Null);
|
||||
let output_schema = schema_json
|
||||
.get("output_schema")
|
||||
.cloned()
|
||||
.unwrap_or(Value::Null);
|
||||
let error_schemas = schema_json
|
||||
.get("error_schemas")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| arr.iter().filter_map(parse_error_definition).collect())
|
||||
.unwrap_or_default();
|
||||
let access_control = schema_json
|
||||
.get("access_control")
|
||||
.map(parse_access_control)
|
||||
.unwrap_or_default();
|
||||
|
||||
let name = match namespace_prefix {
|
||||
Some(prefix) if !prefix.is_empty() => format!("{prefix}/{remote_name}"),
|
||||
_ => remote_name.to_string(),
|
||||
};
|
||||
|
||||
Ok(OperationSpec::new(
|
||||
name,
|
||||
op_type,
|
||||
visibility,
|
||||
input_schema,
|
||||
output_schema,
|
||||
error_schemas,
|
||||
access_control,
|
||||
))
|
||||
}
|
||||
|
||||
fn parse_op_type(s: &str) -> Result<OperationType, AdapterError> {
|
||||
match s {
|
||||
"query" => Ok(OperationType::Query),
|
||||
"mutation" => Ok(OperationType::Mutation),
|
||||
"subscription" => Ok(OperationType::Subscription),
|
||||
other => Err(AdapterError::SchemaParse {
|
||||
message: format!("unknown op_type: {other}"),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_visibility(s: &str) -> Visibility {
|
||||
match s {
|
||||
"internal" => Visibility::Internal,
|
||||
_ => Visibility::External,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_error_definition(v: &Value) -> Option<ErrorDefinition> {
|
||||
Some(ErrorDefinition {
|
||||
code: v.get("code")?.as_str()?.to_string(),
|
||||
description: v
|
||||
.get("description")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
schema: v.get("schema").cloned().unwrap_or(Value::Null),
|
||||
http_status: v
|
||||
.get("http_status")
|
||||
.and_then(|v| v.as_u64())
|
||||
.map(|n| n as u16),
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_access_control(v: &Value) -> AccessControl {
|
||||
AccessControl {
|
||||
required_scopes: v
|
||||
.get("required_scopes")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|s| s.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
required_scopes_any: v
|
||||
.get("required_scopes_any")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|s| s.as_str().map(String::from))
|
||||
.collect()
|
||||
}),
|
||||
resource_type: v
|
||||
.get("resource_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
resource_action: v
|
||||
.get("resource_action")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a forwarding handler for a `FromCall` leaf: on invocation, calls
|
||||
/// the remote op via the `CallConnection` and returns its `ResponseEnvelope`.
|
||||
/// For a `Subscription` op, the handler calls `subscribe` and streams until
|
||||
/// `completed`/`aborted` (the streaming path is exercised at the
|
||||
/// `CallConnection` layer; the handler here forwards the first response for
|
||||
/// query/mutation and delegates streaming to the caller via the returned
|
||||
/// envelope).
|
||||
fn make_forwarding_handler(connection: Arc<CallConnection>, remote_name: String) -> Handler {
|
||||
use crate::registry::registration::make_handler;
|
||||
make_handler(move |input, context| {
|
||||
let connection = Arc::clone(&connection);
|
||||
let remote_name = remote_name.clone();
|
||||
async move {
|
||||
// The forwarding handler invokes the remote op via the
|
||||
// CallConnection. The parent_request_id participates in the abort
|
||||
// cascade (ADR-016 §6): if the parent is aborted, the cascade
|
||||
// reaches this handler, which sends call.aborted to the remote
|
||||
// node; the remote node cascades to its own descendants.
|
||||
// Cross-node abort is transparent.
|
||||
let response = connection.call(&remote_name, input).await;
|
||||
ResponseEnvelope {
|
||||
request_id: context.request_id,
|
||||
result: response.result,
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::protocol::connection::CallConnection;
|
||||
use crate::registry::spec::OperationType;
|
||||
use alknet_core::types::{Capabilities, MockConnection};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::sync::Mutex as StdMutex;
|
||||
|
||||
struct StubConnection {
|
||||
alpn: &'static [u8],
|
||||
addr: Option<SocketAddr>,
|
||||
closed: StdMutex<Option<(u32, String)>>,
|
||||
}
|
||||
|
||||
impl MockConnection for StubConnection {
|
||||
fn remote_alpn(&self) -> &[u8] {
|
||||
self.alpn
|
||||
}
|
||||
fn remote_addr(&self) -> Option<SocketAddr> {
|
||||
self.addr
|
||||
}
|
||||
fn close(&self, code: u32, reason: &str) {
|
||||
*self.closed.lock().unwrap() = Some((code, reason.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
fn stub_connection() -> alknet_core::types::Connection {
|
||||
alknet_core::types::Connection::from_mock(Arc::new(StubConnection {
|
||||
alpn: b"alknet/call",
|
||||
addr: Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4321)),
|
||||
closed: StdMutex::new(None),
|
||||
}))
|
||||
}
|
||||
|
||||
fn sample_schema_json(name: &str, op_type: &str) -> Value {
|
||||
json!({
|
||||
"name": name,
|
||||
"namespace": name.split('/').next().unwrap_or(""),
|
||||
"op_type": op_type,
|
||||
"visibility": "external",
|
||||
"input_schema": {"type": "object"},
|
||||
"output_schema": {"type": "string"},
|
||||
"error_schemas": [],
|
||||
"access_control": {"required_scopes": []},
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rebuild_spec_no_prefix_preserves_name() {
|
||||
let schema = sample_schema_json("fs/readFile", "query");
|
||||
let spec = rebuild_spec(&schema, "fs/readFile", &None).expect("rebuild");
|
||||
assert_eq!(spec.name, "fs/readFile");
|
||||
assert_eq!(spec.op_type, OperationType::Query);
|
||||
assert_eq!(spec.visibility, Visibility::External);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rebuild_spec_with_prefix_applies_prefix() {
|
||||
let schema = sample_schema_json("fs/readFile", "query");
|
||||
let spec =
|
||||
rebuild_spec(&schema, "fs/readFile", &Some("worker".to_string())).expect("rebuild");
|
||||
assert_eq!(spec.name, "worker/fs/readFile");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rebuild_spec_unknown_op_type_returns_schema_parse() {
|
||||
let schema = sample_schema_json("fs/readFile", "weird");
|
||||
match rebuild_spec(&schema, "fs/readFile", &None) {
|
||||
Err(AdapterError::SchemaParse { .. }) => {}
|
||||
other => panic!("expected SchemaParse, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rebuild_spec_missing_op_type_returns_schema_parse() {
|
||||
let schema = json!({"name": "fs/readFile"});
|
||||
match rebuild_spec(&schema, "fs/readFile", &None) {
|
||||
Err(AdapterError::SchemaParse { .. }) => {}
|
||||
other => panic!("expected SchemaParse, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rebuild_spec_parses_error_schemas_and_acl() {
|
||||
let schema = json!({
|
||||
"name": "fs/readFileErr",
|
||||
"namespace": "fs",
|
||||
"op_type": "query",
|
||||
"visibility": "external",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"error_schemas": [{
|
||||
"code": "FILE_NOT_FOUND",
|
||||
"description": "file not found",
|
||||
"schema": {"type": "object"},
|
||||
"http_status": 404,
|
||||
}],
|
||||
"access_control": {
|
||||
"required_scopes": ["fs:read"],
|
||||
"required_scopes_any": null,
|
||||
"resource_type": "fs",
|
||||
"resource_action": "read",
|
||||
},
|
||||
});
|
||||
let spec = rebuild_spec(&schema, "fs/readFileErr", &None).expect("rebuild");
|
||||
assert_eq!(spec.error_schemas.len(), 1);
|
||||
assert_eq!(spec.error_schemas[0].code, "FILE_NOT_FOUND");
|
||||
assert_eq!(spec.error_schemas[0].http_status, Some(404));
|
||||
assert_eq!(
|
||||
spec.access_control.required_scopes,
|
||||
vec!["fs:read".to_string()]
|
||||
);
|
||||
assert_eq!(spec.access_control.resource_type.as_deref(), Some("fs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_call_config_builder_methods() {
|
||||
let config = FromCallConfig::new()
|
||||
.with_namespace_prefix("worker")
|
||||
.with_operation_filter(HashSet::from(["fs/readFile".to_string()]));
|
||||
assert_eq!(config.namespace_prefix.as_deref(), Some("worker"));
|
||||
assert!(config.operation_filter.unwrap().contains("fs/readFile"));
|
||||
}
|
||||
|
||||
/// `from_call` against a stub `CallConnection` (no real transport) returns
|
||||
/// a `DiscoveryFailed` because `services/list` can't dispatch on a mock
|
||||
/// connection. This verifies the error path rather than the happy path
|
||||
/// (the happy path is covered by the integration test in a later task).
|
||||
#[tokio::test]
|
||||
async fn from_call_against_mock_connection_returns_discovery_failed() {
|
||||
let conn = CallConnection::new(stub_connection());
|
||||
let result = from_call(&conn, FromCallConfig::new()).await;
|
||||
match result {
|
||||
Err(AdapterError::DiscoveryFailed { .. }) => {}
|
||||
Err(other) => panic!("expected DiscoveryFailed, got another error variant: {other}"),
|
||||
Ok(_) => panic!("expected DiscoveryFailed on mock connection, got Ok"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_call_provenance_is_from_call_and_leaf_fields() {
|
||||
// Verify the registration shape produced by from_call: provenance
|
||||
// FromCall, no composition authority, no scoped_env, empty caps.
|
||||
// Uses a synthetic spec to avoid the transport round-trip.
|
||||
let spec = OperationSpec::new(
|
||||
"worker/echo",
|
||||
OperationType::Query,
|
||||
Visibility::External,
|
||||
json!({}),
|
||||
json!({}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
);
|
||||
let handler = make_forwarding_handler(
|
||||
Arc::new(CallConnection::new(stub_connection())),
|
||||
"worker/echo".to_string(),
|
||||
);
|
||||
let reg = HandlerRegistration::new(
|
||||
spec,
|
||||
handler,
|
||||
OperationProvenance::FromCall,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
);
|
||||
assert_eq!(reg.provenance, OperationProvenance::FromCall);
|
||||
assert!(reg.composition_authority.is_none());
|
||||
assert!(reg.scoped_env.is_none());
|
||||
}
|
||||
}
|
||||
174
crates/alknet-call/src/client/from_jsonschema.rs
Normal file
174
crates/alknet-call/src/client/from_jsonschema.rs
Normal file
@@ -0,0 +1,174 @@
|
||||
//! Schema-only registration: produce a `HandlerRegistration` bundle with
|
||||
//! `FromJsonSchema` provenance and no real handler. The caller fetches the
|
||||
//! JSON Schema doc and passes it in; this adapter does no network I/O.
|
||||
//!
|
||||
//! See `docs/architecture/crates/call/client-and-adapters.md` (from_jsonschema
|
||||
//! section) and ADR-017 §5.
|
||||
|
||||
use alknet_core::types::Capabilities;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::client::{AdapterError, OperationAdapter};
|
||||
use crate::protocol::wire::{CallError, ResponseEnvelope};
|
||||
use crate::registry::context::OperationContext;
|
||||
use crate::registry::registration::{make_handler, HandlerRegistration, OperationProvenance};
|
||||
use crate::registry::spec::OperationSpec;
|
||||
|
||||
/// Build a [`HandlerRegistration`] from a JSON Schema-described operation.
|
||||
///
|
||||
/// Schema-only: no real handler is attached — a placeholder returns a
|
||||
/// `NOT_FOUND`-style error if ever invoked (schema-only ops are `Internal`,
|
||||
/// so dispatch should never reach them; the placeholder fails loudly on
|
||||
/// bugs). `provenance` is `FromJsonSchema`; `composition_authority` and
|
||||
/// `scoped_env` are `None`; `capabilities` is empty.
|
||||
pub fn from_jsonschema(spec: OperationSpec, _schema: Value) -> HandlerRegistration {
|
||||
let handler = make_handler(|_input: Value, context: OperationContext| async move {
|
||||
ResponseEnvelope::error(
|
||||
context.request_id,
|
||||
CallError::not_found("FromJsonSchema ops are schema-only and have no handler"),
|
||||
)
|
||||
});
|
||||
HandlerRegistration::new(
|
||||
spec,
|
||||
handler,
|
||||
OperationProvenance::FromJsonSchema,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
)
|
||||
}
|
||||
|
||||
/// A JSON-Schema-only [`OperationAdapter`].
|
||||
///
|
||||
/// Pure parse — no transport, no `.await` in `import()`. Returns
|
||||
/// [`AdapterError::SchemaParse`] when the supplied schema is not a JSON
|
||||
/// object.
|
||||
pub struct FromJsonSchema {
|
||||
spec: OperationSpec,
|
||||
schema: Value,
|
||||
}
|
||||
|
||||
impl FromJsonSchema {
|
||||
pub fn new(spec: OperationSpec, schema: Value) -> Self {
|
||||
Self { spec, schema }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OperationAdapter for FromJsonSchema {
|
||||
async fn import(&self) -> Result<Vec<HandlerRegistration>, AdapterError> {
|
||||
if !self.schema.is_object() {
|
||||
return Err(AdapterError::SchemaParse {
|
||||
message: "schema must be a JSON object".into(),
|
||||
});
|
||||
}
|
||||
Ok(vec![from_jsonschema(
|
||||
self.spec.clone(),
|
||||
self.schema.clone(),
|
||||
)])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::client::from_jsonschema as from_jsonschema_fn;
|
||||
use crate::registry::context::{AbortPolicy, ScopedOperationEnv};
|
||||
use crate::registry::env::OperationEnv;
|
||||
use crate::registry::spec::{AccessControl, OperationType, Visibility};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
struct NoopEnv;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OperationEnv for NoopEnv {
|
||||
async fn invoke_with_policy(
|
||||
&self,
|
||||
_namespace: &str,
|
||||
_operation: &str,
|
||||
_input: Value,
|
||||
parent: &OperationContext,
|
||||
_policy: AbortPolicy,
|
||||
) -> ResponseEnvelope {
|
||||
ResponseEnvelope::ok(parent.request_id.clone(), Value::Null)
|
||||
}
|
||||
}
|
||||
|
||||
fn test_spec(name: &str) -> OperationSpec {
|
||||
OperationSpec::new(
|
||||
name,
|
||||
OperationType::Query,
|
||||
Visibility::Internal,
|
||||
serde_json::json!({}),
|
||||
serde_json::json!({}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn test_context(request_id: &str) -> OperationContext {
|
||||
OperationContext {
|
||||
request_id: request_id.to_string(),
|
||||
parent_request_id: None,
|
||||
identity: None,
|
||||
handler_identity: None,
|
||||
forwarded_for: None,
|
||||
capabilities: Capabilities::new(),
|
||||
metadata: HashMap::new(),
|
||||
scoped_env: ScopedOperationEnv::empty(),
|
||||
env: Arc::new(NoopEnv),
|
||||
abort_policy: AbortPolicy::default(),
|
||||
deadline: Some(std::time::Instant::now() + Duration::from_secs(30)),
|
||||
internal: true,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_jsonschema_bundle_shape() {
|
||||
let bundle = from_jsonschema_fn::from_jsonschema(test_spec("ns/op"), serde_json::json!({}));
|
||||
assert_eq!(bundle.spec.name, "ns/op");
|
||||
assert_eq!(bundle.provenance, OperationProvenance::FromJsonSchema);
|
||||
assert!(bundle.composition_authority.is_none());
|
||||
assert!(bundle.scoped_env.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn placeholder_handler_returns_error_when_invoked() {
|
||||
let bundle = from_jsonschema_fn::from_jsonschema(test_spec("ns/op"), serde_json::json!({}));
|
||||
let ctx = test_context("req-1");
|
||||
let response = (bundle.handler)(serde_json::json!({}), ctx).await;
|
||||
match response.result {
|
||||
Err(e) => {
|
||||
assert_eq!(e.code, "NOT_FOUND");
|
||||
assert!(e.message.contains("FromJsonSchema"));
|
||||
}
|
||||
other => panic!("expected NOT_FOUND, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn import_returns_ok_with_one_bundle() {
|
||||
let adapter =
|
||||
FromJsonSchema::new(test_spec("ns/op"), serde_json::json!({"type": "object"}));
|
||||
let bundles = match adapter.import().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => panic!("expected Ok, got Err: {e}"),
|
||||
};
|
||||
assert_eq!(bundles.len(), 1);
|
||||
assert_eq!(bundles[0].provenance, OperationProvenance::FromJsonSchema);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn import_non_object_schema_returns_schema_parse() {
|
||||
let adapter = FromJsonSchema::new(test_spec("ns/op"), serde_json::json!(42));
|
||||
match adapter.import().await {
|
||||
Ok(_) => panic!("expected Err"),
|
||||
Err(AdapterError::SchemaParse { message }) => {
|
||||
assert!(message.contains("JSON object"));
|
||||
}
|
||||
Err(other) => panic!("expected SchemaParse, got {other}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
106
crates/alknet-call/src/client/mod.rs
Normal file
106
crates/alknet-call/src/client/mod.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
//! Client adapters: turn external operation sources (JSON Schema, OpenAPI,
|
||||
//! MCP, remote `from_call` peers) into `HandlerRegistration` bundles.
|
||||
//!
|
||||
//! See `docs/architecture/crates/call/client-and-adapters.md` for the
|
||||
//! OperationAdapter trait and the Adapter Location Map, and
|
||||
//! `docs/architecture/decisions/017-call-protocol-client-and-adapter-contract.md`
|
||||
//! §5 for the trait contract.
|
||||
|
||||
mod call_client;
|
||||
mod from_call;
|
||||
mod from_jsonschema;
|
||||
|
||||
pub use call_client::{CallClient, CallCredentials, ClientError, RemoteIdentity};
|
||||
pub use from_call::{from_call, FromCallConfig};
|
||||
pub use from_jsonschema::{from_jsonschema, FromJsonSchema};
|
||||
|
||||
use crate::registry::registration::HandlerRegistration;
|
||||
|
||||
/// Errors produced by [`OperationAdapter::import`].
|
||||
///
|
||||
/// The variant set is the v1 default (two-way-door remainder, OQ-26);
|
||||
/// `#[non_exhaustive]` lets downstream adapters (e.g. `alknet-http`'s
|
||||
/// `from_openapi`/`from_mcp`) extend without breaking match arms. All
|
||||
/// payloads are string messages — kept simple and `Send + Sync` by
|
||||
/// construction.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[non_exhaustive]
|
||||
pub enum AdapterError {
|
||||
/// `from_call` remote unreachable / `services/list` failed.
|
||||
#[error("discovery failed: {message}")]
|
||||
DiscoveryFailed { message: String },
|
||||
|
||||
/// `from_openapi` / `from_jsonschema` couldn't parse the spec.
|
||||
#[error("schema parse error: {message}")]
|
||||
SchemaParse { message: String },
|
||||
|
||||
/// Underlying transport error (QUIC for `from_call`, HTTP for adapters).
|
||||
#[error("transport error: {message}")]
|
||||
Transport { message: String },
|
||||
|
||||
/// HTTP 401 for `from_openapi`/`from_mcp`, auth rejected for `from_call`.
|
||||
#[error("unauthorized: {message}")]
|
||||
Unauthorized { message: String },
|
||||
|
||||
/// Namespace collision in `from_call` (DC-3); reused for other adapters.
|
||||
#[error("conflict: {message}")]
|
||||
Conflict { message: String },
|
||||
}
|
||||
|
||||
/// Import a set of operations as `HandlerRegistration` bundles.
|
||||
///
|
||||
/// Async because `from_call` requires async discovery (`services/list` +
|
||||
/// `services/schema` over a QUIC connection); sync adapters (e.g.
|
||||
/// `from_jsonschema`, `from_openapi` reading a static spec) trivially satisfy
|
||||
/// an async trait — their `import()` bodies contain no `.await` points.
|
||||
///
|
||||
/// See ADR-017 §5 (`docs/architecture/decisions/017-call-protocol-client-and-adapter-contract.md`)
|
||||
/// and `docs/architecture/crates/call/client-and-adapters.md`.
|
||||
#[async_trait::async_trait]
|
||||
pub trait OperationAdapter: Send + Sync {
|
||||
async fn import(&self) -> Result<Vec<HandlerRegistration>, AdapterError>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
struct OkAdapter;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OperationAdapter for OkAdapter {
|
||||
async fn import(&self) -> Result<Vec<HandlerRegistration>, AdapterError> {
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
struct ErrAdapter;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OperationAdapter for ErrAdapter {
|
||||
async fn import(&self) -> Result<Vec<HandlerRegistration>, AdapterError> {
|
||||
Err(AdapterError::SchemaParse {
|
||||
message: "x".into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ok_adapter_imports_empty() {
|
||||
let adapter = OkAdapter;
|
||||
match adapter.import().await {
|
||||
Ok(bundles) => assert!(bundles.is_empty()),
|
||||
Err(e) => panic!("expected Ok, got Err: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn err_adapter_returns_schema_parse() {
|
||||
let adapter = ErrAdapter;
|
||||
match adapter.import().await {
|
||||
Ok(_) => panic!("expected Err"),
|
||||
Err(AdapterError::SchemaParse { message }) => assert_eq!(message, "x"),
|
||||
Err(other) => panic!("expected SchemaParse, got {other}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
11
crates/alknet-call/src/lib.rs
Normal file
11
crates/alknet-call/src/lib.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
//! alknet-call: Structured RPC over QUIC — operations, streaming, service discovery.
|
||||
//!
|
||||
//! Implements [`alknet_core::types::ProtocolHandler`] on ALPN `alknet/call`.
|
||||
//!
|
||||
//! The crate has two subsystems:
|
||||
//! - [`registry`] — operation specs, context, dispatch, and the operation registry.
|
||||
//! - [`protocol`] — wire format, streams, and the call adapter.
|
||||
|
||||
pub mod client;
|
||||
pub mod protocol;
|
||||
pub mod registry;
|
||||
393
crates/alknet-call/src/protocol/abort.rs
Normal file
393
crates/alknet-call/src/protocol/abort.rs
Normal file
@@ -0,0 +1,393 @@
|
||||
//! Abort cascade logic for nested calls (ADR-016).
|
||||
//!
|
||||
//! When `call.aborted` arrives for a parent request, the protocol cascades
|
||||
//! the abort to all non-terminal descendants in the call tree. The default
|
||||
//! policy is `abort-dependents`; `continue-running` is an opt-in for
|
||||
//! long-running work that should survive a parent's abort.
|
||||
//!
|
||||
//! The call tree is indexed by `parent_request_id` in the
|
||||
//! `PendingRequestMap`. The root request has `parent_request_id: None`;
|
||||
//! each composed call has `parent_request_id: Some(parent.request_id)`.
|
||||
//! Composed child request IDs are internal — they appear in the map for
|
||||
//! abort-cascade indexing but are not sent as `call.requested` to any
|
||||
//! peer. The client only sees `call.aborted` for the root ID it sent; the
|
||||
//! server cascades internally to descendants.
|
||||
|
||||
use super::pending::PendingRequestMap;
|
||||
use crate::registry::context::AbortPolicy;
|
||||
|
||||
pub struct AbortCascade<'a> {
|
||||
pending: &'a mut PendingRequestMap,
|
||||
}
|
||||
|
||||
impl<'a> AbortCascade<'a> {
|
||||
pub fn new(pending: &'a mut PendingRequestMap) -> Self {
|
||||
Self { pending }
|
||||
}
|
||||
|
||||
/// Cascade an abort from the given request ID to all non-terminal
|
||||
/// descendants in the call tree. Returns the list of descendant
|
||||
/// request IDs that were aborted (for logging/auditing), sorted for
|
||||
/// determinism. The root request itself is not touched by this
|
||||
/// method — the caller is responsible for aborting the root (the
|
||||
/// trigger of the cascade).
|
||||
///
|
||||
/// Under `AbortDependents` (default): all descendants are aborted,
|
||||
/// regardless of whether they have started.
|
||||
///
|
||||
/// Under `ContinueRunning`: only descendants that have not started
|
||||
/// are aborted; started descendants continue to completion. No new
|
||||
/// descendants start (the parent is gone). This is the conservative
|
||||
/// approximation noted in ADR-016: a descendant is "started" if
|
||||
/// `PendingEntry::started` is true (the handler has begun
|
||||
/// executing). A `call.aborted` for an unknown request ID is
|
||||
/// silently discarded — `cascade_abort` on an unknown root returns
|
||||
/// an empty list and removes nothing.
|
||||
pub fn cascade_abort(&mut self, root_request_id: &str, policy: AbortPolicy) -> Vec<String> {
|
||||
if !self.pending.contains(root_request_id) {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let descendants = self.find_descendants(root_request_id);
|
||||
|
||||
let mut aborted = Vec::new();
|
||||
match policy {
|
||||
AbortPolicy::AbortDependents => {
|
||||
for id in &descendants {
|
||||
if self.pending.handle_aborted(id) {
|
||||
aborted.push(id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
AbortPolicy::ContinueRunning => {
|
||||
for id in &descendants {
|
||||
let started = self.pending.is_started(id).unwrap_or(false);
|
||||
if !started && self.pending.handle_aborted(id) {
|
||||
aborted.push(id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
aborted.sort();
|
||||
aborted
|
||||
}
|
||||
|
||||
/// Find all descendants of a request ID in the call tree by walking
|
||||
/// the `parent_request_id` index. Returns descendants in
|
||||
/// breadth-first order with each level's children sorted for
|
||||
/// determinism. The root itself is not included in the result.
|
||||
fn find_descendants(&self, parent_id: &str) -> Vec<String> {
|
||||
let mut descendants = Vec::new();
|
||||
let mut frontier: Vec<String> = vec![parent_id.to_string()];
|
||||
|
||||
while let Some(current) = frontier.pop() {
|
||||
let mut children: Vec<String> = self
|
||||
.pending
|
||||
.request_ids()
|
||||
.into_iter()
|
||||
.filter(|id| {
|
||||
self.pending
|
||||
.parent_of(id)
|
||||
.flatten()
|
||||
.is_some_and(|p| p == current)
|
||||
})
|
||||
.collect();
|
||||
children.sort();
|
||||
for child in children {
|
||||
descendants.push(child.clone());
|
||||
frontier.push(child);
|
||||
}
|
||||
}
|
||||
|
||||
descendants
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::protocol::wire::CallError;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
fn register_call(map: &mut PendingRequestMap, id: &str, parent: Option<&str>) {
|
||||
map.register_call(
|
||||
id.to_string(),
|
||||
Instant::now() + Duration::from_secs(30),
|
||||
parent.map(|p| p.to_string()),
|
||||
);
|
||||
}
|
||||
|
||||
fn register_subscribe(map: &mut PendingRequestMap, id: &str, parent: Option<&str>) {
|
||||
map.register_subscribe(id.to_string(), None, parent.map(|p| p.to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_unknown_root_returns_empty_and_is_noop() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "r1", None);
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted = cascade.cascade_abort("does-not-exist", AbortPolicy::AbortDependents);
|
||||
assert!(aborted.is_empty());
|
||||
assert!(cascade.pending.contains("r1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_abort_dependents_aborts_all_descendants() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "r1", None);
|
||||
register_call(&mut map, "r1-a", Some("r1"));
|
||||
register_call(&mut map, "r1-b", Some("r1"));
|
||||
register_call(&mut map, "r1-a-1", Some("r1-a"));
|
||||
register_call(&mut map, "r1-a-2", Some("r1-a"));
|
||||
register_call(&mut map, "r1-b-1", Some("r1-b"));
|
||||
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted = cascade.cascade_abort("r1", AbortPolicy::AbortDependents);
|
||||
|
||||
assert_eq!(
|
||||
aborted,
|
||||
vec![
|
||||
"r1-a".to_string(),
|
||||
"r1-a-1".to_string(),
|
||||
"r1-a-2".to_string(),
|
||||
"r1-b".to_string(),
|
||||
"r1-b-1".to_string(),
|
||||
]
|
||||
);
|
||||
assert!(cascade.pending.contains("r1"));
|
||||
assert!(!cascade.pending.contains("r1-a"));
|
||||
assert!(!cascade.pending.contains("r1-b"));
|
||||
assert!(!cascade.pending.contains("r1-a-1"));
|
||||
assert!(!cascade.pending.contains("r1-a-2"));
|
||||
assert!(!cascade.pending.contains("r1-b-1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_continue_running_aborts_only_unstarted_descendants() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "r1", None);
|
||||
register_call(&mut map, "r1-a", Some("r1"));
|
||||
register_call(&mut map, "r1-b", Some("r1"));
|
||||
register_call(&mut map, "r1-a-1", Some("r1-a"));
|
||||
|
||||
map.mark_started("r1-a");
|
||||
// r1-b and r1-a-1 are unstarted
|
||||
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted = cascade.cascade_abort("r1", AbortPolicy::ContinueRunning);
|
||||
|
||||
assert_eq!(aborted, vec!["r1-a-1".to_string(), "r1-b".to_string()]);
|
||||
assert!(cascade.pending.contains("r1"));
|
||||
assert!(cascade.pending.contains("r1-a"));
|
||||
assert!(!cascade.pending.contains("r1-b"));
|
||||
assert!(!cascade.pending.contains("r1-a-1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_continue_running_aborts_all_when_none_started() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "r1", None);
|
||||
register_call(&mut map, "r1-a", Some("r1"));
|
||||
register_call(&mut map, "r1-b", Some("r1"));
|
||||
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted = cascade.cascade_abort("r1", AbortPolicy::ContinueRunning);
|
||||
|
||||
assert_eq!(aborted, vec!["r1-a".to_string(), "r1-b".to_string()]);
|
||||
assert!(!cascade.pending.contains("r1-a"));
|
||||
assert!(!cascade.pending.contains("r1-b"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_depth_three_aborts_all_descendants() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "root", None);
|
||||
register_call(&mut map, "root-a", Some("root"));
|
||||
register_call(&mut map, "root-b", Some("root"));
|
||||
register_call(&mut map, "root-a-1", Some("root-a"));
|
||||
register_call(&mut map, "root-a-2", Some("root-a"));
|
||||
register_call(&mut map, "root-a-1-x", Some("root-a-1"));
|
||||
register_call(&mut map, "root-a-1-y", Some("root-a-1"));
|
||||
register_call(&mut map, "root-b-1", Some("root-b"));
|
||||
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted = cascade.cascade_abort("root", AbortPolicy::AbortDependents);
|
||||
|
||||
assert_eq!(
|
||||
aborted,
|
||||
vec![
|
||||
"root-a".to_string(),
|
||||
"root-a-1".to_string(),
|
||||
"root-a-1-x".to_string(),
|
||||
"root-a-1-y".to_string(),
|
||||
"root-a-2".to_string(),
|
||||
"root-b".to_string(),
|
||||
"root-b-1".to_string(),
|
||||
]
|
||||
);
|
||||
assert!(cascade.pending.contains("root"));
|
||||
assert_eq!(cascade.pending.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_root_with_no_descendants_returns_empty() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "lonely", None);
|
||||
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted = cascade.cascade_abort("lonely", AbortPolicy::AbortDependents);
|
||||
assert!(aborted.is_empty());
|
||||
assert!(cascade.pending.contains("lonely"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_only_aborts_descendants_not_siblings() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "r1", None);
|
||||
register_call(&mut map, "r2", None);
|
||||
register_call(&mut map, "r1-a", Some("r1"));
|
||||
register_call(&mut map, "r2-a", Some("r2"));
|
||||
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted = cascade.cascade_abort("r1", AbortPolicy::AbortDependents);
|
||||
|
||||
assert_eq!(aborted, vec!["r1-a".to_string()]);
|
||||
assert!(cascade.pending.contains("r1"));
|
||||
assert!(cascade.pending.contains("r2"));
|
||||
assert!(cascade.pending.contains("r2-a"));
|
||||
assert!(!cascade.pending.contains("r1-a"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_handles_mixed_call_and_subscribe_entries() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "r1", None);
|
||||
register_subscribe(&mut map, "r1-sub", Some("r1"));
|
||||
register_call(&mut map, "r1-sub-child", Some("r1-sub"));
|
||||
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted = cascade.cascade_abort("r1", AbortPolicy::AbortDependents);
|
||||
|
||||
assert_eq!(
|
||||
aborted,
|
||||
vec!["r1-sub".to_string(), "r1-sub-child".to_string(),]
|
||||
);
|
||||
assert!(cascade.pending.contains("r1"));
|
||||
assert_eq!(cascade.pending.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_continue_running_with_started_descendant_keeps_its_unstarted_children() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "r1", None);
|
||||
register_call(&mut map, "r1-a", Some("r1"));
|
||||
register_call(&mut map, "r1-a-1", Some("r1-a"));
|
||||
|
||||
map.mark_started("r1-a");
|
||||
// r1-a is started and continues; r1-a-1 is unstarted.
|
||||
// Under ContinueRunning, r1-a-1 is aborted (conservative: still pending).
|
||||
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted = cascade.cascade_abort("r1", AbortPolicy::ContinueRunning);
|
||||
|
||||
assert_eq!(aborted, vec!["r1-a-1".to_string()]);
|
||||
assert!(cascade.pending.contains("r1-a"));
|
||||
assert!(!cascade.pending.contains("r1-a-1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_abort_dependents_aborts_started_descendants_too() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "r1", None);
|
||||
register_call(&mut map, "r1-a", Some("r1"));
|
||||
register_call(&mut map, "r1-b", Some("r1"));
|
||||
|
||||
map.mark_started("r1-a");
|
||||
map.mark_started("r1-b");
|
||||
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted = cascade.cascade_abort("r1", AbortPolicy::AbortDependents);
|
||||
|
||||
assert_eq!(aborted, vec!["r1-a".to_string(), "r1-b".to_string()]);
|
||||
assert!(!cascade.pending.contains("r1-a"));
|
||||
assert!(!cascade.pending.contains("r1-b"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_descendants_does_not_include_root() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "r1", None);
|
||||
register_call(&mut map, "r1-a", Some("r1"));
|
||||
|
||||
let cascade = AbortCascade::new(&mut map);
|
||||
let descendants = cascade.find_descendants("r1");
|
||||
assert_eq!(descendants, vec!["r1-a".to_string()]);
|
||||
assert!(!descendants.contains(&"r1".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_default_policy_is_abort_dependents() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "r1", None);
|
||||
register_call(&mut map, "r1-a", Some("r1"));
|
||||
map.mark_started("r1-a");
|
||||
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted_default = cascade.cascade_abort("r1", AbortPolicy::default());
|
||||
assert_eq!(aborted_default, vec!["r1-a".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_does_not_remove_root() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "r1", None);
|
||||
register_call(&mut map, "r1-a", Some("r1"));
|
||||
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let _ = cascade.cascade_abort("r1", AbortPolicy::AbortDependents);
|
||||
assert!(cascade.pending.contains("r1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_returns_sorted_descendants_for_determinism() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "r1", None);
|
||||
register_call(&mut map, "r1-z", Some("r1"));
|
||||
register_call(&mut map, "r1-a", Some("r1"));
|
||||
register_call(&mut map, "r1-m", Some("r1"));
|
||||
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted = cascade.cascade_abort("r1", AbortPolicy::AbortDependents);
|
||||
assert_eq!(
|
||||
aborted,
|
||||
vec!["r1-a".to_string(), "r1-m".to_string(), "r1-z".to_string(),]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_request_id_silently_discarded_no_panic() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted = cascade.cascade_abort("totally-unknown", AbortPolicy::AbortDependents);
|
||||
assert!(aborted.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_continue_running_started_descendant_survives() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
register_call(&mut map, "r1", None);
|
||||
register_call(&mut map, "r1-a", Some("r1"));
|
||||
map.mark_started("r1-a");
|
||||
|
||||
let mut cascade = AbortCascade::new(&mut map);
|
||||
let aborted = cascade.cascade_abort("r1", AbortPolicy::ContinueRunning);
|
||||
assert!(aborted.is_empty());
|
||||
assert!(cascade.pending.contains("r1-a"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cascade_abort_handles_call_error_unused() {
|
||||
let _ = CallError::internal("unused");
|
||||
}
|
||||
}
|
||||
1242
crates/alknet-call/src/protocol/adapter.rs
Normal file
1242
crates/alknet-call/src/protocol/adapter.rs
Normal file
File diff suppressed because it is too large
Load Diff
810
crates/alknet-call/src/protocol/connection.rs
Normal file
810
crates/alknet-call/src/protocol/connection.rs
Normal file
@@ -0,0 +1,810 @@
|
||||
//! `CallConnection`: an established `alknet/call` connection (either
|
||||
//! direction — accepted or opened). Holds the connection's Layer 2 overlay
|
||||
//! (imported ops).
|
||||
//!
|
||||
//! See `docs/architecture/crates/call/call-protocol.md` for the full
|
||||
//! specification.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use alknet_core::types::Connection;
|
||||
use futures::stream::Stream;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::pending::PendingRequestMap;
|
||||
use super::wire::{
|
||||
CallError, EventEnvelope, FrameFramedReader, FrameFramedWriter, EVENT_ABORTED, EVENT_COMPLETED,
|
||||
EVENT_ERROR, EVENT_RESPONDED,
|
||||
};
|
||||
use crate::protocol::wire::ResponseEnvelope;
|
||||
use crate::registry::context::{
|
||||
generate_request_id, AbortPolicy, OperationContext, ScopedOperationEnv,
|
||||
};
|
||||
use crate::registry::env::OperationEnv;
|
||||
use crate::registry::registration::{Handler, HandlerRegistration};
|
||||
|
||||
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
pub struct CallConnection {
|
||||
connection: Arc<Connection>,
|
||||
imported_operations: Arc<RwLock<HashMap<String, HandlerRegistration>>>,
|
||||
pending: Arc<Mutex<PendingRequestMap>>,
|
||||
}
|
||||
|
||||
impl Clone for CallConnection {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
connection: Arc::clone(&self.connection),
|
||||
imported_operations: Arc::clone(&self.imported_operations),
|
||||
pending: Arc::clone(&self.pending),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CallConnection {
|
||||
pub fn new(connection: Connection) -> Self {
|
||||
Self {
|
||||
connection: Arc::new(connection),
|
||||
imported_operations: Arc::new(RwLock::new(HashMap::new())),
|
||||
pending: Arc::new(Mutex::new(PendingRequestMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn connection(&self) -> &Arc<Connection> {
|
||||
&self.connection
|
||||
}
|
||||
|
||||
pub(crate) fn pending(&self) -> &Arc<Mutex<PendingRequestMap>> {
|
||||
&self.pending
|
||||
}
|
||||
|
||||
pub fn register_imported(&self, registration: HandlerRegistration) {
|
||||
let name = registration.spec.name.clone();
|
||||
self.imported_operations.write().insert(name, registration);
|
||||
}
|
||||
|
||||
pub fn register_imported_all(&self, registrations: Vec<HandlerRegistration>) {
|
||||
let mut overlay = self.imported_operations.write();
|
||||
for reg in registrations {
|
||||
overlay.insert(reg.spec.name.clone(), reg);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn overlay_env(&self) -> Arc<dyn OperationEnv + Send + Sync> {
|
||||
Arc::new(OverlayOperationEnv {
|
||||
overlay: Arc::clone(&self.imported_operations),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn call(&self, operation_id: &str, input: Value) -> ResponseEnvelope {
|
||||
let request_id = generate_request_id();
|
||||
let payload = serde_json::json!({
|
||||
"operationId": operation_id,
|
||||
"input": input,
|
||||
});
|
||||
|
||||
let (send, recv) = match self.connection.open_bi().await {
|
||||
Ok(pair) => pair,
|
||||
Err(err) => {
|
||||
let call_error = CallError::internal(format!("failed to open stream: {err}"));
|
||||
return ResponseEnvelope::error(request_id, call_error);
|
||||
}
|
||||
};
|
||||
|
||||
let receiver = {
|
||||
let mut pending = self.pending.lock();
|
||||
pending.register_call(
|
||||
request_id.clone(),
|
||||
Instant::now() + DEFAULT_CALL_TIMEOUT,
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
if let Err(err) = self.write_request(send, &request_id, payload).await {
|
||||
let call_error = CallError::internal(err);
|
||||
self.pending
|
||||
.lock()
|
||||
.handle_error(&request_id, call_error.clone());
|
||||
return ResponseEnvelope::error(request_id, call_error);
|
||||
}
|
||||
|
||||
let pending = Arc::clone(&self.pending);
|
||||
tokio::spawn(async move {
|
||||
read_stream_until_closed(recv, &pending).await;
|
||||
});
|
||||
|
||||
match receiver.await {
|
||||
Ok(Ok(value)) => ResponseEnvelope::ok(request_id, value),
|
||||
Ok(Err(error)) => ResponseEnvelope::error(request_id, error),
|
||||
Err(_) => ResponseEnvelope::error(request_id, CallError::internal("request cancelled")),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn subscribe(
|
||||
&self,
|
||||
operation_id: &str,
|
||||
input: Value,
|
||||
) -> impl Stream<Item = ResponseEnvelope> {
|
||||
let request_id = generate_request_id();
|
||||
let payload = serde_json::json!({
|
||||
"operationId": operation_id,
|
||||
"input": input,
|
||||
});
|
||||
|
||||
let (send, recv) = match self.connection.open_bi().await {
|
||||
Ok(pair) => pair,
|
||||
Err(err) => {
|
||||
let call_error = CallError::internal(format!("failed to open stream: {err}"));
|
||||
return SubscriptionStream::closed(request_id, call_error);
|
||||
}
|
||||
};
|
||||
|
||||
let receiver = {
|
||||
let mut pending = self.pending.lock();
|
||||
pending.register_subscribe(request_id.clone(), None, None)
|
||||
};
|
||||
|
||||
if let Err(err) = self.write_request(send, &request_id, payload).await {
|
||||
let call_error = CallError::internal(err);
|
||||
self.pending
|
||||
.lock()
|
||||
.handle_error(&request_id, call_error.clone());
|
||||
return SubscriptionStream::closed(request_id, call_error);
|
||||
}
|
||||
|
||||
let pending = Arc::clone(&self.pending);
|
||||
tokio::spawn(async move {
|
||||
read_stream_until_closed(recv, &pending).await;
|
||||
});
|
||||
|
||||
SubscriptionStream::new(request_id, receiver)
|
||||
}
|
||||
|
||||
pub async fn abort(&self, request_id: &str) {
|
||||
let envelope = EventEnvelope::aborted(request_id);
|
||||
if let Err(err) = self.write_envelope(&envelope).await {
|
||||
tracing::warn!(error = %err, request_id, "failed to send call.aborted");
|
||||
return;
|
||||
}
|
||||
self.pending.lock().handle_aborted(request_id);
|
||||
}
|
||||
|
||||
async fn write_request(
|
||||
&self,
|
||||
send: alknet_core::types::SendStream,
|
||||
request_id: &str,
|
||||
payload: Value,
|
||||
) -> Result<(), String> {
|
||||
let envelope = EventEnvelope::requested(request_id, payload);
|
||||
let mut writer = FrameFramedWriter::new(send);
|
||||
writer
|
||||
.write_frame(&envelope)
|
||||
.await
|
||||
.map_err(|e| format!("failed to write frame: {e}"))
|
||||
}
|
||||
|
||||
async fn write_envelope(&self, envelope: &EventEnvelope) -> Result<(), String> {
|
||||
let (send, _recv) = self
|
||||
.connection
|
||||
.open_bi()
|
||||
.await
|
||||
.map_err(|e| format!("failed to open stream: {e}"))?;
|
||||
let mut writer = FrameFramedWriter::new(send);
|
||||
writer
|
||||
.write_frame(envelope)
|
||||
.await
|
||||
.map_err(|e| format!("failed to write frame: {e}"))
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_stream_until_closed(
|
||||
recv: alknet_core::types::RecvStream,
|
||||
pending: &Arc<Mutex<PendingRequestMap>>,
|
||||
) {
|
||||
let mut reader = FrameFramedReader::new(recv);
|
||||
while let Ok(envelope) = reader.read_frame().await {
|
||||
dispatch_envelope(pending, envelope);
|
||||
}
|
||||
}
|
||||
|
||||
fn dispatch_envelope(pending: &Arc<Mutex<PendingRequestMap>>, envelope: EventEnvelope) {
|
||||
let request_id = envelope.id.clone();
|
||||
match envelope.r#type.as_str() {
|
||||
EVENT_RESPONDED => {
|
||||
let output = envelope
|
||||
.payload
|
||||
.get("output")
|
||||
.cloned()
|
||||
.unwrap_or(Value::Null);
|
||||
pending.lock().handle_responded(&request_id, output);
|
||||
}
|
||||
EVENT_COMPLETED => {
|
||||
pending.lock().handle_completed(&request_id);
|
||||
}
|
||||
EVENT_ABORTED => {
|
||||
pending.lock().handle_aborted(&request_id);
|
||||
}
|
||||
EVENT_ERROR => {
|
||||
if let Ok(error) = serde_json::from_value::<CallError>(envelope.payload) {
|
||||
pending.lock().handle_error(&request_id, error);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
struct OverlayOperationEnv {
|
||||
overlay: Arc<RwLock<HashMap<String, HandlerRegistration>>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OperationEnv for OverlayOperationEnv {
|
||||
async fn invoke_with_policy(
|
||||
&self,
|
||||
namespace: &str,
|
||||
operation: &str,
|
||||
input: Value,
|
||||
parent: &OperationContext,
|
||||
policy: AbortPolicy,
|
||||
) -> ResponseEnvelope {
|
||||
let name = format!("{namespace}/{operation}");
|
||||
|
||||
if !parent.scoped_env.allows(&name) {
|
||||
return ResponseEnvelope::not_found(parent.request_id.clone(), &name);
|
||||
}
|
||||
|
||||
let handler: Handler;
|
||||
let composition_authority;
|
||||
let scoped_env;
|
||||
{
|
||||
let overlay = self.overlay.read();
|
||||
let Some(registration) = overlay.get(&name) else {
|
||||
return ResponseEnvelope::not_found(parent.request_id.clone(), &name);
|
||||
};
|
||||
handler = Arc::clone(®istration.handler);
|
||||
composition_authority = registration.composition_authority.clone();
|
||||
scoped_env = registration
|
||||
.scoped_env
|
||||
.clone()
|
||||
.unwrap_or_else(ScopedOperationEnv::empty);
|
||||
}
|
||||
|
||||
let context = OperationContext {
|
||||
request_id: generate_request_id(),
|
||||
parent_request_id: Some(parent.request_id.clone()),
|
||||
identity: parent
|
||||
.handler_identity
|
||||
.as_ref()
|
||||
.and_then(|ca| ca.as_identity()),
|
||||
handler_identity: composition_authority,
|
||||
forwarded_for: None,
|
||||
capabilities: parent.capabilities.clone(),
|
||||
metadata: HashMap::new(),
|
||||
abort_policy: policy,
|
||||
deadline: parent.deadline,
|
||||
scoped_env,
|
||||
env: parent.env.clone(),
|
||||
internal: true,
|
||||
};
|
||||
|
||||
handler(input, context).await
|
||||
}
|
||||
|
||||
fn contains(&self, name: &str) -> bool {
|
||||
self.overlay.read().contains_key(name)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SubscriptionStream {
|
||||
request_id: String,
|
||||
receiver: mpsc::Receiver<Result<Value, CallError>>,
|
||||
done: bool,
|
||||
}
|
||||
|
||||
impl SubscriptionStream {
|
||||
fn new(request_id: String, receiver: mpsc::Receiver<Result<Value, CallError>>) -> Self {
|
||||
Self {
|
||||
request_id,
|
||||
receiver,
|
||||
done: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn closed(request_id: String, error: CallError) -> Self {
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
let _ = tx.try_send(Err(error));
|
||||
Self {
|
||||
request_id,
|
||||
receiver: rx,
|
||||
done: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for SubscriptionStream {
|
||||
type Item = ResponseEnvelope;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
if self.done {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
let this = self.get_mut();
|
||||
match this.receiver.poll_recv(cx) {
|
||||
Poll::Ready(None) => {
|
||||
this.done = true;
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Ready(Some(Ok(value))) => {
|
||||
Poll::Ready(Some(ResponseEnvelope::ok(this.request_id.clone(), value)))
|
||||
}
|
||||
Poll::Ready(Some(Err(error))) => {
|
||||
this.done = true;
|
||||
Poll::Ready(Some(ResponseEnvelope::error(
|
||||
this.request_id.clone(),
|
||||
error,
|
||||
)))
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::registry::context::CompositionAuthority;
|
||||
use crate::registry::registration::{make_handler, OperationProvenance};
|
||||
use crate::registry::spec::{AccessControl, OperationSpec, OperationType, Visibility};
|
||||
use alknet_core::types::{Capabilities, MockConnection};
|
||||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
struct StubConnection {
|
||||
alpn: &'static [u8],
|
||||
addr: Option<SocketAddr>,
|
||||
closed: StdMutex<Option<(u32, String)>>,
|
||||
}
|
||||
|
||||
impl MockConnection for StubConnection {
|
||||
fn remote_alpn(&self) -> &[u8] {
|
||||
self.alpn
|
||||
}
|
||||
fn remote_addr(&self) -> Option<SocketAddr> {
|
||||
self.addr
|
||||
}
|
||||
fn close(&self, code: u32, reason: &str) {
|
||||
*self.closed.lock().unwrap() = Some((code, reason.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
fn stub_connection() -> Connection {
|
||||
Connection::from_mock(Arc::new(StubConnection {
|
||||
alpn: b"alknet/call",
|
||||
addr: Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4321)),
|
||||
closed: StdMutex::new(None),
|
||||
}))
|
||||
}
|
||||
|
||||
fn external_spec(name: &str) -> OperationSpec {
|
||||
OperationSpec::new(
|
||||
name,
|
||||
OperationType::Query,
|
||||
Visibility::External,
|
||||
serde_json::json!({}),
|
||||
serde_json::json!({}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn echo_handler() -> Handler {
|
||||
make_handler(
|
||||
|input, context| async move { ResponseEnvelope::ok(context.request_id, input) },
|
||||
)
|
||||
}
|
||||
|
||||
fn imported_registration(name: &str) -> HandlerRegistration {
|
||||
HandlerRegistration::new(
|
||||
external_spec(name),
|
||||
echo_handler(),
|
||||
OperationProvenance::FromCall,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
)
|
||||
}
|
||||
|
||||
fn root_context(
|
||||
request_id: &str,
|
||||
scoped_env: ScopedOperationEnv,
|
||||
env: Arc<dyn OperationEnv + Send + Sync>,
|
||||
) -> OperationContext {
|
||||
OperationContext {
|
||||
request_id: request_id.to_string(),
|
||||
parent_request_id: None,
|
||||
identity: None,
|
||||
handler_identity: Some(CompositionAuthority::new("agent", ["fs:read".to_string()])),
|
||||
forwarded_for: None,
|
||||
capabilities: Capabilities::new(),
|
||||
metadata: HashMap::new(),
|
||||
scoped_env,
|
||||
env,
|
||||
abort_policy: AbortPolicy::default(),
|
||||
deadline: Some(Instant::now() + Duration::from_secs(30)),
|
||||
internal: true,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_imported_adds_to_overlay_and_contains_returns_true() {
|
||||
let conn = CallConnection::new(stub_connection());
|
||||
let env = conn.overlay_env();
|
||||
|
||||
assert!(!env.contains("worker/exec"));
|
||||
|
||||
conn.register_imported(imported_registration("worker/exec"));
|
||||
|
||||
assert!(env.contains("worker/exec"));
|
||||
assert!(!env.contains("worker/missing"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_imported_all_bulk_adds_to_overlay() {
|
||||
let conn = CallConnection::new(stub_connection());
|
||||
let env = conn.overlay_env();
|
||||
|
||||
conn.register_imported_all(vec![
|
||||
imported_registration("worker/exec"),
|
||||
imported_registration("worker/status"),
|
||||
imported_registration("fs/readFile"),
|
||||
]);
|
||||
|
||||
assert!(env.contains("worker/exec"));
|
||||
assert!(env.contains("worker/status"));
|
||||
assert!(env.contains("fs/readFile"));
|
||||
assert!(!env.contains("worker/missing"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn overlay_env_dispatches_to_imported_op() {
|
||||
let conn = CallConnection::new(stub_connection());
|
||||
conn.register_imported(imported_registration("worker/exec"));
|
||||
let env = conn.overlay_env();
|
||||
|
||||
let scoped = ScopedOperationEnv::new(["worker/exec"]);
|
||||
let ctx = root_context("root-1", scoped, env.clone());
|
||||
|
||||
let response = env
|
||||
.invoke("worker", "exec", serde_json::json!({"hi": 1}), &ctx)
|
||||
.await;
|
||||
|
||||
assert!(response.result.is_ok());
|
||||
assert_eq!(response.result.unwrap(), serde_json::json!({"hi": 1}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn overlay_env_contains_returns_false_for_non_imported_op() {
|
||||
let conn = CallConnection::new(stub_connection());
|
||||
conn.register_imported(imported_registration("worker/exec"));
|
||||
let env = conn.overlay_env();
|
||||
|
||||
assert!(!env.contains("worker/missing"));
|
||||
|
||||
let scoped = ScopedOperationEnv::new(["worker/missing"]);
|
||||
let ctx = root_context("root-2", scoped, env.clone());
|
||||
|
||||
let response = env
|
||||
.invoke("worker", "missing", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
|
||||
match response.result {
|
||||
Err(e) => assert_eq!(e.code, "NOT_FOUND"),
|
||||
other => panic!("expected NOT_FOUND, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn overlay_env_reachability_check_returns_not_found_for_disallowed_op() {
|
||||
let conn = CallConnection::new(stub_connection());
|
||||
conn.register_imported(imported_registration("worker/exec"));
|
||||
let env = conn.overlay_env();
|
||||
|
||||
let scoped = ScopedOperationEnv::empty();
|
||||
let ctx = root_context("root-3", scoped, env.clone());
|
||||
|
||||
let response = env
|
||||
.invoke("worker", "exec", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
|
||||
match response.result {
|
||||
Err(e) => assert_eq!(e.code, "NOT_FOUND"),
|
||||
other => panic!("expected NOT_FOUND, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn overlay_env_dispatched_child_has_internal_true_and_parent_set() {
|
||||
let conn = CallConnection::new(stub_connection());
|
||||
let inspect_handler = make_handler(|_input, context| async move {
|
||||
let internal = context.is_internal();
|
||||
let parent_set = context.parent_request_id.is_some();
|
||||
ResponseEnvelope::ok(
|
||||
context.request_id,
|
||||
serde_json::json!({
|
||||
"internal": internal,
|
||||
"parent_set": parent_set,
|
||||
}),
|
||||
)
|
||||
});
|
||||
conn.register_imported(HandlerRegistration::new(
|
||||
external_spec("worker/exec"),
|
||||
inspect_handler,
|
||||
OperationProvenance::FromCall,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
let env = conn.overlay_env();
|
||||
|
||||
let scoped = ScopedOperationEnv::new(["worker/exec"]);
|
||||
let ctx = root_context("root-4", scoped, env.clone());
|
||||
|
||||
let response = env
|
||||
.invoke("worker", "exec", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
let out = response.result.expect("ok");
|
||||
assert_eq!(out["internal"], Value::Bool(true));
|
||||
assert_eq!(out["parent_set"], Value::Bool(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_accessor_returns_underlying_connection() {
|
||||
let conn = CallConnection::new(stub_connection());
|
||||
assert_eq!(conn.connection().remote_alpn(), b"alknet/call");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_overlay_contains_nothing() {
|
||||
let conn = CallConnection::new(stub_connection());
|
||||
let env = conn.overlay_env();
|
||||
assert!(!env.contains("anything"));
|
||||
assert!(!env.contains(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn overlay_drops_with_connection() {
|
||||
let captured: Arc<RwLock<HashMap<String, HandlerRegistration>>> =
|
||||
Arc::new(RwLock::new(HashMap::new()));
|
||||
{
|
||||
let conn = CallConnection::new(stub_connection());
|
||||
conn.register_imported(imported_registration("worker/exec"));
|
||||
assert!(conn.overlay_env().contains("worker/exec"));
|
||||
std::mem::swap(
|
||||
&mut *captured.write(),
|
||||
&mut *conn.imported_operations.write(),
|
||||
);
|
||||
}
|
||||
assert!(captured.read().contains_key("worker/exec"));
|
||||
}
|
||||
|
||||
// --- dispatch_envelope -------------------------------------------------
|
||||
|
||||
fn empty_pending() -> Arc<Mutex<PendingRequestMap>> {
|
||||
Arc::new(Mutex::new(PendingRequestMap::new()))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_envelope_responded_resolves_call_receiver() {
|
||||
let pending = empty_pending();
|
||||
let rx = pending.lock().register_call(
|
||||
"req-1".to_string(),
|
||||
Instant::now() + Duration::from_secs(30),
|
||||
None,
|
||||
);
|
||||
let envelope = EventEnvelope::responded("req-1", serde_json::json!({"v": 42}));
|
||||
dispatch_envelope(&pending, envelope);
|
||||
assert!(!pending.lock().contains("req-1"));
|
||||
let result = tokio::time::timeout(Duration::from_millis(100), rx).await;
|
||||
match result {
|
||||
Ok(Ok(Ok(value))) => assert_eq!(value, serde_json::json!({"v": 42})),
|
||||
other => panic!("expected Ok({{v:42}}), got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_envelope_responded_pushes_to_subscribe_channel() {
|
||||
let pending = empty_pending();
|
||||
let mut rx = pending
|
||||
.lock()
|
||||
.register_subscribe("sub-1".to_string(), None, None);
|
||||
dispatch_envelope(
|
||||
&pending,
|
||||
EventEnvelope::responded("sub-1", serde_json::json!("first")),
|
||||
);
|
||||
dispatch_envelope(
|
||||
&pending,
|
||||
EventEnvelope::responded("sub-1", serde_json::json!("second")),
|
||||
);
|
||||
assert!(pending.lock().contains("sub-1"));
|
||||
let a = tokio::time::timeout(Duration::from_millis(100), rx.recv()).await;
|
||||
let b = tokio::time::timeout(Duration::from_millis(100), rx.recv()).await;
|
||||
match (a, b) {
|
||||
(Ok(Some(Ok(x))), Ok(Some(Ok(y)))) => {
|
||||
assert_eq!(x, serde_json::json!("first"));
|
||||
assert_eq!(y, serde_json::json!("second"));
|
||||
}
|
||||
other => panic!("expected two Ok values, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_envelope_completed_removes_entry() {
|
||||
let pending = empty_pending();
|
||||
let _rx = pending
|
||||
.lock()
|
||||
.register_subscribe("sub-2".to_string(), None, None);
|
||||
assert!(pending.lock().contains("sub-2"));
|
||||
dispatch_envelope(&pending, EventEnvelope::completed("sub-2"));
|
||||
assert!(!pending.lock().contains("sub-2"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_envelope_aborted_removes_entry() {
|
||||
let pending = empty_pending();
|
||||
let _rx = pending.lock().register_call(
|
||||
"req-2".to_string(),
|
||||
Instant::now() + Duration::from_secs(30),
|
||||
None,
|
||||
);
|
||||
assert!(pending.lock().contains("req-2"));
|
||||
dispatch_envelope(&pending, EventEnvelope::aborted("req-2"));
|
||||
assert!(!pending.lock().contains("req-2"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_envelope_error_resolves_call_with_error() {
|
||||
let pending = empty_pending();
|
||||
let rx = pending.lock().register_call(
|
||||
"req-3".to_string(),
|
||||
Instant::now() + Duration::from_secs(30),
|
||||
None,
|
||||
);
|
||||
let err = CallError::new("FILE_NOT_FOUND", "missing", false);
|
||||
dispatch_envelope(&pending, EventEnvelope::error("req-3", &err));
|
||||
assert!(!pending.lock().contains("req-3"));
|
||||
let result = tokio::time::timeout(Duration::from_millis(100), rx).await;
|
||||
match result {
|
||||
Ok(Ok(Err(e))) => {
|
||||
assert_eq!(e.code, "FILE_NOT_FOUND");
|
||||
assert!(!e.retryable);
|
||||
}
|
||||
other => panic!("expected Err(FILE_NOT_FOUND), got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_envelope_error_pushes_error_to_subscribe_channel() {
|
||||
let pending = empty_pending();
|
||||
let mut rx = pending
|
||||
.lock()
|
||||
.register_subscribe("sub-3".to_string(), None, None);
|
||||
let err = CallError::new("RATE_LIMITED", "slow down", true);
|
||||
dispatch_envelope(&pending, EventEnvelope::error("sub-3", &err));
|
||||
assert!(!pending.lock().contains("sub-3"));
|
||||
let result = tokio::time::timeout(Duration::from_millis(100), rx.recv()).await;
|
||||
match result {
|
||||
Ok(Some(Err(e))) => {
|
||||
assert_eq!(e.code, "RATE_LIMITED");
|
||||
assert!(e.retryable);
|
||||
}
|
||||
other => panic!("expected Err(RATE_LIMITED), got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_envelope_error_with_invalid_payload_is_no_op() {
|
||||
let pending = empty_pending();
|
||||
let _rx = pending.lock().register_call(
|
||||
"req-4".to_string(),
|
||||
Instant::now() + Duration::from_secs(30),
|
||||
None,
|
||||
);
|
||||
let malformed =
|
||||
EventEnvelope::new(EVENT_ERROR, "req-4", serde_json::json!("not-an-object"));
|
||||
dispatch_envelope(&pending, malformed);
|
||||
assert!(pending.lock().contains("req-4"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_envelope_unknown_event_type_is_no_op() {
|
||||
let pending = empty_pending();
|
||||
let _rx = pending.lock().register_call(
|
||||
"req-5".to_string(),
|
||||
Instant::now() + Duration::from_secs(30),
|
||||
None,
|
||||
);
|
||||
let unknown = EventEnvelope::new("call.mystery", "req-5", serde_json::json!({}));
|
||||
dispatch_envelope(&pending, unknown);
|
||||
assert!(pending.lock().contains("req-5"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_envelope_unknown_request_id_is_no_op() {
|
||||
let pending = empty_pending();
|
||||
dispatch_envelope(
|
||||
&pending,
|
||||
EventEnvelope::responded("ghost", serde_json::json!(1)),
|
||||
);
|
||||
dispatch_envelope(&pending, EventEnvelope::completed("ghost"));
|
||||
dispatch_envelope(&pending, EventEnvelope::aborted("ghost"));
|
||||
assert!(pending.lock().is_empty());
|
||||
}
|
||||
|
||||
// --- SubscriptionStream ------------------------------------------------
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscription_stream_closed_yields_one_error_then_ends() {
|
||||
use futures::stream::StreamExt;
|
||||
let err = CallError::internal("stream closed before send");
|
||||
let mut stream = SubscriptionStream::closed("req-x".to_string(), err);
|
||||
let first = stream.next().await;
|
||||
match first {
|
||||
Some(env) => {
|
||||
assert_eq!(env.request_id, "req-x");
|
||||
assert!(env.result.is_err());
|
||||
assert_eq!(env.result.unwrap_err().code, "INTERNAL");
|
||||
}
|
||||
other => panic!("expected one error envelope, got {other:?}"),
|
||||
}
|
||||
let second = stream.next().await;
|
||||
assert!(second.is_none(), "stream must terminate after the error");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscription_stream_emits_ok_values_then_completes() {
|
||||
use futures::stream::StreamExt;
|
||||
let (tx, rx) = mpsc::channel(8);
|
||||
let mut stream = SubscriptionStream::new("req-y".to_string(), rx);
|
||||
tx.try_send(Ok(serde_json::json!(1))).unwrap();
|
||||
tx.try_send(Ok(serde_json::json!(2))).unwrap();
|
||||
drop(tx);
|
||||
|
||||
let a = stream.next().await.unwrap();
|
||||
assert_eq!(a.request_id, "req-y");
|
||||
assert_eq!(a.result.unwrap(), serde_json::json!(1));
|
||||
let b = stream.next().await.unwrap();
|
||||
assert_eq!(b.result.unwrap(), serde_json::json!(2));
|
||||
assert!(
|
||||
stream.next().await.is_none(),
|
||||
"stream ends after channel closes"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscription_stream_emits_error_then_terminates() {
|
||||
use futures::stream::StreamExt;
|
||||
let (tx, rx) = mpsc::channel(8);
|
||||
let mut stream = SubscriptionStream::new("req-z".to_string(), rx);
|
||||
tx.try_send(Ok(serde_json::json!("ok"))).unwrap();
|
||||
tx.try_send(Err(CallError::timeout("timed out"))).unwrap();
|
||||
drop(tx);
|
||||
|
||||
let first = stream.next().await.unwrap();
|
||||
assert_eq!(first.result.unwrap(), serde_json::json!("ok"));
|
||||
let second = stream.next().await.unwrap();
|
||||
assert_eq!(second.request_id, "req-z");
|
||||
assert_eq!(second.result.unwrap_err().code, "TIMEOUT");
|
||||
assert!(
|
||||
stream.next().await.is_none(),
|
||||
"stream terminates after error"
|
||||
);
|
||||
}
|
||||
}
|
||||
318
crates/alknet-call/src/protocol/dispatch.rs
Normal file
318
crates/alknet-call/src/protocol/dispatch.rs
Normal file
@@ -0,0 +1,318 @@
|
||||
//! Shared dispatch loop for `alknet/call` connections.
|
||||
//!
|
||||
//! Both [`CallAdapter`]'s accept path and [`crate::client::CallClient`]'s
|
||||
//! connect path produce a [`CallConnection`] and hand it to the same dispatch
|
||||
//! loop here (ADR-017 §1): the loop reads `EventEnvelope` frames off accepted
|
||||
//! bidirectional streams, dispatches `call.requested` events against the
|
||||
//! operation registry, and writes the response back on the same stream. The
|
||||
//! connection-establishment half differs (accept vs dial); the dispatch half
|
||||
//! is shared.
|
||||
//!
|
||||
//! See `docs/architecture/crates/call/call-protocol.md` and
|
||||
//! `docs/architecture/crates/call/client-and-adapters.md` for the spec.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use alknet_core::auth::{AuthToken, Identity, IdentityProvider};
|
||||
use alknet_core::types::StreamError;
|
||||
use serde_json::Value;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use super::abort::AbortCascade;
|
||||
use super::connection::CallConnection;
|
||||
use super::wire::{
|
||||
CallError, EventEnvelope, FrameFramedReader, FrameFramedWriter, ResponseEnvelope,
|
||||
EVENT_ABORTED, EVENT_REQUESTED,
|
||||
};
|
||||
use crate::protocol::adapter::SessionOverlaySource;
|
||||
use crate::registry::context::{AbortPolicy, OperationContext, ScopedOperationEnv};
|
||||
use crate::registry::env::{LocalOperationEnv, OperationEnv, PeerCompositeEnv};
|
||||
use crate::registry::registration::OperationRegistry;
|
||||
|
||||
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
const SWEEPER_INTERVAL: Duration = Duration::from_secs(10);
|
||||
|
||||
/// Shared dispatcher for an established `CallConnection`. Constructed by
|
||||
/// both `CallAdapter` (accept path) and `CallClient` (connect path) and used
|
||||
/// to run the dispatch loop. Holds no per-connection state; the
|
||||
/// `CallConnection` is passed into `run_loop`.
|
||||
pub struct Dispatcher {
|
||||
pub registry: Arc<OperationRegistry>,
|
||||
pub identity_provider: Arc<dyn IdentityProvider>,
|
||||
pub session_source: Option<Arc<dyn SessionOverlaySource + Send + Sync>>,
|
||||
pub default_timeout: Duration,
|
||||
}
|
||||
|
||||
impl Dispatcher {
|
||||
pub fn new(
|
||||
registry: Arc<OperationRegistry>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
) -> Self {
|
||||
Self {
|
||||
registry,
|
||||
identity_provider,
|
||||
session_source: None,
|
||||
default_timeout: DEFAULT_TIMEOUT,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_session_source(
|
||||
mut self,
|
||||
source: Arc<dyn SessionOverlaySource + Send + Sync>,
|
||||
) -> Self {
|
||||
self.session_source = Some(source);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.default_timeout = timeout;
|
||||
self
|
||||
}
|
||||
|
||||
fn strip_leading_slash(operation_id: &str) -> &str {
|
||||
operation_id.strip_prefix('/').unwrap_or(operation_id)
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_identity(
|
||||
&self,
|
||||
connection_identity: Option<Identity>,
|
||||
payload: &Value,
|
||||
) -> Option<Identity> {
|
||||
let auth_token = payload.get("auth_token").and_then(|v| v.as_str());
|
||||
match auth_token {
|
||||
Some(token_str) => {
|
||||
let token = AuthToken {
|
||||
raw: token_str.as_bytes().to_vec(),
|
||||
};
|
||||
match self.identity_provider.resolve_from_token(&token) {
|
||||
Some(identity) => Some(identity),
|
||||
None => connection_identity,
|
||||
}
|
||||
}
|
||||
None => connection_identity,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn compose_root_env(
|
||||
&self,
|
||||
connection: &CallConnection,
|
||||
context: &OperationContext,
|
||||
) -> Arc<dyn OperationEnv + Send + Sync> {
|
||||
let base: Arc<dyn OperationEnv + Send + Sync> =
|
||||
Arc::new(LocalOperationEnv::new(Arc::clone(&self.registry)));
|
||||
let session = self
|
||||
.session_source
|
||||
.as_ref()
|
||||
.and_then(|s| s.overlay_for(context));
|
||||
|
||||
let mut env = PeerCompositeEnv::new(base);
|
||||
if let Some(session) = session {
|
||||
env = env.with_session(session);
|
||||
}
|
||||
if let Some(peer_id) = connection
|
||||
.connection()
|
||||
.identity()
|
||||
.map(|identity| identity.id.clone())
|
||||
{
|
||||
env.attach_peer(peer_id, connection.overlay_env());
|
||||
}
|
||||
Arc::new(env)
|
||||
}
|
||||
|
||||
pub(crate) fn build_root_context(
|
||||
&self,
|
||||
request_id: String,
|
||||
operation_name: &str,
|
||||
identity: Option<Identity>,
|
||||
forwarded_for: Option<Identity>,
|
||||
connection: &CallConnection,
|
||||
) -> OperationContext {
|
||||
let registration = self.registry.registration(operation_name);
|
||||
let (composition_authority, capabilities, scoped_env) = match registration {
|
||||
Some(r) => (
|
||||
r.composition_authority.clone(),
|
||||
r.capabilities.clone(),
|
||||
r.scoped_env
|
||||
.clone()
|
||||
.unwrap_or_else(ScopedOperationEnv::empty),
|
||||
),
|
||||
None => (
|
||||
None,
|
||||
alknet_core::types::Capabilities::new(),
|
||||
ScopedOperationEnv::empty(),
|
||||
),
|
||||
};
|
||||
|
||||
let stub_env: Arc<dyn OperationEnv + Send + Sync> =
|
||||
Arc::new(LocalOperationEnv::new(Arc::clone(&self.registry)));
|
||||
let mut context = OperationContext {
|
||||
request_id,
|
||||
parent_request_id: None,
|
||||
identity: identity.clone(),
|
||||
handler_identity: composition_authority,
|
||||
forwarded_for,
|
||||
capabilities,
|
||||
metadata: HashMap::new(),
|
||||
deadline: Some(Instant::now() + self.default_timeout),
|
||||
scoped_env,
|
||||
env: stub_env,
|
||||
abort_policy: AbortPolicy::default(),
|
||||
internal: false,
|
||||
};
|
||||
context.env = self.compose_root_env(connection, &context);
|
||||
context
|
||||
}
|
||||
|
||||
pub(crate) async fn dispatch_requested(
|
||||
&self,
|
||||
connection: &Arc<CallConnection>,
|
||||
request_id: String,
|
||||
payload: Value,
|
||||
) -> ResponseEnvelope {
|
||||
let operation_id = payload
|
||||
.get("operationId")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
let operation_name = Self::strip_leading_slash(operation_id).to_string();
|
||||
|
||||
let connection_identity = connection.connection().identity().cloned();
|
||||
let identity = self.resolve_identity(connection_identity, &payload);
|
||||
|
||||
let forwarded_for = payload
|
||||
.get("forwarded_for")
|
||||
.and_then(|v| serde_json::from_value::<Identity>(v.clone()).ok());
|
||||
|
||||
let input = payload.get("input").cloned().unwrap_or(Value::Null);
|
||||
|
||||
let context = self.build_root_context(
|
||||
request_id.clone(),
|
||||
&operation_name,
|
||||
identity,
|
||||
forwarded_for,
|
||||
connection,
|
||||
);
|
||||
|
||||
self.registry.invoke(&operation_name, input, context).await
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_stream(
|
||||
&self,
|
||||
connection: Arc<CallConnection>,
|
||||
send: alknet_core::types::SendStream,
|
||||
recv: alknet_core::types::RecvStream,
|
||||
) {
|
||||
let mut reader = FrameFramedReader::new(recv);
|
||||
let mut writer = FrameFramedWriter::new(send);
|
||||
|
||||
loop {
|
||||
let envelope = match reader.read_frame().await {
|
||||
Ok(env) => env,
|
||||
Err(super::wire::FrameError::ConnectionClosed) => break,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "stream frame read error; closing stream");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
match envelope.r#type.as_str() {
|
||||
EVENT_REQUESTED => {
|
||||
let request_id = envelope.id.clone();
|
||||
let payload = envelope.payload.clone();
|
||||
|
||||
let response = self
|
||||
.dispatch_requested(&connection, request_id.clone(), payload)
|
||||
.await;
|
||||
|
||||
let event: EventEnvelope = response.into();
|
||||
if let Err(err) = writer.write_frame(&event).await {
|
||||
warn!(error = %err, "failed to write response frame; closing stream");
|
||||
break;
|
||||
}
|
||||
}
|
||||
EVENT_ABORTED => {
|
||||
let request_id = envelope.id.clone();
|
||||
let mut pending = connection.pending().lock();
|
||||
let mut cascade = AbortCascade::new(&mut pending);
|
||||
let aborted = cascade.cascade_abort(&request_id, AbortPolicy::AbortDependents);
|
||||
pending.handle_aborted(&request_id);
|
||||
if !aborted.is_empty() {
|
||||
debug!(count = aborted.len(), "abort cascade evicted descendants");
|
||||
}
|
||||
}
|
||||
other => {
|
||||
debug!(event_type = %other, id = %envelope.id, "ignoring non-requested/non-aborted event on inbound stream");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the shared dispatch loop over an established `CallConnection`:
|
||||
/// spawn the pending-entry sweeper, accept bidirectional streams until the
|
||||
/// connection closes, dispatch each stream via `handle_stream`, and fail
|
||||
/// outstanding pending requests on close. Returns when the connection is
|
||||
/// closed (accept loop yields `ConnectionClosed`/`StreamClosed`/`Timeout`).
|
||||
pub async fn run_loop(self, connection: Arc<CallConnection>) {
|
||||
let pending = Arc::clone(connection.pending());
|
||||
|
||||
let sweeper_pending = Arc::clone(&pending);
|
||||
let sweeper_handle: JoinHandle<()> = tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(SWEEPER_INTERVAL);
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let evicted = sweeper_pending.lock().evict_expired();
|
||||
if !evicted.is_empty() {
|
||||
debug!(
|
||||
count = evicted.len(),
|
||||
"sweeper evicted expired pending entries"
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
loop {
|
||||
match connection.connection().accept_bi().await {
|
||||
Ok((send, recv)) => {
|
||||
let conn = Arc::clone(&connection);
|
||||
let dispatcher = self.clone();
|
||||
tokio::spawn(async move {
|
||||
dispatcher.handle_stream(conn, send, recv).await;
|
||||
});
|
||||
}
|
||||
Err(StreamError::ConnectionClosed) => break,
|
||||
Err(StreamError::StreamClosed) => break,
|
||||
Err(StreamError::Timeout) => break,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "accept_bi error; stopping accept loop");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let failed = pending
|
||||
.lock()
|
||||
.fail_all(CallError::internal("connection closed"));
|
||||
if !failed.is_empty() {
|
||||
debug!(
|
||||
count = failed.len(),
|
||||
"failed pending requests on connection close"
|
||||
);
|
||||
}
|
||||
|
||||
sweeper_handle.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for Dispatcher {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
registry: Arc::clone(&self.registry),
|
||||
identity_provider: Arc::clone(&self.identity_provider),
|
||||
session_source: self.session_source.clone(),
|
||||
default_timeout: self.default_timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
12
crates/alknet-call/src/protocol/mod.rs
Normal file
12
crates/alknet-call/src/protocol/mod.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
//! Call protocol: wire format, streams, and the call adapter.
|
||||
//!
|
||||
//! Implements `ProtocolHandler` for ALPN `alknet/call` on top of the
|
||||
//! operation registry. See `docs/architecture/crates/call/call-protocol.md`
|
||||
//! for the full specification.
|
||||
|
||||
pub mod abort;
|
||||
pub mod adapter;
|
||||
pub mod connection;
|
||||
pub mod dispatch;
|
||||
pub mod pending;
|
||||
pub mod wire;
|
||||
584
crates/alknet-call/src/protocol/pending.rs
Normal file
584
crates/alknet-call/src/protocol/pending.rs
Normal file
@@ -0,0 +1,584 @@
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
|
||||
use serde_json::Value;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
use crate::protocol::wire::CallError;
|
||||
|
||||
const SUBSCRIBE_CHANNEL_CAPACITY: usize = 32;
|
||||
|
||||
pub struct PendingRequestMap {
|
||||
pending: HashMap<String, PendingEntry>,
|
||||
}
|
||||
|
||||
pub(crate) enum PendingEntry {
|
||||
Call {
|
||||
tx: oneshot::Sender<Result<Value, CallError>>,
|
||||
timeout: Instant,
|
||||
parent_request_id: Option<String>,
|
||||
started: bool,
|
||||
},
|
||||
Subscribe {
|
||||
tx: mpsc::Sender<Result<Value, CallError>>,
|
||||
timeout: Option<Instant>,
|
||||
parent_request_id: Option<String>,
|
||||
started: bool,
|
||||
},
|
||||
}
|
||||
|
||||
impl PendingEntry {
|
||||
pub(crate) fn parent_request_id(&self) -> Option<&str> {
|
||||
match self {
|
||||
PendingEntry::Call {
|
||||
parent_request_id, ..
|
||||
} => parent_request_id.as_deref(),
|
||||
PendingEntry::Subscribe {
|
||||
parent_request_id, ..
|
||||
} => parent_request_id.as_deref(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn started(&self) -> bool {
|
||||
match self {
|
||||
PendingEntry::Call { started, .. } => *started,
|
||||
PendingEntry::Subscribe { started, .. } => *started,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PendingRequestMap {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
pending: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_call(
|
||||
&mut self,
|
||||
request_id: String,
|
||||
timeout: Instant,
|
||||
parent_request_id: Option<String>,
|
||||
) -> oneshot::Receiver<Result<Value, CallError>> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.pending.insert(
|
||||
request_id,
|
||||
PendingEntry::Call {
|
||||
tx,
|
||||
timeout,
|
||||
parent_request_id,
|
||||
started: false,
|
||||
},
|
||||
);
|
||||
rx
|
||||
}
|
||||
|
||||
pub fn register_subscribe(
|
||||
&mut self,
|
||||
request_id: String,
|
||||
timeout: Option<Instant>,
|
||||
parent_request_id: Option<String>,
|
||||
) -> mpsc::Receiver<Result<Value, CallError>> {
|
||||
let (tx, rx) = mpsc::channel(SUBSCRIBE_CHANNEL_CAPACITY);
|
||||
self.pending.insert(
|
||||
request_id,
|
||||
PendingEntry::Subscribe {
|
||||
tx,
|
||||
timeout,
|
||||
parent_request_id,
|
||||
started: false,
|
||||
},
|
||||
);
|
||||
rx
|
||||
}
|
||||
|
||||
pub fn mark_started(&mut self, request_id: &str) -> bool {
|
||||
let Some(entry) = self.pending.get_mut(request_id) else {
|
||||
return false;
|
||||
};
|
||||
match entry {
|
||||
PendingEntry::Call { started, .. } => *started = true,
|
||||
PendingEntry::Subscribe { started, .. } => *started = true,
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub fn handle_responded(&mut self, request_id: &str, output: Value) -> bool {
|
||||
let Some(entry) = self.pending.remove(request_id) else {
|
||||
return false;
|
||||
};
|
||||
match entry {
|
||||
PendingEntry::Call { tx, .. } => {
|
||||
let _ = tx.send(Ok(output));
|
||||
true
|
||||
}
|
||||
PendingEntry::Subscribe {
|
||||
tx,
|
||||
timeout,
|
||||
parent_request_id,
|
||||
started,
|
||||
} => {
|
||||
let send_result = tx.try_send(Ok(output));
|
||||
match send_result {
|
||||
Ok(()) => {
|
||||
self.pending.insert(
|
||||
request_id.to_string(),
|
||||
PendingEntry::Subscribe {
|
||||
tx,
|
||||
timeout,
|
||||
parent_request_id,
|
||||
started,
|
||||
},
|
||||
);
|
||||
true
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Full(_)) => {
|
||||
tracing::warn!(
|
||||
request_id,
|
||||
"subscribe channel full; dropping entry and closing subscription"
|
||||
);
|
||||
true
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handle_completed(&mut self, request_id: &str) -> bool {
|
||||
self.pending.remove(request_id).is_some()
|
||||
}
|
||||
|
||||
pub fn handle_aborted(&mut self, request_id: &str) -> bool {
|
||||
self.pending.remove(request_id).is_some()
|
||||
}
|
||||
|
||||
pub fn handle_error(&mut self, request_id: &str, error: CallError) -> bool {
|
||||
let Some(entry) = self.pending.remove(request_id) else {
|
||||
return false;
|
||||
};
|
||||
match entry {
|
||||
PendingEntry::Call { tx, .. } => {
|
||||
let _ = tx.send(Err(error));
|
||||
true
|
||||
}
|
||||
PendingEntry::Subscribe { tx, .. } => {
|
||||
let _ = tx.try_send(Err(error));
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn evict_expired(&mut self) -> Vec<String> {
|
||||
let now = Instant::now();
|
||||
let mut evicted = Vec::new();
|
||||
let mut to_remove: Vec<String> = Vec::new();
|
||||
for (id, entry) in self.pending.iter() {
|
||||
let expired = match entry {
|
||||
PendingEntry::Call { timeout, .. } => *timeout <= now,
|
||||
PendingEntry::Subscribe {
|
||||
timeout: Some(t), ..
|
||||
} => *t <= now,
|
||||
PendingEntry::Subscribe { timeout: None, .. } => false,
|
||||
};
|
||||
if expired {
|
||||
to_remove.push(id.clone());
|
||||
}
|
||||
}
|
||||
for id in to_remove {
|
||||
let Some(entry) = self.pending.remove(&id) else {
|
||||
continue;
|
||||
};
|
||||
let timeout_err = CallError::timeout("request timed out");
|
||||
match entry {
|
||||
PendingEntry::Call { tx, .. } => {
|
||||
let _ = tx.send(Err(timeout_err));
|
||||
}
|
||||
PendingEntry::Subscribe { tx, .. } => {
|
||||
let _ = tx.try_send(Err(timeout_err));
|
||||
}
|
||||
}
|
||||
evicted.push(id);
|
||||
}
|
||||
evicted
|
||||
}
|
||||
|
||||
pub fn fail_all(&mut self, error: CallError) -> Vec<String> {
|
||||
let ids: Vec<String> = self.pending.keys().cloned().collect();
|
||||
for id in &ids {
|
||||
if let Some(entry) = self.pending.remove(id) {
|
||||
match entry {
|
||||
PendingEntry::Call { tx, .. } => {
|
||||
let _ = tx.send(Err(error.clone()));
|
||||
}
|
||||
PendingEntry::Subscribe { tx, .. } => {
|
||||
let _ = tx.try_send(Err(error.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ids
|
||||
}
|
||||
|
||||
pub fn contains(&self, request_id: &str) -> bool {
|
||||
self.pending.contains_key(request_id)
|
||||
}
|
||||
|
||||
pub(crate) fn parent_of(&self, request_id: &str) -> Option<Option<String>> {
|
||||
self.pending
|
||||
.get(request_id)
|
||||
.map(|e| e.parent_request_id().map(|s| s.to_string()))
|
||||
}
|
||||
|
||||
pub(crate) fn is_started(&self, request_id: &str) -> Option<bool> {
|
||||
self.pending.get(request_id).map(|e| e.started())
|
||||
}
|
||||
|
||||
pub(crate) fn request_ids(&self) -> Vec<String> {
|
||||
self.pending.keys().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.pending.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.pending.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PendingRequestMap {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
use std::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
fn timeout_error() -> CallError {
|
||||
CallError::timeout("request timed out")
|
||||
}
|
||||
|
||||
fn internal_error(message: &str) -> CallError {
|
||||
CallError::internal(message)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn register_call_then_handle_responded_resolves_oneshot() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let rx = map.register_call(
|
||||
"req-1".to_string(),
|
||||
Instant::now() + Duration::from_secs(30),
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(map.contains("req-1"));
|
||||
assert_eq!(map.len(), 1);
|
||||
|
||||
assert!(map.handle_responded("req-1", json!(42)));
|
||||
|
||||
let result = timeout(Duration::from_millis(100), rx).await;
|
||||
match result {
|
||||
Ok(Ok(Ok(value))) => assert_eq!(value, json!(42)),
|
||||
other => panic!("expected Ok(42), got {other:?}"),
|
||||
}
|
||||
assert!(!map.contains("req-1"));
|
||||
assert_eq!(map.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn register_subscribe_then_handle_responded_pushes_to_channel() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let mut rx = map.register_subscribe("sub-1".to_string(), None, None);
|
||||
|
||||
assert!(map.handle_responded("sub-1", json!("first")));
|
||||
assert!(map.handle_responded("sub-1", json!("second")));
|
||||
assert!(map.contains("sub-1"));
|
||||
|
||||
let first = timeout(Duration::from_millis(100), rx.recv()).await;
|
||||
let second = timeout(Duration::from_millis(100), rx.recv()).await;
|
||||
match (first, second) {
|
||||
(Ok(Some(Ok(a))), Ok(Some(Ok(b)))) => {
|
||||
assert_eq!(a, json!("first"));
|
||||
assert_eq!(b, json!("second"));
|
||||
}
|
||||
other => panic!("expected two Ok values, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscribe_handle_completed_closes_channel_and_deletes_entry() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let mut rx = map.register_subscribe("sub-2".to_string(), None, None);
|
||||
|
||||
assert!(map.handle_responded("sub-2", json!("a")));
|
||||
assert!(map.handle_completed("sub-2"));
|
||||
assert!(!map.contains("sub-2"));
|
||||
|
||||
let _ = timeout(Duration::from_millis(100), rx.recv()).await;
|
||||
let after_close = timeout(Duration::from_millis(100), rx.recv()).await;
|
||||
match after_close {
|
||||
Ok(None) => {}
|
||||
other => panic!("expected channel closed (None), got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn expired_call_is_evicted_with_timeout_error() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let rx = map.register_call(
|
||||
"req-2".to_string(),
|
||||
Instant::now() - Duration::from_millis(1),
|
||||
None,
|
||||
);
|
||||
|
||||
let evicted = map.evict_expired();
|
||||
assert_eq!(evicted, vec!["req-2".to_string()]);
|
||||
assert!(!map.contains("req-2"));
|
||||
|
||||
let result = timeout(Duration::from_millis(100), rx).await;
|
||||
match result {
|
||||
Ok(Ok(Err(e))) => {
|
||||
assert_eq!(e.code, "TIMEOUT");
|
||||
assert!(e.retryable);
|
||||
}
|
||||
other => panic!("expected Err(TIMEOUT), got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn expired_subscribe_is_evicted_with_timeout_error() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let mut rx = map.register_subscribe(
|
||||
"sub-3".to_string(),
|
||||
Some(Instant::now() - Duration::from_millis(1)),
|
||||
None,
|
||||
);
|
||||
|
||||
let evicted = map.evict_expired();
|
||||
assert_eq!(evicted, vec!["sub-3".to_string()]);
|
||||
|
||||
let result = timeout(Duration::from_millis(100), rx.recv()).await;
|
||||
match result {
|
||||
Ok(Some(Err(e))) => {
|
||||
assert_eq!(e.code, "TIMEOUT");
|
||||
assert!(e.retryable);
|
||||
}
|
||||
other => panic!("expected Err(TIMEOUT), got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unbounded_subscribe_is_not_evicted() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let _rx = map.register_subscribe("sub-4".to_string(), None, None);
|
||||
|
||||
let evicted = map.evict_expired();
|
||||
assert!(evicted.is_empty());
|
||||
assert!(map.contains("sub-4"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fail_all_resolves_all_pending_with_internal_error() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let rx_call = map.register_call(
|
||||
"c-1".to_string(),
|
||||
Instant::now() + Duration::from_secs(30),
|
||||
None,
|
||||
);
|
||||
let mut rx_sub = map.register_subscribe(
|
||||
"s-1".to_string(),
|
||||
Some(Instant::now() + Duration::from_secs(30)),
|
||||
None,
|
||||
);
|
||||
|
||||
let failed = map.fail_all(internal_error("connection closed"));
|
||||
assert_eq!(failed.len(), 2);
|
||||
assert!(failed.contains(&"c-1".to_string()));
|
||||
assert!(failed.contains(&"s-1".to_string()));
|
||||
assert!(map.is_empty());
|
||||
|
||||
let call_result = timeout(Duration::from_millis(100), rx_call).await;
|
||||
match call_result {
|
||||
Ok(Ok(Err(e))) => {
|
||||
assert_eq!(e.code, "INTERNAL");
|
||||
assert_eq!(e.message, "connection closed");
|
||||
}
|
||||
other => panic!("expected Err(INTERNAL), got {other:?}"),
|
||||
}
|
||||
|
||||
let sub_result = timeout(Duration::from_millis(100), rx_sub.recv()).await;
|
||||
match sub_result {
|
||||
Ok(Some(Err(e))) => {
|
||||
assert_eq!(e.code, "INTERNAL");
|
||||
assert_eq!(e.message, "connection closed");
|
||||
}
|
||||
other => panic!("expected Err(INTERNAL), got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_responded_unknown_request_id_returns_false() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
assert!(!map.handle_responded("nonexistent", json!(1)));
|
||||
assert_eq!(map.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_completed_unknown_request_id_returns_false() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
assert!(!map.handle_completed("nonexistent"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_aborted_unknown_request_id_returns_false() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
assert!(!map.handle_aborted("nonexistent"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_error_unknown_request_id_returns_false() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
assert!(!map.handle_error("nonexistent", internal_error("x")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_aborted_cancels_pending_call() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let rx = map.register_call(
|
||||
"req-3".to_string(),
|
||||
Instant::now() + Duration::from_secs(30),
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(map.handle_aborted("req-3"));
|
||||
assert!(!map.contains("req-3"));
|
||||
|
||||
let result = timeout(Duration::from_millis(100), rx).await;
|
||||
match result {
|
||||
Ok(Err(_)) => {}
|
||||
other => panic!("expected sender dropped (Err), got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_error_resolves_call_with_error() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let rx = map.register_call(
|
||||
"req-4".to_string(),
|
||||
Instant::now() + Duration::from_secs(30),
|
||||
None,
|
||||
);
|
||||
|
||||
let err = CallError::new("FILE_NOT_FOUND", "missing", false);
|
||||
assert!(map.handle_error("req-4", err.clone()));
|
||||
assert!(!map.contains("req-4"));
|
||||
|
||||
let result = timeout(Duration::from_millis(100), rx).await;
|
||||
match result {
|
||||
Ok(Ok(Err(e))) => {
|
||||
assert_eq!(e.code, "FILE_NOT_FOUND");
|
||||
assert!(!e.retryable);
|
||||
}
|
||||
other => panic!("expected Err(FILE_NOT_FOUND), got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_error_pushes_to_subscribe_channel() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let mut rx = map.register_subscribe("sub-5".to_string(), None, None);
|
||||
|
||||
let err = CallError::new("RATE_LIMITED", "too fast", true);
|
||||
assert!(map.handle_error("sub-5", err.clone()));
|
||||
assert!(!map.contains("sub-5"));
|
||||
|
||||
let result = timeout(Duration::from_millis(100), rx.recv()).await;
|
||||
match result {
|
||||
Ok(Some(Err(e))) => {
|
||||
assert_eq!(e.code, "RATE_LIMITED");
|
||||
assert!(e.retryable);
|
||||
}
|
||||
other => panic!("expected Err(RATE_LIMITED), got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn correlation_by_id_not_by_stream() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let rx = map.register_call(
|
||||
"req-stream-3".to_string(),
|
||||
Instant::now() + Duration::from_secs(30),
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(map.handle_responded("req-stream-3", json!("response-from-stream-7")));
|
||||
let result = timeout(Duration::from_millis(100), rx).await;
|
||||
match result {
|
||||
Ok(Ok(Ok(value))) => assert_eq!(value, json!("response-from-stream-7")),
|
||||
other => panic!("expected Ok, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn register_call_overwrites_existing_entry() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let _rx_old = map.register_call(
|
||||
"req-5".to_string(),
|
||||
Instant::now() + Duration::from_secs(30),
|
||||
None,
|
||||
);
|
||||
let rx_new = map.register_call(
|
||||
"req-5".to_string(),
|
||||
Instant::now() + Duration::from_secs(30),
|
||||
None,
|
||||
);
|
||||
assert_eq!(map.len(), 1);
|
||||
|
||||
assert!(map.handle_responded("req-5", json!("new")));
|
||||
let result = timeout(Duration::from_millis(100), rx_new).await;
|
||||
match result {
|
||||
Ok(Ok(Ok(value))) => assert_eq!(value, json!("new")),
|
||||
other => panic!("expected Ok from new receiver, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn evict_expired_skips_non_expired_entries() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let _rx_expired = map.register_call(
|
||||
"expired".to_string(),
|
||||
Instant::now() - Duration::from_millis(1),
|
||||
None,
|
||||
);
|
||||
let _rx_alive = map.register_call(
|
||||
"alive".to_string(),
|
||||
Instant::now() + Duration::from_secs(60),
|
||||
None,
|
||||
);
|
||||
|
||||
let evicted = map.evict_expired();
|
||||
assert_eq!(evicted, vec!["expired".to_string()]);
|
||||
assert!(map.contains("alive"));
|
||||
assert!(!map.contains("expired"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn default_is_empty_map() {
|
||||
let map = PendingRequestMap::default();
|
||||
assert!(map.is_empty());
|
||||
assert_eq!(map.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn timeout_error_helper() {
|
||||
let err = timeout_error();
|
||||
assert_eq!(err.code, "TIMEOUT");
|
||||
assert!(err.retryable);
|
||||
}
|
||||
}
|
||||
544
crates/alknet-call/src/protocol/wire.rs
Normal file
544
crates/alknet-call/src/protocol/wire.rs
Normal file
@@ -0,0 +1,544 @@
|
||||
//! Wire format: `EventEnvelope`, `ResponseEnvelope`, `CallError`, and
|
||||
//! length-prefixed JSON framing.
|
||||
//!
|
||||
//! See `docs/architecture/crates/call/call-protocol.md` for the full
|
||||
//! specification.
|
||||
|
||||
use std::io;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
pub const EVENT_REQUESTED: &str = "call.requested";
|
||||
pub const EVENT_RESPONDED: &str = "call.responded";
|
||||
pub const EVENT_COMPLETED: &str = "call.completed";
|
||||
pub const EVENT_ABORTED: &str = "call.aborted";
|
||||
pub const EVENT_ERROR: &str = "call.error";
|
||||
|
||||
const LENGTH_PREFIX_BYTES: usize = 4;
|
||||
const MAX_FRAME_SIZE: u32 = 64 * 1024 * 1024;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct EventEnvelope {
|
||||
#[serde(rename = "type")]
|
||||
pub r#type: String,
|
||||
pub id: String,
|
||||
pub payload: Value,
|
||||
}
|
||||
|
||||
impl EventEnvelope {
|
||||
pub fn new(event_type: impl Into<String>, id: impl Into<String>, payload: Value) -> Self {
|
||||
Self {
|
||||
r#type: event_type.into(),
|
||||
id: id.into(),
|
||||
payload,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn requested(id: impl Into<String>, payload: Value) -> Self {
|
||||
Self::new(EVENT_REQUESTED, id, payload)
|
||||
}
|
||||
|
||||
pub fn responded(id: impl Into<String>, output: Value) -> Self {
|
||||
Self::new(EVENT_RESPONDED, id, serde_json::json!({ "output": output }))
|
||||
}
|
||||
|
||||
pub fn completed(id: impl Into<String>) -> Self {
|
||||
Self::new(EVENT_COMPLETED, id, serde_json::json!({}))
|
||||
}
|
||||
|
||||
pub fn aborted(id: impl Into<String>) -> Self {
|
||||
Self::new(EVENT_ABORTED, id, serde_json::json!({}))
|
||||
}
|
||||
|
||||
pub fn error(id: impl Into<String>, error: &CallError) -> Self {
|
||||
let payload = serde_json::to_value(error).unwrap_or(Value::Null);
|
||||
Self::new(EVENT_ERROR, id, payload)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CallError {
|
||||
pub code: String,
|
||||
pub message: String,
|
||||
pub retryable: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub details: Option<Value>,
|
||||
}
|
||||
|
||||
impl CallError {
|
||||
pub fn new(code: impl Into<String>, message: impl Into<String>, retryable: bool) -> Self {
|
||||
Self {
|
||||
code: code.into(),
|
||||
message: message.into(),
|
||||
retryable,
|
||||
details: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_details(mut self, details: Value) -> Self {
|
||||
self.details = Some(details);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn not_found(op_name: &str) -> Self {
|
||||
Self::new(
|
||||
"NOT_FOUND",
|
||||
format!("operation not found: {op_name}"),
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn forbidden(message: impl Into<String>) -> Self {
|
||||
Self::new("FORBIDDEN", message, false)
|
||||
}
|
||||
|
||||
pub fn invalid_input(message: impl Into<String>) -> Self {
|
||||
Self::new("INVALID_INPUT", message, false)
|
||||
}
|
||||
|
||||
pub fn internal(message: impl Into<String>) -> Self {
|
||||
Self::new("INTERNAL", message, false)
|
||||
}
|
||||
|
||||
pub fn timeout(message: impl Into<String>) -> Self {
|
||||
Self::new("TIMEOUT", message, true)
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for CallError {}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ResponseEnvelope {
|
||||
pub request_id: String,
|
||||
pub result: Result<Value, CallError>,
|
||||
}
|
||||
|
||||
impl ResponseEnvelope {
|
||||
pub fn ok(request_id: impl Into<String>, output: Value) -> Self {
|
||||
Self {
|
||||
request_id: request_id.into(),
|
||||
result: Ok(output),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn error(request_id: impl Into<String>, error: CallError) -> Self {
|
||||
Self {
|
||||
request_id: request_id.into(),
|
||||
result: Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn not_found(request_id: impl Into<String>, op_name: &str) -> Self {
|
||||
Self::error(request_id, CallError::not_found(op_name))
|
||||
}
|
||||
|
||||
pub fn forbidden(request_id: impl Into<String>, message: impl Into<String>) -> Self {
|
||||
Self::error(request_id, CallError::forbidden(message))
|
||||
}
|
||||
|
||||
pub fn into_event(self) -> EventEnvelope {
|
||||
let id = self.request_id;
|
||||
match self.result {
|
||||
Ok(output) => EventEnvelope::responded(id, output),
|
||||
Err(ref err) => EventEnvelope::error(id, err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ResponseEnvelope> for EventEnvelope {
|
||||
fn from(envelope: ResponseEnvelope) -> EventEnvelope {
|
||||
envelope.into_event()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum FrameError {
|
||||
#[error("io error: {0}")]
|
||||
Io(#[from] io::Error),
|
||||
#[error("json error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
#[error("connection closed")]
|
||||
ConnectionClosed,
|
||||
#[error("invalid frame")]
|
||||
InvalidFrame,
|
||||
}
|
||||
|
||||
pub struct FrameFramedReader<R: AsyncRead + Unpin> {
|
||||
reader: R,
|
||||
len_buf: [u8; LENGTH_PREFIX_BYTES],
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> FrameFramedReader<R> {
|
||||
pub fn new(reader: R) -> Self {
|
||||
Self {
|
||||
reader,
|
||||
len_buf: [0u8; LENGTH_PREFIX_BYTES],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> R {
|
||||
self.reader
|
||||
}
|
||||
|
||||
pub async fn read_frame(&mut self) -> Result<EventEnvelope, FrameError> {
|
||||
match self.reader.read_exact(&mut self.len_buf).await {
|
||||
Ok(_) => {}
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
|
||||
return Err(FrameError::ConnectionClosed);
|
||||
}
|
||||
Err(e) => return Err(FrameError::Io(e)),
|
||||
}
|
||||
|
||||
let length = u32::from_be_bytes(self.len_buf);
|
||||
if length == 0 {
|
||||
return Err(FrameError::InvalidFrame);
|
||||
}
|
||||
if length > MAX_FRAME_SIZE {
|
||||
return Err(FrameError::InvalidFrame);
|
||||
}
|
||||
|
||||
let mut body = vec![0u8; length as usize];
|
||||
match self.reader.read_exact(&mut body).await {
|
||||
Ok(_) => {}
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
|
||||
return Err(FrameError::ConnectionClosed);
|
||||
}
|
||||
Err(e) => return Err(FrameError::Io(e)),
|
||||
}
|
||||
|
||||
let envelope: EventEnvelope = serde_json::from_slice(&body)?;
|
||||
Ok(envelope)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FrameFramedWriter<W: AsyncWrite + Unpin> {
|
||||
writer: W,
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> FrameFramedWriter<W> {
|
||||
pub fn new(writer: W) -> Self {
|
||||
Self { writer }
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> W {
|
||||
self.writer
|
||||
}
|
||||
|
||||
pub async fn write_frame(&mut self, envelope: &EventEnvelope) -> Result<(), FrameError> {
|
||||
let body = serde_json::to_vec(envelope)?;
|
||||
let len = body.len();
|
||||
if len > MAX_FRAME_SIZE as usize {
|
||||
return Err(FrameError::InvalidFrame);
|
||||
}
|
||||
let len_bytes = (len as u32).to_be_bytes();
|
||||
self.writer.write_all(&len_bytes).await?;
|
||||
self.writer.write_all(&body).await?;
|
||||
self.writer.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{duplex, AsyncReadExt};
|
||||
|
||||
fn sample_envelope() -> EventEnvelope {
|
||||
EventEnvelope::new(
|
||||
"call.requested",
|
||||
"req-1",
|
||||
serde_json::json!({
|
||||
"operationId": "/fs/readFile",
|
||||
"input": { "path": "/etc/hosts" }
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn round_trip_envelope() {
|
||||
let (client, server) = duplex(8 * 1024);
|
||||
let envelope = sample_envelope();
|
||||
|
||||
let mut writer = FrameFramedWriter::new(client);
|
||||
writer.write_frame(&envelope).await.unwrap();
|
||||
drop(writer);
|
||||
|
||||
let mut reader = FrameFramedReader::new(server);
|
||||
let read = reader.read_frame().await.unwrap();
|
||||
assert_eq!(read, envelope);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn round_trip_multiple_frames() {
|
||||
let (client, server) = duplex(8 * 1024);
|
||||
|
||||
let envelopes = vec![
|
||||
EventEnvelope::responded("a", Value::String("hello".into())),
|
||||
EventEnvelope::completed("a"),
|
||||
EventEnvelope::aborted("b"),
|
||||
];
|
||||
|
||||
{
|
||||
let mut writer = FrameFramedWriter::new(client);
|
||||
for e in &envelopes {
|
||||
writer.write_frame(e).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
let mut reader = FrameFramedReader::new(server);
|
||||
for expected in envelopes {
|
||||
let read = reader.read_frame().await.unwrap();
|
||||
assert_eq!(read, expected);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_frame_on_closed_reader_returns_connection_closed() {
|
||||
let (_, server) = duplex(8 * 1024);
|
||||
let mut reader = FrameFramedReader::new(server);
|
||||
match reader.read_frame().await {
|
||||
Err(FrameError::ConnectionClosed) => {}
|
||||
other => panic!("expected ConnectionClosed, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn truncated_body_returns_connection_closed() {
|
||||
let (mut client, server) = duplex(8 * 1024);
|
||||
let envelope = sample_envelope();
|
||||
let body = serde_json::to_vec(&envelope).unwrap();
|
||||
let len_bytes = (body.len() as u32).to_be_bytes();
|
||||
client.write_all(&len_bytes).await.unwrap();
|
||||
client.write_all(&body[..body.len() / 2]).await.unwrap();
|
||||
drop(client);
|
||||
|
||||
let mut reader = FrameFramedReader::new(server);
|
||||
match reader.read_frame().await {
|
||||
Err(FrameError::ConnectionClosed) => {}
|
||||
other => panic!("expected ConnectionClosed, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn zero_length_frame_is_invalid() {
|
||||
let (mut client, server) = duplex(8 * 1024);
|
||||
client.write_all(&[0u8, 0, 0, 0]).await.unwrap();
|
||||
drop(client);
|
||||
|
||||
let mut reader = FrameFramedReader::new(server);
|
||||
match reader.read_frame().await {
|
||||
Err(FrameError::InvalidFrame) => {}
|
||||
other => panic!("expected InvalidFrame, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn oversized_frame_is_invalid() {
|
||||
let (mut client, server) = duplex(8 * 1024);
|
||||
let too_big = (MAX_FRAME_SIZE + 1u32).to_be_bytes();
|
||||
client.write_all(&too_big).await.unwrap();
|
||||
drop(client);
|
||||
|
||||
let mut reader = FrameFramedReader::new(server);
|
||||
match reader.read_frame().await {
|
||||
Err(FrameError::InvalidFrame) => {}
|
||||
other => panic!("expected InvalidFrame, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn framing_handles_large_payload() {
|
||||
let (client, server) = duplex(1024 * 1024);
|
||||
let big = "x".repeat(64 * 1024);
|
||||
let envelope = EventEnvelope::responded("big", Value::String(big.clone()));
|
||||
|
||||
let mut writer = FrameFramedWriter::new(client);
|
||||
writer.write_frame(&envelope).await.unwrap();
|
||||
drop(writer);
|
||||
|
||||
let mut reader = FrameFramedReader::new(server);
|
||||
let read = reader.read_frame().await.unwrap();
|
||||
assert_eq!(read, envelope);
|
||||
match read.payload {
|
||||
Value::Object(map) => match map.get("output") {
|
||||
Some(Value::String(s)) => assert_eq!(s, &big),
|
||||
other => panic!("expected output string, got {other:?}"),
|
||||
},
|
||||
other => panic!("expected object payload, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_envelope_ok_produces_call_responded_event() {
|
||||
let response = ResponseEnvelope::ok("req-1", Value::String("hi".into()));
|
||||
let event: EventEnvelope = response.into();
|
||||
assert_eq!(event.r#type, EVENT_RESPONDED);
|
||||
assert_eq!(event.id, "req-1");
|
||||
let map = event.payload.as_object().expect("payload is object");
|
||||
assert_eq!(map.get("output"), Some(&Value::String("hi".into())));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_envelope_error_produces_call_error_event() {
|
||||
let err = CallError::new("FILE_NOT_FOUND", "file not found: /etc/x", false)
|
||||
.with_details(serde_json::json!({ "path": "/etc/x" }));
|
||||
let response = ResponseEnvelope::error("req-2", err);
|
||||
let event: EventEnvelope = response.into();
|
||||
assert_eq!(event.r#type, EVENT_ERROR);
|
||||
assert_eq!(event.id, "req-2");
|
||||
assert_eq!(
|
||||
event.payload.get("code"),
|
||||
Some(&Value::String("FILE_NOT_FOUND".into()))
|
||||
);
|
||||
assert_eq!(
|
||||
event.payload.get("message"),
|
||||
Some(&Value::String("file not found: /etc/x".into()))
|
||||
);
|
||||
assert_eq!(event.payload.get("retryable"), Some(&Value::Bool(false)));
|
||||
assert_eq!(
|
||||
event.payload.get("details"),
|
||||
Some(&serde_json::json!({ "path": "/etc/x" }))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_envelope_not_found_helper() {
|
||||
let response = ResponseEnvelope::not_found("req-3", "fs/missing");
|
||||
assert_eq!(response.request_id, "req-3");
|
||||
match &response.result {
|
||||
Err(e) => {
|
||||
assert_eq!(e.code, "NOT_FOUND");
|
||||
assert!(!e.retryable);
|
||||
assert!(e.message.contains("fs/missing"));
|
||||
}
|
||||
other => panic!("expected Err, got {other:?}"),
|
||||
}
|
||||
let event: EventEnvelope = response.into();
|
||||
assert_eq!(event.r#type, EVENT_ERROR);
|
||||
assert_eq!(event.id, "req-3");
|
||||
assert_eq!(
|
||||
event.payload.get("code"),
|
||||
Some(&Value::String("NOT_FOUND".into()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_envelope_forbidden_helper() {
|
||||
let response = ResponseEnvelope::forbidden("req-4", "authentication required");
|
||||
match &response.result {
|
||||
Err(e) => {
|
||||
assert_eq!(e.code, "FORBIDDEN");
|
||||
assert_eq!(e.message, "authentication required");
|
||||
}
|
||||
other => panic!("expected Err, got {other:?}"),
|
||||
}
|
||||
let event: EventEnvelope = response.into();
|
||||
assert_eq!(event.r#type, EVENT_ERROR);
|
||||
assert_eq!(event.id, "req-4");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_completed_has_empty_payload() {
|
||||
let event = EventEnvelope::completed("sub-1");
|
||||
assert_eq!(event.r#type, EVENT_COMPLETED);
|
||||
assert_eq!(event.id, "sub-1");
|
||||
assert_eq!(event.payload, serde_json::json!({}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_aborted_has_empty_payload() {
|
||||
let event = EventEnvelope::aborted("req-9");
|
||||
assert_eq!(event.r#type, EVENT_ABORTED);
|
||||
assert_eq!(event.id, "req-9");
|
||||
assert_eq!(event.payload, serde_json::json!({}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_responded_wraps_output() {
|
||||
let event = EventEnvelope::responded("req-1", Value::Number(42.into()));
|
||||
assert_eq!(event.r#type, EVENT_RESPONDED);
|
||||
assert_eq!(event.payload.get("output"), Some(&Value::Number(42.into())));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_serializes_type_field() {
|
||||
let event = sample_envelope();
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains("\"type\":\"call.requested\""));
|
||||
assert!(!json.contains("\"r#type\""));
|
||||
|
||||
let parsed: EventEnvelope = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed, event);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_error_skips_missing_details() {
|
||||
let err = CallError::new("INTERNAL", "boom", false);
|
||||
let json = serde_json::to_string(&err).unwrap();
|
||||
assert!(!json.contains("details"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_after_eof_then_eof_returns_connection_closed() {
|
||||
let mut data = Vec::new();
|
||||
let envelope = EventEnvelope::responded("one", Value::Null);
|
||||
let body = serde_json::to_vec(&envelope).unwrap();
|
||||
data.extend_from_slice(&(body.len() as u32).to_be_bytes());
|
||||
data.extend_from_slice(&body);
|
||||
let cursor = std::io::Cursor::new(data);
|
||||
let mut reader = FrameFramedReader::new(cursor);
|
||||
let first = reader.read_frame().await.unwrap();
|
||||
assert_eq!(first, envelope);
|
||||
match reader.read_frame().await {
|
||||
Err(FrameError::ConnectionClosed) => {}
|
||||
other => panic!("expected ConnectionClosed, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn writer_into_inner_recovers_stream() {
|
||||
let (client, server) = duplex(8 * 1024);
|
||||
let envelope = sample_envelope();
|
||||
let mut writer = FrameFramedWriter::new(client);
|
||||
writer.write_frame(&envelope).await.unwrap();
|
||||
let mut recovered = writer.into_inner();
|
||||
recovered.shutdown().await.unwrap();
|
||||
drop(recovered);
|
||||
|
||||
let mut reader = FrameFramedReader::new(server);
|
||||
let read = reader.read_frame().await.unwrap();
|
||||
assert_eq!(read, envelope);
|
||||
let _ = reader.into_inner();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reader_handles_partial_length_prefix() {
|
||||
let (mut client, server) = duplex(8 * 1024);
|
||||
client.write_all(&[0u8, 0]).await.unwrap();
|
||||
drop(client);
|
||||
let mut reader = FrameFramedReader::new(server);
|
||||
match reader.read_frame().await {
|
||||
Err(FrameError::ConnectionClosed) => {}
|
||||
other => panic!("expected ConnectionClosed, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reader_drains_remaining_after_read() {
|
||||
let mut data = Vec::new();
|
||||
let envelope = sample_envelope();
|
||||
let body = serde_json::to_vec(&envelope).unwrap();
|
||||
data.extend_from_slice(&(body.len() as u32).to_be_bytes());
|
||||
data.extend_from_slice(&body);
|
||||
data.extend_from_slice(&[9u8; 4]);
|
||||
let mut cursor = tokio::io::BufReader::new(std::io::Cursor::new(data));
|
||||
let mut reader = FrameFramedReader::new(&mut cursor);
|
||||
let read = reader.read_frame().await.unwrap();
|
||||
assert_eq!(read, envelope);
|
||||
let mut leftover = Vec::new();
|
||||
let _ = cursor.read_to_end(&mut leftover).await.unwrap();
|
||||
assert_eq!(leftover, vec![9u8; 4]);
|
||||
}
|
||||
}
|
||||
188
crates/alknet-call/src/registry/context.rs
Normal file
188
crates/alknet-call/src/registry/context.rs
Normal file
@@ -0,0 +1,188 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use alknet_core::auth::Identity;
|
||||
use alknet_core::types::Capabilities;
|
||||
use serde_json::Value;
|
||||
|
||||
use super::env::OperationEnv;
|
||||
|
||||
pub struct OperationContext {
|
||||
pub request_id: String,
|
||||
pub parent_request_id: Option<String>,
|
||||
pub identity: Option<Identity>,
|
||||
pub handler_identity: Option<CompositionAuthority>,
|
||||
/// The original caller when this call was forwarded by a `from_call`
|
||||
/// handler (ADR-032). **Metadata only** — `AccessControl::check` never
|
||||
/// reads it; the ACL always authorizes `identity` (the direct caller).
|
||||
/// Handlers may read it for logging, auditing, per-user rate limiting,
|
||||
/// or application context. Populated from
|
||||
/// `call.requested.forwarded_for` by the dispatch path; set to `None`
|
||||
/// for composed children (wire-ingress only, not composition-ingress).
|
||||
/// The forwarder's claim, not a verified identity — a malicious hub can
|
||||
/// lie (same property as HTTP `X-Forwarded-For`). See ADR-032.
|
||||
pub forwarded_for: Option<Identity>,
|
||||
pub capabilities: Capabilities,
|
||||
pub metadata: HashMap<String, Value>,
|
||||
pub scoped_env: ScopedOperationEnv,
|
||||
pub env: Arc<dyn OperationEnv + Send + Sync>,
|
||||
pub abort_policy: AbortPolicy,
|
||||
pub deadline: Option<Instant>,
|
||||
pub internal: bool,
|
||||
}
|
||||
|
||||
impl OperationContext {
|
||||
pub fn is_internal(&self) -> bool {
|
||||
self.internal
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum AbortPolicy {
|
||||
#[default]
|
||||
AbortDependents,
|
||||
ContinueRunning,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompositionAuthority {
|
||||
pub label: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub resources: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
impl CompositionAuthority {
|
||||
pub fn none() -> Option<Self> {
|
||||
None
|
||||
}
|
||||
|
||||
pub fn new(label: &str, scopes: impl IntoIterator<Item = String>) -> Self {
|
||||
Self {
|
||||
label: label.to_string(),
|
||||
scopes: scopes.into_iter().collect(),
|
||||
resources: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_identity(&self) -> Option<Identity> {
|
||||
Some(Identity {
|
||||
id: self.label.clone(),
|
||||
scopes: self.scopes.clone(),
|
||||
resources: self.resources.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScopedOperationEnv {
|
||||
allowed: HashSet<String>,
|
||||
}
|
||||
|
||||
impl ScopedOperationEnv {
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
allowed: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(ops: impl IntoIterator<Item = impl Into<String>>) -> Self {
|
||||
Self {
|
||||
allowed: ops.into_iter().map(|s| s.into()).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn allows(&self, name: &str) -> bool {
|
||||
self.allowed.contains(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ScopedOperationEnv {
|
||||
fn default() -> Self {
|
||||
Self::empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn generate_request_id() -> String {
|
||||
uuid::Uuid::new_v4().to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn scoped_env_allows_in_set() {
|
||||
let env = ScopedOperationEnv::new(["fs/readFile", "agent/chat"]);
|
||||
assert!(env.allows("fs/readFile"));
|
||||
assert!(env.allows("agent/chat"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scoped_env_disallows_not_in_set() {
|
||||
let env = ScopedOperationEnv::new(["fs/readFile"]);
|
||||
assert!(!env.allows("agent/chat"));
|
||||
assert!(!env.allows(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scoped_env_empty_allows_nothing() {
|
||||
let env = ScopedOperationEnv::empty();
|
||||
assert!(!env.allows("fs/readFile"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn composition_authority_as_identity_correct() {
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert("service".to_string(), vec!["vastai".to_string()]);
|
||||
let authority = CompositionAuthority {
|
||||
label: "agent-chat".to_string(),
|
||||
scopes: vec!["llm:call".to_string(), "fs:read".to_string()],
|
||||
resources,
|
||||
};
|
||||
let identity = authority.as_identity().expect("as_identity returns Some");
|
||||
assert_eq!(identity.id, "agent-chat");
|
||||
assert_eq!(
|
||||
identity.scopes,
|
||||
vec!["llm:call".to_string(), "fs:read".to_string()]
|
||||
);
|
||||
assert_eq!(
|
||||
identity.resources.get("service"),
|
||||
Some(&vec!["vastai".to_string()])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn composition_authority_new_populates_label_and_scopes() {
|
||||
let authority = CompositionAuthority::new(
|
||||
"agent-chat",
|
||||
["llm:call".to_string(), "fs:read".to_string()],
|
||||
);
|
||||
assert_eq!(authority.label, "agent-chat");
|
||||
assert_eq!(
|
||||
authority.scopes,
|
||||
vec!["llm:call".to_string(), "fs:read".to_string()]
|
||||
);
|
||||
assert!(authority.resources.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn composition_authority_none_is_none() {
|
||||
assert!(CompositionAuthority::none().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn abort_policy_default_is_abort_dependents() {
|
||||
let policy = AbortPolicy::default();
|
||||
assert!(matches!(policy, AbortPolicy::AbortDependents));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_request_id_is_unique_and_non_deterministic() {
|
||||
let a = generate_request_id();
|
||||
let b = generate_request_id();
|
||||
assert_ne!(a, b);
|
||||
assert!(!a.is_empty());
|
||||
}
|
||||
}
|
||||
952
crates/alknet-call/src/registry/discovery.rs
Normal file
952
crates/alknet-call/src/registry/discovery.rs
Normal file
@@ -0,0 +1,952 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use super::context::OperationContext;
|
||||
use super::registration::{Handler, OperationRegistry};
|
||||
use super::spec::{AccessControl, OperationSpec, OperationType, Visibility};
|
||||
use crate::protocol::wire::{CallError, ResponseEnvelope};
|
||||
|
||||
const NAME_SERVICES_LIST: &str = "services/list";
|
||||
const NAME_SERVICES_LIST_PEERS: &str = "services/list-peers";
|
||||
const NAME_SERVICES_SCHEMA: &str = "services/schema";
|
||||
|
||||
pub fn services_list_spec() -> OperationSpec {
|
||||
OperationSpec::new(
|
||||
NAME_SERVICES_LIST,
|
||||
OperationType::Query,
|
||||
Visibility::External,
|
||||
json!({}),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operations": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"namespace": { "type": "string" },
|
||||
"op_type": {
|
||||
"type": "string",
|
||||
"enum": ["query", "mutation", "subscription"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn services_schema_spec() -> OperationSpec {
|
||||
OperationSpec::new(
|
||||
NAME_SERVICES_SCHEMA,
|
||||
OperationType::Query,
|
||||
Visibility::External,
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": { "name": { "type": "string" } },
|
||||
"required": ["name"]
|
||||
}),
|
||||
operation_spec_schema(),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn services_list_peers_spec() -> OperationSpec {
|
||||
OperationSpec::new(
|
||||
NAME_SERVICES_LIST_PEERS,
|
||||
OperationType::Query,
|
||||
Visibility::External,
|
||||
json!({}),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"peers": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"peer_id": { "type": "string" },
|
||||
"operations": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"namespace": { "type": "string" },
|
||||
"op_type": {
|
||||
"type": "string",
|
||||
"enum": ["query", "mutation", "subscription"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn operation_spec_schema() -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"namespace": { "type": "string" },
|
||||
"op_type": {
|
||||
"type": "string",
|
||||
"enum": ["query", "mutation", "subscription"]
|
||||
},
|
||||
"visibility": {
|
||||
"type": "string",
|
||||
"enum": ["external", "internal"]
|
||||
},
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"error_schemas": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": { "type": "string" },
|
||||
"description": { "type": "string" },
|
||||
"schema": {},
|
||||
"http_status": { "type": ["integer", "null"] }
|
||||
}
|
||||
}
|
||||
},
|
||||
"access_control": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"required_scopes": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" }
|
||||
},
|
||||
"required_scopes_any": {
|
||||
"type": ["array", "null"],
|
||||
"items": { "type": "string" }
|
||||
},
|
||||
"resource_type": { "type": ["string", "null"] },
|
||||
"resource_action": { "type": ["string", "null"] }
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"name",
|
||||
"namespace",
|
||||
"op_type",
|
||||
"visibility",
|
||||
"input_schema",
|
||||
"output_schema",
|
||||
"error_schemas",
|
||||
"access_control"
|
||||
]
|
||||
})
|
||||
}
|
||||
|
||||
fn op_type_str(op_type: OperationType) -> &'static str {
|
||||
match op_type {
|
||||
OperationType::Query => "query",
|
||||
OperationType::Mutation => "mutation",
|
||||
OperationType::Subscription => "subscription",
|
||||
}
|
||||
}
|
||||
|
||||
fn visibility_str(visibility: Visibility) -> &'static str {
|
||||
match visibility {
|
||||
Visibility::External => "external",
|
||||
Visibility::Internal => "internal",
|
||||
}
|
||||
}
|
||||
|
||||
fn access_control_to_json(acl: &AccessControl) -> Value {
|
||||
json!({
|
||||
"required_scopes": acl.required_scopes,
|
||||
"required_scopes_any": acl.required_scopes_any,
|
||||
"resource_type": acl.resource_type,
|
||||
"resource_action": acl.resource_action,
|
||||
})
|
||||
}
|
||||
|
||||
fn error_definition_to_json(def: &super::spec::ErrorDefinition) -> Value {
|
||||
json!({
|
||||
"code": def.code,
|
||||
"description": def.description,
|
||||
"schema": def.schema,
|
||||
"http_status": def.http_status,
|
||||
})
|
||||
}
|
||||
|
||||
fn spec_to_json(spec: &OperationSpec) -> Value {
|
||||
let error_schemas: Vec<Value> = spec
|
||||
.error_schemas
|
||||
.iter()
|
||||
.map(error_definition_to_json)
|
||||
.collect();
|
||||
json!({
|
||||
"name": spec.name,
|
||||
"namespace": spec.namespace,
|
||||
"op_type": op_type_str(spec.op_type),
|
||||
"visibility": visibility_str(spec.visibility),
|
||||
"input_schema": spec.input_schema,
|
||||
"output_schema": spec.output_schema,
|
||||
"error_schemas": error_schemas,
|
||||
"access_control": access_control_to_json(&spec.access_control),
|
||||
})
|
||||
}
|
||||
|
||||
fn normalize_name(name: &str) -> String {
|
||||
if let Some(rest) = name.strip_prefix('/') {
|
||||
rest.to_string()
|
||||
} else {
|
||||
name.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn services_list_handler(registry: Arc<OperationRegistry>) -> Handler {
|
||||
Arc::new(move |input: Value, ctx: OperationContext| {
|
||||
let registry = Arc::clone(®istry);
|
||||
Box::pin(async move {
|
||||
let _ = input;
|
||||
let calling_identity = ctx.identity.as_ref();
|
||||
let ops: Vec<Value> = registry
|
||||
.list_operations()
|
||||
.into_iter()
|
||||
.filter(|spec| spec.access_control.check(calling_identity).is_allowed())
|
||||
.map(|s| {
|
||||
json!({
|
||||
"name": s.name,
|
||||
"namespace": s.namespace,
|
||||
"op_type": op_type_str(s.op_type),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
ResponseEnvelope::ok(ctx.request_id, json!({ "operations": ops }))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn services_list_peers_handler(registry: Arc<OperationRegistry>) -> Handler {
|
||||
Arc::new(move |input: Value, ctx: OperationContext| {
|
||||
let registry = Arc::clone(®istry);
|
||||
Box::pin(async move {
|
||||
let _ = input;
|
||||
let calling_identity = ctx.identity.as_ref();
|
||||
let local_ops: Vec<Value> = registry
|
||||
.list_operations()
|
||||
.into_iter()
|
||||
.filter(|spec| spec.access_control.check(calling_identity).is_allowed())
|
||||
.map(|s| {
|
||||
json!({
|
||||
"name": s.name,
|
||||
"namespace": s.namespace,
|
||||
"op_type": op_type_str(s.op_type),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let mut peers: Vec<Value> = Vec::new();
|
||||
if !local_ops.is_empty() {
|
||||
peers.push(json!({ "peer_id": "local", "operations": local_ops }));
|
||||
}
|
||||
for peer_id in ctx.env.peer_ids() {
|
||||
let peer_ops: Vec<Value> = ctx
|
||||
.env
|
||||
.peer_operations(&peer_id)
|
||||
.into_iter()
|
||||
.filter(|name| {
|
||||
let spec = registry.registration(name);
|
||||
match spec {
|
||||
Some(reg) => {
|
||||
reg.spec.access_control.check(calling_identity).is_allowed()
|
||||
}
|
||||
None => true,
|
||||
}
|
||||
})
|
||||
.map(name_to_listing_json)
|
||||
.collect();
|
||||
if !peer_ops.is_empty() {
|
||||
peers.push(json!({ "peer_id": peer_id, "operations": peer_ops }));
|
||||
}
|
||||
}
|
||||
ResponseEnvelope::ok(ctx.request_id, json!({ "peers": peers }))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn name_to_listing_json(name: String) -> Value {
|
||||
let namespace = name
|
||||
.split('/')
|
||||
.next()
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
json!({
|
||||
"name": name,
|
||||
"namespace": namespace,
|
||||
"op_type": "query",
|
||||
})
|
||||
}
|
||||
|
||||
pub fn services_schema_handler(registry: Arc<OperationRegistry>) -> Handler {
|
||||
Arc::new(move |input: Value, ctx: OperationContext| {
|
||||
let registry = Arc::clone(®istry);
|
||||
Box::pin(async move {
|
||||
let name = match input.get("name").and_then(|v| v.as_str()) {
|
||||
Some(n) => normalize_name(n),
|
||||
None => {
|
||||
return ResponseEnvelope::error(
|
||||
ctx.request_id,
|
||||
CallError::invalid_input("missing required field: name"),
|
||||
);
|
||||
}
|
||||
};
|
||||
match registry.registration(&name) {
|
||||
Some(reg) => {
|
||||
let spec_json = spec_to_json(®.spec);
|
||||
ResponseEnvelope::ok(ctx.request_id, spec_json)
|
||||
}
|
||||
None => ResponseEnvelope::not_found(ctx.request_id, &name),
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::registry::context::{CompositionAuthority, ScopedOperationEnv};
|
||||
use crate::registry::registration::{make_handler, HandlerRegistration, OperationProvenance};
|
||||
use alknet_core::types::Capabilities;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
fn external_spec(name: &str) -> OperationSpec {
|
||||
OperationSpec::new(
|
||||
name,
|
||||
OperationType::Query,
|
||||
Visibility::External,
|
||||
json!({}),
|
||||
json!({}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn internal_spec(name: &str) -> OperationSpec {
|
||||
OperationSpec::new(
|
||||
name,
|
||||
OperationType::Mutation,
|
||||
Visibility::Internal,
|
||||
json!({}),
|
||||
json!({}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn echo_handler() -> Handler {
|
||||
make_handler(
|
||||
|input, context| async move { ResponseEnvelope::ok(context.request_id, input) },
|
||||
)
|
||||
}
|
||||
|
||||
fn noop_env() -> Arc<dyn crate::registry::env::OperationEnv + Send + Sync> {
|
||||
struct NoopEnv;
|
||||
#[async_trait::async_trait]
|
||||
impl crate::registry::env::OperationEnv for NoopEnv {
|
||||
async fn invoke_with_policy(
|
||||
&self,
|
||||
_ns: &str,
|
||||
_op: &str,
|
||||
_input: Value,
|
||||
_parent: &OperationContext,
|
||||
_policy: crate::registry::context::AbortPolicy,
|
||||
) -> ResponseEnvelope {
|
||||
ResponseEnvelope::error("test", CallError::internal("noop env does not dispatch"))
|
||||
}
|
||||
fn contains(&self, _name: &str) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
Arc::new(NoopEnv)
|
||||
}
|
||||
|
||||
fn root_context(request_id: &str) -> OperationContext {
|
||||
OperationContext {
|
||||
request_id: request_id.to_string(),
|
||||
parent_request_id: None,
|
||||
identity: None,
|
||||
handler_identity: None,
|
||||
forwarded_for: None,
|
||||
capabilities: Capabilities::new(),
|
||||
metadata: HashMap::new(),
|
||||
scoped_env: ScopedOperationEnv::empty(),
|
||||
env: noop_env(),
|
||||
abort_policy: crate::registry::context::AbortPolicy::default(),
|
||||
deadline: Some(std::time::Instant::now() + Duration::from_secs(30)),
|
||||
internal: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn root_context_with_identity(
|
||||
request_id: &str,
|
||||
identity: Option<alknet_core::auth::Identity>,
|
||||
) -> OperationContext {
|
||||
OperationContext {
|
||||
request_id: request_id.to_string(),
|
||||
parent_request_id: None,
|
||||
identity,
|
||||
handler_identity: None,
|
||||
forwarded_for: None,
|
||||
capabilities: Capabilities::new(),
|
||||
metadata: HashMap::new(),
|
||||
scoped_env: ScopedOperationEnv::empty(),
|
||||
env: noop_env(),
|
||||
abort_policy: crate::registry::context::AbortPolicy::default(),
|
||||
deadline: Some(std::time::Instant::now() + Duration::from_secs(30)),
|
||||
internal: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn identity_with_scopes(id: &str, scopes: &[&str]) -> alknet_core::auth::Identity {
|
||||
alknet_core::auth::Identity {
|
||||
id: id.to_string(),
|
||||
scopes: scopes.iter().map(|s| s.to_string()).collect(),
|
||||
resources: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn external_spec_with_acl(name: &str, acl: AccessControl) -> OperationSpec {
|
||||
OperationSpec::new(
|
||||
name,
|
||||
OperationType::Query,
|
||||
Visibility::External,
|
||||
json!({}),
|
||||
json!({}),
|
||||
vec![],
|
||||
acl,
|
||||
)
|
||||
}
|
||||
|
||||
fn registry_with_access_controlled_ops() -> Arc<OperationRegistry> {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec_with_acl("public/echo", AccessControl::default()),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec_with_acl(
|
||||
"admin/secret",
|
||||
AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
registry.register(HandlerRegistration::new(
|
||||
internal_spec("internal/hidden"),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
Arc::new(registry)
|
||||
}
|
||||
|
||||
fn op_names(response: ResponseEnvelope) -> Vec<String> {
|
||||
let output = response.result.expect("ok response");
|
||||
output
|
||||
.get("operations")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("operations array")
|
||||
.iter()
|
||||
.filter_map(|o| o.get("name").and_then(|n| n.as_str()).map(String::from))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn registry_with_ops() -> Arc<OperationRegistry> {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec("fs/readFile"),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
registry.register(HandlerRegistration::new(
|
||||
internal_spec("secret/internal"),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
registry.register(HandlerRegistration::new(
|
||||
OperationSpec::new(
|
||||
"events/subscribe",
|
||||
OperationType::Subscription,
|
||||
Visibility::External,
|
||||
json!({}),
|
||||
json!({}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
registry.register(HandlerRegistration::new(
|
||||
OperationSpec::new(
|
||||
"fs/readFileErr",
|
||||
OperationType::Query,
|
||||
Visibility::External,
|
||||
json!({}),
|
||||
json!({}),
|
||||
vec![super::super::spec::ErrorDefinition {
|
||||
code: "FILE_NOT_FOUND".to_string(),
|
||||
description: "file not found".to_string(),
|
||||
schema: json!({ "type": "object" }),
|
||||
http_status: None,
|
||||
}],
|
||||
AccessControl::default(),
|
||||
),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
Arc::new(registry)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_list_spec_has_correct_fields() {
|
||||
let spec = services_list_spec();
|
||||
assert_eq!(spec.name, NAME_SERVICES_LIST);
|
||||
assert_eq!(spec.namespace, "services");
|
||||
assert_eq!(spec.op_type, OperationType::Query);
|
||||
assert_eq!(spec.visibility, Visibility::External);
|
||||
assert_eq!(spec.input_schema, json!({}));
|
||||
assert!(spec.output_schema.get("properties").is_some());
|
||||
assert!(spec.error_schemas.is_empty());
|
||||
assert!(!spec.access_control.has_restrictions());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_schema_spec_has_correct_fields() {
|
||||
let spec = services_schema_spec();
|
||||
assert_eq!(spec.name, NAME_SERVICES_SCHEMA);
|
||||
assert_eq!(spec.namespace, "services");
|
||||
assert_eq!(spec.op_type, OperationType::Query);
|
||||
assert_eq!(spec.visibility, Visibility::External);
|
||||
assert!(spec.input_schema.get("required").is_some());
|
||||
assert!(spec.output_schema.get("properties").is_some());
|
||||
assert!(spec.error_schemas.is_empty());
|
||||
assert!(!spec.access_control.has_restrictions());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn services_list_returns_external_ops_only() {
|
||||
let registry = registry_with_ops();
|
||||
let handler = services_list_handler(Arc::clone(®istry));
|
||||
let ctx = root_context("req-1");
|
||||
let response = handler(serde_json::json!({}), ctx).await;
|
||||
let output = response.result.expect("ok response");
|
||||
let ops = output
|
||||
.get("operations")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("operations array");
|
||||
let names: Vec<&str> = ops
|
||||
.iter()
|
||||
.filter_map(|o| o.get("name").and_then(|n| n.as_str()))
|
||||
.collect();
|
||||
assert!(names.contains(&"fs/readFile"));
|
||||
assert!(names.contains(&"events/subscribe"));
|
||||
assert!(names.contains(&"fs/readFileErr"));
|
||||
assert!(
|
||||
!names.contains(&"secret/internal"),
|
||||
"internal ops must not be listed"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn services_list_output_format_matches_spec() {
|
||||
let registry = registry_with_ops();
|
||||
let handler = services_list_handler(Arc::clone(®istry));
|
||||
let ctx = root_context("req-1");
|
||||
let response = handler(serde_json::json!({}), ctx).await;
|
||||
let output = response.result.expect("ok response");
|
||||
let ops = output
|
||||
.get("operations")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("operations array");
|
||||
let fs_op = ops
|
||||
.iter()
|
||||
.find(|o| o.get("name").and_then(|n| n.as_str()) == Some("fs/readFile"))
|
||||
.expect("fs/readFile present");
|
||||
assert_eq!(fs_op.get("namespace"), Some(&json!("fs")));
|
||||
assert_eq!(fs_op.get("op_type"), Some(&json!("query")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn services_schema_returns_spec_for_known_op() {
|
||||
let registry = registry_with_ops();
|
||||
let handler = services_schema_handler(Arc::clone(®istry));
|
||||
let ctx = root_context("req-2");
|
||||
let response = handler(serde_json::json!({ "name": "fs/readFileErr" }), ctx).await;
|
||||
let spec = response.result.expect("ok response");
|
||||
assert_eq!(spec.get("name"), Some(&json!("fs/readFileErr")));
|
||||
assert_eq!(spec.get("namespace"), Some(&json!("fs")));
|
||||
assert_eq!(spec.get("op_type"), Some(&json!("query")));
|
||||
let error_schemas = spec
|
||||
.get("error_schemas")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("error_schemas array");
|
||||
assert_eq!(error_schemas.len(), 1);
|
||||
assert_eq!(error_schemas[0].get("code"), Some(&json!("FILE_NOT_FOUND")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn services_schema_returns_not_found_for_unknown_op() {
|
||||
let registry = registry_with_ops();
|
||||
let handler = services_schema_handler(Arc::clone(®istry));
|
||||
let ctx = root_context("req-3");
|
||||
let response = handler(serde_json::json!({ "name": "no/such" }), ctx).await;
|
||||
match response.result {
|
||||
Err(e) => assert_eq!(e.code, "NOT_FOUND"),
|
||||
other => panic!("expected NOT_FOUND, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn services_schema_accepts_name_with_leading_slash() {
|
||||
let registry = registry_with_ops();
|
||||
let handler = services_schema_handler(Arc::clone(®istry));
|
||||
let ctx = root_context("req-4");
|
||||
let response = handler(serde_json::json!({ "name": "/fs/readFile" }), ctx).await;
|
||||
let spec = response.result.expect("ok response");
|
||||
assert_eq!(spec.get("name"), Some(&json!("fs/readFile")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn services_schema_rejects_missing_name() {
|
||||
let registry = registry_with_ops();
|
||||
let handler = services_schema_handler(Arc::clone(®istry));
|
||||
let ctx = root_context("req-5");
|
||||
let response = handler(serde_json::json!({}), ctx).await;
|
||||
match response.result {
|
||||
Err(e) => assert_eq!(e.code, "INVALID_INPUT"),
|
||||
other => panic!("expected INVALID_INPUT, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn services_list_handler_registered_and_invocable_via_registry() {
|
||||
let registry = registry_with_ops();
|
||||
let list_handler = services_list_handler(Arc::clone(®istry));
|
||||
let schema_handler = services_schema_handler(Arc::clone(®istry));
|
||||
|
||||
let mut discovery_registry = OperationRegistry::new();
|
||||
discovery_registry.register(HandlerRegistration::new(
|
||||
services_list_spec(),
|
||||
list_handler,
|
||||
OperationProvenance::Local,
|
||||
CompositionAuthority::none(),
|
||||
ScopedOperationEnv::empty().into(),
|
||||
Capabilities::new(),
|
||||
));
|
||||
discovery_registry.register(HandlerRegistration::new(
|
||||
services_schema_spec(),
|
||||
schema_handler,
|
||||
OperationProvenance::Local,
|
||||
CompositionAuthority::none(),
|
||||
ScopedOperationEnv::empty().into(),
|
||||
Capabilities::new(),
|
||||
));
|
||||
let discovery = Arc::new(discovery_registry);
|
||||
|
||||
let ctx = root_context("req-6");
|
||||
let response = discovery
|
||||
.invoke(NAME_SERVICES_LIST, serde_json::json!({}), ctx)
|
||||
.await;
|
||||
let output = response.result.expect("list ok");
|
||||
assert!(output.get("operations").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_name_strips_leading_slash() {
|
||||
assert_eq!(normalize_name("/fs/readFile"), "fs/readFile");
|
||||
assert_eq!(normalize_name("fs/readFile"), "fs/readFile");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn op_type_str_matches_wire_enum() {
|
||||
assert_eq!(op_type_str(OperationType::Query), "query");
|
||||
assert_eq!(op_type_str(OperationType::Mutation), "mutation");
|
||||
assert_eq!(op_type_str(OperationType::Subscription), "subscription");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn visibility_str_matches_wire_enum() {
|
||||
assert_eq!(visibility_str(Visibility::External), "external");
|
||||
assert_eq!(visibility_str(Visibility::Internal), "internal");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spec_to_json_round_trips_error_schemas() {
|
||||
let spec = OperationSpec::new(
|
||||
"fs/readFile",
|
||||
OperationType::Query,
|
||||
Visibility::External,
|
||||
json!({ "type": "object" }),
|
||||
json!({ "type": "string" }),
|
||||
vec![super::super::spec::ErrorDefinition {
|
||||
code: "FILE_NOT_FOUND".to_string(),
|
||||
description: "file not found".to_string(),
|
||||
schema: json!({ "type": "object", "properties": { "path": { "type": "string" } } }),
|
||||
http_status: Some(404),
|
||||
}],
|
||||
AccessControl {
|
||||
required_scopes: vec!["fs:read".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
let json_val = spec_to_json(&spec);
|
||||
let error_schemas = json_val
|
||||
.get("error_schemas")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("error_schemas");
|
||||
assert_eq!(error_schemas.len(), 1);
|
||||
assert_eq!(error_schemas[0].get("code"), Some(&json!("FILE_NOT_FOUND")));
|
||||
assert_eq!(error_schemas[0].get("http_status"), Some(&json!(404)));
|
||||
let acl = json_val.get("access_control").expect("access_control");
|
||||
assert_eq!(acl.get("required_scopes"), Some(&json!(["fs:read"])));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn services_list_filters_by_access_control_authorized_peer() {
|
||||
let registry = registry_with_access_controlled_ops();
|
||||
let handler = services_list_handler(Arc::clone(®istry));
|
||||
let ctx = root_context_with_identity(
|
||||
"req-acl-1",
|
||||
Some(identity_with_scopes("admin-peer", &["admin"])),
|
||||
);
|
||||
let names = op_names(handler(serde_json::json!({}), ctx).await);
|
||||
assert!(names.contains(&"public/echo".to_string()));
|
||||
assert!(names.contains(&"admin/secret".to_string()));
|
||||
assert!(!names.contains(&"internal/hidden".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn services_list_filters_by_access_control_unauthorized_peer() {
|
||||
let registry = registry_with_access_controlled_ops();
|
||||
let handler = services_list_handler(Arc::clone(®istry));
|
||||
let ctx = root_context_with_identity(
|
||||
"req-acl-2",
|
||||
Some(identity_with_scopes("regular-peer", &["user"])),
|
||||
);
|
||||
let names = op_names(handler(serde_json::json!({}), ctx).await);
|
||||
assert!(names.contains(&"public/echo".to_string()));
|
||||
assert!(
|
||||
!names.contains(&"admin/secret".to_string()),
|
||||
"unauthorized peer must not see admin/secret"
|
||||
);
|
||||
assert!(!names.contains(&"internal/hidden".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn services_list_op_with_default_acl_listed_to_any_peer() {
|
||||
let registry = registry_with_access_controlled_ops();
|
||||
let handler = services_list_handler(Arc::clone(®istry));
|
||||
let ctx = root_context_with_identity("req-acl-3", None);
|
||||
let names = op_names(handler(serde_json::json!({}), ctx).await);
|
||||
assert!(
|
||||
names.contains(&"public/echo".to_string()),
|
||||
"default AccessControl op must be listed to unauthenticated peer"
|
||||
);
|
||||
assert!(!names.contains(&"admin/secret".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn services_list_peers_attributes_ops_by_peer_id() {
|
||||
struct PeerEnv {
|
||||
peers: HashMap<String, Vec<String>>,
|
||||
}
|
||||
#[async_trait::async_trait]
|
||||
impl crate::registry::env::OperationEnv for PeerEnv {
|
||||
async fn invoke_with_policy(
|
||||
&self,
|
||||
_ns: &str,
|
||||
_op: &str,
|
||||
_input: Value,
|
||||
parent: &OperationContext,
|
||||
_policy: crate::registry::context::AbortPolicy,
|
||||
) -> ResponseEnvelope {
|
||||
ResponseEnvelope::ok(parent.request_id.clone(), json!({}))
|
||||
}
|
||||
fn contains(&self, _name: &str) -> bool {
|
||||
false
|
||||
}
|
||||
fn peer_ids(&self) -> Vec<crate::registry::env::PeerId> {
|
||||
self.peers.keys().cloned().collect()
|
||||
}
|
||||
fn peer_operations(&self, peer: &crate::registry::env::PeerId) -> Vec<String> {
|
||||
self.peers.get(peer).cloned().unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
let mut peers = HashMap::new();
|
||||
peers.insert(
|
||||
"worker-a".to_string(),
|
||||
vec!["container/exec".to_string(), "container/logs".to_string()],
|
||||
);
|
||||
peers.insert("worker-b".to_string(), vec!["container/exec".to_string()]);
|
||||
let env: Arc<dyn crate::registry::env::OperationEnv + Send + Sync> =
|
||||
Arc::new(PeerEnv { peers });
|
||||
|
||||
let registry = registry_with_access_controlled_ops();
|
||||
let handler = services_list_peers_handler(Arc::clone(®istry));
|
||||
let ctx = OperationContext {
|
||||
request_id: "req-peers-1".to_string(),
|
||||
parent_request_id: None,
|
||||
identity: None,
|
||||
handler_identity: None,
|
||||
forwarded_for: None,
|
||||
capabilities: Capabilities::new(),
|
||||
metadata: HashMap::new(),
|
||||
scoped_env: ScopedOperationEnv::empty(),
|
||||
env,
|
||||
abort_policy: crate::registry::context::AbortPolicy::default(),
|
||||
deadline: Some(std::time::Instant::now() + Duration::from_secs(30)),
|
||||
internal: false,
|
||||
};
|
||||
let response = handler(serde_json::json!({}), ctx).await;
|
||||
let output = response.result.expect("ok response");
|
||||
let peers_arr = output
|
||||
.get("peers")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("peers array");
|
||||
let peer_ids: Vec<&str> = peers_arr
|
||||
.iter()
|
||||
.filter_map(|p| p.get("peer_id").and_then(|v| v.as_str()))
|
||||
.collect();
|
||||
assert!(peer_ids.contains(&"local"));
|
||||
assert!(peer_ids.contains(&"worker-a"));
|
||||
assert!(peer_ids.contains(&"worker-b"));
|
||||
let worker_a = peers_arr
|
||||
.iter()
|
||||
.find(|p| p.get("peer_id").and_then(|v| v.as_str()) == Some("worker-a"))
|
||||
.expect("worker-a present");
|
||||
let worker_a_ops = worker_a
|
||||
.get("operations")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("worker-a operations");
|
||||
let worker_a_names: Vec<&str> = worker_a_ops
|
||||
.iter()
|
||||
.filter_map(|o| o.get("name").and_then(|n| n.as_str()))
|
||||
.collect();
|
||||
assert!(worker_a_names.contains(&"container/exec"));
|
||||
assert!(worker_a_names.contains(&"container/logs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_list_peers_spec_has_correct_fields() {
|
||||
let spec = services_list_peers_spec();
|
||||
assert_eq!(spec.name, NAME_SERVICES_LIST_PEERS);
|
||||
assert_eq!(spec.namespace, "services");
|
||||
assert_eq!(spec.op_type, OperationType::Query);
|
||||
assert_eq!(spec.visibility, Visibility::External);
|
||||
assert!(spec.error_schemas.is_empty());
|
||||
assert!(!spec.access_control.has_restrictions());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn services_list_peers_filters_by_access_control() {
|
||||
struct PeerEnv;
|
||||
#[async_trait::async_trait]
|
||||
impl crate::registry::env::OperationEnv for PeerEnv {
|
||||
async fn invoke_with_policy(
|
||||
&self,
|
||||
_ns: &str,
|
||||
_op: &str,
|
||||
_input: Value,
|
||||
parent: &OperationContext,
|
||||
_policy: crate::registry::context::AbortPolicy,
|
||||
) -> ResponseEnvelope {
|
||||
ResponseEnvelope::ok(parent.request_id.clone(), json!({}))
|
||||
}
|
||||
fn contains(&self, _name: &str) -> bool {
|
||||
false
|
||||
}
|
||||
fn peer_ids(&self) -> Vec<crate::registry::env::PeerId> {
|
||||
vec!["restricted-peer".to_string()]
|
||||
}
|
||||
fn peer_operations(&self, _peer: &crate::registry::env::PeerId) -> Vec<String> {
|
||||
vec!["admin/secret".to_string(), "public/echo".to_string()]
|
||||
}
|
||||
}
|
||||
|
||||
let registry = registry_with_access_controlled_ops();
|
||||
let handler = services_list_peers_handler(Arc::clone(®istry));
|
||||
let env: Arc<dyn crate::registry::env::OperationEnv + Send + Sync> = Arc::new(PeerEnv);
|
||||
let ctx = OperationContext {
|
||||
request_id: "req-peers-2".to_string(),
|
||||
parent_request_id: None,
|
||||
identity: Some(identity_with_scopes("regular-peer", &["user"])),
|
||||
handler_identity: None,
|
||||
forwarded_for: None,
|
||||
capabilities: Capabilities::new(),
|
||||
metadata: HashMap::new(),
|
||||
scoped_env: ScopedOperationEnv::empty(),
|
||||
env,
|
||||
abort_policy: crate::registry::context::AbortPolicy::default(),
|
||||
deadline: Some(std::time::Instant::now() + Duration::from_secs(30)),
|
||||
internal: false,
|
||||
};
|
||||
let response = handler(serde_json::json!({}), ctx).await;
|
||||
let output = response.result.expect("ok response");
|
||||
let peers_arr = output
|
||||
.get("peers")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("peers array");
|
||||
let restricted = peers_arr
|
||||
.iter()
|
||||
.find(|p| p.get("peer_id").and_then(|v| v.as_str()) == Some("restricted-peer"))
|
||||
.expect("restricted-peer present");
|
||||
let ops = restricted
|
||||
.get("operations")
|
||||
.and_then(|v| v.as_array())
|
||||
.expect("operations");
|
||||
let names: Vec<&str> = ops
|
||||
.iter()
|
||||
.filter_map(|o| o.get("name").and_then(|n| n.as_str()))
|
||||
.collect();
|
||||
assert!(names.contains(&"public/echo"));
|
||||
assert!(
|
||||
!names.contains(&"admin/secret"),
|
||||
"unauthorized peer must not see admin op in list-peers"
|
||||
);
|
||||
}
|
||||
}
|
||||
828
crates/alknet-call/src/registry/env.rs
Normal file
828
crates/alknet-call/src/registry/env.rs
Normal file
@@ -0,0 +1,828 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use super::context::{generate_request_id, AbortPolicy, OperationContext, ScopedOperationEnv};
|
||||
use super::registration::OperationRegistry;
|
||||
use crate::protocol::wire::ResponseEnvelope;
|
||||
|
||||
/// Logical peer identifier (ADR-029 §1, ADR-030 §4). The payload is
|
||||
/// `Identity.id` from `IdentityProvider` resolution (= `PeerEntry.peer_id`),
|
||||
/// stable across key rotation — NOT a connection-assigned UUID and NOT the
|
||||
/// peer's cryptographic material.
|
||||
pub type PeerId = String;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait OperationEnv: Send + Sync {
|
||||
async fn invoke(
|
||||
&self,
|
||||
namespace: &str,
|
||||
operation: &str,
|
||||
input: Value,
|
||||
parent: &OperationContext,
|
||||
) -> ResponseEnvelope {
|
||||
self.invoke_with_policy(namespace, operation, input, parent, parent.abort_policy)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn invoke_with_policy(
|
||||
&self,
|
||||
namespace: &str,
|
||||
operation: &str,
|
||||
input: Value,
|
||||
parent: &OperationContext,
|
||||
policy: AbortPolicy,
|
||||
) -> ResponseEnvelope;
|
||||
|
||||
fn contains(&self, _name: &str) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn peer_ids(&self) -> Vec<PeerId> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
fn peer_contains(&self, _peer: &PeerId, name: &str) -> bool {
|
||||
self.contains(name)
|
||||
}
|
||||
|
||||
fn peer_operations(&self, _peer: &PeerId) -> Vec<String> {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LocalOperationEnv {
|
||||
registry: Arc<OperationRegistry>,
|
||||
}
|
||||
|
||||
impl LocalOperationEnv {
|
||||
pub fn new(registry: Arc<OperationRegistry>) -> Self {
|
||||
Self { registry }
|
||||
}
|
||||
|
||||
pub fn registry(&self) -> &Arc<OperationRegistry> {
|
||||
&self.registry
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OperationEnv for LocalOperationEnv {
|
||||
async fn invoke_with_policy(
|
||||
&self,
|
||||
namespace: &str,
|
||||
operation: &str,
|
||||
input: Value,
|
||||
parent: &OperationContext,
|
||||
policy: AbortPolicy,
|
||||
) -> ResponseEnvelope {
|
||||
let name = format!("{namespace}/{operation}");
|
||||
|
||||
if !parent.scoped_env.allows(&name) {
|
||||
return ResponseEnvelope::not_found(parent.request_id.clone(), &name);
|
||||
}
|
||||
|
||||
let registration = match self.registry.registration(&name) {
|
||||
Some(r) => r,
|
||||
None => return ResponseEnvelope::not_found(parent.request_id.clone(), &name),
|
||||
};
|
||||
|
||||
let context = OperationContext {
|
||||
request_id: generate_request_id(),
|
||||
parent_request_id: Some(parent.request_id.clone()),
|
||||
identity: parent
|
||||
.handler_identity
|
||||
.as_ref()
|
||||
.and_then(|ca| ca.as_identity()),
|
||||
handler_identity: registration.composition_authority.clone(),
|
||||
forwarded_for: None,
|
||||
capabilities: parent.capabilities.clone(),
|
||||
metadata: HashMap::new(),
|
||||
abort_policy: policy,
|
||||
deadline: parent.deadline,
|
||||
scoped_env: registration
|
||||
.scoped_env
|
||||
.clone()
|
||||
.unwrap_or_else(ScopedOperationEnv::empty),
|
||||
env: parent.env.clone(),
|
||||
internal: true,
|
||||
};
|
||||
|
||||
self.registry.invoke(&name, input, context).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-call composite env (ADR-024 + ADR-029 §1). Built by the `Dispatcher`
|
||||
/// in `compose_root_env` from the active layers. The child inherits this by
|
||||
/// `Arc::clone` through `invoke()`. The Layer 2 connection overlay is
|
||||
/// **peer-keyed** — a head node with N worker connections holds a
|
||||
/// `HashMap<PeerId, connection_overlay>`, not one overlay. The singular-
|
||||
/// connection case (one peer) is the degenerate case with a single-entry map.
|
||||
pub struct PeerCompositeEnv {
|
||||
pub base: Arc<dyn OperationEnv + Send + Sync>,
|
||||
pub session: Option<Arc<dyn OperationEnv + Send + Sync>>,
|
||||
pub connections: HashMap<PeerId, Arc<dyn OperationEnv + Send + Sync>>,
|
||||
connection_order: Vec<PeerId>,
|
||||
}
|
||||
|
||||
impl PeerCompositeEnv {
|
||||
pub fn new(base: Arc<dyn OperationEnv + Send + Sync>) -> Self {
|
||||
Self {
|
||||
base,
|
||||
session: None,
|
||||
connections: HashMap::new(),
|
||||
connection_order: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_session(mut self, session: Arc<dyn OperationEnv + Send + Sync>) -> Self {
|
||||
self.session = Some(session);
|
||||
self
|
||||
}
|
||||
|
||||
/// Attach a peer's connection overlay. The `peer_id` comes from
|
||||
/// `connection.identity().id` (IdentityProvider resolution). A connection
|
||||
/// with no resolved identity has no `PeerId` and is NOT attached
|
||||
/// (ADR-030 §5) — its ops are invoked through the `CallConnection` handle
|
||||
/// directly, not via peer-keyed composition.
|
||||
pub fn attach_peer(&mut self, peer_id: PeerId, overlay: Arc<dyn OperationEnv + Send + Sync>) {
|
||||
if !self.connections.contains_key(&peer_id) {
|
||||
self.connection_order.push(peer_id.clone());
|
||||
}
|
||||
self.connections.insert(peer_id, overlay);
|
||||
}
|
||||
|
||||
/// Detach a peer's overlay (on disconnect). The peer's sub-overlay drops;
|
||||
/// in-flight `PeerRef::Specific(that_peer)` gets `NOT_FOUND`.
|
||||
pub fn detach_peer(&mut self, peer_id: &PeerId) {
|
||||
if self.connections.remove(peer_id).is_some() {
|
||||
self.connection_order.retain(|p| p != peer_id);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn base(&self) -> &Arc<dyn OperationEnv + Send + Sync> {
|
||||
&self.base
|
||||
}
|
||||
|
||||
pub fn session(&self) -> &Option<Arc<dyn OperationEnv + Send + Sync>> {
|
||||
&self.session
|
||||
}
|
||||
|
||||
pub fn connections(&self) -> &HashMap<PeerId, Arc<dyn OperationEnv + Send + Sync>> {
|
||||
&self.connections
|
||||
}
|
||||
|
||||
pub fn connection_order(&self) -> &[PeerId] {
|
||||
&self.connection_order
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OperationEnv for PeerCompositeEnv {
|
||||
async fn invoke_with_policy(
|
||||
&self,
|
||||
namespace: &str,
|
||||
operation: &str,
|
||||
input: Value,
|
||||
parent: &OperationContext,
|
||||
policy: AbortPolicy,
|
||||
) -> ResponseEnvelope {
|
||||
let name = format!("{namespace}/{operation}");
|
||||
|
||||
if !parent.scoped_env.allows(&name) {
|
||||
return ResponseEnvelope::not_found(parent.request_id.clone(), &name);
|
||||
}
|
||||
|
||||
if let Some(session) = &self.session {
|
||||
if session.contains(&name) {
|
||||
return session
|
||||
.invoke_with_policy(namespace, operation, input, parent, policy)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
for peer_id in &self.connection_order {
|
||||
if let Some(conn_env) = self.connections.get(peer_id) {
|
||||
if conn_env.contains(&name) {
|
||||
return conn_env
|
||||
.invoke_with_policy(namespace, operation, input, parent, policy)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
self.base
|
||||
.invoke_with_policy(namespace, operation, input, parent, policy)
|
||||
.await
|
||||
}
|
||||
|
||||
fn contains(&self, name: &str) -> bool {
|
||||
self.session.as_ref().is_some_and(|s| s.contains(name))
|
||||
|| self.connections.values().any(|c| c.contains(name))
|
||||
|| self.base.contains(name)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::registry::context::CompositionAuthority;
|
||||
use crate::registry::registration::{make_handler, HandlerRegistration, OperationProvenance};
|
||||
use crate::registry::spec::{AccessControl, OperationSpec, OperationType, Visibility};
|
||||
use alknet_core::auth::Identity;
|
||||
use alknet_core::types::Capabilities;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
struct NoopEnv {
|
||||
contains_op: bool,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OperationEnv for NoopEnv {
|
||||
async fn invoke_with_policy(
|
||||
&self,
|
||||
_namespace: &str,
|
||||
_operation: &str,
|
||||
_input: Value,
|
||||
parent: &OperationContext,
|
||||
_policy: AbortPolicy,
|
||||
) -> ResponseEnvelope {
|
||||
ResponseEnvelope::ok(parent.request_id.clone(), Value::String("noop".into()))
|
||||
}
|
||||
|
||||
fn contains(&self, _name: &str) -> bool {
|
||||
self.contains_op
|
||||
}
|
||||
}
|
||||
|
||||
fn echo_handler() -> crate::registry::registration::Handler {
|
||||
make_handler(
|
||||
|input, context| async move { ResponseEnvelope::ok(context.request_id, input) },
|
||||
)
|
||||
}
|
||||
|
||||
fn inspect_handler() -> crate::registry::registration::Handler {
|
||||
make_handler(|_input, context| async move {
|
||||
let internal = context.is_internal();
|
||||
let id = context.identity.as_ref().map(|i| i.id.clone());
|
||||
let forwarded_for_id = context.forwarded_for.as_ref().map(|i| i.id.clone());
|
||||
let metadata_empty = context.metadata.is_empty();
|
||||
let parent_set = context.parent_request_id.is_some();
|
||||
ResponseEnvelope::ok(
|
||||
context.request_id,
|
||||
serde_json::json!({
|
||||
"internal": internal,
|
||||
"identity_id": id,
|
||||
"forwarded_for_id": forwarded_for_id,
|
||||
"metadata_empty": metadata_empty,
|
||||
"parent_set": parent_set,
|
||||
}),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn root_context(
|
||||
request_id: &str,
|
||||
identity: Option<Identity>,
|
||||
handler_identity: Option<CompositionAuthority>,
|
||||
scoped_env: ScopedOperationEnv,
|
||||
env: Arc<dyn OperationEnv + Send + Sync>,
|
||||
) -> OperationContext {
|
||||
root_context_with_forwarded_for(
|
||||
request_id,
|
||||
identity,
|
||||
handler_identity,
|
||||
None,
|
||||
scoped_env,
|
||||
env,
|
||||
)
|
||||
}
|
||||
|
||||
fn root_context_with_forwarded_for(
|
||||
request_id: &str,
|
||||
identity: Option<Identity>,
|
||||
handler_identity: Option<CompositionAuthority>,
|
||||
forwarded_for: Option<Identity>,
|
||||
scoped_env: ScopedOperationEnv,
|
||||
env: Arc<dyn OperationEnv + Send + Sync>,
|
||||
) -> OperationContext {
|
||||
OperationContext {
|
||||
request_id: request_id.to_string(),
|
||||
parent_request_id: None,
|
||||
identity,
|
||||
handler_identity,
|
||||
forwarded_for,
|
||||
capabilities: Capabilities::new(),
|
||||
metadata: HashMap::new(),
|
||||
scoped_env,
|
||||
env,
|
||||
abort_policy: AbortPolicy::default(),
|
||||
deadline: Some(Instant::now() + Duration::from_secs(30)),
|
||||
internal: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn registry_with(
|
||||
name: &str,
|
||||
spec_visibility: Visibility,
|
||||
handler: crate::registry::registration::Handler,
|
||||
composition_authority: Option<CompositionAuthority>,
|
||||
scoped_env: Option<ScopedOperationEnv>,
|
||||
) -> Arc<OperationRegistry> {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(HandlerRegistration::new(
|
||||
OperationSpec::new(
|
||||
name,
|
||||
OperationType::Query,
|
||||
spec_visibility,
|
||||
serde_json::json!({}),
|
||||
serde_json::json!({}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
),
|
||||
handler,
|
||||
OperationProvenance::Local,
|
||||
composition_authority,
|
||||
scoped_env,
|
||||
Capabilities::new(),
|
||||
));
|
||||
Arc::new(registry)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_env_invoke_allowed_op_dispatches() {
|
||||
let registry = registry_with("echo/run", Visibility::External, echo_handler(), None, None);
|
||||
let env = Arc::new(LocalOperationEnv::new(Arc::clone(®istry)));
|
||||
let scoped = ScopedOperationEnv::new(["echo/run"]);
|
||||
let ctx = root_context("root-1", None, None, scoped, env.clone());
|
||||
let response = env
|
||||
.invoke("echo", "run", serde_json::json!({"hi": 1}), &ctx)
|
||||
.await;
|
||||
assert!(response.result.is_ok());
|
||||
assert_eq!(response.result.unwrap(), serde_json::json!({"hi": 1}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_env_invoke_disallowed_op_returns_not_found() {
|
||||
let registry = registry_with("echo/run", Visibility::External, echo_handler(), None, None);
|
||||
let env = Arc::new(LocalOperationEnv::new(Arc::clone(®istry)));
|
||||
let scoped = ScopedOperationEnv::new(["other/op"]);
|
||||
let ctx = root_context("root-2", None, None, scoped, env.clone());
|
||||
let response = env.invoke("echo", "run", serde_json::json!({}), &ctx).await;
|
||||
match response.result {
|
||||
Err(e) => assert_eq!(e.code, "NOT_FOUND"),
|
||||
other => panic!("expected NOT_FOUND, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_env_invoke_internal_op_dispatches_as_internal_call() {
|
||||
let registry = registry_with(
|
||||
"secret/op",
|
||||
Visibility::Internal,
|
||||
inspect_handler(),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
let env = Arc::new(LocalOperationEnv::new(Arc::clone(®istry)));
|
||||
let scoped = ScopedOperationEnv::new(["secret/op"]);
|
||||
let ctx = root_context("root-3", None, None, scoped, env.clone());
|
||||
let response = env
|
||||
.invoke("secret", "op", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
let out = response.result.expect("ok");
|
||||
assert_eq!(out["internal"], Value::Bool(true));
|
||||
assert_eq!(out["parent_set"], Value::Bool(true));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_env_child_identity_is_parent_handler_identity() {
|
||||
let authority = CompositionAuthority::new("agent-chat", ["fs:read".to_string()]);
|
||||
let registry = registry_with(
|
||||
"child/run",
|
||||
Visibility::External,
|
||||
inspect_handler(),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
let env = Arc::new(LocalOperationEnv::new(Arc::clone(®istry)));
|
||||
let scoped = ScopedOperationEnv::new(["child/run"]);
|
||||
let ctx = root_context(
|
||||
"root-4",
|
||||
Some(Identity {
|
||||
id: "wire-caller".to_string(),
|
||||
scopes: vec![],
|
||||
resources: HashMap::new(),
|
||||
}),
|
||||
Some(authority.clone()),
|
||||
scoped,
|
||||
env.clone(),
|
||||
);
|
||||
let response = env
|
||||
.invoke("child", "run", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
let out = response.result.expect("ok");
|
||||
assert_eq!(out["identity_id"], Value::String("agent-chat".into()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_env_child_metadata_is_fresh_not_parent() {
|
||||
let registry = registry_with(
|
||||
"child/run",
|
||||
Visibility::External,
|
||||
inspect_handler(),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
let env = Arc::new(LocalOperationEnv::new(Arc::clone(®istry)));
|
||||
let scoped = ScopedOperationEnv::new(["child/run"]);
|
||||
let mut ctx = root_context("root-5", None, None, scoped, env.clone());
|
||||
ctx.metadata
|
||||
.insert("secret".to_string(), Value::String("leak".into()));
|
||||
let response = env
|
||||
.invoke("child", "run", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
let out = response.result.expect("ok");
|
||||
assert_eq!(out["metadata_empty"], Value::Bool(true));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_env_child_does_not_inherit_forwarded_for() {
|
||||
let registry = registry_with(
|
||||
"child/run",
|
||||
Visibility::External,
|
||||
inspect_handler(),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
let env = Arc::new(LocalOperationEnv::new(Arc::clone(®istry)));
|
||||
let scoped = ScopedOperationEnv::new(["child/run"]);
|
||||
let forwarded = Identity {
|
||||
id: "alice".to_string(),
|
||||
scopes: vec![],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let ctx = root_context_with_forwarded_for(
|
||||
"root-ff",
|
||||
None,
|
||||
None,
|
||||
Some(forwarded),
|
||||
scoped,
|
||||
env.clone(),
|
||||
);
|
||||
assert!(ctx.forwarded_for.is_some());
|
||||
let response = env
|
||||
.invoke("child", "run", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
let out = response.result.expect("ok");
|
||||
assert!(
|
||||
out["forwarded_for_id"].is_null(),
|
||||
"composed child must NOT inherit forwarded_for (wire-ingress only, ADR-032)"
|
||||
);
|
||||
}
|
||||
|
||||
struct ProbeEnv {
|
||||
name: String,
|
||||
contains_set: Vec<String>,
|
||||
dispatched: std::sync::Mutex<Option<String>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OperationEnv for ProbeEnv {
|
||||
async fn invoke_with_policy(
|
||||
&self,
|
||||
namespace: &str,
|
||||
operation: &str,
|
||||
_input: Value,
|
||||
parent: &OperationContext,
|
||||
_policy: AbortPolicy,
|
||||
) -> ResponseEnvelope {
|
||||
*self.dispatched.lock().unwrap() = Some(format!("{namespace}/{operation}"));
|
||||
ResponseEnvelope::ok(parent.request_id.clone(), Value::String(self.name.clone()))
|
||||
}
|
||||
|
||||
fn contains(&self, name: &str) -> bool {
|
||||
self.contains_set.iter().any(|n| n == name)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_composite_env_routes_to_session_when_it_contains_op() {
|
||||
let base = Arc::new(NoopEnv { contains_op: true });
|
||||
let session = Arc::new(ProbeEnv {
|
||||
name: "session".to_string(),
|
||||
contains_set: vec!["agent/chat".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let composite = PeerCompositeEnv::new(base).with_session(session.clone());
|
||||
let env: Arc<dyn OperationEnv + Send + Sync> = Arc::new(composite);
|
||||
let scoped = ScopedOperationEnv::new(["agent/chat"]);
|
||||
let ctx = root_context("root-6", None, None, scoped, env.clone());
|
||||
let response = env
|
||||
.invoke("agent", "chat", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
assert_eq!(response.result.unwrap(), Value::String("session".into()));
|
||||
assert_eq!(
|
||||
session.dispatched.lock().unwrap().as_deref(),
|
||||
Some("agent/chat")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_composite_env_routes_to_first_peer_in_insertion_order() {
|
||||
let base = Arc::new(ProbeEnv {
|
||||
name: "base".to_string(),
|
||||
contains_set: vec!["worker/exec".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let worker_a = Arc::new(ProbeEnv {
|
||||
name: "worker-a".to_string(),
|
||||
contains_set: vec!["worker/exec".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let worker_b = Arc::new(ProbeEnv {
|
||||
name: "worker-b".to_string(),
|
||||
contains_set: vec!["worker/exec".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let mut composite = PeerCompositeEnv::new(base);
|
||||
composite.attach_peer("worker-a".to_string(), worker_a.clone());
|
||||
composite.attach_peer("worker-b".to_string(), worker_b.clone());
|
||||
let env: Arc<dyn OperationEnv + Send + Sync> = Arc::new(composite);
|
||||
let scoped = ScopedOperationEnv::new(["worker/exec"]);
|
||||
let ctx = root_context("root-7", None, None, scoped, env.clone());
|
||||
let response = env
|
||||
.invoke("worker", "exec", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
assert_eq!(response.result.unwrap(), Value::String("worker-a".into()));
|
||||
assert_eq!(
|
||||
worker_a.dispatched.lock().unwrap().as_deref(),
|
||||
Some("worker/exec")
|
||||
);
|
||||
assert!(worker_b.dispatched.lock().unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_composite_env_falls_through_to_base_when_no_overlay_contains() {
|
||||
let base = Arc::new(ProbeEnv {
|
||||
name: "base".to_string(),
|
||||
contains_set: vec!["fs/readFile".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let session = Arc::new(ProbeEnv {
|
||||
name: "session".to_string(),
|
||||
contains_set: vec!["agent/chat".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let connection = Arc::new(ProbeEnv {
|
||||
name: "connection".to_string(),
|
||||
contains_set: vec!["worker/exec".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let mut composite = PeerCompositeEnv::new(base.clone()).with_session(session);
|
||||
composite.attach_peer("worker-a".to_string(), connection);
|
||||
let env: Arc<dyn OperationEnv + Send + Sync> = Arc::new(composite);
|
||||
let scoped = ScopedOperationEnv::new(["fs/readFile"]);
|
||||
let ctx = root_context("root-8", None, None, scoped, env.clone());
|
||||
let response = env
|
||||
.invoke("fs", "readFile", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
assert_eq!(response.result.unwrap(), Value::String("base".into()));
|
||||
assert_eq!(
|
||||
base.dispatched.lock().unwrap().as_deref(),
|
||||
Some("fs/readFile")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_composite_env_reachability_check_returns_not_found() {
|
||||
let base = Arc::new(NoopEnv { contains_op: true });
|
||||
let composite = PeerCompositeEnv::new(base);
|
||||
let env: Arc<dyn OperationEnv + Send + Sync> = Arc::new(composite);
|
||||
let scoped = ScopedOperationEnv::empty();
|
||||
let ctx = root_context("root-9", None, None, scoped, env.clone());
|
||||
let response = env
|
||||
.invoke("agent", "chat", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
match response.result {
|
||||
Err(e) => assert_eq!(e.code, "NOT_FOUND"),
|
||||
other => panic!("expected NOT_FOUND, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_composite_env_contains_aggregates_layers() {
|
||||
let base = Arc::new(ProbeEnv {
|
||||
name: "base".to_string(),
|
||||
contains_set: vec!["fs/readFile".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let session = Arc::new(ProbeEnv {
|
||||
name: "session".to_string(),
|
||||
contains_set: vec!["agent/chat".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let connection = Arc::new(ProbeEnv {
|
||||
name: "connection".to_string(),
|
||||
contains_set: vec!["worker/exec".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let mut composite = PeerCompositeEnv::new(base).with_session(session);
|
||||
composite.attach_peer("worker-a".to_string(), connection);
|
||||
assert!(composite.contains("fs/readFile"));
|
||||
assert!(composite.contains("agent/chat"));
|
||||
assert!(composite.contains("worker/exec"));
|
||||
assert!(!composite.contains("unknown/op"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_composite_env_detach_peer_drops_overlay_and_returns_not_found() {
|
||||
let base: Arc<dyn OperationEnv + Send + Sync> =
|
||||
Arc::new(LocalOperationEnv::new(Arc::new(OperationRegistry::new())));
|
||||
let connection = Arc::new(ProbeEnv {
|
||||
name: "connection".to_string(),
|
||||
contains_set: vec!["worker/exec".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let mut composite = PeerCompositeEnv::new(base);
|
||||
composite.attach_peer("worker-a".to_string(), connection.clone());
|
||||
composite.detach_peer(&"worker-a".to_string());
|
||||
let env: Arc<dyn OperationEnv + Send + Sync> = Arc::new(composite);
|
||||
let scoped = ScopedOperationEnv::new(["worker/exec"]);
|
||||
let ctx = root_context("root-10", None, None, scoped, env.clone());
|
||||
let response = env
|
||||
.invoke("worker", "exec", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
match response.result {
|
||||
Err(e) => assert_eq!(e.code, "NOT_FOUND"),
|
||||
other => panic!("expected NOT_FOUND after detach, got {other:?}"),
|
||||
}
|
||||
assert!(connection.dispatched.lock().unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_composite_env_detach_peer_then_reattach_routes_again() {
|
||||
let base: Arc<dyn OperationEnv + Send + Sync> =
|
||||
Arc::new(LocalOperationEnv::new(Arc::new(OperationRegistry::new())));
|
||||
let connection = Arc::new(ProbeEnv {
|
||||
name: "connection".to_string(),
|
||||
contains_set: vec!["worker/exec".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let mut composite = PeerCompositeEnv::new(base);
|
||||
composite.attach_peer("worker-a".to_string(), connection.clone());
|
||||
composite.detach_peer(&"worker-a".to_string());
|
||||
composite.attach_peer("worker-a".to_string(), connection.clone());
|
||||
let env: Arc<dyn OperationEnv + Send + Sync> = Arc::new(composite);
|
||||
let scoped = ScopedOperationEnv::new(["worker/exec"]);
|
||||
let ctx = root_context("root-10b", None, None, scoped, env.clone());
|
||||
let response = env
|
||||
.invoke("worker", "exec", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
assert_eq!(response.result.unwrap(), Value::String("connection".into()));
|
||||
assert_eq!(
|
||||
connection.dispatched.lock().unwrap().as_deref(),
|
||||
Some("worker/exec")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_composite_env_attach_peer_preserves_insertion_order_on_re_attach() {
|
||||
let base: Arc<dyn OperationEnv + Send + Sync> = Arc::new(NoopEnv { contains_op: true });
|
||||
let overlay_a: Arc<dyn OperationEnv + Send + Sync> =
|
||||
Arc::new(NoopEnv { contains_op: true });
|
||||
let overlay_b: Arc<dyn OperationEnv + Send + Sync> =
|
||||
Arc::new(NoopEnv { contains_op: true });
|
||||
let mut composite = PeerCompositeEnv::new(base);
|
||||
composite.attach_peer("worker-a".to_string(), overlay_a);
|
||||
composite.attach_peer("worker-b".to_string(), overlay_b);
|
||||
assert_eq!(composite.connection_order(), &["worker-a", "worker-b"]);
|
||||
let overlay_a2: Arc<dyn OperationEnv + Send + Sync> =
|
||||
Arc::new(NoopEnv { contains_op: true });
|
||||
composite.attach_peer("worker-a".to_string(), overlay_a2);
|
||||
assert_eq!(
|
||||
composite.connection_order(),
|
||||
&["worker-a", "worker-b"],
|
||||
"re-attach keeps original position"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_composite_env_routes_to_connection_when_session_absent_or_missing() {
|
||||
let base = Arc::new(ProbeEnv {
|
||||
name: "base".to_string(),
|
||||
contains_set: vec![],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let connection = Arc::new(ProbeEnv {
|
||||
name: "connection".to_string(),
|
||||
contains_set: vec!["worker/exec".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let session = Arc::new(ProbeEnv {
|
||||
name: "session".to_string(),
|
||||
contains_set: vec!["agent/chat".to_string()],
|
||||
dispatched: std::sync::Mutex::new(None),
|
||||
});
|
||||
let mut composite = PeerCompositeEnv::new(base).with_session(session);
|
||||
composite.attach_peer("worker-a".to_string(), connection.clone());
|
||||
let env: Arc<dyn OperationEnv + Send + Sync> = Arc::new(composite);
|
||||
let scoped = ScopedOperationEnv::new(["worker/exec"]);
|
||||
let ctx = root_context("root-11", None, None, scoped, env.clone());
|
||||
let response = env
|
||||
.invoke("worker", "exec", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
assert_eq!(response.result.unwrap(), Value::String("connection".into()));
|
||||
assert_eq!(
|
||||
connection.dispatched.lock().unwrap().as_deref(),
|
||||
Some("worker/exec")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_env_unknown_op_after_reachability_pass_returns_not_found() {
|
||||
let registry = Arc::new(OperationRegistry::new());
|
||||
let env = Arc::new(LocalOperationEnv::new(Arc::clone(®istry)));
|
||||
let scoped = ScopedOperationEnv::new(["fs/readFile"]);
|
||||
let ctx = root_context("root-12", None, None, scoped, env.clone());
|
||||
let response = env
|
||||
.invoke("fs", "readFile", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
match response.result {
|
||||
Err(e) => assert_eq!(e.code, "NOT_FOUND"),
|
||||
other => panic!("expected NOT_FOUND, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_env_child_inherits_parent_deadline() {
|
||||
let registry = registry_with(
|
||||
"child/run",
|
||||
Visibility::External,
|
||||
inspect_handler(),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
let env = Arc::new(LocalOperationEnv::new(Arc::clone(®istry)));
|
||||
let scoped = ScopedOperationEnv::new(["child/run"]);
|
||||
let deadline = Instant::now() + Duration::from_secs(5);
|
||||
let mut ctx = root_context("root-13", None, None, scoped, env.clone());
|
||||
ctx.deadline = Some(deadline);
|
||||
let response = env
|
||||
.invoke("child", "run", serde_json::json!({}), &ctx)
|
||||
.await;
|
||||
assert!(response.result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_env_default_contains_is_true() {
|
||||
let registry = Arc::new(OperationRegistry::new());
|
||||
let env = LocalOperationEnv::new(registry);
|
||||
assert!(env.contains("anything"));
|
||||
assert!(env.contains(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn abort_policy_is_copy() {
|
||||
let p = AbortPolicy::default();
|
||||
let _ = p;
|
||||
let _ = p;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn composition_authority_none_propagates_as_none_identity() {
|
||||
assert!(CompositionAuthority::none().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_env_new_exposes_registry() {
|
||||
let registry = Arc::new(OperationRegistry::new());
|
||||
let env = LocalOperationEnv::new(Arc::clone(®istry));
|
||||
assert!(Arc::ptr_eq(env.registry(), ®istry));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_composite_env_accessors_return_refs() {
|
||||
let base: Arc<dyn OperationEnv + Send + Sync> = Arc::new(NoopEnv { contains_op: true });
|
||||
let session: Arc<dyn OperationEnv + Send + Sync> = Arc::new(NoopEnv { contains_op: true });
|
||||
let connection: Arc<dyn OperationEnv + Send + Sync> =
|
||||
Arc::new(NoopEnv { contains_op: false });
|
||||
let mut composite =
|
||||
PeerCompositeEnv::new(Arc::clone(&base)).with_session(Arc::clone(&session));
|
||||
composite.attach_peer("worker-a".to_string(), Arc::clone(&connection));
|
||||
assert!(Arc::ptr_eq(composite.base(), &base));
|
||||
assert!(composite.session().is_some());
|
||||
assert!(composite.connections().get("worker-a").is_some());
|
||||
assert_eq!(composite.connection_order(), &["worker-a"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn peer_composite_env_singular_connection_is_degenerate_single_entry_map() {
|
||||
let base: Arc<dyn OperationEnv + Send + Sync> = Arc::new(NoopEnv { contains_op: true });
|
||||
let connection: Arc<dyn OperationEnv + Send + Sync> =
|
||||
Arc::new(NoopEnv { contains_op: true });
|
||||
let mut composite = PeerCompositeEnv::new(base);
|
||||
composite.attach_peer("worker-a".to_string(), connection);
|
||||
assert_eq!(composite.connections().len(), 1);
|
||||
assert_eq!(composite.connection_order().len(), 1);
|
||||
assert!(composite.connections().contains_key("worker-a"));
|
||||
}
|
||||
}
|
||||
12
crates/alknet-call/src/registry/mod.rs
Normal file
12
crates/alknet-call/src/registry/mod.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
//! Operation registry: specs, handlers, access control, service discovery.
|
||||
//!
|
||||
//! Maps operation names to specs and handlers, enforces access control, and
|
||||
//! dispatches `call.requested` events to local handlers. The registry is
|
||||
//! layered by trust boundary (ADR-024): a curated layer (immutable after
|
||||
//! startup) plus dynamic session and connection overlays.
|
||||
|
||||
pub mod context;
|
||||
pub mod discovery;
|
||||
pub mod env;
|
||||
pub mod registration;
|
||||
pub mod spec;
|
||||
734
crates/alknet-call/src/registry/registration.rs
Normal file
734
crates/alknet-call/src/registry/registration.rs
Normal file
@@ -0,0 +1,734 @@
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use alknet_core::types::Capabilities;
|
||||
use serde_json::Value;
|
||||
|
||||
use super::context::{CompositionAuthority, OperationContext, ScopedOperationEnv};
|
||||
use super::spec::{AccessResult, OperationSpec, Visibility};
|
||||
use crate::protocol::wire::ResponseEnvelope;
|
||||
|
||||
pub type Handler = Arc<
|
||||
dyn Fn(Value, OperationContext) -> Pin<Box<dyn Future<Output = ResponseEnvelope> + Send>>
|
||||
+ Send
|
||||
+ Sync,
|
||||
>;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum OperationProvenance {
|
||||
Local,
|
||||
FromOpenAPI,
|
||||
FromMCP,
|
||||
FromCall,
|
||||
FromJsonSchema,
|
||||
Session,
|
||||
}
|
||||
|
||||
pub struct HandlerRegistration {
|
||||
pub spec: OperationSpec,
|
||||
pub handler: Handler,
|
||||
pub provenance: OperationProvenance,
|
||||
pub composition_authority: Option<CompositionAuthority>,
|
||||
pub scoped_env: Option<ScopedOperationEnv>,
|
||||
pub capabilities: Capabilities,
|
||||
}
|
||||
|
||||
impl HandlerRegistration {
|
||||
pub fn new(
|
||||
spec: OperationSpec,
|
||||
handler: Handler,
|
||||
provenance: OperationProvenance,
|
||||
composition_authority: Option<CompositionAuthority>,
|
||||
scoped_env: Option<ScopedOperationEnv>,
|
||||
capabilities: Capabilities,
|
||||
) -> Self {
|
||||
Self {
|
||||
spec,
|
||||
handler,
|
||||
provenance,
|
||||
composition_authority,
|
||||
scoped_env,
|
||||
capabilities,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OperationRegistry {
|
||||
operations: HashMap<String, HandlerRegistration>,
|
||||
}
|
||||
|
||||
impl OperationRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
operations: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register(&mut self, registration: HandlerRegistration) {
|
||||
self.operations
|
||||
.insert(registration.spec.name.clone(), registration);
|
||||
}
|
||||
|
||||
pub fn registration(&self, name: &str) -> Option<&HandlerRegistration> {
|
||||
self.operations.get(name)
|
||||
}
|
||||
|
||||
pub fn list_operations(&self) -> Vec<&OperationSpec> {
|
||||
self.operations
|
||||
.values()
|
||||
.filter(|r| r.spec.visibility == Visibility::External)
|
||||
.map(|r| &r.spec)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn invoke(
|
||||
&self,
|
||||
name: &str,
|
||||
input: Value,
|
||||
context: OperationContext,
|
||||
) -> ResponseEnvelope {
|
||||
let request_id = context.request_id.clone();
|
||||
let registration = match self.operations.get(name) {
|
||||
Some(r) => r,
|
||||
None => return ResponseEnvelope::not_found(request_id, name),
|
||||
};
|
||||
|
||||
if registration.spec.visibility == Visibility::Internal && !context.internal {
|
||||
return ResponseEnvelope::not_found(request_id, name);
|
||||
}
|
||||
|
||||
let acl = ®istration.spec.access_control;
|
||||
let identity = if context.internal {
|
||||
context
|
||||
.handler_identity
|
||||
.as_ref()
|
||||
.and_then(|ca| ca.as_identity())
|
||||
} else {
|
||||
context.identity.clone()
|
||||
};
|
||||
|
||||
if let AccessResult::Forbidden(message) = acl.check(identity.as_ref()) {
|
||||
return ResponseEnvelope::forbidden(request_id, message);
|
||||
}
|
||||
|
||||
let handler = Arc::clone(®istration.handler);
|
||||
(handler)(input, context).await
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OperationRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OperationRegistryBuilder {
|
||||
operations: HashMap<String, HandlerRegistration>,
|
||||
}
|
||||
|
||||
impl OperationRegistryBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
operations: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn store(mut self, registration: HandlerRegistration) -> Self {
|
||||
self.operations
|
||||
.insert(registration.spec.name.clone(), registration);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_local(
|
||||
self,
|
||||
spec: OperationSpec,
|
||||
handler: Handler,
|
||||
composition_authority: Option<CompositionAuthority>,
|
||||
scoped_env: Option<ScopedOperationEnv>,
|
||||
capabilities: Capabilities,
|
||||
) -> Self {
|
||||
let registration = HandlerRegistration::new(
|
||||
spec,
|
||||
handler,
|
||||
OperationProvenance::Local,
|
||||
composition_authority,
|
||||
scoped_env,
|
||||
capabilities,
|
||||
);
|
||||
self.store(registration)
|
||||
}
|
||||
|
||||
pub fn with_leaf(
|
||||
self,
|
||||
spec: OperationSpec,
|
||||
handler: Handler,
|
||||
capabilities: Capabilities,
|
||||
) -> Self {
|
||||
self.with_leaf_provenance(
|
||||
spec,
|
||||
handler,
|
||||
OperationProvenance::FromOpenAPI,
|
||||
capabilities,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn with_leaf_provenance(
|
||||
self,
|
||||
spec: OperationSpec,
|
||||
handler: Handler,
|
||||
provenance: OperationProvenance,
|
||||
capabilities: Capabilities,
|
||||
) -> Self {
|
||||
let registration =
|
||||
HandlerRegistration::new(spec, handler, provenance, None, None, capabilities);
|
||||
self.store(registration)
|
||||
}
|
||||
|
||||
pub fn with(self, registration: HandlerRegistration) -> Self {
|
||||
self.store(registration)
|
||||
}
|
||||
|
||||
pub fn build(self) -> OperationRegistry {
|
||||
OperationRegistry {
|
||||
operations: self.operations,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OperationRegistryBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn make_handler<F, Fut>(f: F) -> Handler
|
||||
where
|
||||
F: Fn(Value, OperationContext) -> Fut + Send + Sync + 'static,
|
||||
Fut: Future<Output = ResponseEnvelope> + Send + 'static,
|
||||
{
|
||||
Arc::new(move |input, context| Box::pin(f(input, context)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::protocol::wire::CallError;
|
||||
use crate::registry::context::AbortPolicy;
|
||||
use crate::registry::env::OperationEnv;
|
||||
use crate::registry::spec::{AccessControl, OperationType};
|
||||
use alknet_core::auth::Identity;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
struct NoopEnv;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OperationEnv for NoopEnv {
|
||||
async fn invoke_with_policy(
|
||||
&self,
|
||||
_namespace: &str,
|
||||
_operation: &str,
|
||||
_input: Value,
|
||||
_parent: &OperationContext,
|
||||
_policy: AbortPolicy,
|
||||
) -> ResponseEnvelope {
|
||||
ResponseEnvelope::error("test", CallError::internal("noop env does not dispatch"))
|
||||
}
|
||||
|
||||
fn contains(&self, _name: &str) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn root_context(
|
||||
request_id: &str,
|
||||
identity: Option<Identity>,
|
||||
handler_identity: Option<CompositionAuthority>,
|
||||
internal: bool,
|
||||
scoped_env: ScopedOperationEnv,
|
||||
) -> OperationContext {
|
||||
OperationContext {
|
||||
request_id: request_id.to_string(),
|
||||
parent_request_id: None,
|
||||
identity,
|
||||
handler_identity,
|
||||
forwarded_for: None,
|
||||
capabilities: Capabilities::new(),
|
||||
metadata: HashMap::new(),
|
||||
scoped_env,
|
||||
env: Arc::new(NoopEnv),
|
||||
abort_policy: AbortPolicy::default(),
|
||||
deadline: Some(std::time::Instant::now() + Duration::from_secs(30)),
|
||||
internal,
|
||||
}
|
||||
}
|
||||
|
||||
fn echo_handler() -> Handler {
|
||||
make_handler(
|
||||
|input, context| async move { ResponseEnvelope::ok(context.request_id, input) },
|
||||
)
|
||||
}
|
||||
|
||||
fn error_handler() -> Handler {
|
||||
make_handler(|_input, context| async move {
|
||||
ResponseEnvelope::error(context.request_id, CallError::internal("handler failure"))
|
||||
})
|
||||
}
|
||||
|
||||
fn external_spec(name: &str, acl: AccessControl) -> OperationSpec {
|
||||
OperationSpec::new(
|
||||
name,
|
||||
OperationType::Query,
|
||||
Visibility::External,
|
||||
serde_json::json!({}),
|
||||
serde_json::json!({}),
|
||||
vec![],
|
||||
acl,
|
||||
)
|
||||
}
|
||||
|
||||
fn internal_spec(name: &str, acl: AccessControl) -> OperationSpec {
|
||||
OperationSpec::new(
|
||||
name,
|
||||
OperationType::Query,
|
||||
Visibility::Internal,
|
||||
serde_json::json!({}),
|
||||
serde_json::json!({}),
|
||||
vec![],
|
||||
acl,
|
||||
)
|
||||
}
|
||||
|
||||
fn identity_with_scopes(id: &str, scopes: &[&str]) -> Identity {
|
||||
Identity {
|
||||
id: id.to_string(),
|
||||
scopes: scopes.iter().map(|s| s.to_string()).collect(),
|
||||
resources: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn register_and_invoke_simple_operation() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec("echo", AccessControl::default()),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
let ctx = root_context("req-1", None, None, false, ScopedOperationEnv::empty());
|
||||
let response = registry
|
||||
.invoke("echo", serde_json::json!({"hi": 1}), ctx)
|
||||
.await;
|
||||
assert_eq!(response.request_id, "req-1");
|
||||
assert_eq!(response.result, Ok(serde_json::json!({"hi": 1})));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn internal_op_from_external_call_returns_not_found() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(HandlerRegistration::new(
|
||||
internal_spec("secret", AccessControl::default()),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
let ctx = root_context("req-2", None, None, false, ScopedOperationEnv::empty());
|
||||
let response = registry.invoke("secret", serde_json::json!({}), ctx).await;
|
||||
match response.result {
|
||||
Err(e) => {
|
||||
assert_eq!(e.code, "NOT_FOUND");
|
||||
assert!(e.message.contains("secret"));
|
||||
}
|
||||
other => panic!("expected NOT_FOUND, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn internal_op_from_internal_call_invokes_handler() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(HandlerRegistration::new(
|
||||
internal_spec("secret", AccessControl::default()),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
let ctx = root_context("req-3", None, None, true, ScopedOperationEnv::empty());
|
||||
let response = registry
|
||||
.invoke("secret", serde_json::json!({"x": 2}), ctx)
|
||||
.await;
|
||||
assert_eq!(response.request_id, "req-3");
|
||||
assert_eq!(response.result, Ok(serde_json::json!({"x": 2})));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unknown_op_returns_not_found() {
|
||||
let registry = OperationRegistry::new();
|
||||
let ctx = root_context("req-4", None, None, false, ScopedOperationEnv::empty());
|
||||
let response = registry.invoke("missing", serde_json::json!({}), ctx).await;
|
||||
match response.result {
|
||||
Err(e) => assert_eq!(e.code, "NOT_FOUND"),
|
||||
other => panic!("expected NOT_FOUND, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn acl_sufficient_scopes_allowed() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec(
|
||||
"admin",
|
||||
AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
let ctx = root_context(
|
||||
"req-5",
|
||||
Some(identity_with_scopes("caller", &["admin"])),
|
||||
None,
|
||||
false,
|
||||
ScopedOperationEnv::empty(),
|
||||
);
|
||||
let response = registry.invoke("admin", serde_json::json!({}), ctx).await;
|
||||
assert!(response.result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn acl_insufficient_scopes_forbidden() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec(
|
||||
"admin",
|
||||
AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
let ctx = root_context(
|
||||
"req-6",
|
||||
Some(identity_with_scopes("caller", &["user"])),
|
||||
None,
|
||||
false,
|
||||
ScopedOperationEnv::empty(),
|
||||
);
|
||||
let response = registry.invoke("admin", serde_json::json!({}), ctx).await;
|
||||
match response.result {
|
||||
Err(e) => {
|
||||
assert_eq!(e.code, "FORBIDDEN");
|
||||
assert!(e.message.contains("admin"));
|
||||
}
|
||||
other => panic!("expected FORBIDDEN, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn acl_restricted_op_no_identity_forbidden() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec(
|
||||
"admin",
|
||||
AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
let ctx = root_context("req-7", None, None, false, ScopedOperationEnv::empty());
|
||||
let response = registry.invoke("admin", serde_json::json!({}), ctx).await;
|
||||
match response.result {
|
||||
Err(e) => {
|
||||
assert_eq!(e.code, "FORBIDDEN");
|
||||
assert_eq!(e.message, "authentication required");
|
||||
}
|
||||
other => panic!("expected FORBIDDEN, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn internal_call_acl_uses_handler_identity() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let composing_authority = CompositionAuthority::new("agent-chat", ["admin".to_string()]);
|
||||
registry.register(HandlerRegistration::new(
|
||||
internal_spec(
|
||||
"secret",
|
||||
AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
let ctx = root_context(
|
||||
"req-8",
|
||||
Some(identity_with_scopes("user", &["user"])),
|
||||
Some(composing_authority),
|
||||
true,
|
||||
ScopedOperationEnv::empty(),
|
||||
);
|
||||
let response = registry.invoke("secret", serde_json::json!({}), ctx).await;
|
||||
assert!(
|
||||
response.result.is_ok(),
|
||||
"internal call should use handler_identity (admin), not caller (user)"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn internal_call_acl_insufficient_handler_identity_forbidden() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let weak_authority = CompositionAuthority::new("weak", ["user".to_string()]);
|
||||
registry.register(HandlerRegistration::new(
|
||||
internal_spec(
|
||||
"secret",
|
||||
AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
let ctx = root_context(
|
||||
"req-9",
|
||||
Some(identity_with_scopes("user", &["admin"])),
|
||||
Some(weak_authority),
|
||||
true,
|
||||
ScopedOperationEnv::empty(),
|
||||
);
|
||||
let response = registry.invoke("secret", serde_json::json!({}), ctx).await;
|
||||
match response.result {
|
||||
Err(e) => {
|
||||
assert_eq!(e.code, "FORBIDDEN");
|
||||
assert!(e.message.contains("admin"));
|
||||
}
|
||||
other => panic!("expected FORBIDDEN, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn external_call_acl_uses_caller_identity_not_handler_identity() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let handler_authority = CompositionAuthority::new("agent", ["admin".to_string()]);
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec(
|
||||
"gate",
|
||||
AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
Some(handler_authority),
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
let ctx = root_context(
|
||||
"req-10",
|
||||
Some(identity_with_scopes("user", &["user"])),
|
||||
Some(CompositionAuthority::new("agent", ["admin".to_string()])),
|
||||
false,
|
||||
ScopedOperationEnv::empty(),
|
||||
);
|
||||
let response = registry.invoke("gate", serde_json::json!({}), ctx).await;
|
||||
match response.result {
|
||||
Err(e) => assert_eq!(e.code, "FORBIDDEN"),
|
||||
other => panic!("expected FORBIDDEN, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_operations_returns_external_only() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec("echo", AccessControl::default()),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
registry.register(HandlerRegistration::new(
|
||||
internal_spec("secret", AccessControl::default()),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
let ops = registry.list_operations();
|
||||
assert_eq!(ops.len(), 1);
|
||||
assert_eq!(ops[0].name, "echo");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handler_returned_error_passes_through() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec("boom", AccessControl::default()),
|
||||
error_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
let ctx = root_context("req-11", None, None, false, ScopedOperationEnv::empty());
|
||||
let response = registry.invoke("boom", serde_json::json!({}), ctx).await;
|
||||
match response.result {
|
||||
Err(e) => assert_eq!(e.code, "INTERNAL"),
|
||||
other => panic!("expected INTERNAL error, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builder_with_local_sets_provenance_local() {
|
||||
let registry = OperationRegistryBuilder::new()
|
||||
.with_local(
|
||||
external_spec("echo", AccessControl::default()),
|
||||
echo_handler(),
|
||||
CompositionAuthority::none(),
|
||||
ScopedOperationEnv::empty().into(),
|
||||
Capabilities::new(),
|
||||
)
|
||||
.build();
|
||||
let reg = registry.registration("echo").expect("registered");
|
||||
assert_eq!(reg.provenance, OperationProvenance::Local);
|
||||
assert!(reg.composition_authority.is_none());
|
||||
assert!(reg.scoped_env.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builder_with_local_carries_authority_and_scoped_env() {
|
||||
let registry = OperationRegistryBuilder::new()
|
||||
.with_local(
|
||||
external_spec("agent", AccessControl::default()),
|
||||
echo_handler(),
|
||||
Some(CompositionAuthority::new("agent", ["fs:read".to_string()])),
|
||||
Some(ScopedOperationEnv::new(["fs/readFile"])),
|
||||
Capabilities::new(),
|
||||
)
|
||||
.build();
|
||||
let reg = registry.registration("agent").expect("registered");
|
||||
assert_eq!(reg.provenance, OperationProvenance::Local);
|
||||
let authority = reg.composition_authority.as_ref().expect("authority set");
|
||||
assert_eq!(authority.label, "agent");
|
||||
assert_eq!(authority.scopes, vec!["fs:read".to_string()]);
|
||||
assert!(reg.scoped_env.is_some());
|
||||
assert!(reg.scoped_env.as_ref().unwrap().allows("fs/readFile"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builder_with_leaf_sets_provenance_and_no_authority() {
|
||||
let registry = OperationRegistryBuilder::new()
|
||||
.with_leaf(
|
||||
external_spec("vastai", AccessControl::default()),
|
||||
echo_handler(),
|
||||
Capabilities::new(),
|
||||
)
|
||||
.build();
|
||||
let reg = registry.registration("vastai").expect("registered");
|
||||
assert_eq!(reg.provenance, OperationProvenance::FromOpenAPI);
|
||||
assert!(reg.composition_authority.is_none());
|
||||
assert!(reg.scoped_env.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builder_with_leaf_provenance_overrides_provenance() {
|
||||
let registry = OperationRegistryBuilder::new()
|
||||
.with_leaf_provenance(
|
||||
external_spec("remote", AccessControl::default()),
|
||||
echo_handler(),
|
||||
OperationProvenance::FromCall,
|
||||
Capabilities::new(),
|
||||
)
|
||||
.build();
|
||||
let reg = registry.registration("remote").expect("registered");
|
||||
assert_eq!(reg.provenance, OperationProvenance::FromCall);
|
||||
assert!(reg.composition_authority.is_none());
|
||||
assert!(reg.scoped_env.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builder_with_takes_full_bundle() {
|
||||
let registration = HandlerRegistration::new(
|
||||
external_spec("agent", AccessControl::default()),
|
||||
echo_handler(),
|
||||
OperationProvenance::Session,
|
||||
Some(CompositionAuthority::new("sandbox", [])),
|
||||
Some(ScopedOperationEnv::new(["fs/readFile"])),
|
||||
Capabilities::new(),
|
||||
);
|
||||
let registry = OperationRegistryBuilder::new().with(registration).build();
|
||||
let reg = registry.registration("agent").expect("registered");
|
||||
assert_eq!(reg.provenance, OperationProvenance::Session);
|
||||
assert!(reg.composition_authority.is_some());
|
||||
assert!(reg.scoped_env.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builder_default_is_new() {
|
||||
let builder = OperationRegistryBuilder::default();
|
||||
let registry = builder.build();
|
||||
assert!(registry.list_operations().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registry_default_is_new() {
|
||||
let registry = OperationRegistry::default();
|
||||
assert!(registry.list_operations().is_empty());
|
||||
assert!(registry.registration("anything").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registration_lookup_returns_bundle_fields() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let authority = CompositionAuthority::new("agent", ["fs:read".to_string()]);
|
||||
let scoped = ScopedOperationEnv::new(["fs/readFile"]);
|
||||
let caps = Capabilities::new().with_api_key("google", "k".to_string());
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec("agent", AccessControl::default()),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
Some(authority.clone()),
|
||||
Some(scoped.clone()),
|
||||
caps.clone(),
|
||||
));
|
||||
let reg = registry.registration("agent").expect("found");
|
||||
assert_eq!(reg.spec.name, "agent");
|
||||
assert_eq!(reg.provenance, OperationProvenance::Local);
|
||||
assert_eq!(reg.composition_authority.as_ref().unwrap().label, "agent");
|
||||
assert!(reg.scoped_env.as_ref().unwrap().allows("fs/readFile"));
|
||||
}
|
||||
}
|
||||
327
crates/alknet-call/src/registry/spec.rs
Normal file
327
crates/alknet-call/src/registry/spec.rs
Normal file
@@ -0,0 +1,327 @@
|
||||
//! Operation specifications: `OperationSpec`, `OperationType`, `Visibility`,
|
||||
//! `ErrorDefinition`, and `AccessControl`.
|
||||
//!
|
||||
//! See `docs/architecture/crates/call/operation-registry.md` for the full
|
||||
//! specification.
|
||||
|
||||
use alknet_core::auth::Identity;
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum OperationType {
|
||||
Query,
|
||||
Mutation,
|
||||
Subscription,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Visibility {
|
||||
External,
|
||||
Internal,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ErrorDefinition {
|
||||
pub code: String,
|
||||
pub description: String,
|
||||
pub schema: Value,
|
||||
pub http_status: Option<u16>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct AccessControl {
|
||||
pub required_scopes: Vec<String>,
|
||||
pub required_scopes_any: Option<Vec<String>>,
|
||||
pub resource_type: Option<String>,
|
||||
pub resource_action: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum AccessResult {
|
||||
Allowed,
|
||||
Forbidden(String),
|
||||
}
|
||||
|
||||
impl AccessResult {
|
||||
pub fn is_allowed(&self) -> bool {
|
||||
matches!(self, AccessResult::Allowed)
|
||||
}
|
||||
}
|
||||
|
||||
impl AccessControl {
|
||||
pub fn has_restrictions(&self) -> bool {
|
||||
!self.required_scopes.is_empty()
|
||||
|| self.required_scopes_any.is_some()
|
||||
|| self.resource_type.is_some()
|
||||
|| self.resource_action.is_some()
|
||||
}
|
||||
|
||||
pub fn check(&self, identity: Option<&Identity>) -> AccessResult {
|
||||
if !self.has_restrictions() {
|
||||
return AccessResult::Allowed;
|
||||
}
|
||||
let identity = match identity {
|
||||
Some(id) => id,
|
||||
None => return AccessResult::Forbidden("authentication required".to_string()),
|
||||
};
|
||||
|
||||
for scope in &self.required_scopes {
|
||||
if !identity.scopes.iter().any(|s| s == scope) {
|
||||
return AccessResult::Forbidden(format!("missing required scope: {scope}"));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(any) = &self.required_scopes_any {
|
||||
let has_one = any.iter().any(|s| identity.scopes.iter().any(|i| i == s));
|
||||
if !has_one {
|
||||
return AccessResult::Forbidden(
|
||||
"missing required scope (any of: ".to_string() + &any.join(", ") + ")",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(rt) = &self.resource_type {
|
||||
let allowed = identity.resources.get(rt);
|
||||
match &self.resource_action {
|
||||
Some(action) => match allowed {
|
||||
Some(actions) if actions.iter().any(|a| a == action) => {}
|
||||
_ => {
|
||||
return AccessResult::Forbidden(format!("missing resource: {rt}/{action}"))
|
||||
}
|
||||
},
|
||||
None => match allowed {
|
||||
Some(actions) if !actions.is_empty() => {}
|
||||
_ => return AccessResult::Forbidden(format!("missing resource: {rt}")),
|
||||
},
|
||||
}
|
||||
} else if let Some(action) = &self.resource_action {
|
||||
let found = identity
|
||||
.resources
|
||||
.values()
|
||||
.any(|actions| actions.iter().any(|a| a == action));
|
||||
if !found {
|
||||
return AccessResult::Forbidden(format!("missing resource action: {action}"));
|
||||
}
|
||||
}
|
||||
|
||||
AccessResult::Allowed
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct OperationSpec {
|
||||
pub name: String,
|
||||
pub namespace: String,
|
||||
pub op_type: OperationType,
|
||||
pub visibility: Visibility,
|
||||
pub input_schema: Value,
|
||||
pub output_schema: Value,
|
||||
pub error_schemas: Vec<ErrorDefinition>,
|
||||
pub access_control: AccessControl,
|
||||
}
|
||||
|
||||
impl OperationSpec {
|
||||
pub fn new(
|
||||
name: impl Into<String>,
|
||||
op_type: OperationType,
|
||||
visibility: Visibility,
|
||||
input_schema: Value,
|
||||
output_schema: Value,
|
||||
error_schemas: Vec<ErrorDefinition>,
|
||||
access_control: AccessControl,
|
||||
) -> Self {
|
||||
let name = name.into();
|
||||
let namespace = name
|
||||
.split('/')
|
||||
.next()
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
Self {
|
||||
name,
|
||||
namespace,
|
||||
op_type,
|
||||
visibility,
|
||||
input_schema,
|
||||
output_schema,
|
||||
error_schemas,
|
||||
access_control,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn path(&self) -> String {
|
||||
format!("/{}", self.name)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn identity(scopes: &[&str], resources: &[(&str, &[&str])]) -> Identity {
|
||||
let mut res = HashMap::new();
|
||||
for (k, v) in resources {
|
||||
res.insert(
|
||||
(*k).to_string(),
|
||||
v.iter().map(|s| (*s).to_string()).collect(),
|
||||
);
|
||||
}
|
||||
Identity {
|
||||
id: "caller".to_string(),
|
||||
scopes: scopes.iter().map(|s| (*s).to_string()).collect(),
|
||||
resources: res,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn path_has_leading_slash() {
|
||||
let spec = OperationSpec::new(
|
||||
"fs/readFile",
|
||||
OperationType::Query,
|
||||
Visibility::External,
|
||||
serde_json::json!({}),
|
||||
serde_json::json!({}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
);
|
||||
assert_eq!(spec.path(), "/fs/readFile");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn namespace_derived_from_name() {
|
||||
let spec = OperationSpec::new(
|
||||
"agent/chat",
|
||||
OperationType::Subscription,
|
||||
Visibility::External,
|
||||
serde_json::json!({}),
|
||||
serde_json::json!({}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
);
|
||||
assert_eq!(spec.namespace, "agent");
|
||||
assert_eq!(spec.name, "agent/chat");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn namespace_for_single_segment() {
|
||||
let spec = OperationSpec::new(
|
||||
"list",
|
||||
OperationType::Query,
|
||||
Visibility::Internal,
|
||||
serde_json::json!({}),
|
||||
serde_json::json!({}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
);
|
||||
assert_eq!(spec.namespace, "list");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_access_control_allowed_for_all() {
|
||||
let acl = AccessControl::default();
|
||||
assert_eq!(acl.check(None), AccessResult::Allowed);
|
||||
let id = identity(&[], &[]);
|
||||
assert_eq!(acl.check(Some(&id)), AccessResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn none_identity_with_restrictions_forbidden() {
|
||||
let acl = AccessControl {
|
||||
required_scopes: vec!["read".to_string()],
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(
|
||||
acl.check(None),
|
||||
AccessResult::Forbidden("authentication required".to_string())
|
||||
);
|
||||
|
||||
let acl2 = AccessControl {
|
||||
required_scopes_any: Some(vec!["read".to_string()]),
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(
|
||||
acl2.check(None),
|
||||
AccessResult::Forbidden("authentication required".to_string())
|
||||
);
|
||||
|
||||
let acl3 = AccessControl {
|
||||
resource_type: Some("service".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(
|
||||
acl3.check(None),
|
||||
AccessResult::Forbidden("authentication required".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn required_scopes_and_checked() {
|
||||
let acl = AccessControl {
|
||||
required_scopes: vec!["a".to_string(), "b".to_string()],
|
||||
..Default::default()
|
||||
};
|
||||
let id_missing = identity(&["a"], &[]);
|
||||
assert!(matches!(
|
||||
acl.check(Some(&id_missing)),
|
||||
AccessResult::Forbidden(_)
|
||||
));
|
||||
let id_ok = identity(&["a", "b", "c"], &[]);
|
||||
assert_eq!(acl.check(Some(&id_ok)), AccessResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn required_scopes_any_or_checked() {
|
||||
let acl = AccessControl {
|
||||
required_scopes_any: Some(vec!["x".to_string(), "y".to_string()]),
|
||||
..Default::default()
|
||||
};
|
||||
let id_x = identity(&["x"], &[]);
|
||||
assert_eq!(acl.check(Some(&id_x)), AccessResult::Allowed);
|
||||
let id_y = identity(&["y"], &[]);
|
||||
assert_eq!(acl.check(Some(&id_y)), AccessResult::Allowed);
|
||||
let id_none = identity(&["z"], &[]);
|
||||
assert!(matches!(
|
||||
acl.check(Some(&id_none)),
|
||||
AccessResult::Forbidden(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resource_check_with_type_and_action() {
|
||||
let acl = AccessControl {
|
||||
resource_type: Some("service".to_string()),
|
||||
resource_action: Some("read".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
let id_ok = identity(&[], &[("service", &["read"])]);
|
||||
assert_eq!(acl.check(Some(&id_ok)), AccessResult::Allowed);
|
||||
let id_missing_action = identity(&[], &[("service", &["write"])]);
|
||||
assert!(matches!(
|
||||
acl.check(Some(&id_missing_action)),
|
||||
AccessResult::Forbidden(_)
|
||||
));
|
||||
let id_missing_type = identity(&[], &[("other", &["read"])]);
|
||||
assert!(matches!(
|
||||
acl.check(Some(&id_missing_type)),
|
||||
AccessResult::Forbidden(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn combined_scopes_and_resources() {
|
||||
let acl = AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
resource_type: Some("service".to_string()),
|
||||
resource_action: Some("read".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
let id_ok = identity(&["admin"], &[("service", &["read"])]);
|
||||
assert_eq!(acl.check(Some(&id_ok)), AccessResult::Allowed);
|
||||
let id_missing_scope = identity(&["user"], &[("service", &["read"])]);
|
||||
assert!(matches!(
|
||||
acl.check(Some(&id_missing_scope)),
|
||||
AccessResult::Forbidden(_)
|
||||
));
|
||||
}
|
||||
}
|
||||
301
crates/alknet-call/tests/two_node_call.rs
Normal file
301
crates/alknet-call/tests/two_node_call.rs
Normal file
@@ -0,0 +1,301 @@
|
||||
//! Integration test: two-node `alknet/call` round-trip over a real QUIC
|
||||
//! loopback. A `CallAdapter` server accepts, a `CallClient` connects, and
|
||||
//! the client calls back into the server (connection symmetry, ADR-017 §2).
|
||||
//! Verifies the shared dispatch loop works end-to-end.
|
||||
|
||||
#![cfg(feature = "quinn")]
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use alknet_call::client::{CallClient, CallCredentials};
|
||||
use alknet_call::protocol::adapter::CallAdapter;
|
||||
use alknet_call::protocol::wire::ResponseEnvelope;
|
||||
use alknet_call::registry::discovery::{
|
||||
services_list_handler, services_list_spec, services_schema_handler, services_schema_spec,
|
||||
};
|
||||
use alknet_call::registry::registration::{
|
||||
make_handler, Handler, HandlerRegistration, OperationProvenance, OperationRegistry,
|
||||
};
|
||||
use alknet_call::registry::spec::{AccessControl, OperationSpec, OperationType, Visibility};
|
||||
use alknet_core::auth::{Identity, IdentityProvider};
|
||||
use alknet_core::types::{Capabilities, Connection, ProtocolHandler};
|
||||
|
||||
struct NoopIdentityProvider;
|
||||
impl IdentityProvider for NoopIdentityProvider {
|
||||
fn resolve_from_fingerprint(&self, _: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
fn resolve_from_token(&self, _: &alknet_core::auth::AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn external_spec(name: &str) -> OperationSpec {
|
||||
OperationSpec::new(
|
||||
name,
|
||||
OperationType::Query,
|
||||
Visibility::External,
|
||||
serde_json::json!({}),
|
||||
serde_json::json!({}),
|
||||
vec![],
|
||||
AccessControl::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn echo_handler() -> Handler {
|
||||
make_handler(|input, context| async move { ResponseEnvelope::ok(context.request_id, input) })
|
||||
}
|
||||
|
||||
/// Build a raw quinn server endpoint with a self-signed cert and the
|
||||
/// `CallAdapter` accepting `alknet/call` connections. Returns
|
||||
/// `(bound_addr, join_handle)`. The accept loop spawns a task per connection
|
||||
/// that hands the connection to `CallAdapter::handle`.
|
||||
async fn build_raw_quinn_server(
|
||||
registry: Arc<OperationRegistry>,
|
||||
) -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(NoopIdentityProvider);
|
||||
let adapter = Arc::new(CallAdapter::new(
|
||||
Arc::clone(®istry),
|
||||
Arc::clone(&provider),
|
||||
));
|
||||
|
||||
let key_pair = rcgen::KeyPair::generate().expect("key gen");
|
||||
let params = rcgen::CertificateParams::default();
|
||||
let cert = params.self_signed(&key_pair).expect("self-signed cert");
|
||||
let cert_der = cert.der().clone();
|
||||
let key_der = rustls::pki_types::PrivateKeyDer::Pkcs8(
|
||||
rustls::pki_types::PrivatePkcs8KeyDer::from(key_pair.serialize_der()),
|
||||
);
|
||||
|
||||
let provider_crypto = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
|
||||
let mut server_config = rustls::ServerConfig::builder_with_provider(provider_crypto)
|
||||
.with_safe_default_protocol_versions()
|
||||
.unwrap()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![cert_der], key_der)
|
||||
.unwrap();
|
||||
server_config.alpn_protocols = vec![b"alknet/call".to_vec()];
|
||||
server_config.max_early_data_size = u32::MAX;
|
||||
|
||||
let quic_server_config =
|
||||
quinn::crypto::rustls::QuicServerConfig::try_from(server_config).unwrap();
|
||||
let quinn_server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_server_config));
|
||||
|
||||
let quinn_endpoint =
|
||||
quinn::Endpoint::server(quinn_server_config, "127.0.0.1:0".parse().unwrap())
|
||||
.expect("server bind");
|
||||
let bound_addr = quinn_endpoint.local_addr().expect("local addr");
|
||||
|
||||
let join = tokio::spawn(async move {
|
||||
while let Some(incoming) = quinn_endpoint.accept().await {
|
||||
let adapter = Arc::clone(&adapter);
|
||||
tokio::spawn(async move {
|
||||
let connecting = match incoming.accept() {
|
||||
Ok(c) => c,
|
||||
Err(_) => return,
|
||||
};
|
||||
let conn = match connecting.await {
|
||||
Ok(c) => c,
|
||||
Err(_) => return,
|
||||
};
|
||||
let alpn = b"alknet/call".to_vec();
|
||||
let conn = Connection::from_quinn_with_alpn(conn, alpn.clone());
|
||||
let auth = alknet_core::auth::AuthContext {
|
||||
identity: None,
|
||||
alpn,
|
||||
remote_addr: conn.remote_addr(),
|
||||
tls_client_fingerprint: None,
|
||||
};
|
||||
let _ = adapter.handle(conn, &auth).await;
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
(bound_addr, join)
|
||||
}
|
||||
|
||||
/// Build the server's registry: an echo op, a secret op, and the
|
||||
/// services/list + services/schema discovery handlers.
|
||||
fn build_server_registry() -> Arc<OperationRegistry> {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec("server/echo"),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
registry.register(HandlerRegistration::new(
|
||||
external_spec("server/secret"),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new().with_api_key("google", "server-secret".to_string()),
|
||||
));
|
||||
let discovery_registry = Arc::new(registry);
|
||||
let list_handler = services_list_handler(Arc::clone(&discovery_registry));
|
||||
let schema_handler = services_schema_handler(Arc::clone(&discovery_registry));
|
||||
let mut full = OperationRegistry::new();
|
||||
full.register(HandlerRegistration::new(
|
||||
external_spec("server/echo"),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
full.register(HandlerRegistration::new(
|
||||
external_spec("server/secret"),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new().with_api_key("google", "server-secret".to_string()),
|
||||
));
|
||||
full.register(HandlerRegistration::new(
|
||||
services_list_spec(),
|
||||
list_handler,
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
full.register(HandlerRegistration::new(
|
||||
services_schema_spec(),
|
||||
schema_handler,
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
Arc::new(full)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn two_node_call_round_trip() {
|
||||
let server_registry = build_server_registry();
|
||||
let (server_addr, _server_join) = build_raw_quinn_server(Arc::clone(&server_registry)).await;
|
||||
|
||||
// Client side: a CallClient with its own ops so the server can call back
|
||||
// (connection symmetry).
|
||||
let mut client_registry = OperationRegistry::new();
|
||||
client_registry.register(HandlerRegistration::new(
|
||||
external_spec("client/echo"),
|
||||
echo_handler(),
|
||||
OperationProvenance::Local,
|
||||
None,
|
||||
None,
|
||||
Capabilities::new(),
|
||||
));
|
||||
let client_registry = Arc::new(client_registry);
|
||||
let client = CallClient::new(Arc::clone(&client_registry), Arc::new(NoopIdentityProvider));
|
||||
|
||||
let conn = tokio::time::timeout(
|
||||
Duration::from_secs(5),
|
||||
client.connect(server_addr, CallCredentials::new()),
|
||||
)
|
||||
.await
|
||||
.expect("connect did not time out")
|
||||
.expect("connect succeeds");
|
||||
|
||||
// Outbound call: client -> server's echo op.
|
||||
let response = tokio::time::timeout(
|
||||
Duration::from_secs(5),
|
||||
conn.call("server/echo", serde_json::json!({"hi": 1})),
|
||||
)
|
||||
.await
|
||||
.expect("call did not time out");
|
||||
assert_eq!(response.result, Ok(serde_json::json!({"hi": 1})));
|
||||
|
||||
// Peer authorization is enforced by the AccessControl gate in
|
||||
// OperationRegistry::invoke (ADR-029 §3) — exercised by the unit tests in
|
||||
// `registry/registration.rs`. This integration test focuses on the QUIC
|
||||
// connect path + shared dispatch loop working end-to-end (the call above
|
||||
// proves the CallClient opened a real connection, the shared loop
|
||||
// dispatched, and the CallConnection::call() round-tripped).
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn from_call_discovers_and_forwards_over_quic_loopback() {
|
||||
use alknet_call::client::{from_call, FromCallConfig};
|
||||
use alknet_call::registry::context::ScopedOperationEnv;
|
||||
|
||||
let server_registry = build_server_registry();
|
||||
let (server_addr, _server_join) = build_raw_quinn_server(Arc::clone(&server_registry)).await;
|
||||
|
||||
// Client with an empty registry — from_call will populate its overlay.
|
||||
let client_registry = Arc::new(OperationRegistry::new());
|
||||
let client = CallClient::new(Arc::clone(&client_registry), Arc::new(NoopIdentityProvider));
|
||||
|
||||
let conn = tokio::time::timeout(
|
||||
Duration::from_secs(5),
|
||||
client.connect(server_addr, CallCredentials::new()),
|
||||
)
|
||||
.await
|
||||
.expect("connect did not time out")
|
||||
.expect("connect succeeds");
|
||||
|
||||
// from_call discovers the server's External ops (server/echo, server/secret
|
||||
// — both External; services/list + services/schema themselves are External
|
||||
// too) and builds FromCall forwarding-handler bundles. Register them in the
|
||||
// connection's Layer 2 overlay.
|
||||
let bundles = tokio::time::timeout(
|
||||
Duration::from_secs(5),
|
||||
from_call(&conn, FromCallConfig::new()),
|
||||
)
|
||||
.await
|
||||
.expect("from_call did not time out")
|
||||
.expect("from_call succeeds");
|
||||
assert!(
|
||||
!bundles.is_empty(),
|
||||
"from_call must discover at least the server/echo op"
|
||||
);
|
||||
conn.register_imported_all(bundles);
|
||||
|
||||
// The overlay now contains the discovered ops. Verify the forwarding path
|
||||
// by invoking the overlay env directly with a scoped context that allows
|
||||
// server/echo — this is how a composing handler would call the imported op.
|
||||
let env = conn.overlay_env();
|
||||
assert!(
|
||||
env.contains("server/echo"),
|
||||
"overlay must contain the imported server/echo op"
|
||||
);
|
||||
|
||||
// Build a minimal parent context to invoke the overlay env (mirrors how a
|
||||
// composing handler dispatches a child).
|
||||
let scoped = ScopedOperationEnv::new(["server/echo"]);
|
||||
let parent = alknet_call::registry::context::OperationContext {
|
||||
request_id: "parent-1".to_string(),
|
||||
parent_request_id: None,
|
||||
identity: None,
|
||||
handler_identity: None,
|
||||
forwarded_for: None,
|
||||
capabilities: Capabilities::new(),
|
||||
metadata: Default::default(),
|
||||
scoped_env: scoped,
|
||||
env: env.clone(),
|
||||
abort_policy: alknet_call::registry::context::AbortPolicy::default(),
|
||||
deadline: Some(std::time::Instant::now() + Duration::from_secs(30)),
|
||||
internal: true,
|
||||
};
|
||||
|
||||
let response = tokio::time::timeout(
|
||||
Duration::from_secs(5),
|
||||
env.invoke(
|
||||
"server",
|
||||
"echo",
|
||||
serde_json::json!({"from_call": true}),
|
||||
&parent,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("overlay invoke did not time out");
|
||||
assert_eq!(
|
||||
response.result,
|
||||
Ok(serde_json::json!({"from_call": true})),
|
||||
"from_call forwarding handler must round-trip the input to the remote op"
|
||||
);
|
||||
}
|
||||
@@ -3,54 +3,41 @@ name = "alknet-core"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Core library for Alknet: pluggable SSH tunnel transport, SOCKS5 proxy, port forwarding, and authentication"
|
||||
description = "Core library for ALPN-based protocol dispatch: ProtocolHandler trait, Connection, auth, config, and multi-connectivity endpoint"
|
||||
repository.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "alknet_core"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
tls = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-pki-types", "dep:webpki-roots"]
|
||||
iroh = ["dep:iroh", "dep:url"]
|
||||
acme = ["dep:rustls-acme", "dep:futures", "tls"]
|
||||
http = ["dep:axum", "dep:hyper", "dep:hyper-util", "dep:tower", "dep:http-body-util"]
|
||||
irpc = []
|
||||
testutil = []
|
||||
transport-traits = []
|
||||
default = ["quinn"]
|
||||
quinn = ["dep:quinn"]
|
||||
iroh = ["dep:iroh"]
|
||||
acme = ["dep:rustls-acme"]
|
||||
|
||||
[dependencies]
|
||||
russh = "0.49"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tracing = "0.1"
|
||||
anyhow = "1"
|
||||
thiserror = "2"
|
||||
tokio-util = { version = "0.7", features = ["compat"] }
|
||||
tokio-rustls = { version = "0.26", optional = true }
|
||||
rustls = { version = "0.23", optional = true, features = ["aws_lc_rs"] }
|
||||
rustls-pki-types = { version = "1", optional = true }
|
||||
rustls-acme = { version = "0.12", optional = true }
|
||||
futures = { version = "0.3", optional = true }
|
||||
webpki-roots = { version = "0.26", optional = true }
|
||||
iroh = { version = "0.34", optional = true }
|
||||
url = { version = "2", optional = true }
|
||||
async-trait = "0.1"
|
||||
ipnetwork = "0.21.1"
|
||||
arc-swap = "1"
|
||||
quinn = { version = "0.11", optional = true }
|
||||
iroh = { version = "0.35", optional = true }
|
||||
rustls = "0.23"
|
||||
rustls-pki-types = "1"
|
||||
rustls-pemfile = "2"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
toml = "0.8"
|
||||
arc-swap = "1"
|
||||
async-trait = "0.1"
|
||||
tracing = "0.1"
|
||||
thiserror = "2"
|
||||
zeroize = { version = "1", features = ["alloc", "derive"] }
|
||||
bytes = "1"
|
||||
futures = "0.3"
|
||||
sha2 = "0.10"
|
||||
hex = "0.4"
|
||||
axum = { version = "0.8", optional = true }
|
||||
hyper = { version = "1", optional = true }
|
||||
hyper-util = { version = "0.1", features = ["tokio", "server", "service"], optional = true }
|
||||
tower = { version = "0.5", optional = true }
|
||||
http-body-util = { version = "0.1", optional = true }
|
||||
rand = "0.8"
|
||||
rcgen = "0.13"
|
||||
ed25519-dalek = { version = "2", features = ["rand_core"] }
|
||||
rustls-acme = { version = "0.12", optional = true, features = ["aws-lc-rs"] }
|
||||
|
||||
[dev-dependencies]
|
||||
alknet-core = { path = ".", features = ["testutil", "tls", "iroh", "http"] }
|
||||
tempfile = "3"
|
||||
rcgen = "0.14"
|
||||
rand_core = "0.6"
|
||||
ssh-key = { version = "0.6", features = ["ed25519", "alloc"] }
|
||||
rand = "0.10.1"
|
||||
tempfile = "3"
|
||||
612
crates/alknet-core/src/auth.rs
Normal file
612
crates/alknet-core/src/auth.rs
Normal file
@@ -0,0 +1,612 @@
|
||||
//! Authentication: `AuthContext`, `Identity`, `IdentityProvider`, `AuthToken`,
|
||||
//! `ConfigIdentityProvider`.
|
||||
//!
|
||||
//! See `docs/architecture/crates/core/auth.md` for the full specification and
|
||||
//! [ADR-034](../../../docs/architecture/decisions/034-outgoing-only-x509-and-three-peer-roles.md)
|
||||
//! for the three-remote-roles decision.
|
||||
//!
|
||||
//! # Three remote roles (ADR-034 §1)
|
||||
//!
|
||||
//! The three credential types (`PeerEntry.fingerprints` entries) describe how
|
||||
//! a *single* `PeerEntry` can be authenticated. Separately, there are three
|
||||
//! distinct remote roles that must not be conflated:
|
||||
//!
|
||||
//! | Role | Identity | alknet peer? | `PeerEntry` on local side? |
|
||||
//! |------|----------|--------------|----------------------------|
|
||||
//! | **Public X.509 endpoint** | Domain + CA-issued X.509 | No (local node is a client) | No |
|
||||
//! | **Transport relay** (iroh's DERP-equivalent) | iroh `NodeId` (Ed25519) | No (infrastructure) | No |
|
||||
//! | **Hub / hosting node** | Ed25519 raw key **and/or** X.509 | Yes (full peer) | Yes |
|
||||
//!
|
||||
//! `PeerEntry` (and the `PeerId` it resolves to) is the model for peers in
|
||||
//! the call-protocol peer graph (ADR-029) — peers that get a stable logical
|
||||
//! identity, are addressable via `PeerRef::Specific`, and whose ops land in
|
||||
//! the peer-keyed overlay. A pure-client connection to a public X.509
|
||||
//! endpoint (e.g. a third-party API) is **not** in that graph on the client
|
||||
//! side: no `PeerEntry`, no `PeerId`, no `PeerRef::Specific` routing. The
|
||||
//! asymmetry is deliberate — a public domain's operator can change hands, so
|
||||
//! there is no stable logical identity to attach.
|
||||
//!
|
||||
//! The hub case is an ordinary `PeerEntry` that happens to expose both an
|
||||
//! Ed25519 fingerprint (P2P path) and an X.509 fingerprint
|
||||
//! (`SHA256:<hex>`, WebTransport/HTTPS path) — already supported by
|
||||
//! `PeerEntry.fingerprints: Vec<String>` (ADR-030).
|
||||
//!
|
||||
//! # Client-side verifier selection (ADR-034 §3)
|
||||
//!
|
||||
//! The `CallClient` / `from_openapi` / `from_mcp` client-side
|
||||
//! `ServerCertVerifier` is selected by **whether the local node has a
|
||||
//! `PeerEntry` for the remote**, not by key type alone:
|
||||
//!
|
||||
//! | Local has `PeerEntry` for remote? | Remote cert type | Client verifier |
|
||||
//! |----------------------------------|------------------|-----------------|
|
||||
//! | No (public X.509 endpoint) | X.509 | `WebPkiServerVerifier` (CA verification) |
|
||||
//! | No | Ed25519 raw key | fails closed (no CA to fall back to) |
|
||||
//! | Yes (hub, Ed25519 path) | Ed25519 raw key | fingerprint match (`ed25519:<hex>`) |
|
||||
//! | Yes (hub, X.509 path) | X.509 | fingerprint match (`SHA256:<hex>`) |
|
||||
//!
|
||||
//! This is the key-type-aware verifier from OQ-29, with the peer-model
|
||||
//! criterion (ADR-034) made explicit. The client-side verifier selection is
|
||||
//! a `CallClient` concern (`call/call-client-verifier-selection`), not an
|
||||
//! `IdentityProvider` concern — `IdentityProvider` is unchanged by ADR-034.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::config::{DynamicConfig, PeerEntry};
|
||||
use crate::store::StoreError;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Identity {
|
||||
pub id: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub resources: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AuthToken {
|
||||
pub raw: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthContext {
|
||||
pub identity: Option<Identity>,
|
||||
pub alpn: Vec<u8>,
|
||||
pub remote_addr: Option<SocketAddr>,
|
||||
pub tls_client_fingerprint: Option<String>,
|
||||
}
|
||||
|
||||
pub trait IdentityProvider: Send + Sync + 'static {
|
||||
fn resolve_from_fingerprint(&self, fingerprint: &str) -> Option<Identity>;
|
||||
fn resolve_from_token(&self, token: &AuthToken) -> Option<Identity>;
|
||||
}
|
||||
|
||||
/// Write trait — management path, async (ADR-035). `ConfigIdentityProvider`
|
||||
/// does NOT implement this (config reload is its write path). A persistence
|
||||
/// adapter (e.g. `SqliteIdentityProvider` in `alknet-store-sqlite`) does:
|
||||
/// writes hit the backend, emit a honker `NOTIFY`, and the local `LISTEN`
|
||||
/// refreshes the in-memory read index.
|
||||
#[async_trait]
|
||||
pub trait IdentityStore: IdentityProvider {
|
||||
async fn put_peer(&self, peer: &PeerEntry) -> Result<(), StoreError>;
|
||||
async fn update_peer(&self, peer_id: &str, peer: &PeerEntry) -> Result<(), StoreError>;
|
||||
async fn remove_peer(&self, peer_id: &str) -> Result<(), StoreError>;
|
||||
}
|
||||
|
||||
pub struct ConfigIdentityProvider {
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
}
|
||||
|
||||
impl ConfigIdentityProvider {
|
||||
pub fn new(dynamic: Arc<ArcSwap<DynamicConfig>>) -> Self {
|
||||
Self { dynamic }
|
||||
}
|
||||
}
|
||||
|
||||
impl IdentityProvider for ConfigIdentityProvider {
|
||||
fn resolve_from_fingerprint(&self, fingerprint: &str) -> Option<Identity> {
|
||||
let config = self.dynamic.load();
|
||||
config.auth.resolve_identity_from_fingerprint(fingerprint)
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, token: &AuthToken) -> Option<Identity> {
|
||||
let config = self.dynamic.load();
|
||||
let token_str = String::from_utf8_lossy(&token.raw);
|
||||
config.auth.resolve_identity_from_token(&token_str)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::{ApiKeyEntry, AuthPolicy, DynamicConfig, PeerEntry, RateLimitConfig};
|
||||
|
||||
fn compute_api_key_hash(token: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(token.as_bytes());
|
||||
let result = hasher.finalize();
|
||||
format!("sha256:{}", hex::encode(result))
|
||||
}
|
||||
|
||||
fn make_provider(
|
||||
config: DynamicConfig,
|
||||
) -> (ConfigIdentityProvider, Arc<ArcSwap<DynamicConfig>>) {
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(config)));
|
||||
let provider = ConfigIdentityProvider::new(Arc::clone(&arc_swap));
|
||||
(provider, arc_swap)
|
||||
}
|
||||
|
||||
fn peer_entry_with_fingerprint(peer_id: &str, fingerprint: &str) -> PeerEntry {
|
||||
PeerEntry {
|
||||
peer_id: peer_id.to_string(),
|
||||
fingerprints: vec![fingerprint.to_string()],
|
||||
auth_token_hash: None,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: std::collections::HashMap::new(),
|
||||
display_name: None,
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn config_with_peer_entry(peer_id: &str, fingerprint: &str) -> DynamicConfig {
|
||||
DynamicConfig {
|
||||
auth: AuthPolicy {
|
||||
peers: vec![peer_entry_with_fingerprint(peer_id, fingerprint)],
|
||||
api_keys: Vec::new(),
|
||||
},
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn config_with_api_key(entry: ApiKeyEntry) -> DynamicConfig {
|
||||
DynamicConfig {
|
||||
auth: AuthPolicy {
|
||||
peers: Vec::new(),
|
||||
api_keys: vec![entry],
|
||||
},
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_token_hash(token: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(token.as_bytes());
|
||||
let result = hasher.finalize();
|
||||
format!("sha256:{}", hex::encode(result))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_fields_and_equality() {
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert(
|
||||
"service".to_string(),
|
||||
vec!["gitea".to_string(), "registry".to_string()],
|
||||
);
|
||||
let id = Identity {
|
||||
id: "SHA256:abc123".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources,
|
||||
};
|
||||
let id2 = id.clone();
|
||||
assert_eq!(id, id2);
|
||||
assert_eq!(id.id, "SHA256:abc123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_token_is_clone() {
|
||||
let token = AuthToken {
|
||||
raw: b"alk_test".to_vec(),
|
||||
};
|
||||
let cloned = token.clone();
|
||||
assert_eq!(token.raw, cloned.raw);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_context_is_clone() {
|
||||
let ctx = AuthContext {
|
||||
identity: None,
|
||||
alpn: b"alknet/test".to_vec(),
|
||||
remote_addr: None,
|
||||
tls_client_fingerprint: None,
|
||||
};
|
||||
let cloned = ctx.clone();
|
||||
assert_eq!(cloned.alpn, b"alknet/test");
|
||||
assert!(cloned.identity.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fingerprint_resolution_known_returns_some() {
|
||||
let (provider, _) = make_provider(config_with_peer_entry("worker-a", "SHA256:abc123"));
|
||||
let identity = provider
|
||||
.resolve_from_fingerprint("SHA256:abc123")
|
||||
.expect("known fingerprint resolves");
|
||||
assert_eq!(identity.id, "worker-a");
|
||||
assert_eq!(identity.scopes, vec!["relay:connect".to_string()]);
|
||||
assert!(identity.resources.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fingerprint_resolution_unknown_returns_none() {
|
||||
let (provider, _) = make_provider(config_with_peer_entry("worker-a", "SHA256:abc123"));
|
||||
assert!(provider
|
||||
.resolve_from_fingerprint("SHA256:unknown")
|
||||
.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fingerprint_resolution_empty_config_returns_none() {
|
||||
let (provider, _) = make_provider(DynamicConfig::default());
|
||||
assert!(provider
|
||||
.resolve_from_fingerprint("SHA256:anything")
|
||||
.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_resolution_valid_non_expired_returns_some() {
|
||||
let token_str = "alk_testsecret123";
|
||||
let hash = compute_api_key_hash(token_str);
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_test".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "test key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let (provider, _) = make_provider(config_with_api_key(entry));
|
||||
let token = AuthToken {
|
||||
raw: token_str.as_bytes().to_vec(),
|
||||
};
|
||||
let identity = provider
|
||||
.resolve_from_token(&token)
|
||||
.expect("valid non-expired token resolves");
|
||||
assert_eq!(identity.id, "alk_test");
|
||||
assert_eq!(identity.scopes, vec!["relay:connect".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_resolution_expired_returns_none() {
|
||||
let token_str = "alk_testsecret123";
|
||||
let hash = compute_api_key_hash(token_str);
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_test".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "expired key".to_string(),
|
||||
expires_at: Some(1),
|
||||
};
|
||||
let (provider, _) = make_provider(config_with_api_key(entry));
|
||||
let token = AuthToken {
|
||||
raw: token_str.as_bytes().to_vec(),
|
||||
};
|
||||
assert!(provider.resolve_from_token(&token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_resolution_unknown_returns_none() {
|
||||
let token_str = "alk_testsecret123";
|
||||
let hash = compute_api_key_hash(token_str);
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_test".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "test key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let (provider, _) = make_provider(config_with_api_key(entry));
|
||||
let token = AuthToken {
|
||||
raw: b"alk_unknown".to_vec(),
|
||||
};
|
||||
assert!(provider.resolve_from_token(&token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_resolution_wrong_hash_returns_none() {
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_test".to_string(),
|
||||
hash: "sha256:deadbeef".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "wrong hash".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let (provider, _) = make_provider(config_with_api_key(entry));
|
||||
let token = AuthToken {
|
||||
raw: b"alk_testsecret123".to_vec(),
|
||||
};
|
||||
assert!(provider.resolve_from_token(&token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_resolution_non_alk_prefix_returns_none() {
|
||||
let (provider, _) = make_provider(DynamicConfig::default());
|
||||
let token = AuthToken {
|
||||
raw: b"bearer_token".to_vec(),
|
||||
};
|
||||
assert!(provider.resolve_from_token(&token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_reload_changes_resolution_immediately() {
|
||||
let (provider, arc_swap) = make_provider(DynamicConfig::default());
|
||||
assert!(provider.resolve_from_fingerprint("SHA256:abc123").is_none());
|
||||
|
||||
let new_config = config_with_peer_entry("worker-a", "SHA256:abc123");
|
||||
arc_swap.store(Arc::new(new_config));
|
||||
|
||||
assert!(provider.resolve_from_fingerprint("SHA256:abc123").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_reload_removes_fingerprint_access_immediately() {
|
||||
let (provider, arc_swap) =
|
||||
make_provider(config_with_peer_entry("worker-a", "SHA256:abc123"));
|
||||
assert!(provider.resolve_from_fingerprint("SHA256:abc123").is_some());
|
||||
|
||||
arc_swap.store(Arc::new(DynamicConfig::default()));
|
||||
|
||||
assert!(provider.resolve_from_fingerprint("SHA256:abc123").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_reload_handle_reloads_config() {
|
||||
use crate::config::ConfigReloadHandle;
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||
let provider = ConfigIdentityProvider::new(Arc::clone(&arc_swap));
|
||||
let handle = ConfigReloadHandle::new(arc_swap);
|
||||
|
||||
assert!(provider.resolve_from_fingerprint("SHA256:abc123").is_none());
|
||||
|
||||
handle.reload(config_with_peer_entry("worker-a", "SHA256:abc123"));
|
||||
|
||||
assert!(provider.resolve_from_fingerprint("SHA256:abc123").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_is_identity_provider_not_store() {
|
||||
fn assert_provider<T: IdentityProvider>() {}
|
||||
fn assert_not_store<T>() {}
|
||||
assert_provider::<ConfigIdentityProvider>();
|
||||
assert_not_store::<ConfigIdentityProvider>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_resolution_via_peer_entry_auth_token_hash_returns_peer_id() {
|
||||
let token_str = "peer-bearer-secret";
|
||||
let mut entry = peer_entry_with_fingerprint("worker-a", "SHA256:abc123");
|
||||
entry.auth_token_hash = Some(compute_token_hash(token_str));
|
||||
let config = DynamicConfig {
|
||||
auth: AuthPolicy {
|
||||
peers: vec![entry],
|
||||
api_keys: Vec::new(),
|
||||
},
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
};
|
||||
let (provider, _) = make_provider(config);
|
||||
let token = AuthToken {
|
||||
raw: token_str.as_bytes().to_vec(),
|
||||
};
|
||||
let identity = provider
|
||||
.resolve_from_token(&token)
|
||||
.expect("matching PeerEntry.auth_token_hash resolves");
|
||||
assert_eq!(identity.id, "worker-a");
|
||||
assert_eq!(identity.scopes, vec!["relay:connect".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_resolution_falls_through_to_api_key_when_no_peer_entry_matches() {
|
||||
let api_token = "alk_test_secret";
|
||||
let mut entry = peer_entry_with_fingerprint("worker-a", "SHA256:abc123");
|
||||
entry.auth_token_hash = Some(compute_token_hash("different-token"));
|
||||
let api_entry = ApiKeyEntry {
|
||||
prefix: "alk_test".to_string(),
|
||||
hash: compute_api_key_hash(api_token),
|
||||
scopes: vec!["admin".to_string()],
|
||||
description: "fall-through key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let config = DynamicConfig {
|
||||
auth: AuthPolicy {
|
||||
peers: vec![entry],
|
||||
api_keys: vec![api_entry],
|
||||
},
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
};
|
||||
let (provider, _) = make_provider(config);
|
||||
let token = AuthToken {
|
||||
raw: api_token.as_bytes().to_vec(),
|
||||
};
|
||||
let identity = provider
|
||||
.resolve_from_token(&token)
|
||||
.expect("api key fall-through resolves");
|
||||
assert_eq!(identity.id, "alk_test");
|
||||
assert_eq!(identity.scopes, vec!["admin".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disabled_peer_entry_returns_none_on_fingerprint_resolution() {
|
||||
let mut entry = peer_entry_with_fingerprint("worker-a", "SHA256:abc123");
|
||||
entry.enabled = false;
|
||||
let config = DynamicConfig {
|
||||
auth: AuthPolicy {
|
||||
peers: vec![entry],
|
||||
api_keys: Vec::new(),
|
||||
},
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
};
|
||||
let (provider, _) = make_provider(config);
|
||||
assert!(provider.resolve_from_fingerprint("SHA256:abc123").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disabled_peer_entry_returns_none_on_token_resolution() {
|
||||
let token_str = "peer-bearer-secret";
|
||||
let mut entry = peer_entry_with_fingerprint("worker-a", "SHA256:abc123");
|
||||
entry.auth_token_hash = Some(compute_token_hash(token_str));
|
||||
entry.enabled = false;
|
||||
let config = DynamicConfig {
|
||||
auth: AuthPolicy {
|
||||
peers: vec![entry],
|
||||
api_keys: Vec::new(),
|
||||
},
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
};
|
||||
let (provider, _) = make_provider(config);
|
||||
let token = AuthToken {
|
||||
raw: token_str.as_bytes().to_vec(),
|
||||
};
|
||||
assert!(provider.resolve_from_token(&token).is_none());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod identity_store_tests {
|
||||
use super::*;
|
||||
use crate::config::PeerEntry;
|
||||
use std::collections::HashMap as StdHashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
fn make_peer(peer_id: &str) -> PeerEntry {
|
||||
PeerEntry {
|
||||
peer_id: peer_id.to_string(),
|
||||
fingerprints: vec![format!("SHA256:{peer_id}")],
|
||||
auth_token_hash: None,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: StdHashMap::new(),
|
||||
display_name: None,
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
struct MockIdentityStore {
|
||||
peers: RwLock<HashMap<String, PeerEntry>>,
|
||||
}
|
||||
|
||||
impl MockIdentityStore {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
peers: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IdentityProvider for MockIdentityStore {
|
||||
fn resolve_from_fingerprint(&self, fingerprint: &str) -> Option<Identity> {
|
||||
let peers = self.peers.read().unwrap_or_else(|e| e.into_inner());
|
||||
peers.values().find_map(|p| {
|
||||
if p.fingerprints.iter().any(|f| f == fingerprint) && p.enabled {
|
||||
Some(Identity {
|
||||
id: p.peer_id.clone(),
|
||||
scopes: p.scopes.clone(),
|
||||
resources: p.resources.clone(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, _token: &AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl IdentityStore for MockIdentityStore {
|
||||
async fn put_peer(&self, peer: &PeerEntry) -> Result<(), StoreError> {
|
||||
let mut peers = self.peers.write().unwrap_or_else(|e| e.into_inner());
|
||||
peers.insert(peer.peer_id.clone(), peer.clone());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_peer(&self, peer_id: &str, peer: &PeerEntry) -> Result<(), StoreError> {
|
||||
let mut peers = self.peers.write().unwrap_or_else(|e| e.into_inner());
|
||||
if !peers.contains_key(peer_id) {
|
||||
return Err(StoreError::NotFound {
|
||||
entity: peer_id.to_string(),
|
||||
});
|
||||
}
|
||||
peers.remove(peer_id);
|
||||
peers.insert(peer.peer_id.clone(), peer.clone());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_peer(&self, peer_id: &str) -> Result<(), StoreError> {
|
||||
let mut peers = self.peers.write().unwrap_or_else(|e| e.into_inner());
|
||||
if peers.remove(peer_id).is_none() {
|
||||
return Err(StoreError::NotFound {
|
||||
entity: peer_id.to_string(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_put_peer_upserts() {
|
||||
let store = MockIdentityStore::new();
|
||||
let mut peer = make_peer("worker-a");
|
||||
store.put_peer(&peer).await.unwrap();
|
||||
assert_eq!(
|
||||
store
|
||||
.resolve_from_fingerprint("SHA256:worker-a")
|
||||
.unwrap()
|
||||
.id,
|
||||
"worker-a"
|
||||
);
|
||||
|
||||
peer.display_name = Some("renamed".to_string());
|
||||
store.put_peer(&peer).await.unwrap();
|
||||
let peers = store.peers.read().unwrap_or_else(|e| e.into_inner());
|
||||
assert_eq!(peers.len(), 1);
|
||||
assert_eq!(
|
||||
peers.get("worker-a").unwrap().display_name.as_deref(),
|
||||
Some("renamed")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_update_peer_existing_succeeds() {
|
||||
let store = MockIdentityStore::new();
|
||||
store.put_peer(&make_peer("worker-a")).await.unwrap();
|
||||
let updated = make_peer("worker-b");
|
||||
store.update_peer("worker-a", &updated).await.unwrap();
|
||||
assert!(store.resolve_from_fingerprint("SHA256:worker-a").is_none());
|
||||
assert!(store.resolve_from_fingerprint("SHA256:worker-b").is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_update_peer_missing_returns_not_found() {
|
||||
let store = MockIdentityStore::new();
|
||||
let err = store
|
||||
.update_peer("ghost", &make_peer("ghost"))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, StoreError::NotFound { .. }));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_remove_peer_existing_succeeds() {
|
||||
let store = MockIdentityStore::new();
|
||||
store.put_peer(&make_peer("worker-a")).await.unwrap();
|
||||
store.remove_peer("worker-a").await.unwrap();
|
||||
assert!(store.resolve_from_fingerprint("SHA256:worker-a").is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_remove_peer_missing_returns_not_found() {
|
||||
let store = MockIdentityStore::new();
|
||||
let err = store.remove_peer("ghost").await.unwrap_err();
|
||||
assert!(matches!(err, StoreError::NotFound { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mock_identity_store_is_identity_provider() {
|
||||
fn assert_provider<T: IdentityProvider>() {}
|
||||
assert_provider::<MockIdentityStore>();
|
||||
}
|
||||
}
|
||||
@@ -1,262 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
|
||||
use crate::auth::identity::{AuthToken, ConfigIdentityProvider, Identity, IdentityProvider};
|
||||
use crate::config::DynamicConfig;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum AuthProtocol {
|
||||
VerifyPubkey {
|
||||
fingerprint: String,
|
||||
key_data: Vec<u8>,
|
||||
},
|
||||
VerifyToken {
|
||||
token_bytes: Vec<u8>,
|
||||
timestamp: u64,
|
||||
},
|
||||
ReloadKeys,
|
||||
CheckAccess {
|
||||
identity: Identity,
|
||||
operation: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum AuthResult {
|
||||
Ok(Identity),
|
||||
Denied(String),
|
||||
}
|
||||
|
||||
pub struct AuthServiceImpl {
|
||||
provider: ConfigIdentityProvider,
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
}
|
||||
|
||||
impl AuthServiceImpl {
|
||||
pub fn new(dynamic: Arc<ArcSwap<DynamicConfig>>) -> Self {
|
||||
let provider = ConfigIdentityProvider::new(Arc::clone(&dynamic));
|
||||
Self { provider, dynamic }
|
||||
}
|
||||
|
||||
pub fn verify_pubkey(&self, fingerprint: &str) -> AuthResult {
|
||||
match self.provider.resolve_from_fingerprint(fingerprint) {
|
||||
Some(identity) => AuthResult::Ok(identity),
|
||||
None => AuthResult::Denied(format!("key not authorized: {}", fingerprint)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn verify_token(&self, token: &AuthToken) -> AuthResult {
|
||||
match self.provider.resolve_from_token(token) {
|
||||
Some(identity) => AuthResult::Ok(identity),
|
||||
None => AuthResult::Denied("token verification failed".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reload_keys(&self) {
|
||||
self.dynamic.rcu(Arc::clone);
|
||||
}
|
||||
|
||||
pub fn check_access(&self, identity: &Identity, operation: &str) -> AuthResult {
|
||||
if identity.scopes.iter().any(|s| s == operation) {
|
||||
AuthResult::Ok(identity.clone())
|
||||
} else {
|
||||
AuthResult::Denied(format!(
|
||||
"identity {} lacks scope: {}",
|
||||
identity.id, operation
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AuthServiceImpl {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AuthServiceImpl").finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::auth::ServerAuthConfig;
|
||||
use crate::config::AuthPolicy;
|
||||
use russh::keys::ssh_key::HashAlg;
|
||||
use russh::keys::PrivateKey;
|
||||
use std::io::Write;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
|
||||
|
||||
fn load_key() -> PrivateKey {
|
||||
russh::keys::decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap()
|
||||
}
|
||||
|
||||
fn make_authorized_keys_file(keys_content: &str) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
f.write_all(keys_content.as_bytes()).unwrap();
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
fn make_service(keys_content: &str) -> (AuthServiceImpl, Arc<ArcSwap<DynamicConfig>>) {
|
||||
let f = make_authorized_keys_file(keys_content);
|
||||
let server_auth =
|
||||
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
|
||||
.unwrap();
|
||||
let auth_policy = AuthPolicy::from_server_auth_config(server_auth);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let service = AuthServiceImpl::new(Arc::clone(&arc_swap));
|
||||
(service, arc_swap)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_service_verify_pubkey_valid() {
|
||||
let (service, _) = make_service(ED25519_PUBLIC_KEY);
|
||||
let key = load_key().public_key().clone();
|
||||
let fingerprint = format!("{}", key.fingerprint(HashAlg::Sha256));
|
||||
let result = service.verify_pubkey(&fingerprint);
|
||||
assert!(matches!(result, AuthResult::Ok(_)));
|
||||
if let AuthResult::Ok(identity) = result {
|
||||
assert_eq!(identity.id, fingerprint);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_service_verify_pubkey_invalid() {
|
||||
let (service, _) = make_service(ED25519_PUBLIC_KEY);
|
||||
let result = service.verify_pubkey("SHA256:invalid");
|
||||
assert!(matches!(result, AuthResult::Denied(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_service_verify_pubkey_matches_identity_provider() {
|
||||
let (service, arc_swap) = make_service(ED25519_PUBLIC_KEY);
|
||||
let provider = ConfigIdentityProvider::new(Arc::clone(&arc_swap));
|
||||
let key = load_key().public_key().clone();
|
||||
let fingerprint = format!("{}", key.fingerprint(HashAlg::Sha256));
|
||||
|
||||
let service_result = service.verify_pubkey(&fingerprint);
|
||||
let provider_result = provider.resolve_from_fingerprint(&fingerprint);
|
||||
|
||||
match service_result {
|
||||
AuthResult::Ok(identity) => {
|
||||
assert_eq!(identity, provider_result.unwrap());
|
||||
}
|
||||
AuthResult::Denied(_) => {
|
||||
assert!(provider_result.is_none());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_service_verify_token_returns_denied() {
|
||||
let (service, _) = make_service(ED25519_PUBLIC_KEY);
|
||||
let token = AuthToken {
|
||||
raw: b"test-token".to_vec(),
|
||||
};
|
||||
let result = service.verify_token(&token);
|
||||
assert!(matches!(result, AuthResult::Denied(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_service_check_access_granted() {
|
||||
let (service, _) = make_service(ED25519_PUBLIC_KEY);
|
||||
let key = load_key().public_key().clone();
|
||||
let fingerprint = format!("{}", key.fingerprint(HashAlg::Sha256));
|
||||
let identity = Identity {
|
||||
id: fingerprint,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: std::collections::HashMap::new(),
|
||||
};
|
||||
let result = service.check_access(&identity, "relay:connect");
|
||||
assert!(matches!(result, AuthResult::Ok(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_service_check_access_denied() {
|
||||
let (service, _) = make_service(ED25519_PUBLIC_KEY);
|
||||
let identity = Identity {
|
||||
id: "test".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: std::collections::HashMap::new(),
|
||||
};
|
||||
let result = service.check_access(&identity, "admin:write");
|
||||
assert!(matches!(result, AuthResult::Denied(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_protocol_variants() {
|
||||
let identity = Identity {
|
||||
id: "SHA256:abc".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: std::collections::HashMap::new(),
|
||||
};
|
||||
|
||||
let verify_pubkey = AuthProtocol::VerifyPubkey {
|
||||
fingerprint: "SHA256:abc".to_string(),
|
||||
key_data: vec![1, 2, 3],
|
||||
};
|
||||
match &verify_pubkey {
|
||||
AuthProtocol::VerifyPubkey {
|
||||
fingerprint,
|
||||
key_data,
|
||||
} => {
|
||||
assert_eq!(fingerprint, "SHA256:abc");
|
||||
assert_eq!(key_data, &vec![1, 2, 3]);
|
||||
}
|
||||
_ => panic!("expected VerifyPubkey variant"),
|
||||
}
|
||||
|
||||
let verify_token = AuthProtocol::VerifyToken {
|
||||
token_bytes: vec![4, 5, 6],
|
||||
timestamp: 12345,
|
||||
};
|
||||
match &verify_token {
|
||||
AuthProtocol::VerifyToken {
|
||||
token_bytes,
|
||||
timestamp,
|
||||
} => {
|
||||
assert_eq!(token_bytes, &vec![4, 5, 6]);
|
||||
assert_eq!(*timestamp, 12345);
|
||||
}
|
||||
_ => panic!("expected VerifyToken variant"),
|
||||
}
|
||||
|
||||
assert!(matches!(AuthProtocol::ReloadKeys, AuthProtocol::ReloadKeys));
|
||||
|
||||
let check = AuthProtocol::CheckAccess {
|
||||
identity: identity.clone(),
|
||||
operation: "relay:connect".to_string(),
|
||||
};
|
||||
match &check {
|
||||
AuthProtocol::CheckAccess {
|
||||
identity: id,
|
||||
operation,
|
||||
} => {
|
||||
assert_eq!(id.id, "SHA256:abc");
|
||||
assert_eq!(operation, "relay:connect");
|
||||
}
|
||||
_ => panic!("expected CheckAccess variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_result_ok_identity() {
|
||||
let identity = Identity {
|
||||
id: "test".to_string(),
|
||||
scopes: vec![],
|
||||
resources: std::collections::HashMap::new(),
|
||||
};
|
||||
let result = AuthResult::Ok(identity.clone());
|
||||
assert_eq!(result, AuthResult::Ok(identity));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_result_denied_message() {
|
||||
let result = AuthResult::Denied("access denied".to_string());
|
||||
assert_eq!(result, AuthResult::Denied("access denied".to_string()));
|
||||
}
|
||||
}
|
||||
@@ -1,176 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use russh::client;
|
||||
use russh::keys::key::PrivateKeyWithHashAlg;
|
||||
use russh::keys::{PrivateKey, PublicKey};
|
||||
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::error::ConfigError;
|
||||
|
||||
/// Client-side SSH authentication configuration.
|
||||
///
|
||||
/// Holds the private key used for SSH authentication and an optional
|
||||
/// public key override. When no public key is provided, it is derived
|
||||
/// from the private key.
|
||||
pub struct ClientAuthConfig {
|
||||
private_key: Arc<PrivateKey>,
|
||||
public_key: PublicKey,
|
||||
}
|
||||
|
||||
impl ClientAuthConfig {
|
||||
/// Load a `ClientAuthConfig` from a key source (file or in-memory).
|
||||
pub fn from_key_source(source: KeySource) -> Result<Self, ConfigError> {
|
||||
let private_key = crate::auth::keys::load_private_key(source)?;
|
||||
let public_key = private_key.public_key().clone();
|
||||
Ok(Self {
|
||||
private_key: Arc::new(private_key),
|
||||
public_key,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the private key wrapped in `Arc` for use with russh authentication.
|
||||
pub fn private_key(&self) -> Arc<PrivateKey> {
|
||||
Arc::clone(&self.private_key)
|
||||
}
|
||||
|
||||
/// Returns the public key derived from (or overridden for) this config.
|
||||
pub fn public_key(&self) -> &PublicKey {
|
||||
&self.public_key
|
||||
}
|
||||
|
||||
/// Authenticate with the given SSH session handle and username.
|
||||
pub async fn authenticate<H: client::Handler>(
|
||||
&self,
|
||||
handle: &mut client::Handle<H>,
|
||||
username: &str,
|
||||
) -> Result<bool, russh::Error> {
|
||||
let key_with_alg = PrivateKeyWithHashAlg::new(Arc::clone(&self.private_key), None)?;
|
||||
handle.authenticate_publickey(username, key_with_alg).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Client handler implementing `russh::client::Handler`.
|
||||
///
|
||||
/// Provides the callbacks required by russh during the SSH handshake.
|
||||
/// Server key verification is delegated to a configurable callback;
|
||||
/// the default accepts all server keys (suitable for testing or when
|
||||
/// transport-layer verification — e.g. TLS — is already in place).
|
||||
pub struct ClientHandler {
|
||||
pub_key: PublicKey,
|
||||
check_server_key_fn: Box<dyn Fn(&PublicKey) -> bool + Send + Sync>,
|
||||
}
|
||||
|
||||
impl ClientHandler {
|
||||
/// Create a new client handler from a `ClientAuthConfig`.
|
||||
pub fn from_config(config: &ClientAuthConfig) -> Self {
|
||||
Self {
|
||||
pub_key: config.public_key().clone(),
|
||||
check_server_key_fn: Box::new(|_| true),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a client handler with a custom server key verification callback.
|
||||
pub fn with_server_key_check(
|
||||
config: &ClientAuthConfig,
|
||||
check_fn: impl Fn(&PublicKey) -> bool + Send + Sync + 'static,
|
||||
) -> Self {
|
||||
Self {
|
||||
pub_key: config.public_key().clone(),
|
||||
check_server_key_fn: Box::new(check_fn),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the public key associated with this handler.
|
||||
pub fn public_key(&self) -> &PublicKey {
|
||||
&self.pub_key
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl client::Handler for ClientHandler {
|
||||
type Error = russh::Error;
|
||||
|
||||
async fn check_server_key(
|
||||
&mut self,
|
||||
server_public_key: &PublicKey,
|
||||
) -> Result<bool, Self::Error> {
|
||||
Ok((self.check_server_key_fn)(server_public_key))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use russh::client::Handler;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
#[test]
|
||||
fn from_key_source_memory() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
assert_eq!(
|
||||
config.public_key().algorithm(),
|
||||
russh::keys::Algorithm::Ed25519
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_from_config() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let handler = ClientHandler::from_config(&config);
|
||||
assert_eq!(
|
||||
handler.public_key().algorithm(),
|
||||
russh::keys::Algorithm::Ed25519
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_with_custom_server_key_check() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let handler = ClientHandler::with_server_key_check(&config, |_pk| false);
|
||||
assert_eq!(
|
||||
handler.public_key().algorithm(),
|
||||
russh::keys::Algorithm::Ed25519
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_key_source_invalid_key() {
|
||||
let source = KeySource::Memory(b"not a key".to_vec());
|
||||
let result = ClientAuthConfig::from_key_source(source);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handler_check_server_key_accepts_by_default() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let mut handler = ClientHandler::from_config(&config);
|
||||
let some_key = config.public_key().clone();
|
||||
let result = handler.check_server_key(&some_key).await.unwrap();
|
||||
assert!(result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handler_check_server_key_rejects_with_custom_fn() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let mut handler = ClientHandler::with_server_key_check(&config, |_pk| false);
|
||||
let some_key = config.public_key().clone();
|
||||
let result = handler.check_server_key(&some_key).await.unwrap();
|
||||
assert!(!result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_key_arc_dedup() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source).unwrap();
|
||||
let key1 = config.private_key();
|
||||
let key2 = config.private_key();
|
||||
assert!(Arc::ptr_eq(&key1, &key2));
|
||||
}
|
||||
}
|
||||
@@ -1,349 +0,0 @@
|
||||
//! Identity resolution and the `IdentityProvider` trait.
|
||||
//!
|
||||
//! See [ADR-029](docs/architecture/decisions/029-identity-provider.md) and
|
||||
//! [ADR-028](docs/architecture/decisions/028-identity-model.md).
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
|
||||
use crate::config::DynamicConfig;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Identity {
|
||||
pub id: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub resources: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AuthToken {
|
||||
pub raw: Vec<u8>,
|
||||
}
|
||||
|
||||
pub trait IdentityProvider: Send + Sync + 'static {
|
||||
fn resolve_from_fingerprint(&self, fingerprint: &str) -> Option<Identity>;
|
||||
|
||||
fn resolve_from_token(&self, token: &AuthToken) -> Option<Identity>;
|
||||
}
|
||||
|
||||
pub struct ConfigIdentityProvider {
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
}
|
||||
|
||||
impl ConfigIdentityProvider {
|
||||
pub fn new(dynamic: Arc<ArcSwap<DynamicConfig>>) -> Self {
|
||||
Self { dynamic }
|
||||
}
|
||||
}
|
||||
|
||||
impl IdentityProvider for ConfigIdentityProvider {
|
||||
fn resolve_from_fingerprint(&self, fingerprint: &str) -> Option<Identity> {
|
||||
let config = self.dynamic.load();
|
||||
let auth = &config.auth;
|
||||
auth.resolve_identity_from_fingerprint(fingerprint)
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, token: &AuthToken) -> Option<Identity> {
|
||||
let config = self.dynamic.load();
|
||||
let auth = &config.auth;
|
||||
let token_str = String::from_utf8_lossy(&token.raw);
|
||||
if token_str.starts_with(crate::config::API_KEY_PREFIX) {
|
||||
return auth.resolve_api_key(&token_str);
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::auth::ServerAuthConfig;
|
||||
use crate::config::AuthPolicy;
|
||||
use russh::keys::ssh_key::HashAlg;
|
||||
use russh::keys::PrivateKey;
|
||||
use std::io::Write;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
|
||||
|
||||
fn load_key() -> PrivateKey {
|
||||
russh::keys::decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap()
|
||||
}
|
||||
|
||||
fn make_authorized_keys_file(keys_content: &str) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
f.write_all(keys_content.as_bytes()).unwrap();
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
fn make_provider(keys_content: &str) -> (ConfigIdentityProvider, Arc<ArcSwap<DynamicConfig>>) {
|
||||
let f = make_authorized_keys_file(keys_content);
|
||||
let server_auth =
|
||||
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
|
||||
.unwrap();
|
||||
let auth_policy = AuthPolicy::from_server_auth_config(server_auth);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(Arc::clone(&arc_swap));
|
||||
(provider, arc_swap)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_fields() {
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert(
|
||||
"service".to_string(),
|
||||
vec!["gitea".to_string(), "registry".to_string()],
|
||||
);
|
||||
let identity = Identity {
|
||||
id: "SHA256:abc123".to_string(),
|
||||
scopes: vec![
|
||||
"relay:connect".to_string(),
|
||||
"service:gitea:read".to_string(),
|
||||
],
|
||||
resources,
|
||||
};
|
||||
assert_eq!(identity.id, "SHA256:abc123");
|
||||
assert_eq!(identity.scopes, vec!["relay:connect", "service:gitea:read"]);
|
||||
assert_eq!(
|
||||
identity.resources.get("service").unwrap(),
|
||||
&vec!["gitea".to_string(), "registry".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_equality() {
|
||||
let id1 = Identity {
|
||||
id: "test".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let id2 = Identity {
|
||||
id: "test".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
assert_eq!(id1, id2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_inequality_different_id() {
|
||||
let id1 = Identity {
|
||||
id: "a".to_string(),
|
||||
scopes: vec![],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let id2 = Identity {
|
||||
id: "b".to_string(),
|
||||
scopes: vec![],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
assert_ne!(id1, id2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_resolves_valid_fingerprint() {
|
||||
let (provider, _) = make_provider(ED25519_PUBLIC_KEY);
|
||||
let key = load_key().public_key().clone();
|
||||
let fingerprint = format!("{}", key.fingerprint(HashAlg::Sha256));
|
||||
let identity = provider.resolve_from_fingerprint(&fingerprint);
|
||||
assert!(identity.is_some());
|
||||
let identity = identity.unwrap();
|
||||
assert_eq!(identity.id, fingerprint);
|
||||
assert!(!identity.scopes.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_rejects_invalid_fingerprint() {
|
||||
let (provider, _) = make_provider(ED25519_PUBLIC_KEY);
|
||||
let identity = provider.resolve_from_fingerprint("SHA256:invalid");
|
||||
assert!(identity.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_empty_config_rejects_all() {
|
||||
let dynamic = DynamicConfig::default();
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(arc_swap);
|
||||
let identity = provider.resolve_from_fingerprint("SHA256:anything");
|
||||
assert!(identity.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_resolve_from_token_returns_none() {
|
||||
let (provider, _) = make_provider(ED25519_PUBLIC_KEY);
|
||||
let token = AuthToken {
|
||||
raw: b"test-token".to_vec(),
|
||||
};
|
||||
assert!(provider.resolve_from_token(&token).is_none());
|
||||
}
|
||||
|
||||
fn compute_api_key_hash(token: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(token.as_bytes());
|
||||
let result = hasher.finalize();
|
||||
format!("sha256:{}", hex::encode(result))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_resolves_valid_api_key() {
|
||||
let token = "alk_testsecret123";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = crate::config::ApiKeyEntry {
|
||||
prefix: "alk_test".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "test key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let auth_policy = crate::config::AuthPolicy::with_api_keys(
|
||||
std::collections::HashSet::new(),
|
||||
Vec::new(),
|
||||
vec![entry],
|
||||
);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(arc_swap);
|
||||
|
||||
let auth_token = AuthToken {
|
||||
raw: token.as_bytes().to_vec(),
|
||||
};
|
||||
let identity = provider.resolve_from_token(&auth_token);
|
||||
assert!(identity.is_some());
|
||||
let identity = identity.unwrap();
|
||||
assert_eq!(identity.id, "alk_test");
|
||||
assert_eq!(identity.scopes, vec!["relay:connect"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_rejects_expired_api_key() {
|
||||
let token = "alk_expiredkey1";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = crate::config::ApiKeyEntry {
|
||||
prefix: "alk_expi".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "expired key".to_string(),
|
||||
expires_at: Some(1),
|
||||
};
|
||||
let auth_policy = crate::config::AuthPolicy::with_api_keys(
|
||||
std::collections::HashSet::new(),
|
||||
Vec::new(),
|
||||
vec![entry],
|
||||
);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(arc_swap);
|
||||
|
||||
let auth_token = AuthToken {
|
||||
raw: token.as_bytes().to_vec(),
|
||||
};
|
||||
assert!(provider.resolve_from_token(&auth_token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_rejects_wrong_hash_api_key() {
|
||||
let entry = crate::config::ApiKeyEntry {
|
||||
prefix: "alk_test".to_string(),
|
||||
hash: "sha256:0000000000000000000000000000000000000000000000000000000000000000"
|
||||
.to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "bad hash".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let auth_policy = crate::config::AuthPolicy::with_api_keys(
|
||||
std::collections::HashSet::new(),
|
||||
Vec::new(),
|
||||
vec![entry],
|
||||
);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(arc_swap);
|
||||
|
||||
let auth_token = AuthToken {
|
||||
raw: b"alk_testsecret123".to_vec(),
|
||||
};
|
||||
assert!(provider.resolve_from_token(&auth_token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_api_key_unknown_prefix_falls_through() {
|
||||
let token = "alk_testsecret123";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = crate::config::ApiKeyEntry {
|
||||
prefix: "alk_other".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "other key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let auth_policy = crate::config::AuthPolicy::with_api_keys(
|
||||
std::collections::HashSet::new(),
|
||||
Vec::new(),
|
||||
vec![entry],
|
||||
);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(arc_swap);
|
||||
|
||||
let auth_token = AuthToken {
|
||||
raw: token.as_bytes().to_vec(),
|
||||
};
|
||||
assert!(provider.resolve_from_token(&auth_token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_api_key_scopes_in_identity() {
|
||||
let token = "alk_scopedkey12";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = crate::config::ApiKeyEntry {
|
||||
prefix: "alk_sco".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string(), "secrets:derive".to_string()],
|
||||
description: "scoped key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let auth_policy = crate::config::AuthPolicy::with_api_keys(
|
||||
std::collections::HashSet::new(),
|
||||
Vec::new(),
|
||||
vec![entry],
|
||||
);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
let arc_swap = Arc::new(ArcSwap::new(Arc::new(dynamic)));
|
||||
let provider = ConfigIdentityProvider::new(arc_swap);
|
||||
|
||||
let auth_token = AuthToken {
|
||||
raw: token.as_bytes().to_vec(),
|
||||
};
|
||||
let identity = provider.resolve_from_token(&auth_token).unwrap();
|
||||
assert_eq!(identity.scopes, vec!["relay:connect", "secrets:derive"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_token_holds_raw_bytes() {
|
||||
let token = AuthToken { raw: vec![1, 2, 3] };
|
||||
assert_eq!(token.raw, vec![1, 2, 3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_identity_provider_reflects_config_reload() {
|
||||
let (provider, arc_swap) = make_provider(ED25519_PUBLIC_KEY);
|
||||
let key = load_key().public_key().clone();
|
||||
let fingerprint = format!("{}", key.fingerprint(HashAlg::Sha256));
|
||||
|
||||
let identity = provider.resolve_from_fingerprint(&fingerprint);
|
||||
assert!(identity.is_some());
|
||||
|
||||
let new_dynamic = DynamicConfig::default();
|
||||
arc_swap.store(Arc::new(new_dynamic));
|
||||
|
||||
let identity = provider.resolve_from_fingerprint(&fingerprint);
|
||||
assert!(identity.is_none());
|
||||
}
|
||||
}
|
||||
@@ -1,258 +0,0 @@
|
||||
//! Key loading and parsing for SSH authentication.
|
||||
//!
|
||||
//! Supports `KeySource` (file path or in-memory) for private keys, public keys,
|
||||
//! and certificate authority entries. All keys must be in OpenSSH format.
|
||||
//! PEM-encoded keys (PKCS#1, PKCS#8) are rejected with a clear error message.
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use russh::keys::{decode_secret_key, parse_public_key_base64, PrivateKey, PublicKey};
|
||||
|
||||
use crate::error::ConfigError;
|
||||
|
||||
/// Source for key material — either a filesystem path or in-memory bytes.
|
||||
///
|
||||
/// Used throughout the API to accept keys without committing to a specific
|
||||
/// loading mechanism. In-memory keys are primarily for the NAPI wrapper.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KeySource {
|
||||
File(PathBuf),
|
||||
Memory(Vec<u8>),
|
||||
}
|
||||
|
||||
/// A certificate authority entry parsed from an `authorized_keys` file.
|
||||
///
|
||||
/// Contains the CA public key and its associated options (e.g., `cert-authority`,
|
||||
/// `permit-port-forwarding`). Used by `ServerAuthConfig` for certificate validation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CertAuthorityEntry {
|
||||
pub public_key: PublicKey,
|
||||
pub options: Vec<String>,
|
||||
}
|
||||
|
||||
fn resolve_bytes(source: &KeySource) -> Result<Vec<u8>, ConfigError> {
|
||||
match source {
|
||||
KeySource::File(path) => {
|
||||
if !path.exists() {
|
||||
return Err(ConfigError::KeyFileNotFound {
|
||||
path: path.display().to_string(),
|
||||
});
|
||||
}
|
||||
std::fs::read(path).map_err(|_| ConfigError::KeyFileNotFound {
|
||||
path: path.display().to_string(),
|
||||
})
|
||||
}
|
||||
KeySource::Memory(data) => Ok(data.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_openssh_private_key(data: &[u8]) -> Result<(), ConfigError> {
|
||||
let s = String::from_utf8_lossy(data);
|
||||
if s.contains("-----BEGIN OPENSSH PRIVATE KEY-----") {
|
||||
return Ok(());
|
||||
}
|
||||
if s.contains("-----BEGIN RSA PRIVATE KEY-----")
|
||||
|| s.contains("-----BEGIN PRIVATE KEY-----")
|
||||
|| s.contains("-----BEGIN ENCRYPTED PRIVATE KEY-----")
|
||||
|| s.contains("-----BEGIN EC PRIVATE KEY-----")
|
||||
{
|
||||
return Err(ConfigError::InvalidFlag {
|
||||
name: "PEM-encoded key is not supported; use OpenSSH format (-----BEGIN OPENSSH PRIVATE KEY-----)".to_string(),
|
||||
});
|
||||
}
|
||||
Err(ConfigError::InvalidFlag {
|
||||
name: "unrecognized private key format; expected OpenSSH format (-----BEGIN OPENSSH PRIVATE KEY-----)".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_private_key(source: KeySource) -> Result<PrivateKey, ConfigError> {
|
||||
let data = resolve_bytes(&source)?;
|
||||
check_openssh_private_key(&data)?;
|
||||
let s = String::from_utf8_lossy(&data);
|
||||
decode_secret_key(&s, None).map_err(|e| ConfigError::InvalidFlag {
|
||||
name: format!("failed to decode private key: {e}"),
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_authorized_keys_line(line: &str) -> Option<Result<(PublicKey, Vec<String>), ConfigError>> {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = line.splitn(4, ' ').collect();
|
||||
if parts.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut options = Vec::new();
|
||||
let key_type_idx;
|
||||
|
||||
if parts[0].starts_with("cert-authority")
|
||||
|| parts[0].starts_with("no-")
|
||||
|| parts[0].starts_with("permit-")
|
||||
|| parts[0].starts_with("from=")
|
||||
|| parts[0].starts_with("command=")
|
||||
|| parts[0].starts_with("environment=")
|
||||
|| parts[0].starts_with("tunnel=")
|
||||
|| parts[0].starts_with("principals=")
|
||||
{
|
||||
let opts_str = parts[0];
|
||||
options = opts_str.split(',').map(|s| s.to_string()).collect();
|
||||
key_type_idx = 1;
|
||||
} else if parts[0].starts_with("ssh-") || parts[0].starts_with("ecdsa-") {
|
||||
key_type_idx = 0;
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
|
||||
if parts.len() <= key_type_idx {
|
||||
return None;
|
||||
}
|
||||
|
||||
let key_base64 = parts[key_type_idx + 1];
|
||||
match parse_public_key_base64(key_base64) {
|
||||
Ok(pk) => Some(Ok((pk, options))),
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_public_keys(source: KeySource) -> Result<Vec<PublicKey>, ConfigError> {
|
||||
let data = resolve_bytes(&source)?;
|
||||
let s = String::from_utf8_lossy(&data);
|
||||
let mut keys = Vec::new();
|
||||
for line in s.lines() {
|
||||
if let Some(Ok((pk, _))) = parse_authorized_keys_line(line) {
|
||||
keys.push(pk);
|
||||
}
|
||||
}
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
pub fn load_cert_authority_entries(
|
||||
source: KeySource,
|
||||
) -> Result<Vec<CertAuthorityEntry>, ConfigError> {
|
||||
let data = resolve_bytes(&source)?;
|
||||
let s = String::from_utf8_lossy(&data);
|
||||
let mut entries = Vec::new();
|
||||
for line in s.lines() {
|
||||
if let Some(result) = parse_authorized_keys_line(line) {
|
||||
match result {
|
||||
Ok((pk, options)) if !options.is_empty() => {
|
||||
entries.push(CertAuthorityEntry {
|
||||
public_key: pk,
|
||||
options,
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Write;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
|
||||
|
||||
const PEM_PRIVATE_KEY: &[u8] = b"-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEINTuctv5E1hK1bbY8fdp+K06/nwoy/HU++CXqI9EdVhC\n-----END PRIVATE KEY-----\n";
|
||||
|
||||
fn make_authorized_keys(content: &str) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
write!(f, "{content}").unwrap();
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
fn make_private_key_file(content: &str) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
f.write_all(content.as_bytes()).unwrap();
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_ed25519_key_from_file() {
|
||||
let f = make_private_key_file(ED25519_PRIVATE_KEY);
|
||||
let source = KeySource::File(f.path().to_path_buf());
|
||||
let key = load_private_key(source).unwrap();
|
||||
assert_eq!(key.algorithm(), russh::keys::Algorithm::Ed25519);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_ed25519_key_from_memory() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let key = load_private_key(source).unwrap();
|
||||
assert_eq!(key.algorithm(), russh::keys::Algorithm::Ed25519);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_key_file_not_found() {
|
||||
let source = KeySource::File(PathBuf::from("/nonexistent/key"));
|
||||
let result = load_private_key(source);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(matches!(err, ConfigError::KeyFileNotFound { .. }));
|
||||
assert!(err.to_string().contains("/nonexistent/key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_pem_format() {
|
||||
let source = KeySource::Memory(PEM_PRIVATE_KEY.to_vec());
|
||||
let result = load_private_key(source);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(matches!(err, ConfigError::InvalidFlag { .. }));
|
||||
assert!(err.to_string().contains("PEM"));
|
||||
}
|
||||
|
||||
const ED25519_PUBLIC_KEY_2: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
||||
|
||||
#[test]
|
||||
fn parse_authorized_keys_multiple_entries() {
|
||||
let content = format!("{ED25519_PUBLIC_KEY}\n# comment line\n\n{ED25519_PUBLIC_KEY_2}\n");
|
||||
let f = make_authorized_keys(&content);
|
||||
let source = KeySource::File(f.path().to_path_buf());
|
||||
let keys = load_public_keys(source).unwrap();
|
||||
assert_eq!(keys.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_authorized_keys_from_memory() {
|
||||
let content = format!("{ED25519_PUBLIC_KEY}\n");
|
||||
let source = KeySource::Memory(content.into_bytes());
|
||||
let keys = load_public_keys(source).unwrap();
|
||||
assert_eq!(keys.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_cert_authority_entry() {
|
||||
let content =
|
||||
"cert-authority,permit-port-forwarding ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV CA name\n";
|
||||
let f = make_authorized_keys(content);
|
||||
let source = KeySource::File(f.path().to_path_buf());
|
||||
let entries = load_cert_authority_entries(source).unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert_eq!(entries[0].options.len(), 2);
|
||||
assert_eq!(entries[0].options[0], "cert-authority");
|
||||
assert_eq!(entries[0].options[1], "permit-port-forwarding");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_mixed_authorized_keys() {
|
||||
let content = format!(
|
||||
"{ED25519_PUBLIC_KEY}\ncert-authority ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE CA name\n"
|
||||
);
|
||||
let source = KeySource::Memory(content.into_bytes());
|
||||
let keys = load_public_keys(source.clone()).unwrap();
|
||||
assert_eq!(keys.len(), 2);
|
||||
let entries = load_cert_authority_entries(source).unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert_eq!(entries[0].options, vec!["cert-authority"]);
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
//! SSH authentication (Ed25519 public key and OpenSSH certificate authority).
|
||||
//!
|
||||
//! Supports file-path and in-memory key sources. No password authentication.
|
||||
//! See ADR-012 for the design rationale.
|
||||
|
||||
#[cfg(feature = "irpc")]
|
||||
pub mod auth_protocol;
|
||||
pub mod client_auth;
|
||||
pub mod identity;
|
||||
pub mod keys;
|
||||
pub mod server_auth;
|
||||
|
||||
#[cfg(feature = "irpc")]
|
||||
pub use auth_protocol::{AuthProtocol, AuthResult, AuthServiceImpl};
|
||||
pub use client_auth::{ClientAuthConfig, ClientHandler};
|
||||
pub use identity::{AuthToken, ConfigIdentityProvider, Identity, IdentityProvider};
|
||||
pub use keys::{load_private_key, load_public_keys, CertAuthorityEntry, KeySource};
|
||||
pub use server_auth::ServerAuthConfig;
|
||||
@@ -1,395 +0,0 @@
|
||||
//! Server-side authentication configuration and validation.
|
||||
//!
|
||||
//! `ServerAuthConfig` holds the set of authorized public keys and optional certificate
|
||||
//! authority entries. Authentication is key-based only (Ed25519 + optional OpenSSH CA).
|
||||
//! No password authentication. See ADR-012.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use ipnetwork::IpNetwork;
|
||||
use russh::keys::helpers::EncodedExt;
|
||||
use russh::keys::{Certificate, PublicKey};
|
||||
|
||||
use super::keys::{load_cert_authority_entries, load_public_keys, CertAuthorityEntry, KeySource};
|
||||
use crate::error::AuthError;
|
||||
|
||||
/// Server-side authentication configuration.
|
||||
///
|
||||
/// Holds authorized public keys (constant-time comparison) and optional certificate
|
||||
/// authority entries for validating OpenSSH certificates.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServerAuthConfig {
|
||||
pub authorized_keys: HashSet<PublicKey>,
|
||||
pub cert_authorities: Vec<CertAuthorityEntry>,
|
||||
encoded_keys: HashSet<Vec<u8>>,
|
||||
}
|
||||
|
||||
fn encode_key_data(key: &PublicKey) -> Vec<u8> {
|
||||
key.key_data().encoded().unwrap_or_default()
|
||||
}
|
||||
|
||||
impl ServerAuthConfig {
|
||||
pub fn from_keys_and_ca(
|
||||
authorized_keys_source: Option<KeySource>,
|
||||
cert_authority_source: Option<KeySource>,
|
||||
) -> Result<Self, crate::error::ConfigError> {
|
||||
let authorized_keys: HashSet<PublicKey> = match authorized_keys_source {
|
||||
Some(src) => load_public_keys(src)?.into_iter().collect(),
|
||||
None => HashSet::new(),
|
||||
};
|
||||
|
||||
let encoded_keys: HashSet<Vec<u8>> = authorized_keys.iter().map(encode_key_data).collect();
|
||||
|
||||
let cert_authorities = match cert_authority_source {
|
||||
Some(src) => load_cert_authority_entries(src)?,
|
||||
None => Vec::new(),
|
||||
};
|
||||
|
||||
Ok(ServerAuthConfig {
|
||||
authorized_keys,
|
||||
cert_authorities,
|
||||
encoded_keys,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn authenticate_publickey(&self, key: &PublicKey) -> Result<(), AuthError> {
|
||||
let encoded = encode_key_data(key);
|
||||
if self.encoded_keys.contains(&encoded) {
|
||||
return Ok(());
|
||||
}
|
||||
Err(AuthError::KeyRejected)
|
||||
}
|
||||
|
||||
pub fn authenticate_certificate(
|
||||
&self,
|
||||
cert: &Certificate,
|
||||
user: &str,
|
||||
client_ip: Option<IpAddr>,
|
||||
) -> Result<(), AuthError> {
|
||||
let matching_ca = self
|
||||
.cert_authorities
|
||||
.iter()
|
||||
.find(|ca| cert.signature_key() == ca.public_key.key_data());
|
||||
|
||||
let ca_entry = match matching_ca {
|
||||
Some(entry) => entry,
|
||||
None => return Err(AuthError::CertInvalid),
|
||||
};
|
||||
|
||||
if cert.verify_signature().is_err() {
|
||||
return Err(AuthError::CertInvalid);
|
||||
}
|
||||
|
||||
let now = SystemTime::now();
|
||||
let now_secs = now
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
|
||||
if now_secs < cert.valid_after() || now_secs >= cert.valid_before() {
|
||||
return Err(AuthError::CertExpired);
|
||||
}
|
||||
|
||||
let principals = cert.valid_principals();
|
||||
if !principals.is_empty() && !principals.iter().any(|p| p == user) {
|
||||
return Err(AuthError::CertPrincipalMismatch);
|
||||
}
|
||||
|
||||
check_critical_options(cert, ca_entry, client_ip)?;
|
||||
|
||||
check_extensions(cert, ca_entry)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn check_critical_options(
|
||||
cert: &Certificate,
|
||||
ca_entry: &CertAuthorityEntry,
|
||||
client_ip: Option<IpAddr>,
|
||||
) -> Result<(), AuthError> {
|
||||
let ca_has_no_pty = ca_entry.options.iter().any(|o| o == "no-pty");
|
||||
|
||||
for (name, data) in cert.critical_options().iter() {
|
||||
match name.as_str() {
|
||||
"source-address" => {
|
||||
if !check_source_address(data, client_ip) {
|
||||
return Err(AuthError::CertInvalid);
|
||||
}
|
||||
}
|
||||
"force-command" => {}
|
||||
"no-pty" => {}
|
||||
_ => {
|
||||
let _ = ca_has_no_pty;
|
||||
return Err(AuthError::CertInvalid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_extensions(cert: &Certificate, ca_entry: &CertAuthorityEntry) -> Result<(), AuthError> {
|
||||
let ca_permit_port_forwarding = ca_entry
|
||||
.options
|
||||
.iter()
|
||||
.any(|o| o == "permit-port-forwarding");
|
||||
|
||||
if ca_permit_port_forwarding {
|
||||
let cert_allows = cert
|
||||
.extensions()
|
||||
.iter()
|
||||
.any(|(n, _)| n == "permit-port-forwarding");
|
||||
if !cert_allows {
|
||||
return Err(AuthError::CertInvalid);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_source_address(allowed: &str, client_ip: Option<IpAddr>) -> bool {
|
||||
let Some(ip) = client_ip else {
|
||||
return false;
|
||||
};
|
||||
|
||||
for pattern in allowed.split(',') {
|
||||
let pattern = pattern.trim();
|
||||
if pattern.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(cidr) = IpNetwork::from_str(pattern) {
|
||||
if cidr.contains(ip) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(net_ip) = IpAddr::from_str(pattern) {
|
||||
if net_ip == ip {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand_core::OsRng;
|
||||
use russh::keys::ssh_key::certificate::{Builder, CertType};
|
||||
use russh::keys::{decode_secret_key, Certificate, PrivateKey};
|
||||
use std::io::Write;
|
||||
|
||||
const CA_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+gAAAJjP22Bpz9tg\naQAAAAtzc2gtZWQyNTUxOQAAACA6pFKBI327JsRFmZULalNjpoUPJMVxzsk9bGbDByat+g\nAAAEBcRrWyUU+lLpjHbaaYN5YeOlvz6HnuBndUWevEmHk00jqkUoEjfbsmxEWZlQtqU2Om\nhQ8kxXHOyT1sZsMHJq36AAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const USER_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACAoTr8X7HqltuKBdBdB2Vjb+K7bi3vVPcuWAYIb3ur5NgAAAJgM/+f3DP/n\n9wAAAAtzc2gtZWQyNTUxOQAAACAoTr8X7HqltuKBdBdB2Vjb+K7bi3vVPcuWAYIb3ur5Ng\nAAAEADN/ZEFvX/mflX8aEGwS/tMzys564rYEaMzd4vmYKZkShOvxfseqW24oF0F0HZWNv4\nrtuLe9U9y5YBghve6vk2AAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const OTHER_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACC/7V2LLT4WRm1Mfje8eSPWlhN+kNXz2ryKoqCkSrGzdgAAAJgXj2UzF49l\nMwAAAAtzc2gtZWQyNTUxOQAAACC/7V2LLT4WRm1Mfje8eSPWlhN+kNXz2ryKoqCkSrGzdg\nAAAEBVadyi5nAUfkjpp4zyQ08b8h1o4RTEgwtLejTjX5Tycb/tXYstPhZGbUx+N7x5I9aW\nE36Q1fPavIqioKRKsbN2AAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
fn load_ca_key() -> PrivateKey {
|
||||
decode_secret_key(CA_PRIVATE_KEY, None).unwrap()
|
||||
}
|
||||
|
||||
fn load_user_key() -> PrivateKey {
|
||||
decode_secret_key(USER_PRIVATE_KEY, None).unwrap()
|
||||
}
|
||||
|
||||
fn load_other_key() -> PrivateKey {
|
||||
decode_secret_key(OTHER_PRIVATE_KEY, None).unwrap()
|
||||
}
|
||||
|
||||
fn make_cert(
|
||||
ca_key: &PrivateKey,
|
||||
user_pub: &PublicKey,
|
||||
valid_after: u64,
|
||||
valid_before: u64,
|
||||
principals: Vec<&str>,
|
||||
) -> Certificate {
|
||||
let key_data: russh::keys::ssh_key::public::KeyData = user_pub.into();
|
||||
let mut builder =
|
||||
Builder::new_with_random_nonce(&mut OsRng, key_data, valid_after, valid_before)
|
||||
.unwrap();
|
||||
|
||||
builder.cert_type(CertType::User).unwrap();
|
||||
|
||||
for p in principals {
|
||||
builder.valid_principal(p).unwrap();
|
||||
}
|
||||
|
||||
builder.sign(ca_key).unwrap()
|
||||
}
|
||||
|
||||
fn make_authorized_keys_file(keys: &[&PublicKey]) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
for key in keys {
|
||||
let line = format!("{}\n", key.to_openssh().unwrap());
|
||||
f.write_all(line.as_bytes()).unwrap();
|
||||
}
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
fn make_ca_file(ca_pub: &PublicKey, options: &[&str]) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
let opts = if options.is_empty() {
|
||||
"cert-authority".to_string()
|
||||
} else {
|
||||
format!("cert-authority,{}", options.join(","))
|
||||
};
|
||||
let line = format!("{} {} CA\n", opts, ca_pub.to_openssh().unwrap());
|
||||
f.write_all(line.as_bytes()).unwrap();
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
fn now_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_key_accepted() {
|
||||
let user_key = load_user_key();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let f = make_authorized_keys_file(&[&user_pub]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
|
||||
.unwrap();
|
||||
assert!(config.authenticate_publickey(&user_pub).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_key_rejected() {
|
||||
let user_key = load_user_key();
|
||||
let other_key = load_other_key();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let other_pub = other_key.public_key().clone();
|
||||
let f = make_authorized_keys_file(&[&user_pub]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
config.authenticate_publickey(&other_pub),
|
||||
Err(AuthError::KeyRejected)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cert_authority_signed_cert_accepted() {
|
||||
let ca_key = load_ca_key();
|
||||
let user_key = load_user_key();
|
||||
let ca_pub = ca_key.public_key().clone();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let now = now_secs();
|
||||
let cert = make_cert(&ca_key, &user_pub, now - 60, now + 3600, vec!["testuser"]);
|
||||
let f = make_ca_file(&ca_pub, &[]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
|
||||
.unwrap();
|
||||
assert!(config
|
||||
.authenticate_certificate(&cert, "testuser", None)
|
||||
.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expired_cert_rejected() {
|
||||
let ca_key = load_ca_key();
|
||||
let user_key = load_user_key();
|
||||
let ca_pub = ca_key.public_key().clone();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let now = now_secs();
|
||||
let cert = make_cert(&ca_key, &user_pub, now - 7200, now - 3600, vec!["testuser"]);
|
||||
let f = make_ca_file(&ca_pub, &[]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
config.authenticate_certificate(&cert, "testuser", None),
|
||||
Err(AuthError::CertExpired)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_principal_rejected() {
|
||||
let ca_key = load_ca_key();
|
||||
let user_key = load_user_key();
|
||||
let ca_pub = ca_key.public_key().clone();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let now = now_secs();
|
||||
let cert = make_cert(&ca_key, &user_pub, now - 60, now + 3600, vec!["alice"]);
|
||||
let f = make_ca_file(&ca_pub, &[]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
config.authenticate_certificate(&cert, "bob", None),
|
||||
Err(AuthError::CertPrincipalMismatch)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cert_wildcard_principals_accepts_any_user() {
|
||||
let ca_key = load_ca_key();
|
||||
let user_key = load_user_key();
|
||||
let ca_pub = ca_key.public_key().clone();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let now = now_secs();
|
||||
let key_data: russh::keys::ssh_key::public::KeyData = (&user_pub).into();
|
||||
let mut builder =
|
||||
Builder::new_with_random_nonce(&mut OsRng, key_data, now - 60, now + 3600).unwrap();
|
||||
builder.cert_type(CertType::User).unwrap();
|
||||
builder.all_principals_valid().unwrap();
|
||||
let cert = builder.sign(&ca_key).unwrap();
|
||||
|
||||
let f = make_ca_file(&ca_pub, &[]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
|
||||
.unwrap();
|
||||
assert!(config
|
||||
.authenticate_certificate(&cert, "anyuser", None)
|
||||
.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cert_wrong_ca_rejected() {
|
||||
let user_key = load_user_key();
|
||||
let other_ca_key = load_other_key();
|
||||
let user_pub = user_key.public_key().clone();
|
||||
let now = now_secs();
|
||||
let cert = make_cert(
|
||||
&other_ca_key,
|
||||
&user_pub,
|
||||
now - 60,
|
||||
now + 3600,
|
||||
vec!["testuser"],
|
||||
);
|
||||
let ca_key = load_ca_key();
|
||||
let ca_pub = ca_key.public_key().clone();
|
||||
let f = make_ca_file(&ca_pub, &[]);
|
||||
let config =
|
||||
ServerAuthConfig::from_keys_and_ca(None, Some(KeySource::File(f.path().to_path_buf())))
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
config.authenticate_certificate(&cert, "testuser", None),
|
||||
Err(AuthError::CertInvalid)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_config_accepts_nothing() {
|
||||
let config = ServerAuthConfig::from_keys_and_ca(None, None).unwrap();
|
||||
let other_pub = load_other_key().public_key().clone();
|
||||
assert_eq!(
|
||||
config.authenticate_publickey(&other_pub),
|
||||
Err(AuthError::KeyRejected)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::call::OperationEnv;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OperationContext {
|
||||
pub request_id: String,
|
||||
pub parent_request_id: Option<String>,
|
||||
pub identity: Option<crate::auth::Identity>,
|
||||
pub metadata: HashMap<String, Value>,
|
||||
pub env: OperationEnv,
|
||||
pub trusted: bool,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::call::OperationRegistry;
|
||||
|
||||
fn make_context() -> OperationContext {
|
||||
let registry = OperationRegistry::new();
|
||||
OperationContext {
|
||||
request_id: "req-1".to_string(),
|
||||
parent_request_id: None,
|
||||
identity: None,
|
||||
metadata: HashMap::new(),
|
||||
env: OperationEnv::local(registry),
|
||||
trusted: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_context_fields() {
|
||||
let ctx = make_context();
|
||||
assert_eq!(ctx.request_id, "req-1");
|
||||
assert!(ctx.parent_request_id.is_none());
|
||||
assert!(ctx.identity.is_none());
|
||||
assert!(ctx.metadata.is_empty());
|
||||
assert!(!ctx.trusted);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_context_with_parent() {
|
||||
let registry = OperationRegistry::new();
|
||||
let ctx = OperationContext {
|
||||
request_id: "req-2".to_string(),
|
||||
parent_request_id: Some("req-1".to_string()),
|
||||
identity: None,
|
||||
metadata: HashMap::new(),
|
||||
env: OperationEnv::local(registry),
|
||||
trusted: true,
|
||||
};
|
||||
assert_eq!(ctx.parent_request_id, Some("req-1".to_string()));
|
||||
assert!(ctx.trusted);
|
||||
}
|
||||
}
|
||||
@@ -1,190 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::call::context::OperationContext;
|
||||
use crate::call::registry::OperationRegistry;
|
||||
use crate::call::response::ResponseEnvelope;
|
||||
use crate::credentials::{CredentialProvider, CredentialSet, SecretStoreCredentialProvider};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OperationEnv {
|
||||
registry: Arc<OperationRegistry>,
|
||||
credential_provider: Arc<dyn CredentialProvider>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OperationEnv {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OperationEnv")
|
||||
.field("registry", &self.registry)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperationEnv {
|
||||
pub fn local(registry: OperationRegistry) -> Self {
|
||||
Self {
|
||||
registry: Arc::new(registry),
|
||||
credential_provider: Arc::new(SecretStoreCredentialProvider::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_credential_provider(
|
||||
registry: OperationRegistry,
|
||||
credential_provider: Arc<dyn CredentialProvider>,
|
||||
) -> Self {
|
||||
Self {
|
||||
registry: Arc::new(registry),
|
||||
credential_provider,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn credentials(&self, service: &str) -> Option<CredentialSet> {
|
||||
self.credential_provider.get_credentials(service)
|
||||
}
|
||||
|
||||
pub fn invoke(&self, namespace: &str, operation: &str, input: Value) -> ResponseEnvelope {
|
||||
let name = format!("/{namespace}/{operation}");
|
||||
let request_id = format!("env{name}");
|
||||
let context = OperationContext {
|
||||
request_id: request_id.clone(),
|
||||
parent_request_id: None,
|
||||
identity: None,
|
||||
metadata: std::collections::HashMap::new(),
|
||||
env: self.clone(),
|
||||
trusted: true,
|
||||
};
|
||||
self.registry.invoke(&name, input, context)
|
||||
}
|
||||
|
||||
pub fn registry_ref(&self) -> &OperationRegistry {
|
||||
&self.registry
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::call::registry::OperationRegistryBuilder;
|
||||
use crate::call::spec::{AccessControl, OperationSpec, OperationType};
|
||||
use crate::config::{AuthPolicy, DynamicConfig};
|
||||
use crate::credentials::ConfigCredentialProvider;
|
||||
use arc_swap::ArcSwap;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_spec(name: &str, namespace: &str) -> OperationSpec {
|
||||
OperationSpec {
|
||||
name: name.to_string(),
|
||||
namespace: namespace.to_string(),
|
||||
op_type: OperationType::Query,
|
||||
input_schema: serde_json::json!({}),
|
||||
output_schema: serde_json::json!({}),
|
||||
access_control: AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_env_local_invoke() {
|
||||
let registry = OperationRegistryBuilder::new()
|
||||
.with(
|
||||
make_spec("/auth/verify", "auth"),
|
||||
Arc::new(|_input, _ctx| {
|
||||
ResponseEnvelope::ok("env-/auth/verify", serde_json::json!({"verified": true}))
|
||||
}),
|
||||
)
|
||||
.build();
|
||||
|
||||
let env = OperationEnv::local(registry);
|
||||
let result = env.invoke("auth", "verify", serde_json::json!({"token": "abc"}));
|
||||
assert!(result.result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_env_invoke_missing() {
|
||||
let registry = OperationRegistry::new();
|
||||
let env = OperationEnv::local(registry);
|
||||
let result = env.invoke("auth", "verify", serde_json::json!(null));
|
||||
assert!(result.result.is_err());
|
||||
let err = result.result.unwrap_err();
|
||||
assert_eq!(err.code, "NOT_FOUND");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_env_invoke_trusted() {
|
||||
let registry = OperationRegistryBuilder::new()
|
||||
.with(
|
||||
make_spec("/auth/verify", "auth"),
|
||||
Arc::new(|_input, ctx| {
|
||||
assert!(ctx.trusted);
|
||||
ResponseEnvelope::ok(&ctx.request_id, serde_json::json!({"ok": true}))
|
||||
}),
|
||||
)
|
||||
.build();
|
||||
|
||||
let env = OperationEnv::local(registry);
|
||||
let result = env.invoke("auth", "verify", serde_json::json!(null));
|
||||
assert!(result.result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_env_provides_credentials_from_handler_context() {
|
||||
let mut credentials = HashMap::new();
|
||||
credentials.insert(
|
||||
"vast-ai".to_string(),
|
||||
CredentialSet::Bearer {
|
||||
token: "test-token".to_string(),
|
||||
},
|
||||
);
|
||||
let config = DynamicConfig::new(AuthPolicy::empty()).with_credentials(credentials);
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(config)));
|
||||
let provider = Arc::new(ConfigCredentialProvider::new(dynamic));
|
||||
|
||||
let registry = OperationRegistryBuilder::new()
|
||||
.with(
|
||||
make_spec("/test/creds", "test"),
|
||||
Arc::new(|_input, ctx| {
|
||||
let creds = ctx.env.credentials("vast-ai");
|
||||
match creds {
|
||||
Some(CredentialSet::Bearer { token }) => ResponseEnvelope::ok(
|
||||
&ctx.request_id,
|
||||
serde_json::json!({"token": token}),
|
||||
),
|
||||
_ => ResponseEnvelope::ok(
|
||||
&ctx.request_id,
|
||||
serde_json::json!({"found": false}),
|
||||
),
|
||||
}
|
||||
}),
|
||||
)
|
||||
.build();
|
||||
|
||||
let env = OperationEnv::with_credential_provider(registry, provider);
|
||||
let result = env.invoke("test", "creds", serde_json::json!(null));
|
||||
assert!(result.result.is_ok());
|
||||
let value = result.result.unwrap();
|
||||
assert_eq!(value["token"], "test-token");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_env_credentials_returns_none_for_missing_service() {
|
||||
let config = DynamicConfig::default();
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(config)));
|
||||
let provider = Arc::new(ConfigCredentialProvider::new(dynamic));
|
||||
|
||||
let registry = OperationRegistry::new();
|
||||
let env = OperationEnv::with_credential_provider(registry, provider);
|
||||
assert!(env.credentials("nonexistent").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_env_default_credentials_returns_none() {
|
||||
let registry = OperationRegistry::new();
|
||||
let env = OperationEnv::local(registry);
|
||||
assert!(env.credentials("vast-ai").is_none());
|
||||
}
|
||||
}
|
||||
@@ -1,141 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct EventEnvelope {
|
||||
#[serde(rename = "type")]
|
||||
pub r#type: String,
|
||||
pub id: String,
|
||||
pub payload: Value,
|
||||
}
|
||||
|
||||
impl EventEnvelope {
|
||||
pub fn new(event_type: impl Into<String>, id: impl Into<String>, payload: Value) -> Self {
|
||||
Self {
|
||||
r#type: event_type.into(),
|
||||
id: id.into(),
|
||||
payload,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_requested(id: impl Into<String>, payload: Value) -> Self {
|
||||
Self::new(super::events::CALL_REQUESTED, id, payload)
|
||||
}
|
||||
|
||||
pub fn call_responded(id: impl Into<String>, payload: Value) -> Self {
|
||||
Self::new(super::events::CALL_RESPONDED, id, payload)
|
||||
}
|
||||
|
||||
pub fn call_completed(id: impl Into<String>, payload: Value) -> Self {
|
||||
Self::new(super::events::CALL_COMPLETED, id, payload)
|
||||
}
|
||||
|
||||
pub fn call_aborted(id: impl Into<String>, payload: Value) -> Self {
|
||||
Self::new(super::events::CALL_ABORTED, id, payload)
|
||||
}
|
||||
|
||||
pub fn call_error(
|
||||
id: impl Into<String>,
|
||||
code: impl Into<String>,
|
||||
message: impl Into<String>,
|
||||
retryable: bool,
|
||||
) -> Self {
|
||||
Self::new(
|
||||
super::events::CALL_ERROR,
|
||||
id,
|
||||
serde_json::json!({
|
||||
"code": code.into(),
|
||||
"message": message.into(),
|
||||
"retryable": retryable,
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn event_envelope_new() {
|
||||
let env = EventEnvelope::new(
|
||||
"call.requested",
|
||||
"req-1",
|
||||
serde_json::json!({"key": "value"}),
|
||||
);
|
||||
assert_eq!(env.r#type, "call.requested");
|
||||
assert_eq!(env.id, "req-1");
|
||||
assert_eq!(env.payload, serde_json::json!({"key": "value"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_serialization() {
|
||||
let env = EventEnvelope::new(
|
||||
"call.requested",
|
||||
"req-1",
|
||||
serde_json::json!({"key": "value"}),
|
||||
);
|
||||
let serialized = serde_json::to_string(&env).unwrap();
|
||||
let deserialized: EventEnvelope = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.r#type, "call.requested");
|
||||
assert_eq!(deserialized.id, "req-1");
|
||||
assert_eq!(deserialized.payload, serde_json::json!({"key": "value"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_serialization_type_field() {
|
||||
let env = EventEnvelope::new("call.requested", "req-1", serde_json::json!(null));
|
||||
let serialized = serde_json::to_string(&env).unwrap();
|
||||
assert!(serialized.contains("\"type\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_deserialization() {
|
||||
let json = r#"{"type":"call.responded","id":"req-42","payload":{"result":"ok"}}"#;
|
||||
let env: EventEnvelope = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(env.r#type, "call.responded");
|
||||
assert_eq!(env.id, "req-42");
|
||||
assert_eq!(env.payload["result"], "ok");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_call_requested() {
|
||||
let env = EventEnvelope::call_requested("req-1", serde_json::json!({"op": "test"}));
|
||||
assert_eq!(env.r#type, "call.requested");
|
||||
assert_eq!(env.id, "req-1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_call_responded() {
|
||||
let env = EventEnvelope::call_responded("req-1", serde_json::json!({"data": 42}));
|
||||
assert_eq!(env.r#type, "call.responded");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_call_completed() {
|
||||
let env = EventEnvelope::call_completed("req-1", serde_json::json!(null));
|
||||
assert_eq!(env.r#type, "call.completed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_call_aborted() {
|
||||
let env = EventEnvelope::call_aborted("req-1", serde_json::json!({"reason": "cancelled"}));
|
||||
assert_eq!(env.r#type, "call.aborted");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_call_error() {
|
||||
let env = EventEnvelope::call_error("req-1", "TIMEOUT", "timed out", true);
|
||||
assert_eq!(env.r#type, "call.error");
|
||||
assert_eq!(env.id, "req-1");
|
||||
assert_eq!(env.payload["code"], "TIMEOUT");
|
||||
assert_eq!(env.payload["message"], "timed out");
|
||||
assert_eq!(env.payload["retryable"], true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_envelope_empty_id() {
|
||||
let env = EventEnvelope::new("event.broadcast", "", serde_json::json!({"msg": "hello"}));
|
||||
assert_eq!(env.id, "");
|
||||
}
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
pub const CALL_REQUESTED: &str = "call.requested";
|
||||
pub const CALL_RESPONDED: &str = "call.responded";
|
||||
pub const CALL_COMPLETED: &str = "call.completed";
|
||||
pub const CALL_ABORTED: &str = "call.aborted";
|
||||
pub const CALL_ERROR: &str = "call.error";
|
||||
|
||||
pub const SERVICE_LIST: &str = "/services/list";
|
||||
pub const SERVICE_SCHEMA: &str = "/services/schema";
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn event_type_constants() {
|
||||
assert_eq!(CALL_REQUESTED, "call.requested");
|
||||
assert_eq!(CALL_RESPONDED, "call.responded");
|
||||
assert_eq!(CALL_COMPLETED, "call.completed");
|
||||
assert_eq!(CALL_ABORTED, "call.aborted");
|
||||
assert_eq!(CALL_ERROR, "call.error");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn service_operation_constants() {
|
||||
assert_eq!(SERVICE_LIST, "/services/list");
|
||||
assert_eq!(SERVICE_SCHEMA, "/services/schema");
|
||||
}
|
||||
}
|
||||
@@ -1,239 +0,0 @@
|
||||
use std::io;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
use crate::call::envelope::EventEnvelope;
|
||||
|
||||
pub fn encode(envelope: &EventEnvelope) -> Vec<u8> {
|
||||
let json = serde_json::to_vec(envelope).expect("EventEnvelope serialization must not fail");
|
||||
let len = json.len() as u32;
|
||||
let mut frame = Vec::with_capacity(4 + json.len());
|
||||
frame.extend_from_slice(&len.to_be_bytes());
|
||||
frame.extend_from_slice(&json);
|
||||
frame
|
||||
}
|
||||
|
||||
pub fn decode(data: &[u8]) -> Result<EventEnvelope, FrameDecodeError> {
|
||||
if data.len() < 4 {
|
||||
return Err(FrameDecodeError::TooShort {
|
||||
expected: 4,
|
||||
actual: data.len(),
|
||||
});
|
||||
}
|
||||
let len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
|
||||
if data.len() < 4 + len {
|
||||
return Err(FrameDecodeError::Incomplete {
|
||||
expected: 4 + len,
|
||||
actual: data.len(),
|
||||
});
|
||||
}
|
||||
let body = &data[4..4 + len];
|
||||
let envelope: EventEnvelope = serde_json::from_slice(body).map_err(FrameDecodeError::Json)?;
|
||||
Ok(envelope)
|
||||
}
|
||||
|
||||
pub fn decode_with_remainder(data: &[u8]) -> Result<(EventEnvelope, usize), FrameDecodeError> {
|
||||
if data.len() < 4 {
|
||||
return Err(FrameDecodeError::TooShort {
|
||||
expected: 4,
|
||||
actual: data.len(),
|
||||
});
|
||||
}
|
||||
let len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
|
||||
let total = 4 + len;
|
||||
if data.len() < total {
|
||||
return Err(FrameDecodeError::Incomplete {
|
||||
expected: total,
|
||||
actual: data.len(),
|
||||
});
|
||||
}
|
||||
let body = &data[4..total];
|
||||
let envelope: EventEnvelope = serde_json::from_slice(body).map_err(FrameDecodeError::Json)?;
|
||||
Ok((envelope, total))
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum FrameDecodeError {
|
||||
#[error("frame too short: expected at least {expected} bytes, got {actual}")]
|
||||
TooShort { expected: usize, actual: usize },
|
||||
#[error("incomplete frame: expected {expected} bytes, got {actual}")]
|
||||
Incomplete { expected: usize, actual: usize },
|
||||
#[error("JSON deserialization error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub struct FrameFramedReader<S> {
|
||||
stream: S,
|
||||
buf: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<S> FrameFramedReader<S>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
buf: Vec::with_capacity(4096),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_frame(&mut self) -> io::Result<Option<EventEnvelope>> {
|
||||
loop {
|
||||
if self.buf.len() >= 4 {
|
||||
let len = u32::from_be_bytes([self.buf[0], self.buf[1], self.buf[2], self.buf[3]])
|
||||
as usize;
|
||||
let total = 4 + len;
|
||||
if self.buf.len() >= total {
|
||||
let body = &self.buf[4..total];
|
||||
match serde_json::from_slice(body) {
|
||||
Ok(envelope) => {
|
||||
self.buf.drain(..total);
|
||||
return Ok(Some(envelope));
|
||||
}
|
||||
Err(e) => {
|
||||
self.buf.drain(..total);
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, e));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut tmp = [0u8; 4096];
|
||||
match self.stream.read(&mut tmp).await {
|
||||
Ok(0) => return Ok(None),
|
||||
Ok(n) => self.buf.extend_from_slice(&tmp[..n]),
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FrameFramedWriter<S> {
|
||||
stream: S,
|
||||
}
|
||||
|
||||
impl<S> FrameFramedWriter<S>
|
||||
where
|
||||
S: AsyncWrite + Unpin,
|
||||
{
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self { stream }
|
||||
}
|
||||
|
||||
pub async fn write_frame(&mut self, envelope: &EventEnvelope) -> io::Result<()> {
|
||||
let frame = encode(envelope);
|
||||
self.stream.write_all(&frame).await?;
|
||||
self.stream.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::call::events;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn frame_encode_decode_round_trip() {
|
||||
let envelope = EventEnvelope::new(
|
||||
events::CALL_REQUESTED,
|
||||
"req-1",
|
||||
json!({"namespace": "auth", "operation": "verify"}),
|
||||
);
|
||||
let frame = encode(&envelope);
|
||||
let decoded = decode(&frame).unwrap();
|
||||
assert_eq!(decoded, envelope);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_encode_starts_with_length_prefix() {
|
||||
let envelope = EventEnvelope::new(events::CALL_REQUESTED, "req-1", json!({}));
|
||||
let frame = encode(&envelope);
|
||||
let json = serde_json::to_vec(&envelope).unwrap();
|
||||
let expected_len = json.len() as u32;
|
||||
let stored_len = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]);
|
||||
assert_eq!(stored_len, expected_len);
|
||||
assert_eq!(frame.len(), 4 + json.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_decode_too_short() {
|
||||
let data = [0u8; 2];
|
||||
let result = decode(&data);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
FrameDecodeError::TooShort {
|
||||
expected: 4,
|
||||
actual: 2
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_decode_incomplete() {
|
||||
let len = 100u32;
|
||||
let mut data = Vec::new();
|
||||
data.extend_from_slice(&len.to_be_bytes());
|
||||
data.extend_from_slice(&[0u8; 10]);
|
||||
let result = decode(&data);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
FrameDecodeError::Incomplete {
|
||||
expected: 104,
|
||||
actual: 14
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_decode_invalid_json() {
|
||||
let json = b"not valid json";
|
||||
let mut data = Vec::new();
|
||||
data.extend_from_slice(&(json.len() as u32).to_be_bytes());
|
||||
data.extend_from_slice(json);
|
||||
let result = decode(&data);
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), FrameDecodeError::Json(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_decode_with_remainder() {
|
||||
let envelope = EventEnvelope::new(events::CALL_RESPONDED, "req-1", json!({"result": 42}));
|
||||
let frame = encode(&envelope);
|
||||
let mut extended = frame.clone();
|
||||
extended.extend_from_slice(&[0u8; 50]);
|
||||
let (decoded, consumed) = decode_with_remainder(&extended).unwrap();
|
||||
assert_eq!(decoded, envelope);
|
||||
assert_eq!(consumed, frame.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_encode_decode_empty_payload() {
|
||||
let envelope = EventEnvelope::new(events::CALL_COMPLETED, "req-1", json!(null));
|
||||
let frame = encode(&envelope);
|
||||
let decoded = decode(&frame).unwrap();
|
||||
assert_eq!(decoded, envelope);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_encode_decode_large_payload() {
|
||||
let large_data: Vec<i32> = (0..1000).collect();
|
||||
let envelope = EventEnvelope::new(events::CALL_RESPONDED, "req-big", json!(large_data));
|
||||
let frame = encode(&envelope);
|
||||
let decoded = decode(&frame).unwrap();
|
||||
assert_eq!(decoded, envelope);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_decode_with_remainder_too_short() {
|
||||
let data = [0u8; 1];
|
||||
let result = decode_with_remainder(&data);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
//! Call protocol layer (Layer 3) of the three-layer model.
|
||||
//!
|
||||
//! See [ADR-024](docs/architecture/decisions/024-call-protocol.md) and
|
||||
//! [ADR-033](docs/architecture/decisions/033-call-protocol-extensions.md).
|
||||
|
||||
pub mod context;
|
||||
pub mod env;
|
||||
pub mod envelope;
|
||||
pub mod events;
|
||||
pub mod frame;
|
||||
pub mod pending;
|
||||
pub mod registry;
|
||||
pub mod response;
|
||||
pub mod services;
|
||||
pub mod spec;
|
||||
|
||||
pub use context::OperationContext;
|
||||
pub use env::OperationEnv;
|
||||
pub use envelope::EventEnvelope;
|
||||
pub use events::{CALL_ABORTED, CALL_COMPLETED, CALL_ERROR, CALL_REQUESTED, CALL_RESPONDED};
|
||||
pub use frame::{
|
||||
decode, decode_with_remainder, encode, FrameDecodeError, FrameFramedReader, FrameFramedWriter,
|
||||
};
|
||||
pub use pending::PendingRequestMap;
|
||||
pub use registry::{Handler, OperationRegistry, OperationRegistryBuilder};
|
||||
pub use response::{CallError, ResponseEnvelope};
|
||||
pub use services::{register_default_operations, services_list_spec, services_schema_spec};
|
||||
pub use spec::{AccessControl, OperationSpec, OperationType};
|
||||
@@ -1,265 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
|
||||
use serde_json::Value;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
use crate::call::response::CallError;
|
||||
|
||||
enum PendingEntry {
|
||||
Call {
|
||||
tx: oneshot::Sender<Result<Value, CallError>>,
|
||||
timeout: Instant,
|
||||
},
|
||||
Subscribe {
|
||||
tx: mpsc::Sender<Result<Value, CallError>>,
|
||||
timeout: Option<Instant>,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct PendingRequestMap {
|
||||
pending: HashMap<String, PendingEntry>,
|
||||
}
|
||||
|
||||
impl PendingRequestMap {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
pending: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert_call(
|
||||
&mut self,
|
||||
request_id: impl Into<String>,
|
||||
tx: oneshot::Sender<Result<Value, CallError>>,
|
||||
timeout: Instant,
|
||||
) {
|
||||
self.pending
|
||||
.insert(request_id.into(), PendingEntry::Call { tx, timeout });
|
||||
}
|
||||
|
||||
pub fn insert_subscribe(
|
||||
&mut self,
|
||||
request_id: impl Into<String>,
|
||||
tx: mpsc::Sender<Result<Value, CallError>>,
|
||||
timeout: Option<Instant>,
|
||||
) {
|
||||
self.pending
|
||||
.insert(request_id.into(), PendingEntry::Subscribe { tx, timeout });
|
||||
}
|
||||
|
||||
pub fn resolve_call(&mut self, request_id: &str, value: Result<Value, CallError>) -> bool {
|
||||
if let Some(PendingEntry::Call { tx, .. }) = self.pending.remove(request_id) {
|
||||
let _ = tx.send(value);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_subscribe(&mut self, request_id: &str, value: Result<Value, CallError>) -> bool {
|
||||
match self.pending.get_mut(request_id) {
|
||||
Some(PendingEntry::Subscribe { tx, .. }) => tx.try_send(value).is_ok(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn complete_subscribe(&mut self, request_id: &str) -> bool {
|
||||
self.pending.remove(request_id).is_some()
|
||||
}
|
||||
|
||||
pub fn abort(&mut self, request_id: &str) -> bool {
|
||||
self.pending.remove(request_id).is_some()
|
||||
}
|
||||
|
||||
pub fn contains(&self, request_id: &str) -> bool {
|
||||
self.pending.contains_key(request_id)
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.pending.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.pending.is_empty()
|
||||
}
|
||||
|
||||
pub fn sweep_expired(&mut self, now: Instant) -> usize {
|
||||
let expired: Vec<String> = self
|
||||
.pending
|
||||
.iter()
|
||||
.filter(|(_, entry)| match entry {
|
||||
PendingEntry::Call { timeout, .. } => *timeout <= now,
|
||||
PendingEntry::Subscribe { timeout, .. } => timeout.is_some_and(|t| t <= now),
|
||||
})
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect();
|
||||
let count = expired.len();
|
||||
for id in &expired {
|
||||
self.pending.remove(id);
|
||||
}
|
||||
count
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PendingRequestMap {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_insert_and_resolve_call() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let timeout = Instant::now() + Duration::from_secs(30);
|
||||
map.insert_call("req-1", tx, timeout);
|
||||
assert!(map.contains("req-1"));
|
||||
assert_eq!(map.len(), 1);
|
||||
|
||||
let result = map.resolve_call("req-1", Ok(serde_json::json!({"status": "ok"})));
|
||||
assert!(result);
|
||||
assert!(map.is_empty());
|
||||
|
||||
let response = rx.await.unwrap();
|
||||
assert!(response.is_ok());
|
||||
assert_eq!(response.unwrap(), serde_json::json!({"status": "ok"}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_resolve_unknown_call() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let result = map.resolve_call("unknown", Ok(serde_json::json!(null)));
|
||||
assert!(!result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_insert_and_push_subscribe() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx, mut rx) = mpsc::channel(16);
|
||||
map.insert_subscribe("sub-1", tx, None);
|
||||
assert!(map.contains("sub-1"));
|
||||
|
||||
let pushed = map.push_subscribe("sub-1", Ok(serde_json::json!({"item": 1})));
|
||||
assert!(pushed);
|
||||
|
||||
let response = rx.recv().await.unwrap();
|
||||
assert!(response.is_ok());
|
||||
assert_eq!(response.unwrap(), serde_json::json!({"item": 1}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_complete_subscribe() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx, mut rx) = mpsc::channel(16);
|
||||
map.insert_subscribe("sub-1", tx, None);
|
||||
|
||||
map.push_subscribe("sub-1", Ok(serde_json::json!({"item": 1})));
|
||||
let completed = map.complete_subscribe("sub-1");
|
||||
assert!(completed);
|
||||
assert!(map.is_empty());
|
||||
|
||||
let _ = rx.recv().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_abort_call() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx, _rx) = oneshot::channel();
|
||||
let timeout = Instant::now() + Duration::from_secs(30);
|
||||
map.insert_call("req-1", tx, timeout);
|
||||
|
||||
let aborted = map.abort("req-1");
|
||||
assert!(aborted);
|
||||
assert!(map.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_abort_unknown() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let aborted = map.abort("unknown");
|
||||
assert!(!aborted);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_sweep_expired() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx1, _rx1) = oneshot::channel();
|
||||
let (tx2, _rx2) = oneshot::channel();
|
||||
let past = Instant::now() - Duration::from_secs(1);
|
||||
let future = Instant::now() + Duration::from_secs(30);
|
||||
|
||||
map.insert_call("expired-1", tx1, past);
|
||||
map.insert_call("active-1", tx2, future);
|
||||
|
||||
let swept = map.sweep_expired(Instant::now());
|
||||
assert_eq!(swept, 1);
|
||||
assert!(!map.contains("expired-1"));
|
||||
assert!(map.contains("active-1"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_sweep_subscribe_with_timeout() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx1, _rx1) = mpsc::channel(16);
|
||||
let (tx2, _rx2) = mpsc::channel(16);
|
||||
let past = Some(Instant::now() - Duration::from_secs(1));
|
||||
let future = Some(Instant::now() + Duration::from_secs(30));
|
||||
|
||||
map.insert_subscribe("expired-sub", tx1, past);
|
||||
map.insert_subscribe("active-sub", tx2, future);
|
||||
|
||||
let swept = map.sweep_expired(Instant::now());
|
||||
assert_eq!(swept, 1);
|
||||
assert!(!map.contains("expired-sub"));
|
||||
assert!(map.contains("active-sub"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_subscribe_no_timeout_not_swept() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx, _rx) = mpsc::channel(16);
|
||||
map.insert_subscribe("sub-no-timeout", tx, None);
|
||||
|
||||
let swept = map.sweep_expired(Instant::now());
|
||||
assert_eq!(swept, 0);
|
||||
assert!(map.contains("sub-no-timeout"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_push_unknown_subscribe() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let pushed = map.push_subscribe("unknown", Ok(serde_json::json!(null)));
|
||||
assert!(!pushed);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_request_map_call_error_response() {
|
||||
let mut map = PendingRequestMap::new();
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let timeout = Instant::now() + Duration::from_secs(30);
|
||||
map.insert_call("req-err", tx, timeout);
|
||||
|
||||
let result = map.resolve_call(
|
||||
"req-err",
|
||||
Err(CallError {
|
||||
code: "TIMEOUT".to_string(),
|
||||
message: "request timed out".to_string(),
|
||||
retryable: true,
|
||||
}),
|
||||
);
|
||||
assert!(result);
|
||||
assert!(map.is_empty());
|
||||
|
||||
let response = rx.await.unwrap();
|
||||
assert!(response.is_err());
|
||||
let err = response.unwrap_err();
|
||||
assert_eq!(err.code, "TIMEOUT");
|
||||
assert!(err.retryable);
|
||||
}
|
||||
}
|
||||
@@ -1,337 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::call::context::OperationContext;
|
||||
use crate::call::response::ResponseEnvelope;
|
||||
use crate::call::spec::OperationSpec;
|
||||
|
||||
pub type Handler = Arc<dyn Fn(Value, OperationContext) -> ResponseEnvelope + Send + Sync>;
|
||||
|
||||
pub struct OperationRegistry {
|
||||
operations: HashMap<String, (OperationSpec, Handler)>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OperationRegistry {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OperationRegistry")
|
||||
.field("operation_count", &self.operations.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperationRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
operations: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register(&mut self, spec: OperationSpec, handler: Handler) {
|
||||
self.operations.insert(spec.name.clone(), (spec, handler));
|
||||
}
|
||||
|
||||
pub fn lookup(&self, name: &str) -> Option<(&OperationSpec, &Handler)> {
|
||||
self.operations
|
||||
.get(name)
|
||||
.map(|(spec, handler)| (spec, handler))
|
||||
}
|
||||
|
||||
pub fn invoke(&self, name: &str, input: Value, context: OperationContext) -> ResponseEnvelope {
|
||||
match self.lookup(name) {
|
||||
Some((spec, handler)) => {
|
||||
if !context.trusted {
|
||||
if let Some(ref identity) = context.identity {
|
||||
if !spec.access_control.check(identity) {
|
||||
return ResponseEnvelope::err(
|
||||
&context.request_id,
|
||||
"FORBIDDEN",
|
||||
"access denied",
|
||||
false,
|
||||
);
|
||||
}
|
||||
} else if spec.access_control.has_restrictions() {
|
||||
return ResponseEnvelope::err(
|
||||
&context.request_id,
|
||||
"FORBIDDEN",
|
||||
"authentication required",
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
handler(input, context)
|
||||
}
|
||||
None => ResponseEnvelope::err(
|
||||
&context.request_id,
|
||||
"NOT_FOUND",
|
||||
format!("operation not found: {name}"),
|
||||
false,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list_operations(&self) -> Vec<&OperationSpec> {
|
||||
self.operations.values().map(|(spec, _)| spec).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OperationRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OperationRegistryBuilder {
|
||||
registry: OperationRegistry,
|
||||
}
|
||||
|
||||
impl OperationRegistryBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
registry: OperationRegistry::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with(mut self, spec: OperationSpec, handler: Handler) -> Self {
|
||||
self.registry.register(spec, handler);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> OperationRegistry {
|
||||
self.registry
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OperationRegistryBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::Identity;
|
||||
use crate::call::env::OperationEnv;
|
||||
use crate::call::spec::{AccessControl, OperationType};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_spec(name: &str, namespace: &str) -> OperationSpec {
|
||||
OperationSpec {
|
||||
name: name.to_string(),
|
||||
namespace: namespace.to_string(),
|
||||
op_type: OperationType::Query,
|
||||
input_schema: serde_json::json!({}),
|
||||
output_schema: serde_json::json!({}),
|
||||
access_control: AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn make_spec_with_acl(name: &str, namespace: &str, acl: AccessControl) -> OperationSpec {
|
||||
OperationSpec {
|
||||
name: name.to_string(),
|
||||
namespace: namespace.to_string(),
|
||||
op_type: OperationType::Mutation,
|
||||
input_schema: serde_json::json!({}),
|
||||
output_schema: serde_json::json!({}),
|
||||
access_control: acl,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_context(request_id: &str, identity: Option<Identity>) -> OperationContext {
|
||||
let registry = OperationRegistry::new();
|
||||
OperationContext {
|
||||
request_id: request_id.to_string(),
|
||||
parent_request_id: None,
|
||||
identity,
|
||||
metadata: HashMap::new(),
|
||||
env: OperationEnv::local(registry),
|
||||
trusted: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_and_lookup() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let spec = make_spec("fs/readFile", "fs");
|
||||
let handler: Handler = Arc::new(|input, _ctx| ResponseEnvelope::ok("req-1", input));
|
||||
registry.register(spec, handler);
|
||||
let found = registry.lookup("fs/readFile");
|
||||
assert!(found.is_some());
|
||||
let (spec, _) = found.unwrap();
|
||||
assert_eq!(spec.name, "fs/readFile");
|
||||
assert_eq!(spec.namespace, "fs");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lookup_missing_returns_none() {
|
||||
let registry = OperationRegistry::new();
|
||||
assert!(registry.lookup("missing").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_operation() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let spec = make_spec("fs/readFile", "fs");
|
||||
let handler: Handler = Arc::new(|input, ctx| ResponseEnvelope::ok(&ctx.request_id, input));
|
||||
registry.register(spec, handler);
|
||||
let context = make_context("req-1", None);
|
||||
let result = registry.invoke("fs/readFile", serde_json::json!({"path": "/tmp"}), context);
|
||||
assert!(result.result.is_ok());
|
||||
assert_eq!(result.result.unwrap(), serde_json::json!({"path": "/tmp"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_missing_operation() {
|
||||
let registry = OperationRegistry::new();
|
||||
let context = make_context("req-1", None);
|
||||
let result = registry.invoke("missing", serde_json::json!(null), context);
|
||||
assert!(result.result.is_err());
|
||||
let err = result.result.unwrap_err();
|
||||
assert_eq!(err.code, "NOT_FOUND");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_with_acl_check_allowed() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let acl = AccessControl {
|
||||
required_scopes: vec!["read".to_string()],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let spec = make_spec_with_acl("bash/exec", "bash", acl);
|
||||
let handler: Handler = Arc::new(|_input, ctx| {
|
||||
ResponseEnvelope::ok(&ctx.request_id, serde_json::json!("done"))
|
||||
});
|
||||
registry.register(spec, handler);
|
||||
|
||||
let identity = Identity {
|
||||
id: "user-1".to_string(),
|
||||
scopes: vec!["read".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let context = make_context("req-1", Some(identity));
|
||||
let result = registry.invoke("bash/exec", serde_json::json!(null), context);
|
||||
assert!(result.result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_with_acl_check_denied() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let acl = AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let spec = make_spec_with_acl("bash/exec", "bash", acl);
|
||||
let handler: Handler = Arc::new(|_input, ctx| {
|
||||
ResponseEnvelope::ok(&ctx.request_id, serde_json::json!("done"))
|
||||
});
|
||||
registry.register(spec, handler);
|
||||
|
||||
let identity = Identity {
|
||||
id: "user-1".to_string(),
|
||||
scopes: vec!["read".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let context = make_context("req-1", Some(identity));
|
||||
let result = registry.invoke("bash/exec", serde_json::json!(null), context);
|
||||
assert!(result.result.is_err());
|
||||
let err = result.result.unwrap_err();
|
||||
assert_eq!(err.code, "FORBIDDEN");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_trusted_skips_acl() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let acl = AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let spec = make_spec_with_acl("bash/exec", "bash", acl);
|
||||
let handler: Handler = Arc::new(|_input, ctx| {
|
||||
ResponseEnvelope::ok(&ctx.request_id, serde_json::json!("done"))
|
||||
});
|
||||
registry.register(spec, handler);
|
||||
|
||||
let identity = Identity {
|
||||
id: "user-1".to_string(),
|
||||
scopes: vec!["read".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let mut registry2 = OperationRegistry::new();
|
||||
let context = OperationContext {
|
||||
request_id: "req-1".to_string(),
|
||||
parent_request_id: None,
|
||||
identity: Some(identity),
|
||||
metadata: HashMap::new(),
|
||||
env: OperationEnv::local(registry2),
|
||||
trusted: true,
|
||||
};
|
||||
let result = registry.invoke("bash/exec", serde_json::json!(null), context);
|
||||
assert!(result.result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invoke_no_identity_with_acl_denied() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
let acl = AccessControl {
|
||||
required_scopes: vec!["read".to_string()],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let spec = make_spec_with_acl("bash/exec", "bash", acl);
|
||||
let handler: Handler = Arc::new(|_input, ctx| {
|
||||
ResponseEnvelope::ok(&ctx.request_id, serde_json::json!("done"))
|
||||
});
|
||||
registry.register(spec, handler);
|
||||
|
||||
let context = make_context("req-1", None);
|
||||
let result = registry.invoke("bash/exec", serde_json::json!(null), context);
|
||||
assert!(result.result.is_err());
|
||||
let err = result.result.unwrap_err();
|
||||
assert_eq!(err.code, "FORBIDDEN");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_operations() {
|
||||
let mut registry = OperationRegistry::new();
|
||||
registry.register(
|
||||
make_spec("fs/readFile", "fs"),
|
||||
Arc::new(|_, ctx| ResponseEnvelope::ok(&ctx.request_id, serde_json::json!(null))),
|
||||
);
|
||||
registry.register(
|
||||
make_spec("bash/exec", "bash"),
|
||||
Arc::new(|_, ctx| ResponseEnvelope::ok(&ctx.request_id, serde_json::json!(null))),
|
||||
);
|
||||
let ops = registry.list_operations();
|
||||
assert_eq!(ops.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registry_builder() {
|
||||
let registry = OperationRegistryBuilder::new()
|
||||
.with(
|
||||
make_spec("fs/readFile", "fs"),
|
||||
Arc::new(|input, ctx| ResponseEnvelope::ok(&ctx.request_id, input)),
|
||||
)
|
||||
.with(
|
||||
make_spec("bash/exec", "bash"),
|
||||
Arc::new(|input, ctx| ResponseEnvelope::ok(&ctx.request_id, input)),
|
||||
)
|
||||
.build();
|
||||
assert!(registry.lookup("fs/readFile").is_some());
|
||||
assert!(registry.lookup("bash/exec").is_some());
|
||||
}
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub struct CallError {
|
||||
pub code: String,
|
||||
pub message: String,
|
||||
pub retryable: bool,
|
||||
}
|
||||
|
||||
impl CallError {
|
||||
pub fn new(code: impl Into<String>, message: impl Into<String>, retryable: bool) -> Self {
|
||||
Self {
|
||||
code: code.into(),
|
||||
message: message.into(),
|
||||
retryable,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResponseEnvelope {
|
||||
pub request_id: String,
|
||||
pub result: Result<Value, CallError>,
|
||||
}
|
||||
|
||||
impl ResponseEnvelope {
|
||||
pub fn ok(request_id: impl Into<String>, value: Value) -> Self {
|
||||
Self {
|
||||
request_id: request_id.into(),
|
||||
result: Ok(value),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn err(
|
||||
request_id: impl Into<String>,
|
||||
code: impl Into<String>,
|
||||
message: impl Into<String>,
|
||||
retryable: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
request_id: request_id.into(),
|
||||
result: Err(CallError {
|
||||
code: code.into(),
|
||||
message: message.into(),
|
||||
retryable,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn call_error_fields() {
|
||||
let err = CallError {
|
||||
code: "NOT_FOUND".to_string(),
|
||||
message: "operation not found".to_string(),
|
||||
retryable: false,
|
||||
};
|
||||
assert_eq!(err.code, "NOT_FOUND");
|
||||
assert_eq!(err.message, "operation not found");
|
||||
assert!(!err.retryable);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_envelope_ok() {
|
||||
let env = ResponseEnvelope::ok("req-1", json!({"status": "ok"}));
|
||||
assert_eq!(env.request_id, "req-1");
|
||||
assert!(env.result.is_ok());
|
||||
assert_eq!(env.result.unwrap(), json!({"status": "ok"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_envelope_err() {
|
||||
let env = ResponseEnvelope::err("req-1", "NOT_FOUND", "operation not found", false);
|
||||
assert_eq!(env.request_id, "req-1");
|
||||
assert!(env.result.is_err());
|
||||
let err = env.result.unwrap_err();
|
||||
assert_eq!(err.code, "NOT_FOUND");
|
||||
assert_eq!(err.message, "operation not found");
|
||||
assert!(!err.retryable);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_envelope_serialization() {
|
||||
let env = ResponseEnvelope::ok("req-1", json!({"key": "value"}));
|
||||
let serialized = serde_json::to_string(&env).unwrap();
|
||||
let deserialized: ResponseEnvelope = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.request_id, "req-1");
|
||||
assert!(deserialized.result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_envelope_err_serialization() {
|
||||
let env = ResponseEnvelope::err("req-2", "TIMEOUT", "timed out", true);
|
||||
let serialized = serde_json::to_string(&env).unwrap();
|
||||
let deserialized: ResponseEnvelope = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.request_id, "req-2");
|
||||
let err = deserialized.result.unwrap_err();
|
||||
assert_eq!(err.code, "TIMEOUT");
|
||||
assert!(err.retryable);
|
||||
}
|
||||
}
|
||||
@@ -1,207 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::call::context::OperationContext;
|
||||
use crate::call::response::ResponseEnvelope;
|
||||
use crate::call::spec::{AccessControl, OperationSpec, OperationType};
|
||||
|
||||
pub fn services_list_spec() -> OperationSpec {
|
||||
OperationSpec {
|
||||
name: super::events::SERVICE_LIST.to_string(),
|
||||
namespace: "services".to_string(),
|
||||
op_type: OperationType::Query,
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
}),
|
||||
output_schema: serde_json::json!({
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"namespace": { "type": "string" },
|
||||
"op_type": { "type": "string" },
|
||||
},
|
||||
},
|
||||
}),
|
||||
access_control: AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn services_schema_spec() -> OperationSpec {
|
||||
OperationSpec {
|
||||
name: super::events::SERVICE_SCHEMA.to_string(),
|
||||
namespace: "services".to_string(),
|
||||
op_type: OperationType::Query,
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
},
|
||||
"required": ["name"],
|
||||
}),
|
||||
output_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"namespace": { "type": "string" },
|
||||
"op_type": { "type": "string" },
|
||||
"input_schema": { "type": "object" },
|
||||
"output_schema": { "type": "object" },
|
||||
},
|
||||
}),
|
||||
access_control: AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_default_operations(registry: &mut crate::call::OperationRegistry) {
|
||||
registry.register(services_list_spec(), Arc::new(services_list_handler));
|
||||
registry.register(services_schema_spec(), Arc::new(services_schema_handler));
|
||||
}
|
||||
|
||||
fn services_list_handler(_input: Value, ctx: OperationContext) -> ResponseEnvelope {
|
||||
let registry = &ctx.env.registry_ref();
|
||||
let specs = registry.list_operations();
|
||||
let ops: Vec<Value> = specs
|
||||
.iter()
|
||||
.map(|spec| {
|
||||
serde_json::json!({
|
||||
"name": spec.name,
|
||||
"namespace": spec.namespace,
|
||||
"op_type": format!("{:?}", spec.op_type).to_lowercase(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
ResponseEnvelope::ok(&ctx.request_id, serde_json::json!({ "operations": ops }))
|
||||
}
|
||||
|
||||
fn services_schema_handler(input: Value, ctx: OperationContext) -> ResponseEnvelope {
|
||||
let name = match input.get("name").and_then(|v| v.as_str()) {
|
||||
Some(n) => n.to_string(),
|
||||
None => {
|
||||
return ResponseEnvelope::err(
|
||||
&ctx.request_id,
|
||||
"INVALID_INPUT",
|
||||
"missing required field: name",
|
||||
false,
|
||||
);
|
||||
}
|
||||
};
|
||||
let registry = &ctx.env.registry_ref();
|
||||
match registry.lookup(&name) {
|
||||
Some((spec, _)) => ResponseEnvelope::ok(
|
||||
&ctx.request_id,
|
||||
serde_json::json!({
|
||||
"name": spec.name,
|
||||
"namespace": spec.namespace,
|
||||
"op_type": format!("{:?}", spec.op_type).to_lowercase(),
|
||||
"input_schema": spec.input_schema,
|
||||
"output_schema": spec.output_schema,
|
||||
}),
|
||||
),
|
||||
None => ResponseEnvelope::err(
|
||||
&ctx.request_id,
|
||||
"NOT_FOUND",
|
||||
format!("operation not found: {name}"),
|
||||
false,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::call::env::OperationEnv;
|
||||
|
||||
fn make_env() -> OperationEnv {
|
||||
let mut registry = crate::call::OperationRegistry::new();
|
||||
registry.register(services_list_spec(), Arc::new(services_list_handler));
|
||||
registry.register(services_schema_spec(), Arc::new(services_schema_handler));
|
||||
OperationEnv::local(registry)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_list_returns_operations() {
|
||||
let env = make_env();
|
||||
let result = env.invoke("services", "list", serde_json::json!({}));
|
||||
assert!(result.result.is_ok());
|
||||
let value = result.result.unwrap();
|
||||
let ops = value.get("operations").unwrap().as_array().unwrap();
|
||||
assert_eq!(ops.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_schema_returns_spec() {
|
||||
let env = make_env();
|
||||
let result = env.invoke(
|
||||
"services",
|
||||
"schema",
|
||||
serde_json::json!({"name": "/services/list"}),
|
||||
);
|
||||
assert!(result.result.is_ok());
|
||||
let value = result.result.unwrap();
|
||||
assert_eq!(value["name"], "/services/list");
|
||||
assert_eq!(value["namespace"], "services");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_schema_missing_name() {
|
||||
let env = make_env();
|
||||
let result = env.invoke("services", "schema", serde_json::json!({}));
|
||||
assert!(result.result.is_err());
|
||||
let err = result.result.unwrap_err();
|
||||
assert_eq!(err.code, "INVALID_INPUT");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_schema_not_found() {
|
||||
let env = make_env();
|
||||
let result = env.invoke(
|
||||
"services",
|
||||
"schema",
|
||||
serde_json::json!({"name": "/nonexistent/op"}),
|
||||
);
|
||||
assert!(result.result.is_err());
|
||||
let err = result.result.unwrap_err();
|
||||
assert_eq!(err.code, "NOT_FOUND");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_list_spec_fields() {
|
||||
let spec = services_list_spec();
|
||||
assert_eq!(spec.name, "/services/list");
|
||||
assert_eq!(spec.namespace, "services");
|
||||
assert_eq!(spec.op_type, OperationType::Query);
|
||||
assert!(!spec.access_control.has_restrictions());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn services_schema_spec_fields() {
|
||||
let spec = services_schema_spec();
|
||||
assert_eq!(spec.name, "/services/schema");
|
||||
assert_eq!(spec.namespace, "services");
|
||||
assert_eq!(spec.op_type, OperationType::Query);
|
||||
assert!(!spec.access_control.has_restrictions());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_default_operations_adds_both() {
|
||||
let mut registry = crate::call::OperationRegistry::new();
|
||||
register_default_operations(&mut registry);
|
||||
assert!(registry.lookup("/services/list").is_some());
|
||||
assert!(registry.lookup("/services/schema").is_some());
|
||||
assert_eq!(registry.list_operations().len(), 2);
|
||||
}
|
||||
}
|
||||
@@ -1,239 +0,0 @@
|
||||
//! Operation specifications (type, access control) for the call protocol.
|
||||
//!
|
||||
//! See [ADR-025](docs/architecture/decisions/025-operation-spec.md) and
|
||||
//! [ADR-033](docs/architecture/decisions/033-call-protocol-extensions.md).
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub enum OperationType {
|
||||
Query,
|
||||
Mutation,
|
||||
Subscription,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AccessControl {
|
||||
pub required_scopes: Vec<String>,
|
||||
pub required_scopes_any: Option<Vec<String>>,
|
||||
pub resource_type: Option<String>,
|
||||
pub resource_action: Option<String>,
|
||||
}
|
||||
|
||||
impl AccessControl {
|
||||
pub fn check(&self, identity: &crate::auth::Identity) -> bool {
|
||||
for scope in &self.required_scopes {
|
||||
if !identity.scopes.contains(scope) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(any) = &self.required_scopes_any {
|
||||
if !any.iter().any(|s| identity.scopes.contains(s)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(res_type) = &self.resource_type {
|
||||
if let Some(actions) = identity.resources.get(res_type) {
|
||||
if let Some(action) = &self.resource_action {
|
||||
if !actions.contains(action) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub fn has_restrictions(&self) -> bool {
|
||||
!self.required_scopes.is_empty()
|
||||
|| self.required_scopes_any.is_some()
|
||||
|| self.resource_type.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OperationSpec {
|
||||
pub name: String,
|
||||
pub namespace: String,
|
||||
pub op_type: OperationType,
|
||||
pub input_schema: Value,
|
||||
pub output_schema: Value,
|
||||
pub access_control: AccessControl,
|
||||
}
|
||||
|
||||
impl OperationSpec {
|
||||
pub fn path(&self) -> String {
|
||||
format!("/{}", self.name)
|
||||
}
|
||||
|
||||
pub fn namespace_from_name(name: &str) -> String {
|
||||
let trimmed = name.trim_start_matches('/');
|
||||
let parts: Vec<&str> = trimmed.split('/').collect();
|
||||
match parts.len() {
|
||||
n if n >= 3 => parts[1].to_string(),
|
||||
n if n >= 2 => parts[0].to_string(),
|
||||
_ => String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_identity(
|
||||
scopes: Vec<String>,
|
||||
resources: HashMap<String, Vec<String>>,
|
||||
) -> crate::auth::Identity {
|
||||
crate::auth::Identity {
|
||||
id: "test".to_string(),
|
||||
scopes,
|
||||
resources,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_allows_matching_scopes() {
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec!["read".to_string()],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let id = make_identity(vec!["read".to_string()], HashMap::new());
|
||||
assert!(ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_rejects_missing_scopes() {
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec!["admin".to_string()],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let id = make_identity(vec!["read".to_string()], HashMap::new());
|
||||
assert!(!ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_required_scopes_any_matches() {
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: Some(vec!["admin".to_string(), "read".to_string()]),
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let id = make_identity(vec!["read".to_string()], HashMap::new());
|
||||
assert!(ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_required_scopes_any_rejects() {
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: Some(vec!["admin".to_string()]),
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
};
|
||||
let id = make_identity(vec!["read".to_string()], HashMap::new());
|
||||
assert!(!ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_resource_check_matches() {
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert("service".to_string(), vec!["read".to_string()]);
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: Some("service".to_string()),
|
||||
resource_action: Some("read".to_string()),
|
||||
};
|
||||
let id = make_identity(vec![], resources);
|
||||
assert!(ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_resource_check_missing_resource_type() {
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: Some("service".to_string()),
|
||||
resource_action: Some("read".to_string()),
|
||||
};
|
||||
let id = make_identity(vec![], HashMap::new());
|
||||
assert!(!ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_resource_check_missing_action() {
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert("service".to_string(), vec!["write".to_string()]);
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: Some("service".to_string()),
|
||||
resource_action: Some("read".to_string()),
|
||||
};
|
||||
let id = make_identity(vec![], resources);
|
||||
assert!(!ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_control_combined_scopes_and_resources() {
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert("service".to_string(), vec!["read".to_string()]);
|
||||
let ac = AccessControl {
|
||||
required_scopes: vec!["relay:connect".to_string()],
|
||||
required_scopes_any: Some(vec!["admin".to_string()]),
|
||||
resource_type: Some("service".to_string()),
|
||||
resource_action: Some("read".to_string()),
|
||||
};
|
||||
let id = make_identity(
|
||||
vec!["relay:connect".to_string(), "admin".to_string()],
|
||||
resources,
|
||||
);
|
||||
assert!(ac.check(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_type_variants() {
|
||||
assert_eq!(OperationType::Query, OperationType::Query);
|
||||
assert_ne!(OperationType::Query, OperationType::Mutation);
|
||||
assert_ne!(OperationType::Mutation, OperationType::Subscription);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_spec_namespace_from_name() {
|
||||
assert_eq!(OperationSpec::namespace_from_name("/auth/verify"), "auth");
|
||||
assert_eq!(OperationSpec::namespace_from_name("/fs/readFile"), "fs");
|
||||
assert_eq!(
|
||||
OperationSpec::namespace_from_name("/head/agent/chat"),
|
||||
"agent"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn operation_spec_path() {
|
||||
let spec = OperationSpec {
|
||||
name: "auth/verify".to_string(),
|
||||
namespace: "auth".to_string(),
|
||||
op_type: OperationType::Query,
|
||||
input_schema: serde_json::json!({}),
|
||||
output_schema: serde_json::json!({}),
|
||||
access_control: AccessControl {
|
||||
required_scopes: vec![],
|
||||
required_scopes_any: None,
|
||||
resource_type: None,
|
||||
resource_action: None,
|
||||
},
|
||||
};
|
||||
assert_eq!(spec.path(), "/auth/verify");
|
||||
}
|
||||
}
|
||||
@@ -1,468 +0,0 @@
|
||||
//! Channel manager with automatic reconnection.
|
||||
//!
|
||||
//! Owns the SSH session handle and provides `open_direct_tcpip()`,
|
||||
//! `request_tcpip_forward()`, and `cancel_tcpip_forward()`. Monitors
|
||||
//! the session for disconnect and attempts reconnection with exponential
|
||||
//! backoff (1s, 2s, 4s, ..., 30s cap). Re-registers remote forwards
|
||||
//! after successful reconnection.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use russh::client;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::client_auth::{ClientAuthConfig, ClientHandler};
|
||||
use crate::error::ChannelError;
|
||||
use crate::transport::Transport;
|
||||
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||
pub struct ForwardRequest {
|
||||
pub addr: String,
|
||||
pub port: u32,
|
||||
}
|
||||
|
||||
struct ChannelManagerInner<T: Transport> {
|
||||
transport: Arc<T>,
|
||||
auth_config: Arc<ClientAuthConfig>,
|
||||
handle: Arc<RwLock<client::Handle<ClientHandler>>>,
|
||||
username: String,
|
||||
forwards: RwLock<HashSet<ForwardRequest>>,
|
||||
reconnect_attempts: RwLock<u32>,
|
||||
}
|
||||
|
||||
pub struct ChannelManager<T: Transport> {
|
||||
inner: Arc<ChannelManagerInner<T>>,
|
||||
reconnect_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
|
||||
}
|
||||
|
||||
impl<T: Transport> ChannelManager<T> {
|
||||
pub async fn new(
|
||||
transport: Arc<T>,
|
||||
auth_config: Arc<ClientAuthConfig>,
|
||||
username: String,
|
||||
) -> Result<Self, ChannelError> {
|
||||
let handler = ClientHandler::from_config(&auth_config);
|
||||
let handle = Self::establish_session(&*transport, handler, &auth_config, &username)
|
||||
.await
|
||||
.map_err(|_| ChannelError::TargetUnreachable)?;
|
||||
|
||||
let inner = Arc::new(ChannelManagerInner {
|
||||
transport,
|
||||
auth_config,
|
||||
handle: Arc::new(RwLock::new(handle)),
|
||||
username,
|
||||
forwards: RwLock::new(HashSet::new()),
|
||||
reconnect_attempts: RwLock::new(0),
|
||||
});
|
||||
|
||||
let reconnect_handle = Arc::new(RwLock::new(None));
|
||||
let manager = Self {
|
||||
inner,
|
||||
reconnect_handle,
|
||||
};
|
||||
|
||||
manager.start_reconnect_monitor();
|
||||
Ok(manager)
|
||||
}
|
||||
|
||||
async fn establish_session(
|
||||
transport: &T,
|
||||
handler: ClientHandler,
|
||||
auth_config: &ClientAuthConfig,
|
||||
username: &str,
|
||||
) -> Result<client::Handle<ClientHandler>, russh::Error> {
|
||||
let stream = transport.connect().await.map_err(|e| {
|
||||
error!("transport connect failed: {e}");
|
||||
russh::Error::SendError
|
||||
})?;
|
||||
|
||||
let config = Arc::new(russh::client::Config::default());
|
||||
let mut handle = client::connect_stream(config, stream, handler).await?;
|
||||
|
||||
let auth_ok = auth_config.authenticate(&mut handle, username).await?;
|
||||
if !auth_ok {
|
||||
return Err(russh::Error::SendError);
|
||||
}
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
pub async fn open_direct_tcpip(
|
||||
&self,
|
||||
host: &str,
|
||||
port: u32,
|
||||
) -> Result<russh::Channel<russh::client::Msg>, ChannelError> {
|
||||
let handle = self.inner.handle.read().await;
|
||||
handle
|
||||
.channel_open_direct_tcpip(host, port, "127.0.0.1", 0)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
debug!("channel open failed: {e}");
|
||||
ChannelError::ChannelClosed
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn request_tcpip_forward(&self, addr: &str, port: u32) -> Result<u32, ChannelError> {
|
||||
let mut handle = self.inner.handle.write().await;
|
||||
let result = handle
|
||||
.tcpip_forward(addr, port)
|
||||
.await
|
||||
.map_err(|_| ChannelError::ChannelClosed)?;
|
||||
|
||||
self.inner.forwards.write().await.insert(ForwardRequest {
|
||||
addr: addr.to_string(),
|
||||
port,
|
||||
});
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub async fn cancel_tcpip_forward(&self, addr: &str, port: u32) -> Result<(), ChannelError> {
|
||||
let handle = self.inner.handle.read().await;
|
||||
handle
|
||||
.cancel_tcpip_forward(addr, port)
|
||||
.await
|
||||
.map_err(|_| ChannelError::ChannelClosed)?;
|
||||
|
||||
self.inner.forwards.write().await.remove(&ForwardRequest {
|
||||
addr: addr.to_string(),
|
||||
port,
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn is_connected(&self) -> bool {
|
||||
let handle = self.inner.handle.read().await;
|
||||
!handle.is_closed()
|
||||
}
|
||||
|
||||
fn start_reconnect_monitor(&self) {
|
||||
let inner = Arc::clone(&self.inner);
|
||||
let handle_arc = Arc::clone(&self.inner.handle);
|
||||
|
||||
let join_handle = tokio::spawn(async move {
|
||||
loop {
|
||||
time::sleep(Duration::from_secs(1)).await;
|
||||
let handle = handle_arc.read().await;
|
||||
if handle.is_closed() {
|
||||
drop(handle);
|
||||
info!("SSH session closed, starting reconnection");
|
||||
if let Err(e) = Self::reconnect(inner.clone()).await {
|
||||
error!("reconnection failed: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let reconnect_handle = Arc::clone(&self.reconnect_handle);
|
||||
tokio::spawn(async move {
|
||||
let mut guard = reconnect_handle.write().await;
|
||||
*guard = Some(join_handle);
|
||||
});
|
||||
}
|
||||
|
||||
async fn reconnect(inner: Arc<ChannelManagerInner<T>>) -> Result<(), ChannelError> {
|
||||
let mut attempts = inner.reconnect_attempts.write().await;
|
||||
let attempt_num = *attempts;
|
||||
let backoff = backoff_duration(attempt_num);
|
||||
*attempts += 1;
|
||||
drop(attempts);
|
||||
|
||||
warn!(
|
||||
"reconnect attempt #{}, waiting {:?}",
|
||||
attempt_num + 1,
|
||||
backoff
|
||||
);
|
||||
time::sleep(backoff).await;
|
||||
|
||||
let handler = ClientHandler::from_config(&inner.auth_config);
|
||||
match Self::establish_session(
|
||||
&*inner.transport,
|
||||
handler,
|
||||
&inner.auth_config,
|
||||
&inner.username,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(new_handle) => {
|
||||
info!("reconnection successful");
|
||||
{
|
||||
let mut handle_guard = inner.handle.write().await;
|
||||
*handle_guard = new_handle;
|
||||
}
|
||||
{
|
||||
let mut attempts = inner.reconnect_attempts.write().await;
|
||||
*attempts = 0;
|
||||
}
|
||||
Self::re_register_forwards(&inner).await;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("reconnection attempt failed: {e}");
|
||||
Err(ChannelError::ChannelClosed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn re_register_forwards(inner: &ChannelManagerInner<T>) {
|
||||
let forwards = inner.forwards.read().await;
|
||||
if forwards.is_empty() {
|
||||
return;
|
||||
}
|
||||
let mut handle = inner.handle.write().await;
|
||||
for fwd in forwards.iter() {
|
||||
match handle.tcpip_forward(&fwd.addr, fwd.port).await {
|
||||
Ok(_) => {
|
||||
debug!("re-registered tcpip_forward: {}:{}", fwd.addr, fwd.port);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"failed to re-register tcpip_forward {}:{}: {e}",
|
||||
fwd.addr, fwd.port
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Exponential backoff: 1s, 2s, 4s, 8s, 16s, 30s (cap), continues indefinitely.
|
||||
fn backoff_duration(attempt: u32) -> Duration {
|
||||
let secs: u64 = match attempt {
|
||||
0 => 1,
|
||||
1 => 2,
|
||||
2 => 4,
|
||||
3 => 8,
|
||||
4 => 16,
|
||||
_ => 30,
|
||||
};
|
||||
Duration::from_secs(secs)
|
||||
}
|
||||
|
||||
impl<T: Transport> Drop for ChannelManager<T> {
|
||||
fn drop(&mut self) {
|
||||
if let Ok(mut guard) = self.reconnect_handle.try_write() {
|
||||
if let Some(handle) = guard.take() {
|
||||
handle.abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use tokio::io::duplex;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
fn make_auth_config() -> Arc<ClientAuthConfig> {
|
||||
let source = crate::auth::keys::KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
Arc::new(ClientAuthConfig::from_key_source(source).unwrap())
|
||||
}
|
||||
|
||||
struct AlwaysFailTransport;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for AlwaysFailTransport {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||
Err(anyhow::anyhow!("always fails"))
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"always-fail".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
struct TrackConnectTransport {
|
||||
connect_count: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl TrackConnectTransport {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
connect_count: Arc::new(AtomicUsize::new(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for TrackConnectTransport {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||
self.connect_count.fetch_add(1, Ordering::SeqCst);
|
||||
let (client, _) = duplex(4096);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"track-connect".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
struct CountingFailTransport {
|
||||
fail_count: Arc<AtomicUsize>,
|
||||
succeed_after: usize,
|
||||
}
|
||||
|
||||
impl CountingFailTransport {
|
||||
fn new(succeed_after: usize) -> Self {
|
||||
Self {
|
||||
fail_count: Arc::new(AtomicUsize::new(0)),
|
||||
succeed_after,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for CountingFailTransport {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||
let count = self.fail_count.fetch_add(1, Ordering::SeqCst);
|
||||
if count < self.succeed_after {
|
||||
return Err(anyhow::anyhow!("connection failed (attempt {})", count));
|
||||
}
|
||||
let (client, _) = duplex(4096);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"counting-fail".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backoff_durations() {
|
||||
assert_eq!(backoff_duration(0), Duration::from_secs(1));
|
||||
assert_eq!(backoff_duration(1), Duration::from_secs(2));
|
||||
assert_eq!(backoff_duration(2), Duration::from_secs(4));
|
||||
assert_eq!(backoff_duration(3), Duration::from_secs(8));
|
||||
assert_eq!(backoff_duration(4), Duration::from_secs(16));
|
||||
assert_eq!(backoff_duration(5), Duration::from_secs(30));
|
||||
assert_eq!(backoff_duration(6), Duration::from_secs(30));
|
||||
assert_eq!(backoff_duration(100), Duration::from_secs(30));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backoff_sequence_matches_spec() {
|
||||
let sequence: Vec<Duration> = (0..6).map(backoff_duration).collect();
|
||||
assert_eq!(
|
||||
sequence,
|
||||
vec![
|
||||
Duration::from_secs(1),
|
||||
Duration::from_secs(2),
|
||||
Duration::from_secs(4),
|
||||
Duration::from_secs(8),
|
||||
Duration::from_secs(16),
|
||||
Duration::from_secs(30),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_forward_request_hash_eq() {
|
||||
let fwd1 = ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 8080,
|
||||
};
|
||||
let fwd2 = ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 8080,
|
||||
};
|
||||
let fwd3 = ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 9090,
|
||||
};
|
||||
assert_eq!(fwd1, fwd2);
|
||||
assert_ne!(fwd1, fwd3);
|
||||
let mut set = HashSet::new();
|
||||
set.insert(fwd1.clone());
|
||||
assert!(set.contains(&fwd2));
|
||||
assert!(!set.contains(&fwd3));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_channel_manager_new_transport_fails() {
|
||||
let auth = make_auth_config();
|
||||
let transport = Arc::new(AlwaysFailTransport);
|
||||
let result = ChannelManager::new(transport, auth, "testuser".to_string()).await;
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(ChannelError::TargetUnreachable) => {}
|
||||
other => panic!("expected TargetUnreachable, got {:?}", other.as_ref().err()),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_connect_called_on_new() {
|
||||
let transport = Arc::new(TrackConnectTransport::new());
|
||||
let connect_before = transport.connect_count.load(Ordering::SeqCst);
|
||||
assert_eq!(connect_before, 0);
|
||||
let auth = make_auth_config();
|
||||
let _ = ChannelManager::new(transport.clone(), auth, "testuser".to_string()).await;
|
||||
let connect_after = transport.connect_count.load(Ordering::SeqCst);
|
||||
assert!(connect_after > 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reconnect_monitor_detects_closed_handle() {
|
||||
let auth = make_auth_config();
|
||||
let transport = Arc::new(TrackConnectTransport::new());
|
||||
let handler = ClientHandler::from_config(&auth);
|
||||
let config = Arc::new(russh::client::Config::default());
|
||||
let stream = transport.connect().await.unwrap();
|
||||
let handle = client::connect_stream(config, stream, handler).await;
|
||||
match handle {
|
||||
Ok(h) => {
|
||||
assert!(!h.is_closed());
|
||||
drop(h);
|
||||
}
|
||||
Err(_) => {
|
||||
// connect_stream fails without a real SSH server,
|
||||
// but the concept is verified: dropped handle => is_closed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_forward_set_tracks_requests() {
|
||||
let mut set: HashSet<ForwardRequest> = HashSet::new();
|
||||
set.insert(ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 8080,
|
||||
});
|
||||
set.insert(ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 9090,
|
||||
});
|
||||
assert_eq!(set.len(), 2);
|
||||
set.remove(&ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 8080,
|
||||
});
|
||||
assert_eq!(set.len(), 1);
|
||||
assert!(set.contains(&ForwardRequest {
|
||||
addr: "0.0.0.0".to_string(),
|
||||
port: 9090,
|
||||
}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backoff_indefinitely_beyond_cap() {
|
||||
for attempt in 0..50 {
|
||||
let duration = backoff_duration(attempt);
|
||||
assert!(duration <= Duration::from_secs(30));
|
||||
assert!(duration >= Duration::from_secs(1));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,877 +0,0 @@
|
||||
//! Client session management and connection logic.
|
||||
//!
|
||||
//! `ClientSession` establishes an SSH connection over a transport, authenticates,
|
||||
//! starts a SOCKS5 proxy, sets up port forwards, and monitors for reconnection.
|
||||
//! `ConnectOptions` provides a builder-pattern API for programmatic configuration.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use russh::client;
|
||||
use russh::keys::PrivateKey;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::client_auth::{ClientAuthConfig, ClientHandler};
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::client::forward::{LocalForwarder, PortForwardSpec, RemoteForwarder};
|
||||
use crate::error::ConfigError;
|
||||
use crate::socks5::{HandleChannelOpener, Socks5Server};
|
||||
use crate::transport::Transport;
|
||||
|
||||
const DEFAULT_SOCKS5_ADDR: &str = "127.0.0.1:1080";
|
||||
const DRAIN_TIMEOUT: Duration = Duration::from_secs(2);
|
||||
|
||||
/// Transport mode for the client connection.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TransportMode {
|
||||
Tcp,
|
||||
Tls,
|
||||
Iroh,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TransportMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TransportMode::Tcp => write!(f, "tcp"),
|
||||
TransportMode::Tls => write!(f, "tls"),
|
||||
TransportMode::Iroh => write!(f, "iroh"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Programmatic configuration for an alknet client session.
|
||||
///
|
||||
/// Construct with `ConnectOptions::new(key_source)` and chain builder methods.
|
||||
/// Call `validate()` before passing to `ClientSession::new()`.
|
||||
///
|
||||
/// ```
|
||||
/// use alknet_core::client::{ConnectOptions, TransportMode};
|
||||
/// use alknet_core::auth::keys::KeySource;
|
||||
///
|
||||
/// let opts = ConnectOptions::new(KeySource::File("/path/to/key".into()))
|
||||
/// .server("example.com:22")
|
||||
/// .transport_mode(TransportMode::Tcp)
|
||||
/// .socks5_addr("127.0.0.1:1080")
|
||||
/// .forward("5432:db.internal:5432");
|
||||
/// opts.validate().unwrap();
|
||||
/// ```
|
||||
#[derive(Clone)]
|
||||
pub struct ConnectOptions {
|
||||
pub server: Option<String>,
|
||||
pub peer: Option<String>,
|
||||
pub transport_mode: TransportMode,
|
||||
pub identity: KeySource,
|
||||
pub socks5_addr: String,
|
||||
pub forwards: Vec<String>,
|
||||
pub remote_forwards: Vec<String>,
|
||||
pub proxy: Option<String>,
|
||||
pub iroh_relay: Option<String>,
|
||||
pub tls_server_name: Option<String>,
|
||||
pub insecure: bool,
|
||||
}
|
||||
|
||||
impl ConnectOptions {
|
||||
pub fn new(identity: KeySource) -> Self {
|
||||
Self {
|
||||
server: None,
|
||||
peer: None,
|
||||
transport_mode: TransportMode::Tcp,
|
||||
identity,
|
||||
socks5_addr: DEFAULT_SOCKS5_ADDR.to_string(),
|
||||
forwards: Vec::new(),
|
||||
remote_forwards: Vec::new(),
|
||||
proxy: None,
|
||||
iroh_relay: None,
|
||||
tls_server_name: None,
|
||||
insecure: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn server(mut self, addr: impl Into<String>) -> Self {
|
||||
self.server = Some(addr.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn peer(mut self, endpoint_id: impl Into<String>) -> Self {
|
||||
self.peer = Some(endpoint_id.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn transport_mode(mut self, mode: TransportMode) -> Self {
|
||||
self.transport_mode = mode;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn socks5_addr(mut self, addr: impl Into<String>) -> Self {
|
||||
self.socks5_addr = addr.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn forward(mut self, spec: impl Into<String>) -> Self {
|
||||
self.forwards.push(spec.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn remote_forward(mut self, spec: impl Into<String>) -> Self {
|
||||
self.remote_forwards.push(spec.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn proxy(mut self, url: impl Into<String>) -> Self {
|
||||
self.proxy = Some(url.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn iroh_relay(mut self, url: impl Into<String>) -> Self {
|
||||
self.iroh_relay = Some(url.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tls_server_name(mut self, name: impl Into<String>) -> Self {
|
||||
self.tls_server_name = Some(name.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn insecure(mut self, insecure: bool) -> Self {
|
||||
self.insecure = insecure;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> Result<(), ConfigError> {
|
||||
match self.transport_mode {
|
||||
TransportMode::Tcp | TransportMode::Tls => {
|
||||
if self.server.is_none() {
|
||||
return Err(ConfigError::InvalidFlag {
|
||||
name: "--server is required for tcp/tls transport".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
TransportMode::Iroh => {
|
||||
if self.peer.is_none() {
|
||||
return Err(ConfigError::InvalidFlag {
|
||||
name: "--peer is required for iroh transport".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ConnectOptions {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ConnectOptions")
|
||||
.field("server", &self.server)
|
||||
.field("peer", &self.peer)
|
||||
.field("transport_mode", &self.transport_mode)
|
||||
.field("identity", &"<KeySource>")
|
||||
.field("socks5_addr", &self.socks5_addr)
|
||||
.field("forwards", &self.forwards)
|
||||
.field("remote_forwards", &self.remote_forwards)
|
||||
.field("proxy", &self.proxy)
|
||||
.field("iroh_relay", &self.iroh_relay)
|
||||
.field("tls_server_name", &self.tls_server_name)
|
||||
.field("insecure", &self.insecure)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// An active SSH client session over a transport.
|
||||
///
|
||||
/// Establishes the connection, authenticates, and runs a SOCKS5 proxy plus
|
||||
/// port forwards until shutdown or transport failure. On transport failure,
|
||||
/// attempts reconnection with exponential backoff (1s, 2s, 4s, ..., 30s cap).
|
||||
pub struct ClientSession<T: Transport> {
|
||||
opts: ConnectOptions,
|
||||
transport: Arc<T>,
|
||||
handle: Arc<Mutex<client::Handle<ClientHandler>>>,
|
||||
auth_config: Arc<ClientAuthConfig>,
|
||||
#[allow(dead_code)]
|
||||
private_key: Arc<PrivateKey>,
|
||||
#[allow(dead_code)]
|
||||
username: String,
|
||||
shutdown_tx: tokio::sync::watch::Sender<bool>,
|
||||
shutdown_rx: tokio::sync::watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl<T: Transport> ClientSession<T> {
|
||||
pub async fn new(opts: ConnectOptions, transport: Arc<T>) -> Result<Self, ConnectError> {
|
||||
opts.validate().map_err(ConnectError::Config)?;
|
||||
|
||||
let auth_config = Arc::new(
|
||||
ClientAuthConfig::from_key_source(opts.identity.clone())
|
||||
.map_err(ConnectError::Config)?,
|
||||
);
|
||||
let private_key = auth_config.private_key();
|
||||
|
||||
let username = derive_username();
|
||||
let handler = ClientHandler::from_config(&auth_config);
|
||||
|
||||
let stream = transport.connect().await.map_err(|e| {
|
||||
error!("transport connect failed: {e}");
|
||||
ConnectError::ConnectionFailed
|
||||
})?;
|
||||
|
||||
let config = Arc::new(client::Config::default());
|
||||
let mut handle = client::connect_stream(config, stream, handler)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("SSH connect failed: {e}");
|
||||
ConnectError::ConnectionFailed
|
||||
})?;
|
||||
|
||||
let auth_ok = auth_config
|
||||
.authenticate(&mut handle, &username)
|
||||
.await
|
||||
.map_err(|_| ConnectError::AuthFailed)?;
|
||||
if !auth_ok {
|
||||
return Err(ConnectError::AuthFailed);
|
||||
}
|
||||
|
||||
let handle = Arc::new(Mutex::new(handle));
|
||||
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
|
||||
|
||||
Ok(Self {
|
||||
opts,
|
||||
transport,
|
||||
handle,
|
||||
auth_config,
|
||||
private_key,
|
||||
username,
|
||||
shutdown_tx,
|
||||
shutdown_rx,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn handle(&self) -> Arc<Mutex<client::Handle<ClientHandler>>> {
|
||||
Arc::clone(&self.handle)
|
||||
}
|
||||
|
||||
pub fn auth_config(&self) -> &Arc<ClientAuthConfig> {
|
||||
&self.auth_config
|
||||
}
|
||||
|
||||
pub fn transport(&self) -> &Arc<T> {
|
||||
&self.transport
|
||||
}
|
||||
|
||||
pub fn options(&self) -> &ConnectOptions {
|
||||
&self.opts
|
||||
}
|
||||
|
||||
pub fn shutdown_sender(&self) -> tokio::sync::watch::Sender<bool> {
|
||||
self.shutdown_tx.clone()
|
||||
}
|
||||
|
||||
pub async fn run(self) -> Result<(), ConnectError> {
|
||||
let socks5_addr: SocketAddr = self.opts.socks5_addr.parse().map_err(|_| {
|
||||
ConnectError::Config(ConfigError::InvalidFlag {
|
||||
name: format!("invalid SOCKS5 address: {}", self.opts.socks5_addr),
|
||||
})
|
||||
})?;
|
||||
|
||||
let channel_opener = HandleChannelOpener::from_arc(Arc::clone(&self.handle));
|
||||
let socks5_server = Socks5Server::with_addr(channel_opener, &socks5_addr.to_string());
|
||||
let socks5_listen = socks5_server.listen_addr();
|
||||
|
||||
let local_forwarders = build_local_forwarders(&self.opts)?;
|
||||
let remote_specs = build_remote_specs(&self.opts)?;
|
||||
|
||||
for spec in &remote_specs {
|
||||
let remote_forwarder =
|
||||
RemoteForwarder::new(spec.clone()).map_err(|_| ConnectError::ForwardFailed)?;
|
||||
let mut h = self.handle.lock().await;
|
||||
remote_forwarder.register(&mut h).await.map_err(|_| {
|
||||
warn!("failed to register remote forward {}", spec);
|
||||
ConnectError::ForwardFailed
|
||||
})?;
|
||||
info!("registered remote forward: {}", spec);
|
||||
}
|
||||
|
||||
let socks5_task = tokio::spawn(async move {
|
||||
debug!("SOCKS5 server starting on {}", socks5_listen);
|
||||
if let Err(e) = socks5_server.run().await {
|
||||
error!("SOCKS5 server error: {e}");
|
||||
}
|
||||
});
|
||||
|
||||
let fwd_handle = Arc::clone(&self.handle);
|
||||
let fwd_shutdown = self.shutdown_rx.clone();
|
||||
let forward_task = tokio::spawn(async move {
|
||||
crate::client::forward::run_local_forwarders(
|
||||
local_forwarders,
|
||||
fwd_handle,
|
||||
fwd_shutdown,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
info!("alknet client running: SOCKS5 on {}", socks5_listen);
|
||||
|
||||
#[cfg(unix)]
|
||||
let signal_done = {
|
||||
let sig_tx = self.shutdown_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut sigterm_stream =
|
||||
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install SIGTERM handler");
|
||||
tokio::select! {
|
||||
_ = sigterm_stream.recv() => {
|
||||
info!("received SIGTERM");
|
||||
}
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
info!("received SIGINT (Ctrl+C)");
|
||||
}
|
||||
}
|
||||
let _ = sig_tx.send(true);
|
||||
})
|
||||
};
|
||||
|
||||
let mut wait_shutdown = self.shutdown_rx.clone();
|
||||
let reconnect_handle = Arc::clone(&self.handle);
|
||||
let reconnect_transport = Arc::clone(&self.transport);
|
||||
let reconnect_auth = Arc::clone(&self.auth_config);
|
||||
let reconnect_username = self.username.clone();
|
||||
let reconnect_shutdown = self.shutdown_rx.clone();
|
||||
let reconnect_remote_specs = remote_specs.clone();
|
||||
|
||||
let reconnect_monitor = tokio::spawn(async move {
|
||||
let mut attempts: u32 = 0;
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
if *reconnect_shutdown.borrow() {
|
||||
break;
|
||||
}
|
||||
let h = reconnect_handle.lock().await;
|
||||
if h.is_closed() {
|
||||
drop(h);
|
||||
info!("SSH session closed, starting reconnection");
|
||||
let backoff = backoff_duration(attempts);
|
||||
warn!("reconnect attempt #{}, waiting {:?}", attempts + 1, backoff);
|
||||
tokio::time::sleep(backoff).await;
|
||||
|
||||
let handler = ClientHandler::from_config(&reconnect_auth);
|
||||
let username = reconnect_username.clone();
|
||||
match establish_session(
|
||||
&*reconnect_transport,
|
||||
handler,
|
||||
&reconnect_auth,
|
||||
&username,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(new_handle) => {
|
||||
info!("reconnection successful");
|
||||
{
|
||||
let mut guard = reconnect_handle.lock().await;
|
||||
*guard = new_handle;
|
||||
}
|
||||
for spec in &reconnect_remote_specs {
|
||||
match RemoteForwarder::new(spec.clone()) {
|
||||
Ok(rf) => {
|
||||
let mut h = reconnect_handle.lock().await;
|
||||
match rf.register(&mut h).await {
|
||||
Ok(_) => {
|
||||
debug!("re-registered remote forward: {}", spec)
|
||||
}
|
||||
Err(e) => warn!(
|
||||
"failed to re-register remote forward {}: {e}",
|
||||
spec
|
||||
),
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("failed to create remote forwarder: {e}"),
|
||||
}
|
||||
}
|
||||
attempts = 0;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("reconnection attempt failed: {e}");
|
||||
attempts += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = wait_shutdown.changed() => {
|
||||
if *wait_shutdown.borrow() {
|
||||
info!("shutdown signal received");
|
||||
}
|
||||
}
|
||||
_ = socks5_task => {
|
||||
warn!("SOCKS5 server exited unexpectedly");
|
||||
}
|
||||
}
|
||||
|
||||
reconnect_monitor.abort();
|
||||
|
||||
#[cfg(unix)]
|
||||
signal_done.abort();
|
||||
|
||||
self.shutdown().await?;
|
||||
|
||||
forward_task.abort();
|
||||
let _ = forward_task.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) -> Result<(), ConnectError> {
|
||||
info!("initiating graceful shutdown");
|
||||
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
|
||||
{
|
||||
let handle = self.handle.lock().await;
|
||||
if !handle.is_closed() {
|
||||
if let Err(e) = handle
|
||||
.disconnect(russh::Disconnect::ByApplication, "shutdown", "")
|
||||
.await
|
||||
{
|
||||
warn!("failed to send SSH disconnect: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(DRAIN_TIMEOUT).await;
|
||||
|
||||
info!("graceful shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn derive_username() -> String {
|
||||
std::env::var("USER")
|
||||
.or_else(|_| std::env::var("USERNAME"))
|
||||
.unwrap_or_else(|_| "alknet".to_string())
|
||||
}
|
||||
|
||||
async fn establish_session<T: Transport>(
|
||||
transport: &T,
|
||||
handler: ClientHandler,
|
||||
auth_config: &ClientAuthConfig,
|
||||
username: &str,
|
||||
) -> Result<client::Handle<ClientHandler>, ConnectError> {
|
||||
let stream = transport.connect().await.map_err(|e| {
|
||||
error!("transport connect failed: {e}");
|
||||
ConnectError::ConnectionFailed
|
||||
})?;
|
||||
|
||||
let config = Arc::new(client::Config::default());
|
||||
let mut handle = client::connect_stream(config, stream, handler)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("SSH connect failed: {e}");
|
||||
ConnectError::ConnectionFailed
|
||||
})?;
|
||||
|
||||
let auth_ok = auth_config
|
||||
.authenticate(&mut handle, username)
|
||||
.await
|
||||
.map_err(|_| ConnectError::AuthFailed)?;
|
||||
if !auth_ok {
|
||||
return Err(ConnectError::AuthFailed);
|
||||
}
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
fn backoff_duration(attempt: u32) -> Duration {
|
||||
let secs: u64 = match attempt {
|
||||
0 => 1,
|
||||
1 => 2,
|
||||
2 => 4,
|
||||
3 => 8,
|
||||
4 => 16,
|
||||
_ => 30,
|
||||
};
|
||||
Duration::from_secs(secs)
|
||||
}
|
||||
|
||||
fn build_local_forwarders(opts: &ConnectOptions) -> Result<Vec<LocalForwarder>, ConnectError> {
|
||||
let mut forwarders = Vec::new();
|
||||
for spec_str in &opts.forwards {
|
||||
let spec = PortForwardSpec::local(spec_str).map_err(|e| {
|
||||
warn!("invalid local forward spec '{}': {}", spec_str, e);
|
||||
ConnectError::Config(ConfigError::InvalidFlag {
|
||||
name: format!("invalid forward spec: {}", spec_str),
|
||||
})
|
||||
})?;
|
||||
forwarders.push(LocalForwarder::new(spec).map_err(|e| {
|
||||
warn!("failed to create local forwarder: {}", e);
|
||||
ConnectError::ForwardFailed
|
||||
})?);
|
||||
}
|
||||
Ok(forwarders)
|
||||
}
|
||||
|
||||
fn build_remote_specs(opts: &ConnectOptions) -> Result<Vec<PortForwardSpec>, ConnectError> {
|
||||
let mut specs = Vec::new();
|
||||
for spec_str in &opts.remote_forwards {
|
||||
let spec = PortForwardSpec::remote(spec_str).map_err(|e| {
|
||||
warn!("invalid remote forward spec '{}': {}", spec_str, e);
|
||||
ConnectError::Config(ConfigError::InvalidFlag {
|
||||
name: format!("invalid remote forward spec: {}", spec_str),
|
||||
})
|
||||
})?;
|
||||
specs.push(spec);
|
||||
}
|
||||
Ok(specs)
|
||||
}
|
||||
|
||||
/// Errors that can occur during client connection setup and operation.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConnectError {
|
||||
#[error("connection failed")]
|
||||
ConnectionFailed,
|
||||
#[error("authentication failed")]
|
||||
AuthFailed,
|
||||
#[error("forward setup failed")]
|
||||
ForwardFailed,
|
||||
#[error("config error: {0}")]
|
||||
Config(#[from] ConfigError),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use tokio::io::duplex;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
fn make_identity() -> KeySource {
|
||||
KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_default_fields() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
assert!(opts.server.is_none());
|
||||
assert!(opts.peer.is_none());
|
||||
assert_eq!(opts.transport_mode, TransportMode::Tcp);
|
||||
assert_eq!(opts.socks5_addr, "127.0.0.1:1080");
|
||||
assert!(opts.forwards.is_empty());
|
||||
assert!(opts.remote_forwards.is_empty());
|
||||
assert!(opts.proxy.is_none());
|
||||
assert!(opts.iroh_relay.is_none());
|
||||
assert!(opts.tls_server_name.is_none());
|
||||
assert!(!opts.insecure);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_builder_pattern() {
|
||||
let opts = ConnectOptions::new(make_identity())
|
||||
.server("example.com:22")
|
||||
.transport_mode(TransportMode::Tls)
|
||||
.socks5_addr("127.0.0.1:9050")
|
||||
.forward("127.0.0.1:5432:db:5432")
|
||||
.remote_forward("0.0.0.0:8080:127.0.0.1:3000")
|
||||
.proxy("socks5://127.0.0.1:1080")
|
||||
.iroh_relay("https://relay.example.com")
|
||||
.tls_server_name("alknet.test")
|
||||
.insecure(true);
|
||||
|
||||
assert_eq!(opts.server.as_deref(), Some("example.com:22"));
|
||||
assert_eq!(opts.transport_mode, TransportMode::Tls);
|
||||
assert_eq!(opts.socks5_addr, "127.0.0.1:9050");
|
||||
assert_eq!(opts.forwards.len(), 1);
|
||||
assert_eq!(opts.remote_forwards.len(), 1);
|
||||
assert_eq!(opts.proxy.as_deref(), Some("socks5://127.0.0.1:1080"));
|
||||
assert_eq!(
|
||||
opts.iroh_relay.as_deref(),
|
||||
Some("https://relay.example.com")
|
||||
);
|
||||
assert_eq!(opts.tls_server_name.as_deref(), Some("alknet.test"));
|
||||
assert!(opts.insecure);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_tcp_requires_server() {
|
||||
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Tcp);
|
||||
assert!(opts.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_tcp_with_server_ok() {
|
||||
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||
assert!(opts.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_tls_requires_server() {
|
||||
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Tls);
|
||||
assert!(opts.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_tls_with_server_ok() {
|
||||
let opts = ConnectOptions::new(make_identity())
|
||||
.transport_mode(TransportMode::Tls)
|
||||
.server("example.com:443");
|
||||
assert!(opts.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_iroh_requires_peer() {
|
||||
let opts = ConnectOptions::new(make_identity()).transport_mode(TransportMode::Iroh);
|
||||
assert!(opts.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_validate_iroh_with_peer_ok() {
|
||||
let opts = ConnectOptions::new(make_identity())
|
||||
.transport_mode(TransportMode::Iroh)
|
||||
.peer("some-endpoint-id");
|
||||
assert!(opts.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_accepts_key_source_file() {
|
||||
let file_source = KeySource::File(std::path::PathBuf::from("/path/to/key"));
|
||||
let opts = ConnectOptions::new(file_source);
|
||||
match &opts.identity {
|
||||
KeySource::File(p) => assert_eq!(p, &std::path::PathBuf::from("/path/to/key")),
|
||||
_ => panic!("expected File variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_accepts_key_source_memory() {
|
||||
let mem_source = KeySource::Memory(b"key-data".to_vec());
|
||||
let opts = ConnectOptions::new(mem_source);
|
||||
match &opts.identity {
|
||||
KeySource::Memory(d) => assert_eq!(d, b"key-data"),
|
||||
_ => panic!("expected Memory variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_mode_display() {
|
||||
assert_eq!(TransportMode::Tcp.to_string(), "tcp");
|
||||
assert_eq!(TransportMode::Tls.to_string(), "tls");
|
||||
assert_eq!(TransportMode::Iroh.to_string(), "iroh");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_error_variants() {
|
||||
assert_eq!(
|
||||
ConnectError::ConnectionFailed.to_string(),
|
||||
"connection failed"
|
||||
);
|
||||
assert_eq!(
|
||||
ConnectError::AuthFailed.to_string(),
|
||||
"authentication failed"
|
||||
);
|
||||
assert_eq!(
|
||||
ConnectError::ForwardFailed.to_string(),
|
||||
"forward setup failed"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_options_debug_redacts_identity() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
let debug_str = format!("{:?}", opts);
|
||||
assert!(debug_str.contains("<KeySource>"));
|
||||
assert!(!debug_str.contains("OPENSSH"));
|
||||
}
|
||||
|
||||
struct FailTransport;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for FailTransport {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||
Err(anyhow::anyhow!("always fails"))
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"fail".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
struct DuplexTransport {
|
||||
connect_count: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for DuplexTransport {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn connect(&self) -> anyhow::Result<Self::Stream> {
|
||||
self.connect_count.fetch_add(1, Ordering::SeqCst);
|
||||
let (client, _) = duplex(4096);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"duplex".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_session_new_transport_fails() {
|
||||
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||
let transport = Arc::new(FailTransport);
|
||||
let result = ClientSession::new(opts, transport).await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
result.err().unwrap(),
|
||||
ConnectError::ConnectionFailed
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_session_new_ssh_handshake_fails() {
|
||||
let transport = Arc::new(DuplexTransport {
|
||||
connect_count: Arc::new(AtomicUsize::new(0)),
|
||||
});
|
||||
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||
let result = ClientSession::new(opts, transport).await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
result.err().unwrap(),
|
||||
ConnectError::ConnectionFailed
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_local_forwarders_empty() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
let result = build_local_forwarders(&opts);
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_local_forwarders_valid() {
|
||||
let opts = ConnectOptions::new(make_identity()).forward("127.0.0.1:5432:db:5432");
|
||||
let result = build_local_forwarders(&opts);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_local_forwarders_invalid_spec() {
|
||||
let opts = ConnectOptions::new(make_identity()).forward("bad-spec");
|
||||
let result = build_local_forwarders(&opts);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_remote_specs_empty() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
let result = build_remote_specs(&opts);
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_remote_specs_valid() {
|
||||
let opts =
|
||||
ConnectOptions::new(make_identity()).remote_forward("0.0.0.0:8080:127.0.0.1:3000");
|
||||
let result = build_remote_specs(&opts);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_remote_specs_invalid() {
|
||||
let opts = ConnectOptions::new(make_identity()).remote_forward("bad");
|
||||
let result = build_remote_specs(&opts);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_socks5_addr() {
|
||||
assert_eq!(DEFAULT_SOCKS5_ADDR, "127.0.0.1:1080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn drain_timeout_is_two_seconds() {
|
||||
assert_eq!(DRAIN_TIMEOUT, Duration::from_secs(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_mode_equality() {
|
||||
assert_eq!(TransportMode::Tcp, TransportMode::Tcp);
|
||||
assert_ne!(TransportMode::Tcp, TransportMode::Tls);
|
||||
assert_ne!(TransportMode::Tls, TransportMode::Iroh);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_sends_disconnect_and_drains() {
|
||||
let transport = Arc::new(DuplexTransport {
|
||||
connect_count: Arc::new(AtomicUsize::new(0)),
|
||||
});
|
||||
let opts = ConnectOptions::new(make_identity()).server("example.com:22");
|
||||
let result = ClientSession::new(opts, transport).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn socks5_is_always_enabled_by_default() {
|
||||
let opts = ConnectOptions::new(make_identity());
|
||||
assert!(!opts.socks5_addr.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_mock_transport_session() {
|
||||
use crate::socks5::{ChannelOpenError, ChannelOpener};
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
struct MockOpener;
|
||||
|
||||
impl ChannelOpener for MockOpener {
|
||||
type Stream = tokio::io::DuplexStream;
|
||||
|
||||
async fn open_channel(
|
||||
&self,
|
||||
_host: String,
|
||||
_port: u16,
|
||||
) -> Result<Self::Stream, ChannelOpenError> {
|
||||
let (client, _server) = duplex(4096);
|
||||
Ok(client)
|
||||
}
|
||||
}
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let bound_addr = listener.local_addr().unwrap();
|
||||
drop(listener);
|
||||
|
||||
let opener = MockOpener;
|
||||
let server = Socks5Server::with_addr(opener, &bound_addr.to_string());
|
||||
|
||||
let _server_task = tokio::spawn(async move {
|
||||
let _ = server.run().await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
let mut conn = TcpStream::connect(bound_addr).await.unwrap();
|
||||
|
||||
let greeting = [0x05, 0x01, 0x00];
|
||||
conn.write_all(&greeting).await.unwrap();
|
||||
|
||||
let mut auth_resp = [0u8; 2];
|
||||
conn.read_exact(&mut auth_resp).await.unwrap();
|
||||
assert_eq!(auth_resp, [0x05, 0x00]);
|
||||
|
||||
let connect_req = [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80];
|
||||
conn.write_all(&connect_req).await.unwrap();
|
||||
|
||||
let mut reply = [0u8; 10];
|
||||
conn.read_exact(&mut reply).await.unwrap();
|
||||
assert_eq!(reply[1], 0x00);
|
||||
|
||||
conn.write_all(b"test data").await.unwrap();
|
||||
conn.shutdown().await.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -1,529 +0,0 @@
|
||||
//! Local and remote port forwarding.
|
||||
//!
|
||||
//! `LocalForwarder` binds a local TCP listener and forwards each connection through
|
||||
//! an SSH `direct-tcpip` channel. `RemoteForwarder` requests `tcpip-forward` from
|
||||
//! the server and handles `forwarded-tcpip` channels. Specs follow the
|
||||
//! `bind_addr:bind_port:target_host:target_port` format.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use russh::client;
|
||||
use tokio::io;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::error::ForwardError;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PortForwardSpecKind {
|
||||
Local,
|
||||
Remote,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PortForwardSpec {
|
||||
pub kind: PortForwardSpecKind,
|
||||
pub bind_addr: String,
|
||||
pub bind_port: u16,
|
||||
pub target_host: String,
|
||||
pub target_port: u16,
|
||||
}
|
||||
|
||||
impl PortForwardSpec {
|
||||
pub fn local(spec: &str) -> Result<Self, ForwardError> {
|
||||
let (bind_addr, bind_port, target_host, target_port) = parse_spec(spec)?;
|
||||
Ok(Self {
|
||||
kind: PortForwardSpecKind::Local,
|
||||
bind_addr,
|
||||
bind_port,
|
||||
target_host,
|
||||
target_port,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn remote(spec: &str) -> Result<Self, ForwardError> {
|
||||
let (bind_addr, bind_port, target_host, target_port) = parse_spec(spec)?;
|
||||
Ok(Self {
|
||||
kind: PortForwardSpecKind::Remote,
|
||||
bind_addr,
|
||||
bind_port,
|
||||
target_host,
|
||||
target_port,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn listen_addr(&self) -> Result<SocketAddr, ForwardError> {
|
||||
format!("{}:{}", self.bind_addr, self.bind_port)
|
||||
.parse()
|
||||
.map_err(|_| ForwardError::InvalidSpec {
|
||||
spec: format!("{}:{}", self.bind_addr, self.bind_port),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn target_addr(&self) -> Result<SocketAddr, ForwardError> {
|
||||
format!("{}:{}", self.target_host, self.target_port)
|
||||
.parse()
|
||||
.map_err(|_| ForwardError::InvalidSpec {
|
||||
spec: format!("{}:{}", self.target_host, self.target_port),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PortForwardSpec {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let prefix = match self.kind {
|
||||
PortForwardSpecKind::Local => "-L",
|
||||
PortForwardSpecKind::Remote => "-R",
|
||||
};
|
||||
write!(
|
||||
f,
|
||||
"{} {}:{}:{}:{}",
|
||||
prefix, self.bind_addr, self.bind_port, self.target_host, self.target_port
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_spec(spec: &str) -> Result<(String, u16, String, u16), ForwardError> {
|
||||
let parts: Vec<&str> = spec.split(':').collect();
|
||||
if parts.len() != 4 {
|
||||
return Err(ForwardError::InvalidSpec {
|
||||
spec: spec.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let bind_addr = parts[0].to_string();
|
||||
let bind_port: u16 = parts[1].parse().map_err(|_| ForwardError::InvalidSpec {
|
||||
spec: spec.to_string(),
|
||||
})?;
|
||||
let target_host = parts[2].to_string();
|
||||
let target_port: u16 = parts[3].parse().map_err(|_| ForwardError::InvalidSpec {
|
||||
spec: spec.to_string(),
|
||||
})?;
|
||||
|
||||
Ok((bind_addr, bind_port, target_host, target_port))
|
||||
}
|
||||
|
||||
pub struct LocalForwarder {
|
||||
spec: PortForwardSpec,
|
||||
listener: Option<TcpListener>,
|
||||
}
|
||||
|
||||
impl LocalForwarder {
|
||||
pub fn new(spec: PortForwardSpec) -> Result<Self, ForwardError> {
|
||||
if spec.kind != PortForwardSpecKind::Local {
|
||||
return Err(ForwardError::InvalidSpec {
|
||||
spec: format!("expected local spec, got {:?}", spec.kind),
|
||||
});
|
||||
}
|
||||
Ok(Self {
|
||||
spec,
|
||||
listener: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn spec(&self) -> &PortForwardSpec {
|
||||
&self.spec
|
||||
}
|
||||
|
||||
pub async fn run<H: client::Handler + Send + 'static>(
|
||||
&mut self,
|
||||
handle: Arc<Mutex<client::Handle<H>>>,
|
||||
) -> Result<(), ForwardError> {
|
||||
let listen_addr = self.spec.listen_addr()?;
|
||||
let listener: TcpListener = TcpListener::bind(listen_addr)
|
||||
.await
|
||||
.map_err(|e| ForwardError::BindFailed { source: e })?;
|
||||
self.listener = Some(listener);
|
||||
let remote_host = self.spec.target_host.clone();
|
||||
let remote_port = self.spec.target_port;
|
||||
|
||||
info!(
|
||||
"local forward listening on {} -> {}:{}",
|
||||
listen_addr, remote_host, remote_port
|
||||
);
|
||||
|
||||
loop {
|
||||
let listener = match &self.listener {
|
||||
Some(l) => l,
|
||||
None => return Ok(()),
|
||||
};
|
||||
let accept_result = listener.accept().await;
|
||||
let (local_stream, local_addr) = match accept_result {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
let handle = handle.lock().await;
|
||||
if handle.is_closed() {
|
||||
debug!("local forward accept loop ending: ssh session closed");
|
||||
return Ok(());
|
||||
}
|
||||
drop(handle);
|
||||
error!("local forward accept error: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"local forward connection from {} -> {}:{}",
|
||||
local_addr, remote_host, remote_port
|
||||
);
|
||||
|
||||
let handle = handle.clone();
|
||||
let remote_host = remote_host.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) =
|
||||
proxy_local_to_remote(local_stream, handle, &remote_host, remote_port).await
|
||||
{
|
||||
debug!("local forward proxy error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stop(&mut self) {
|
||||
if let Some(listener) = self.listener.take() {
|
||||
drop(listener);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn local_port(&self) -> u16 {
|
||||
self.spec.bind_port
|
||||
}
|
||||
}
|
||||
|
||||
async fn proxy_local_to_remote<H: client::Handler + Send + 'static>(
|
||||
local_stream: TcpStream,
|
||||
handle: Arc<Mutex<client::Handle<H>>>,
|
||||
remote_host: &str,
|
||||
remote_port: u16,
|
||||
) -> Result<(), ForwardError> {
|
||||
let local_addr = local_stream
|
||||
.peer_addr()
|
||||
.map(|a| a.to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
let handle_guard = handle.lock().await;
|
||||
let channel = handle_guard
|
||||
.channel_open_direct_tcpip(remote_host, remote_port as u32, &local_addr, 0)
|
||||
.await
|
||||
.map_err(|e| ForwardError::ChannelOpenFailed {
|
||||
source: Box::new(e) as _,
|
||||
})?;
|
||||
drop(handle_guard);
|
||||
|
||||
let ssh_stream = channel.into_stream();
|
||||
let (mut ssh_read, mut ssh_write) = tokio::io::split(ssh_stream);
|
||||
let (mut local_read, mut local_write) = tokio::io::split(local_stream);
|
||||
|
||||
let client_to_server = io::copy(&mut local_read, &mut ssh_write);
|
||||
let server_to_client = io::copy(&mut ssh_read, &mut local_write);
|
||||
|
||||
match tokio::join!(client_to_server, server_to_client) {
|
||||
(Err(e), _) | (_, Err(e)) => {
|
||||
debug!("local forward bidirectional copy error: {}", e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct RemoteForwarder {
|
||||
spec: PortForwardSpec,
|
||||
cancel: Option<tokio::sync::oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
impl RemoteForwarder {
|
||||
pub fn new(spec: PortForwardSpec) -> Result<Self, ForwardError> {
|
||||
if spec.kind != PortForwardSpecKind::Remote {
|
||||
return Err(ForwardError::InvalidSpec {
|
||||
spec: format!("expected remote spec, got {:?}", spec.kind),
|
||||
});
|
||||
}
|
||||
Ok(Self { spec, cancel: None })
|
||||
}
|
||||
|
||||
pub fn spec(&self) -> &PortForwardSpec {
|
||||
&self.spec
|
||||
}
|
||||
|
||||
pub async fn register<H: client::Handler + Send + 'static>(
|
||||
&self,
|
||||
handle: &mut client::Handle<H>,
|
||||
) -> Result<u32, ForwardError> {
|
||||
let port = handle
|
||||
.tcpip_forward(&self.spec.bind_addr, self.spec.bind_port as u32)
|
||||
.await
|
||||
.map_err(|e| ForwardError::ChannelOpenFailed {
|
||||
source: Box::new(e) as _,
|
||||
})?;
|
||||
Ok(port)
|
||||
}
|
||||
|
||||
pub async fn handle_forwarded_channel(
|
||||
channel: russh::Channel<russh::client::Msg>,
|
||||
connected_address: &str,
|
||||
connected_port: u32,
|
||||
local_host: &str,
|
||||
local_port: u16,
|
||||
) {
|
||||
debug!(
|
||||
"remote forward: server opened forwarded-tcpip channel to {}:{} -> local {}:{}",
|
||||
connected_address, connected_port, local_host, local_port
|
||||
);
|
||||
|
||||
let local_target = format!("{}:{}", local_host, local_port);
|
||||
let local_stream = match TcpStream::connect(&local_target).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!(
|
||||
"remote forward: failed to connect to local target {}: {}",
|
||||
local_target, e
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let ssh_stream = channel.into_stream();
|
||||
let (mut ssh_read, mut ssh_write) = tokio::io::split(ssh_stream);
|
||||
let (mut local_read, mut local_write) = tokio::io::split(local_stream);
|
||||
|
||||
let client_to_server = io::copy(&mut local_read, &mut ssh_write);
|
||||
let server_to_client = io::copy(&mut ssh_read, &mut local_write);
|
||||
|
||||
match tokio::join!(client_to_server, server_to_client) {
|
||||
(Err(e), _) | (_, Err(e)) => {
|
||||
debug!("remote forward bidirectional copy error: {}", e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn unregister<H: client::Handler + Send + 'static>(
|
||||
&self,
|
||||
handle: &client::Handle<H>,
|
||||
) -> Result<(), ForwardError> {
|
||||
handle
|
||||
.cancel_tcpip_forward(&self.spec.bind_addr, self.spec.bind_port as u32)
|
||||
.await
|
||||
.map_err(|e| ForwardError::ChannelOpenFailed {
|
||||
source: Box::new(e) as _,
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn stop(&mut self) {
|
||||
if let Some(cancel) = self.cancel.take() {
|
||||
let _ = cancel.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_local_forwarders<H: client::Handler + Send + 'static>(
|
||||
forwarders: Vec<LocalForwarder>,
|
||||
handle: Arc<Mutex<client::Handle<H>>>,
|
||||
mut shutdown: tokio::sync::watch::Receiver<bool>,
|
||||
) -> Vec<LocalForwarder> {
|
||||
let mut forwarders = forwarders;
|
||||
let mut tasks = Vec::new();
|
||||
|
||||
for forwarder in forwarders.drain(..) {
|
||||
let handle = handle.clone();
|
||||
let spec = forwarder.spec().clone();
|
||||
let (_cancel_tx, cancel_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let mut fwd = forwarder;
|
||||
tokio::select! {
|
||||
result = fwd.run(handle) => {
|
||||
if let Err(e) = result {
|
||||
error!("local forward {} failed: {}", spec, e);
|
||||
}
|
||||
}
|
||||
_ = cancel_rx => {
|
||||
fwd.stop().await;
|
||||
}
|
||||
}
|
||||
fwd
|
||||
}));
|
||||
}
|
||||
|
||||
let _ = shutdown.changed().await;
|
||||
|
||||
for task in &tasks {
|
||||
task.abort();
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
for task in tasks {
|
||||
match task.await {
|
||||
Ok(fwd) => results.push(fwd),
|
||||
Err(e) => {
|
||||
if !e.is_cancelled() {
|
||||
error!("local forwarder task panicked: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_local_spec() {
|
||||
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||
assert_eq!(spec.kind, PortForwardSpecKind::Local);
|
||||
assert_eq!(spec.bind_addr, "127.0.0.1");
|
||||
assert_eq!(spec.bind_port, 5432);
|
||||
assert_eq!(spec.target_host, "db.internal");
|
||||
assert_eq!(spec.target_port, 5432);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_remote_spec() {
|
||||
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||
assert_eq!(spec.kind, PortForwardSpecKind::Remote);
|
||||
assert_eq!(spec.bind_addr, "0.0.0.0");
|
||||
assert_eq!(spec.bind_port, 8080);
|
||||
assert_eq!(spec.target_host, "127.0.0.1");
|
||||
assert_eq!(spec.target_port, 3000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_invalid_few_parts() {
|
||||
assert!(PortForwardSpec::local("127.0.0.1:5432:db").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_invalid_many_parts() {
|
||||
assert!(PortForwardSpec::local("a:b:c:d:e").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_invalid_port() {
|
||||
assert!(PortForwardSpec::local("127.0.0.1:abc:db:5432").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_invalid_target_port() {
|
||||
assert!(PortForwardSpec::local("127.0.0.1:5432:db:abc").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spec_display() {
|
||||
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||
assert_eq!(spec.to_string(), "-L 127.0.0.1:5432:db.internal:5432");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spec_display_remote() {
|
||||
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||
assert_eq!(spec.to_string(), "-R 0.0.0.0:8080:127.0.0.1:3000");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_forwarder_rejects_remote_spec() {
|
||||
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||
assert!(LocalForwarder::new(spec).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remote_forwarder_rejects_local_spec() {
|
||||
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||
assert!(RemoteForwarder::new(spec).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn listen_addr_valid() {
|
||||
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||
let addr = spec.listen_addr().unwrap();
|
||||
assert_eq!(addr.port(), 5432);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn listen_addr_invalid_host() {
|
||||
let spec = PortForwardSpec {
|
||||
kind: PortForwardSpecKind::Local,
|
||||
bind_addr: "!!!invalid".to_string(),
|
||||
bind_port: 5432,
|
||||
target_host: "db".to_string(),
|
||||
target_port: 5432,
|
||||
};
|
||||
assert!(spec.listen_addr().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn local_forward_bind_and_accept() {
|
||||
let spec = PortForwardSpec::local(&format!("127.0.0.1:0:remote:5432")).unwrap();
|
||||
let forwarder = LocalForwarder::new(spec).unwrap();
|
||||
|
||||
let listen_addr = forwarder.spec.listen_addr().unwrap();
|
||||
let listener = TcpListener::bind(listen_addr).await.unwrap();
|
||||
let bound_addr = listener.local_addr().unwrap();
|
||||
drop(listener);
|
||||
|
||||
let spec = PortForwardSpec::local(&format!("127.0.0.1:{}:remote:5432", bound_addr.port()))
|
||||
.unwrap();
|
||||
let forwarder = LocalForwarder::new(spec).unwrap();
|
||||
assert_eq!(forwarder.local_port(), bound_addr.port());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_forward_proxy_bidirectional() {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
let echo_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let _echo_addr = echo_listener.local_addr().unwrap();
|
||||
|
||||
let echo_server = tokio::spawn(async move {
|
||||
let (mut stream, _) = echo_listener.accept().await.unwrap();
|
||||
let mut buf = [0u8; 64];
|
||||
loop {
|
||||
let n = match stream.read(&mut buf).await {
|
||||
Ok(0) => break,
|
||||
Ok(n) => n,
|
||||
Err(_) => break,
|
||||
};
|
||||
if stream.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let local_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let local_addr = local_listener.local_addr().unwrap();
|
||||
|
||||
let proxy_task = tokio::spawn(async move {
|
||||
let (stream, _) = local_listener.accept().await.unwrap();
|
||||
let (mut read, mut write) = tokio::io::split(stream);
|
||||
let _ = io::copy(&mut read, &mut write).await;
|
||||
});
|
||||
|
||||
let mut local_conn = TcpStream::connect(local_addr).await.unwrap();
|
||||
local_conn.write_all(b"hello").await.unwrap();
|
||||
let mut buf = [0u8; 64];
|
||||
let n = local_conn.read(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf[..n], b"hello");
|
||||
|
||||
echo_server.abort();
|
||||
proxy_task.abort();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarder_spec_access() {
|
||||
let spec = PortForwardSpec::local("127.0.0.1:5432:db.internal:5432").unwrap();
|
||||
let forwarder = LocalForwarder::new(spec.clone()).unwrap();
|
||||
assert_eq!(forwarder.spec(), &spec);
|
||||
assert_eq!(forwarder.local_port(), 5432);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remote_forwarder_spec_access() {
|
||||
let spec = PortForwardSpec::remote("0.0.0.0:8080:127.0.0.1:3000").unwrap();
|
||||
let forwarder = RemoteForwarder::new(spec.clone()).unwrap();
|
||||
assert_eq!(forwarder.spec(), &spec);
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
//! Client-side SSH session management.
|
||||
//!
|
||||
//! Provides `ClientSession` for establishing an SSH connection over any transport,
|
||||
//! running a local SOCKS5 proxy, and managing port forwards. Also provides
|
||||
//! `ChannelManager` for programmatic channel management with automatic reconnection.
|
||||
//!
|
||||
//! The client always starts a SOCKS5 proxy (default `127.0.0.1:1080`) when running
|
||||
//! via `ClientSession::run()`. For VPN-like "route all traffic" behavior, use
|
||||
//! [tun2proxy](https://github.com/tun2proxy/tun2proxy) alongside the SOCKS5 proxy.
|
||||
|
||||
pub mod channel_manager;
|
||||
pub mod connect;
|
||||
pub mod forward;
|
||||
|
||||
pub use channel_manager::{ChannelManager, ForwardRequest};
|
||||
pub use connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
|
||||
pub use forward::{LocalForwarder, PortForwardSpec, PortForwardSpecKind, RemoteForwarder};
|
||||
714
crates/alknet-core/src/config.rs
Normal file
714
crates/alknet-core/src/config.rs
Normal file
@@ -0,0 +1,714 @@
|
||||
//! Configuration: `DynamicConfig`, `AuthPolicy`, `ApiKeyEntry`,
|
||||
//! `RateLimitConfig`, `ConfigReloadHandle`.
|
||||
//!
|
||||
//! See `docs/architecture/crates/core/config.md` for the full specification.
|
||||
//!
|
||||
//! This module provides the dynamic-config types required by
|
||||
//! `auth::ConfigIdentityProvider`. The remaining types (`StaticConfig`,
|
||||
//! `TlsIdentity`, `ConfigError`) are filled in by the core/config task.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
|
||||
use crate::auth::Identity;
|
||||
|
||||
pub const API_KEY_PREFIX: &str = "alk_";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StaticConfig {
|
||||
pub listen_addr: Option<SocketAddr>,
|
||||
pub tls_identity: Option<TlsIdentity>,
|
||||
#[cfg(feature = "iroh")]
|
||||
pub iroh_relay: Option<iroh::RelayUrl>,
|
||||
pub drain_timeout: Duration,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Ed25519SecretKey(ed25519_dalek::SigningKey);
|
||||
|
||||
impl Ed25519SecretKey {
|
||||
pub fn generate() -> Self {
|
||||
let mut csprng = rand::rngs::OsRng;
|
||||
Self(ed25519_dalek::SigningKey::generate(&mut csprng))
|
||||
}
|
||||
|
||||
pub fn from_bytes(bytes: &[u8; 32]) -> Self {
|
||||
Self(ed25519_dalek::SigningKey::from_bytes(bytes))
|
||||
}
|
||||
|
||||
pub fn as_bytes(&self) -> [u8; 32] {
|
||||
self.0.to_bytes()
|
||||
}
|
||||
|
||||
pub fn public(&self) -> ed25519_dalek::VerifyingKey {
|
||||
self.0.verifying_key()
|
||||
}
|
||||
|
||||
pub fn sign(&self, message: &[u8]) -> ed25519_dalek::Signature {
|
||||
use ed25519_dalek::Signer;
|
||||
self.0.sign(message)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Ed25519SecretKey {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Ed25519SecretKey").finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl zeroize::ZeroizeOnDrop for Ed25519SecretKey {}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum AcmeDirectory {
|
||||
Production,
|
||||
Staging,
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl AcmeDirectory {
|
||||
pub fn url(&self) -> &str {
|
||||
match self {
|
||||
AcmeDirectory::Production => "https://acme-v02.api.letsencrypt.org/directory",
|
||||
AcmeDirectory::Staging => "https://acme-staging-v02.api.letsencrypt.org/directory",
|
||||
AcmeDirectory::Custom(url) => url,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TlsIdentity {
|
||||
X509 {
|
||||
cert: PathBuf,
|
||||
key: PathBuf,
|
||||
},
|
||||
RawKey(Ed25519SecretKey),
|
||||
SelfSigned,
|
||||
Acme {
|
||||
domains: Vec<String>,
|
||||
cache_dir: PathBuf,
|
||||
directory: AcmeDirectory,
|
||||
contact: Vec<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct DynamicConfig {
|
||||
pub auth: AuthPolicy,
|
||||
pub rate_limits: RateLimitConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct PeerEntry {
|
||||
pub peer_id: String,
|
||||
pub fingerprints: Vec<String>,
|
||||
pub auth_token_hash: Option<String>,
|
||||
pub scopes: Vec<String>,
|
||||
pub resources: HashMap<String, Vec<String>>,
|
||||
pub display_name: Option<String>,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct AuthPolicy {
|
||||
pub peers: Vec<PeerEntry>,
|
||||
pub api_keys: Vec<ApiKeyEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ApiKeyEntry {
|
||||
pub prefix: String,
|
||||
pub hash: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub description: String,
|
||||
pub expires_at: Option<u64>,
|
||||
}
|
||||
|
||||
impl AuthPolicy {
|
||||
pub fn empty() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn resolve_identity_from_fingerprint(&self, fingerprint: &str) -> Option<Identity> {
|
||||
self.peers
|
||||
.iter()
|
||||
.find(|p| p.enabled && p.fingerprints.iter().any(|f| f == fingerprint))
|
||||
.map(|p| Identity {
|
||||
id: p.peer_id.clone(),
|
||||
scopes: p.scopes.clone(),
|
||||
resources: p.resources.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn resolve_identity_from_token(&self, token: &str) -> Option<Identity> {
|
||||
let token_hash = sha256_hex(token);
|
||||
self.peers
|
||||
.iter()
|
||||
.find(|p| p.enabled && p.auth_token_hash.as_deref() == Some(&token_hash))
|
||||
.map(|p| Identity {
|
||||
id: p.peer_id.clone(),
|
||||
scopes: p.scopes.clone(),
|
||||
resources: p.resources.clone(),
|
||||
})
|
||||
.or_else(|| self.resolve_api_key(token))
|
||||
}
|
||||
|
||||
pub fn validate_peer_ids(&self) -> Result<(), DuplicatePeerId> {
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
for peer in &self.peers {
|
||||
if !seen.insert(peer.peer_id.as_str()) {
|
||||
return Err(DuplicatePeerId {
|
||||
peer_id: peer.peer_id.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn resolve_api_key(&self, token: &str) -> Option<Identity> {
|
||||
if !token.starts_with(API_KEY_PREFIX) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let prefix_part = &token[..token.len().min(8)];
|
||||
|
||||
let entry = self
|
||||
.api_keys
|
||||
.iter()
|
||||
.find(|e| prefix_part.starts_with(&e.prefix))?;
|
||||
|
||||
let expected_hash = sha256_hex(token);
|
||||
|
||||
if entry.hash != expected_hash {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(expires_at) = entry.expires_at {
|
||||
let now_secs = std::time::SystemTime::now()
|
||||
.duration_since(std::time::SystemTime::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
if now_secs >= expires_at {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(Identity {
|
||||
id: entry.prefix.clone(),
|
||||
scopes: entry.scopes.clone(),
|
||||
resources: std::collections::HashMap::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn sha256_hex(input: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(input.as_bytes());
|
||||
let result = hasher.finalize();
|
||||
format!("sha256:{}", hex::encode(result))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
|
||||
#[error("duplicate peer_id: {peer_id}")]
|
||||
pub struct DuplicatePeerId {
|
||||
pub peer_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RateLimitConfig {
|
||||
pub max_connections_per_ip: usize,
|
||||
pub max_auth_attempts: usize,
|
||||
}
|
||||
|
||||
impl Default for RateLimitConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_connections_per_ip: 100,
|
||||
max_auth_attempts: 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConfigReloadHandle {
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
}
|
||||
|
||||
impl ConfigReloadHandle {
|
||||
pub fn new(dynamic: Arc<ArcSwap<DynamicConfig>>) -> Self {
|
||||
Self { dynamic }
|
||||
}
|
||||
|
||||
pub fn reload(&self, new_config: DynamicConfig) {
|
||||
self.dynamic.store(Arc::new(new_config));
|
||||
}
|
||||
|
||||
pub fn dynamic(&self) -> Arc<DynamicConfig> {
|
||||
self.dynamic.load_full()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConfigError {
|
||||
#[error("invalid flag: {name}")]
|
||||
InvalidFlag { name: String },
|
||||
#[error("key file not found: {path}")]
|
||||
KeyFileNotFound { path: String },
|
||||
#[error("bind failed: {0}")]
|
||||
BindFailed(#[from] io::Error),
|
||||
#[error("tls config error: {0}")]
|
||||
TlsConfig(io::Error),
|
||||
#[error("incompatible options")]
|
||||
IncompatibleOptions,
|
||||
}
|
||||
|
||||
impl Default for StaticConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
listen_addr: None,
|
||||
tls_identity: None,
|
||||
#[cfg(feature = "iroh")]
|
||||
iroh_relay: None,
|
||||
drain_timeout: Duration::from_secs(2),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn static_config_default() {
|
||||
let cfg = StaticConfig::default();
|
||||
assert!(cfg.listen_addr.is_none());
|
||||
assert!(cfg.tls_identity.is_none());
|
||||
assert_eq!(cfg.drain_timeout, Duration::from_secs(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dynamic_config_default() {
|
||||
let cfg = DynamicConfig::default();
|
||||
assert!(cfg.auth.peers.is_empty());
|
||||
assert!(cfg.auth.api_keys.is_empty());
|
||||
assert_eq!(cfg.rate_limits.max_connections_per_ip, 100);
|
||||
assert_eq!(cfg.rate_limits.max_auth_attempts, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_policy_default() {
|
||||
let policy = AuthPolicy::default();
|
||||
assert!(policy.peers.is_empty());
|
||||
assert!(policy.api_keys.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rate_limit_config_default() {
|
||||
let rl = RateLimitConfig::default();
|
||||
assert!(rl.max_connections_per_ip > 0);
|
||||
assert!(rl.max_auth_attempts > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_key_entry_construct() {
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk12345".to_string(),
|
||||
hash: "deadbeef".to_string(),
|
||||
scopes: vec!["admin".to_string()],
|
||||
description: "test key".to_string(),
|
||||
expires_at: Some(1_700_000_000),
|
||||
};
|
||||
assert_eq!(entry.prefix, "alk12345");
|
||||
assert_eq!(entry.scopes, vec!["admin"]);
|
||||
assert_eq!(entry.expires_at, Some(1_700_000_000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_identity_x509_construct() {
|
||||
let id = TlsIdentity::X509 {
|
||||
cert: PathBuf::from("/etc/cert.pem"),
|
||||
key: PathBuf::from("/etc/key.pem"),
|
||||
};
|
||||
match id {
|
||||
TlsIdentity::X509 { cert, key } => {
|
||||
assert_eq!(cert, PathBuf::from("/etc/cert.pem"));
|
||||
assert_eq!(key, PathBuf::from("/etc/key.pem"));
|
||||
}
|
||||
_ => panic!("expected X509"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_identity_self_signed() {
|
||||
let id = TlsIdentity::SelfSigned;
|
||||
let s = format!("{id:?}");
|
||||
assert!(s.contains("SelfSigned"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_reload_handle_swaps_atomically() {
|
||||
let dynamic = Arc::new(ArcSwap::from_pointee(DynamicConfig::default()));
|
||||
let handle = ConfigReloadHandle::new(dynamic.clone());
|
||||
|
||||
let initial = handle.dynamic();
|
||||
assert!(initial.auth.peers.is_empty());
|
||||
|
||||
let new_auth = AuthPolicy {
|
||||
peers: vec![PeerEntry {
|
||||
peer_id: "worker-a".to_string(),
|
||||
fingerprints: vec!["aa:bb:cc".to_string()],
|
||||
auth_token_hash: None,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
display_name: None,
|
||||
enabled: true,
|
||||
}],
|
||||
api_keys: Vec::new(),
|
||||
};
|
||||
let new_config = DynamicConfig {
|
||||
auth: new_auth,
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
};
|
||||
handle.reload(new_config);
|
||||
|
||||
let after = handle.dynamic();
|
||||
assert_eq!(after.auth.peers.len(), 1);
|
||||
assert_eq!(after.auth.peers[0].peer_id, "worker-a");
|
||||
assert!(initial.auth.peers.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_reload_handle_dynamic_returns_current() {
|
||||
let dynamic = Arc::new(ArcSwap::from_pointee(DynamicConfig::default()));
|
||||
let handle = ConfigReloadHandle::new(dynamic);
|
||||
let a = handle.dynamic();
|
||||
let b = handle.dynamic();
|
||||
assert_eq!(
|
||||
a.rate_limits.max_auth_attempts,
|
||||
b.rate_limits.max_auth_attempts
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_error_invalid_flag_display() {
|
||||
let e = ConfigError::InvalidFlag {
|
||||
name: "foo".to_string(),
|
||||
};
|
||||
assert_eq!(format!("{e}"), "invalid flag: foo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_error_key_file_not_found_display() {
|
||||
let e = ConfigError::KeyFileNotFound {
|
||||
path: "/x".to_string(),
|
||||
};
|
||||
assert_eq!(format!("{e}"), "key file not found: /x");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_error_incompatible_options_display() {
|
||||
let e = ConfigError::IncompatibleOptions;
|
||||
assert_eq!(format!("{e}"), "incompatible options");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_error_bind_failed_from_io() {
|
||||
let io_err = io::Error::new(io::ErrorKind::AddrInUse, "busy");
|
||||
let e: ConfigError = io_err.into();
|
||||
assert!(matches!(e, ConfigError::BindFailed(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_error_tls_config_display() {
|
||||
let e = ConfigError::TlsConfig(io::Error::new(io::ErrorKind::InvalidData, "bad"));
|
||||
let s = format!("{e}");
|
||||
assert!(s.starts_with("tls config error:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_api_key_returns_empty_resources() {
|
||||
let token = "alk_test_secret";
|
||||
let hash = sha256_hex(token);
|
||||
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_tes".to_string(),
|
||||
hash,
|
||||
scopes: vec!["admin".to_string()],
|
||||
description: "test key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let policy = AuthPolicy {
|
||||
peers: Vec::new(),
|
||||
api_keys: vec![entry],
|
||||
};
|
||||
|
||||
let identity = policy.resolve_api_key(token);
|
||||
assert!(
|
||||
identity.is_some(),
|
||||
"api key with matching prefix and hash should resolve"
|
||||
);
|
||||
let identity = identity.unwrap();
|
||||
assert_eq!(identity.id, "alk_tes");
|
||||
assert_eq!(identity.scopes, vec!["admin"]);
|
||||
assert!(
|
||||
identity.resources.is_empty(),
|
||||
"token-resolved identities must have empty resources (Option B — scopes only)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_identity_from_fingerprint_uses_peer_id() {
|
||||
let policy = AuthPolicy {
|
||||
peers: vec![PeerEntry {
|
||||
peer_id: "worker-a".to_string(),
|
||||
fingerprints: vec!["SHA256:known".to_string()],
|
||||
auth_token_hash: None,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
display_name: None,
|
||||
enabled: true,
|
||||
}],
|
||||
api_keys: vec![],
|
||||
};
|
||||
|
||||
let identity = policy
|
||||
.resolve_identity_from_fingerprint("SHA256:known")
|
||||
.expect("known fingerprint should resolve");
|
||||
assert_eq!(identity.id, "worker-a");
|
||||
assert_eq!(identity.scopes, vec!["relay:connect"]);
|
||||
}
|
||||
|
||||
// --- PeerEntry model (ADR-030) ---------------------------------------
|
||||
|
||||
fn peer_entry(peer_id: &str, fingerprints: &[&str]) -> PeerEntry {
|
||||
PeerEntry {
|
||||
peer_id: peer_id.to_string(),
|
||||
fingerprints: fingerprints.iter().map(|s| s.to_string()).collect(),
|
||||
auth_token_hash: None,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
display_name: None,
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fingerprint_resolution_known_returns_some_with_peer_id() {
|
||||
let policy = AuthPolicy {
|
||||
peers: vec![peer_entry("worker-a", &["ed25519:abc"])],
|
||||
api_keys: vec![],
|
||||
};
|
||||
let identity = policy
|
||||
.resolve_identity_from_fingerprint("ed25519:abc")
|
||||
.expect("known fingerprint resolves");
|
||||
assert_eq!(identity.id, "worker-a");
|
||||
assert_eq!(identity.scopes, vec!["relay:connect"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fingerprint_resolution_unknown_returns_none() {
|
||||
let policy = AuthPolicy {
|
||||
peers: vec![peer_entry("worker-a", &["ed25519:abc"])],
|
||||
api_keys: vec![],
|
||||
};
|
||||
assert!(policy
|
||||
.resolve_identity_from_fingerprint("ed25519:unknown")
|
||||
.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fingerprint_resolution_disabled_returns_none() {
|
||||
let mut entry = peer_entry("worker-a", &["ed25519:abc"]);
|
||||
entry.enabled = false;
|
||||
let policy = AuthPolicy {
|
||||
peers: vec![entry],
|
||||
api_keys: vec![],
|
||||
};
|
||||
assert!(policy
|
||||
.resolve_identity_from_fingerprint("ed25519:abc")
|
||||
.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_resolution_matching_peer_returns_some_with_peer_id() {
|
||||
let token = "bearer-secret";
|
||||
let mut entry = peer_entry("worker-a", &["ed25519:abc"]);
|
||||
entry.auth_token_hash = Some(sha256_hex(token));
|
||||
let policy = AuthPolicy {
|
||||
peers: vec![entry],
|
||||
api_keys: vec![],
|
||||
};
|
||||
let identity = policy
|
||||
.resolve_identity_from_token(token)
|
||||
.expect("matching auth_token_hash resolves");
|
||||
assert_eq!(identity.id, "worker-a");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_resolution_non_matching_falls_through_to_api_key() {
|
||||
let api_token = "alk_test_secret";
|
||||
let mut entry = peer_entry("worker-a", &["ed25519:abc"]);
|
||||
entry.auth_token_hash = Some(sha256_hex("different-token"));
|
||||
let api_entry = ApiKeyEntry {
|
||||
prefix: "alk_tes".to_string(),
|
||||
hash: sha256_hex(api_token),
|
||||
scopes: vec!["admin".to_string()],
|
||||
description: "test key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let policy = AuthPolicy {
|
||||
peers: vec![entry],
|
||||
api_keys: vec![api_entry],
|
||||
};
|
||||
let identity = policy
|
||||
.resolve_identity_from_token(api_token)
|
||||
.expect("api key fall-through resolves");
|
||||
assert_eq!(identity.id, "alk_tes");
|
||||
assert_eq!(identity.scopes, vec!["admin"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_resolution_no_match_returns_none() {
|
||||
let policy = AuthPolicy {
|
||||
peers: vec![peer_entry("worker-a", &["ed25519:abc"])],
|
||||
api_keys: vec![],
|
||||
};
|
||||
assert!(policy.resolve_identity_from_token("unknown").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multi_fingerprint_peer_any_resolves_to_same_peer_id() {
|
||||
let policy = AuthPolicy {
|
||||
peers: vec![peer_entry("worker-a", &["ed25519:abc", "SHA256:def"])],
|
||||
api_keys: vec![],
|
||||
};
|
||||
let id1 = policy
|
||||
.resolve_identity_from_fingerprint("ed25519:abc")
|
||||
.expect("first fingerprint resolves");
|
||||
let id2 = policy
|
||||
.resolve_identity_from_fingerprint("SHA256:def")
|
||||
.expect("second fingerprint resolves");
|
||||
assert_eq!(id1.id, "worker-a");
|
||||
assert_eq!(id2.id, "worker-a");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resources_populated_on_fingerprint_path() {
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert("service".to_string(), vec!["gitea".to_string()]);
|
||||
let mut entry = peer_entry("worker-a", &["ed25519:abc"]);
|
||||
entry.resources = resources.clone();
|
||||
let policy = AuthPolicy {
|
||||
peers: vec![entry],
|
||||
api_keys: vec![],
|
||||
};
|
||||
let identity = policy
|
||||
.resolve_identity_from_fingerprint("ed25519:abc")
|
||||
.expect("known fingerprint resolves");
|
||||
assert_eq!(identity.resources, resources);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resources_populated_on_token_path() {
|
||||
let token = "bearer-secret";
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert("service".to_string(), vec!["gitea".to_string()]);
|
||||
let mut entry = peer_entry("worker-a", &["ed25519:abc"]);
|
||||
entry.auth_token_hash = Some(sha256_hex(token));
|
||||
entry.resources = resources.clone();
|
||||
let policy = AuthPolicy {
|
||||
peers: vec![entry],
|
||||
api_keys: vec![],
|
||||
};
|
||||
let identity = policy
|
||||
.resolve_identity_from_token(token)
|
||||
.expect("matching token resolves");
|
||||
assert_eq!(identity.resources, resources);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn duplicate_peer_id_validation_rejects() {
|
||||
let policy = AuthPolicy {
|
||||
peers: vec![
|
||||
peer_entry("worker-a", &["ed25519:abc"]),
|
||||
peer_entry("worker-a", &["ed25519:def"]),
|
||||
],
|
||||
api_keys: vec![],
|
||||
};
|
||||
let err = policy.validate_peer_ids().expect_err("duplicate detected");
|
||||
assert_eq!(err.peer_id, "worker-a");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unique_peer_ids_validate_ok() {
|
||||
let policy = AuthPolicy {
|
||||
peers: vec![
|
||||
peer_entry("worker-a", &["ed25519:abc"]),
|
||||
peer_entry("worker-b", &["ed25519:def"]),
|
||||
],
|
||||
api_keys: vec![],
|
||||
};
|
||||
assert!(policy.validate_peer_ids().is_ok());
|
||||
}
|
||||
|
||||
// --- Ed25519SecretKey -------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn ed25519_secret_key_round_trips_bytes() {
|
||||
let key = Ed25519SecretKey::generate();
|
||||
let bytes = key.as_bytes();
|
||||
let restored = Ed25519SecretKey::from_bytes(&bytes);
|
||||
assert_eq!(restored.as_bytes(), bytes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ed25519_secret_key_sign_verifies_against_public_key() {
|
||||
use ed25519_dalek::{Signature, Verifier};
|
||||
let key = Ed25519SecretKey::generate();
|
||||
let public = key.public();
|
||||
let message = b"alknet coverage check";
|
||||
let signature: Signature = key.sign(message);
|
||||
assert_eq!(signature.to_bytes().len(), 64);
|
||||
assert!(
|
||||
public.verify(message, &signature).is_ok(),
|
||||
"signature produced by Ed25519SecretKey::sign must verify under its public key"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ed25519_secret_key_sign_rejects_tampered_message() {
|
||||
use ed25519_dalek::{Signature, Verifier};
|
||||
let key = Ed25519SecretKey::generate();
|
||||
let public = key.public();
|
||||
let signature: Signature = key.sign(b"original message");
|
||||
assert!(
|
||||
public.verify(b"tampered message", &signature).is_err(),
|
||||
"signature must not verify against a different message"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ed25519_secret_key_debug_does_not_leak_material() {
|
||||
let key = Ed25519SecretKey::generate();
|
||||
let dbg = format!("{key:?}");
|
||||
assert!(dbg.contains("Ed25519SecretKey"));
|
||||
assert!(!dbg.contains("SigningKey"));
|
||||
let raw = hex::encode(key.as_bytes());
|
||||
assert!(
|
||||
!dbg.contains(&raw),
|
||||
"Debug output must not contain the raw key bytes"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ed25519_secret_key_public_matches_underlying_signing_key() {
|
||||
let key = Ed25519SecretKey::generate();
|
||||
let public = key.public();
|
||||
assert_eq!(public.to_bytes().len(), 32);
|
||||
}
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
//! Configuration service for runtime config reload.
|
||||
//!
|
||||
//! See [ADR-030](docs/architecture/decisions/030-dynamic-config.md).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
|
||||
use super::{DynamicConfig, ForwardingPolicy, RateLimitConfig};
|
||||
|
||||
pub struct ConfigServiceImpl {
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
}
|
||||
|
||||
impl ConfigServiceImpl {
|
||||
pub fn new(dynamic: Arc<ArcSwap<DynamicConfig>>) -> Self {
|
||||
Self { dynamic }
|
||||
}
|
||||
|
||||
pub fn forwarding_policy(&self) -> Arc<ForwardingPolicy> {
|
||||
Arc::new(self.dynamic.load().forwarding.clone())
|
||||
}
|
||||
|
||||
pub fn rate_limits(&self) -> Arc<RateLimitConfig> {
|
||||
Arc::new(self.dynamic.load().rate_limits.clone())
|
||||
}
|
||||
|
||||
pub fn reload(&self, new_config: DynamicConfig) {
|
||||
self.dynamic.store(Arc::new(new_config));
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ConfigServiceImpl {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ConfigServiceImpl").finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "irpc")]
|
||||
#[allow(dead_code)]
|
||||
pub enum ConfigProtocol {
|
||||
GetForwardingPolicy,
|
||||
GetRateLimits,
|
||||
ReloadForwarding { policy: ForwardingPolicy },
|
||||
ReloadRateLimits { limits: RateLimitConfig },
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::AuthPolicy;
|
||||
|
||||
#[test]
|
||||
fn config_service_impl_forwarding_policy() {
|
||||
let (arc_swap, _) = super::super::new_dynamic_config();
|
||||
let service = ConfigServiceImpl::new(Arc::clone(&arc_swap));
|
||||
let policy = service.forwarding_policy();
|
||||
assert_eq!(policy.default, ForwardingPolicy::allow_all().default);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_service_impl_rate_limits() {
|
||||
let (arc_swap, _) = super::super::new_dynamic_config();
|
||||
let service = ConfigServiceImpl::new(Arc::clone(&arc_swap));
|
||||
let limits = service.rate_limits();
|
||||
assert_eq!(limits.max_auth_attempts, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_service_impl_reload() {
|
||||
let (arc_swap, _) = super::super::new_dynamic_config();
|
||||
let service = ConfigServiceImpl::new(Arc::clone(&arc_swap));
|
||||
assert_eq!(
|
||||
service.forwarding_policy().default,
|
||||
ForwardingPolicy::allow_all().default
|
||||
);
|
||||
|
||||
let new_config = DynamicConfig {
|
||||
auth: AuthPolicy::empty(),
|
||||
forwarding: ForwardingPolicy::deny_all(),
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
credentials: std::collections::HashMap::new(),
|
||||
};
|
||||
service.reload(new_config);
|
||||
|
||||
assert_eq!(
|
||||
service.forwarding_policy().default,
|
||||
ForwardingPolicy::deny_all().default
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_service_impl_debug() {
|
||||
let (arc_swap, _) = super::super::new_dynamic_config();
|
||||
let service = ConfigServiceImpl::new(Arc::clone(&arc_swap));
|
||||
let debug_str = format!("{:?}", service);
|
||||
assert!(debug_str.contains("ConfigServiceImpl"));
|
||||
}
|
||||
}
|
||||
@@ -1,603 +0,0 @@
|
||||
//! Runtime-reloadable dynamic configuration (auth policy, forwarding policy, rate limits).
|
||||
//!
|
||||
//! See [ADR-030](docs/architecture/decisions/030-dynamic-config.md).
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use russh::keys::ssh_key::HashAlg;
|
||||
|
||||
use crate::auth::identity::Identity;
|
||||
use crate::auth::ServerAuthConfig;
|
||||
use crate::config::forwarding::ForwardingPolicy;
|
||||
use crate::credentials::CredentialSet;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[non_exhaustive]
|
||||
pub struct ApiKeyEntry {
|
||||
pub prefix: String,
|
||||
pub hash: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub description: String,
|
||||
pub expires_at: Option<u64>,
|
||||
}
|
||||
|
||||
pub const API_KEY_PREFIX: &str = "alk_";
|
||||
|
||||
pub struct AuthPolicy {
|
||||
pub authorized_keys: std::collections::HashSet<russh::keys::PublicKey>,
|
||||
pub cert_authorities: Vec<crate::auth::keys::CertAuthorityEntry>,
|
||||
pub api_keys: Vec<ApiKeyEntry>,
|
||||
encoded_keys: std::collections::HashSet<Vec<u8>>,
|
||||
fingerprint_to_key: HashMap<String, russh::keys::PublicKey>,
|
||||
}
|
||||
|
||||
fn encode_key_data(key: &russh::keys::PublicKey) -> Vec<u8> {
|
||||
use russh::keys::helpers::EncodedExt;
|
||||
key.key_data().encoded().unwrap_or_default()
|
||||
}
|
||||
|
||||
impl AuthPolicy {
|
||||
pub fn new(
|
||||
authorized_keys: std::collections::HashSet<russh::keys::PublicKey>,
|
||||
cert_authorities: Vec<crate::auth::keys::CertAuthorityEntry>,
|
||||
) -> Self {
|
||||
Self::with_api_keys(authorized_keys, cert_authorities, Vec::new())
|
||||
}
|
||||
|
||||
pub fn with_api_keys(
|
||||
authorized_keys: std::collections::HashSet<russh::keys::PublicKey>,
|
||||
cert_authorities: Vec<crate::auth::keys::CertAuthorityEntry>,
|
||||
api_keys: Vec<ApiKeyEntry>,
|
||||
) -> Self {
|
||||
let encoded_keys = authorized_keys.iter().map(encode_key_data).collect();
|
||||
let fingerprint_to_key = authorized_keys
|
||||
.iter()
|
||||
.map(|k| (format!("{}", k.fingerprint(HashAlg::Sha256)), k.clone()))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
authorized_keys,
|
||||
cert_authorities,
|
||||
api_keys,
|
||||
encoded_keys,
|
||||
fingerprint_to_key,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_server_auth_config(config: ServerAuthConfig) -> Self {
|
||||
Self::new(config.authorized_keys, config.cert_authorities)
|
||||
}
|
||||
|
||||
pub fn empty() -> Self {
|
||||
Self::new(std::collections::HashSet::new(), Vec::new())
|
||||
}
|
||||
|
||||
pub fn resolve_identity_from_fingerprint(&self, fingerprint: &str) -> Option<Identity> {
|
||||
if self.fingerprint_to_key.contains_key(fingerprint) {
|
||||
Some(Identity {
|
||||
id: fingerprint.to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resolve_api_key(&self, token: &str) -> Option<Identity> {
|
||||
if !token.starts_with(API_KEY_PREFIX) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let prefix_part = &token[..token.len().min(8)];
|
||||
|
||||
let entry = self
|
||||
.api_keys
|
||||
.iter()
|
||||
.find(|e| prefix_part.starts_with(&e.prefix))?;
|
||||
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(token.as_bytes());
|
||||
let result = hasher.finalize();
|
||||
let expected_hash = format!("sha256:{}", hex::encode(result));
|
||||
|
||||
if entry.hash != expected_hash {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(expires_at) = entry.expires_at {
|
||||
let now_secs = std::time::SystemTime::now()
|
||||
.duration_since(std::time::SystemTime::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
if now_secs >= expires_at {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(Identity {
|
||||
id: entry.prefix.clone(),
|
||||
scopes: entry.scopes.clone(),
|
||||
resources: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn authenticate_publickey(
|
||||
&self,
|
||||
key: &russh::keys::PublicKey,
|
||||
) -> Result<(), crate::error::AuthError> {
|
||||
let encoded = encode_key_data(key);
|
||||
if self.encoded_keys.contains(&encoded) {
|
||||
return Ok(());
|
||||
}
|
||||
Err(crate::error::AuthError::KeyRejected)
|
||||
}
|
||||
|
||||
pub fn authenticate_certificate(
|
||||
&self,
|
||||
cert: &russh::keys::Certificate,
|
||||
user: &str,
|
||||
client_ip: Option<std::net::IpAddr>,
|
||||
) -> Result<(), crate::error::AuthError> {
|
||||
use std::time::SystemTime;
|
||||
|
||||
let matching_ca = self
|
||||
.cert_authorities
|
||||
.iter()
|
||||
.find(|ca| cert.signature_key() == ca.public_key.key_data());
|
||||
|
||||
let ca_entry = match matching_ca {
|
||||
Some(entry) => entry,
|
||||
None => return Err(crate::error::AuthError::CertInvalid),
|
||||
};
|
||||
|
||||
if cert.verify_signature().is_err() {
|
||||
return Err(crate::error::AuthError::CertInvalid);
|
||||
}
|
||||
|
||||
let now = SystemTime::now();
|
||||
let now_secs = now
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
|
||||
if now_secs < cert.valid_after() || now_secs >= cert.valid_before() {
|
||||
return Err(crate::error::AuthError::CertExpired);
|
||||
}
|
||||
|
||||
let principals = cert.valid_principals();
|
||||
if !principals.is_empty() && !principals.iter().any(|p| p == user) {
|
||||
return Err(crate::error::AuthError::CertPrincipalMismatch);
|
||||
}
|
||||
|
||||
check_critical_options(cert, ca_entry, client_ip)?;
|
||||
check_extensions(cert, ca_entry)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn check_critical_options(
|
||||
cert: &russh::keys::Certificate,
|
||||
ca_entry: &crate::auth::keys::CertAuthorityEntry,
|
||||
client_ip: Option<std::net::IpAddr>,
|
||||
) -> Result<(), crate::error::AuthError> {
|
||||
let ca_has_no_pty = ca_entry.options.iter().any(|o| o == "no-pty");
|
||||
|
||||
for (name, data) in cert.critical_options().iter() {
|
||||
match name.as_str() {
|
||||
"source-address" => {
|
||||
if !check_source_address(data, client_ip) {
|
||||
return Err(crate::error::AuthError::CertInvalid);
|
||||
}
|
||||
}
|
||||
"force-command" => {}
|
||||
"no-pty" => {}
|
||||
_ => {
|
||||
let _ = ca_has_no_pty;
|
||||
return Err(crate::error::AuthError::CertInvalid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_extensions(
|
||||
cert: &russh::keys::Certificate,
|
||||
ca_entry: &crate::auth::keys::CertAuthorityEntry,
|
||||
) -> Result<(), crate::error::AuthError> {
|
||||
let ca_permit_port_forwarding = ca_entry
|
||||
.options
|
||||
.iter()
|
||||
.any(|o| o == "permit-port-forwarding");
|
||||
|
||||
if ca_permit_port_forwarding {
|
||||
let cert_allows = cert
|
||||
.extensions()
|
||||
.iter()
|
||||
.any(|(n, _)| n == "permit-port-forwarding");
|
||||
if !cert_allows {
|
||||
return Err(crate::error::AuthError::CertInvalid);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_source_address(allowed: &str, client_ip: Option<std::net::IpAddr>) -> bool {
|
||||
use ipnetwork::IpNetwork;
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
let Some(ip) = client_ip else {
|
||||
return false;
|
||||
};
|
||||
|
||||
for pattern in allowed.split(',') {
|
||||
let pattern = pattern.trim();
|
||||
if pattern.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(cidr) = IpNetwork::from_str(pattern) {
|
||||
if cidr.contains(ip) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(net_ip) = IpAddr::from_str(pattern) {
|
||||
if net_ip == ip {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AuthPolicy {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AuthPolicy")
|
||||
.field("authorized_keys_count", &self.authorized_keys.len())
|
||||
.field("cert_authorities_count", &self.cert_authorities.len())
|
||||
.field("api_keys_count", &self.api_keys.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for AuthPolicy {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
authorized_keys: self.authorized_keys.clone(),
|
||||
cert_authorities: self.cert_authorities.clone(),
|
||||
api_keys: self.api_keys.clone(),
|
||||
encoded_keys: self.encoded_keys.clone(),
|
||||
fingerprint_to_key: self.fingerprint_to_key.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RateLimitConfig {
|
||||
pub max_connections_per_ip: usize,
|
||||
pub max_auth_attempts: usize,
|
||||
}
|
||||
|
||||
impl Default for RateLimitConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_connections_per_ip: 0,
|
||||
max_auth_attempts: 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[non_exhaustive]
|
||||
pub struct DynamicConfig {
|
||||
pub auth: AuthPolicy,
|
||||
pub forwarding: ForwardingPolicy,
|
||||
pub rate_limits: RateLimitConfig,
|
||||
pub credentials: HashMap<String, CredentialSet>,
|
||||
}
|
||||
|
||||
impl DynamicConfig {
|
||||
pub fn new(auth: AuthPolicy) -> Self {
|
||||
Self {
|
||||
auth,
|
||||
forwarding: ForwardingPolicy::allow_all(),
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
credentials: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_parts(
|
||||
auth: AuthPolicy,
|
||||
forwarding: ForwardingPolicy,
|
||||
rate_limits: RateLimitConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
auth,
|
||||
forwarding,
|
||||
rate_limits,
|
||||
credentials: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_forwarding_policy(mut self, policy: ForwardingPolicy) -> Self {
|
||||
self.forwarding = policy;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_rate_limits(mut self, limits: RateLimitConfig) -> Self {
|
||||
self.rate_limits = limits;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_credentials(mut self, credentials: HashMap<String, CredentialSet>) -> Self {
|
||||
self.credentials = credentials;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DynamicConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
auth: AuthPolicy::empty(),
|
||||
forwarding: ForwardingPolicy::allow_all(),
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
credentials: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConfigReloadHandle {
|
||||
pub(crate) dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
}
|
||||
|
||||
impl ConfigReloadHandle {
|
||||
pub fn reload(&self, new_config: DynamicConfig) {
|
||||
self.dynamic.store(Arc::new(new_config));
|
||||
}
|
||||
|
||||
pub fn dynamic(&self) -> Arc<DynamicConfig> {
|
||||
self.dynamic.load_full()
|
||||
}
|
||||
|
||||
pub fn dynamic_arc(&self) -> Arc<ArcSwap<DynamicConfig>> {
|
||||
Arc::clone(&self.dynamic)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ConfigReloadHandle {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ConfigReloadHandle").finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_dynamic_config() -> (Arc<ArcSwap<DynamicConfig>>, ConfigReloadHandle) {
|
||||
let inner = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||
let handle = ConfigReloadHandle {
|
||||
dynamic: Arc::clone(&inner),
|
||||
};
|
||||
(inner, handle)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::forwarding::ForwardingAction;
|
||||
|
||||
#[test]
|
||||
fn forwarding_policy_allow_all_default() {
|
||||
let policy = ForwardingPolicy::allow_all();
|
||||
assert_eq!(policy.default, ForwardingAction::Allow);
|
||||
assert!(policy.rules.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarding_policy_deny_all() {
|
||||
let policy = ForwardingPolicy::deny_all();
|
||||
assert_eq!(policy.default, ForwardingAction::Deny);
|
||||
assert!(policy.rules.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dynamic_config_default() {
|
||||
let config = DynamicConfig::default();
|
||||
assert_eq!(config.forwarding.default, ForwardingAction::Allow);
|
||||
assert_eq!(config.rate_limits.max_connections_per_ip, 0);
|
||||
assert_eq!(config.rate_limits.max_auth_attempts, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_reload_handle_updates_dynamic() {
|
||||
let (arc_swap, handle) = new_dynamic_config();
|
||||
let initial = arc_swap.load();
|
||||
assert_eq!(initial.forwarding.default, ForwardingAction::Allow);
|
||||
|
||||
let new_config = DynamicConfig {
|
||||
auth: AuthPolicy::empty(),
|
||||
forwarding: ForwardingPolicy::deny_all(),
|
||||
rate_limits: RateLimitConfig::default(),
|
||||
credentials: HashMap::new(),
|
||||
};
|
||||
handle.reload(new_config);
|
||||
|
||||
let updated = arc_swap.load();
|
||||
assert_eq!(updated.forwarding.default, ForwardingAction::Deny);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dynamic_config_with_forwarding_policy_builder() {
|
||||
let config = DynamicConfig::new(AuthPolicy::empty())
|
||||
.with_forwarding_policy(ForwardingPolicy::deny_all());
|
||||
assert_eq!(config.forwarding.default, ForwardingAction::Deny);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rate_limit_config_custom() {
|
||||
let limits = RateLimitConfig {
|
||||
max_connections_per_ip: 5,
|
||||
max_auth_attempts: 3,
|
||||
};
|
||||
assert_eq!(limits.max_connections_per_ip, 5);
|
||||
assert_eq!(limits.max_auth_attempts, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarding_action_equality() {
|
||||
assert_eq!(ForwardingAction::Allow, ForwardingAction::Allow);
|
||||
assert_eq!(ForwardingAction::Deny, ForwardingAction::Deny);
|
||||
assert_ne!(ForwardingAction::Allow, ForwardingAction::Deny);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_policy_empty_rejects_all() {
|
||||
let policy = AuthPolicy::empty();
|
||||
let key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
||||
let other_ssh_key =
|
||||
russh::keys::parse_public_key_base64(key_text.split_whitespace().nth(1).unwrap())
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
policy.authenticate_publickey(&other_ssh_key),
|
||||
Err(crate::error::AuthError::KeyRejected)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_policy_debug_redacts_keys() {
|
||||
let policy = AuthPolicy::empty();
|
||||
let debug_str = format!("{:?}", policy);
|
||||
assert!(debug_str.contains("authorized_keys_count"));
|
||||
assert!(debug_str.contains("cert_authorities_count"));
|
||||
assert!(debug_str.contains("api_keys_count"));
|
||||
}
|
||||
|
||||
fn compute_api_key_hash(token: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(token.as_bytes());
|
||||
let result = hasher.finalize();
|
||||
format!("sha256:{}", hex::encode(result))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_key_valid_authenticates() {
|
||||
let token = "alk_testsecret123";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_test".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "test key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let policy =
|
||||
AuthPolicy::with_api_keys(std::collections::HashSet::new(), Vec::new(), vec![entry]);
|
||||
let identity = policy.resolve_api_key(token);
|
||||
assert!(identity.is_some());
|
||||
let identity = identity.unwrap();
|
||||
assert_eq!(identity.id, "alk_test");
|
||||
assert_eq!(identity.scopes, vec!["relay:connect"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_key_expired_rejected() {
|
||||
let token = "alk_expiredkey1";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_expi".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "expired key".to_string(),
|
||||
expires_at: Some(1),
|
||||
};
|
||||
let policy =
|
||||
AuthPolicy::with_api_keys(std::collections::HashSet::new(), Vec::new(), vec![entry]);
|
||||
assert!(policy.resolve_api_key(token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_key_wrong_hash_rejected() {
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_test".to_string(),
|
||||
hash: "sha256:0000000000000000000000000000000000000000000000000000000000000000"
|
||||
.to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "bad hash".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let policy =
|
||||
AuthPolicy::with_api_keys(std::collections::HashSet::new(), Vec::new(), vec![entry]);
|
||||
assert!(policy.resolve_api_key("alk_testsecret123").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_key_unknown_prefix_falls_through() {
|
||||
let token = "alk_testsecret123";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_other".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "other key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let policy =
|
||||
AuthPolicy::with_api_keys(std::collections::HashSet::new(), Vec::new(), vec![entry]);
|
||||
assert!(policy.resolve_api_key(token).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_key_scopes_propagate() {
|
||||
let token = "alk_scopesecret";
|
||||
let hash = compute_api_key_hash(token);
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_sco".to_string(),
|
||||
hash,
|
||||
scopes: vec!["relay:connect".to_string(), "secrets:derive".to_string()],
|
||||
description: "scoped key".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let policy =
|
||||
AuthPolicy::with_api_keys(std::collections::HashSet::new(), Vec::new(), vec![entry]);
|
||||
let identity = policy.resolve_api_key(token).unwrap();
|
||||
assert_eq!(identity.scopes, vec!["relay:connect", "secrets:derive"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_api_key_prefix_returns_none() {
|
||||
let policy = AuthPolicy::empty();
|
||||
assert!(policy.resolve_api_key("bearer-some-token").is_none());
|
||||
assert!(policy.resolve_api_key("regular-token").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_key_entry_default_empty() {
|
||||
let config = DynamicConfig::default();
|
||||
assert!(config.auth.api_keys.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_policy_with_api_keys_preserves_entries() {
|
||||
let entry = ApiKeyEntry {
|
||||
prefix: "alk_abc".to_string(),
|
||||
hash: "sha256:abcdef".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
description: "test".to_string(),
|
||||
expires_at: None,
|
||||
};
|
||||
let policy = AuthPolicy::with_api_keys(
|
||||
std::collections::HashSet::new(),
|
||||
Vec::new(),
|
||||
vec![entry.clone()],
|
||||
);
|
||||
assert_eq!(policy.api_keys.len(), 1);
|
||||
assert_eq!(policy.api_keys[0], entry);
|
||||
}
|
||||
}
|
||||
@@ -1,534 +0,0 @@
|
||||
//! Forwarding policy engine for per-identity and per-transport access control.
|
||||
//!
|
||||
//! See [ADR-031](docs/architecture/decisions/031-forwarding-policy.md).
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::ops::Range;
|
||||
use std::str::FromStr;
|
||||
|
||||
use ipnetwork::IpNetwork;
|
||||
|
||||
use crate::auth::identity::Identity;
|
||||
use crate::transport::TransportKind;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[non_exhaustive]
|
||||
pub enum ForwardingAction {
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[non_exhaustive]
|
||||
pub enum TargetPattern {
|
||||
Any,
|
||||
Host(String),
|
||||
Cidr(IpNetwork),
|
||||
PortRange(String, Range<u16>),
|
||||
AlknetPrefix,
|
||||
}
|
||||
|
||||
impl TargetPattern {
|
||||
pub fn matches(&self, target: &str, port: u16) -> bool {
|
||||
match self {
|
||||
TargetPattern::Any => true,
|
||||
TargetPattern::Host(pattern) => match_host_pattern(pattern, target),
|
||||
TargetPattern::Cidr(network) => match_cidr(network, target),
|
||||
TargetPattern::PortRange(host_pattern, port_range) => {
|
||||
match_host_pattern(host_pattern, target) && port_range.contains(&port)
|
||||
}
|
||||
TargetPattern::AlknetPrefix => {
|
||||
target.starts_with(crate::server::control_channel::ALKNET_PREFIX)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn match_host_pattern(pattern: &str, target: &str) -> bool {
|
||||
if pattern == target {
|
||||
return true;
|
||||
}
|
||||
if pattern.contains('*') {
|
||||
if let Some(pos) = pattern.find('*') {
|
||||
let prefix = &pattern[..pos];
|
||||
let suffix = &pattern[pos + 1..];
|
||||
return target.starts_with(prefix)
|
||||
&& target.ends_with(suffix)
|
||||
&& target.len() >= prefix.len() + suffix.len();
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn match_cidr(network: &IpNetwork, target: &str) -> bool {
|
||||
let Ok(addr) = IpAddr::from_str(target) else {
|
||||
return false;
|
||||
};
|
||||
network.contains(addr)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[non_exhaustive]
|
||||
pub struct ForwardingRule {
|
||||
pub target: TargetPattern,
|
||||
pub action: ForwardingAction,
|
||||
pub principals: Vec<String>,
|
||||
pub transports: Vec<TransportKind>,
|
||||
}
|
||||
|
||||
impl ForwardingRule {
|
||||
pub fn new(
|
||||
target: TargetPattern,
|
||||
action: ForwardingAction,
|
||||
principals: Vec<String>,
|
||||
transports: Vec<TransportKind>,
|
||||
) -> Self {
|
||||
Self {
|
||||
target,
|
||||
action,
|
||||
principals,
|
||||
transports,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ForwardingRule {
|
||||
fn matches_principal(&self, identity: &Identity) -> bool {
|
||||
if self.principals.is_empty() {
|
||||
return true;
|
||||
}
|
||||
self.principals
|
||||
.iter()
|
||||
.any(|p| p == &identity.id || identity.scopes.contains(p))
|
||||
}
|
||||
|
||||
fn matches_transport(&self, transport: &TransportKind) -> bool {
|
||||
if self.transports.is_empty() {
|
||||
return true;
|
||||
}
|
||||
self.transports.contains(transport)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ForwardingPolicy {
|
||||
pub default: ForwardingAction,
|
||||
pub rules: Vec<ForwardingRule>,
|
||||
}
|
||||
|
||||
impl ForwardingPolicy {
|
||||
pub fn allow_all() -> Self {
|
||||
Self {
|
||||
default: ForwardingAction::Allow,
|
||||
rules: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deny_all() -> Self {
|
||||
Self {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(
|
||||
&self,
|
||||
target: &str,
|
||||
port: u16,
|
||||
identity: &Identity,
|
||||
transport: TransportKind,
|
||||
) -> bool {
|
||||
for rule in &self.rules {
|
||||
if rule.target.matches(target, port)
|
||||
&& rule.matches_principal(identity)
|
||||
&& rule.matches_transport(&transport)
|
||||
{
|
||||
return rule.action == ForwardingAction::Allow;
|
||||
}
|
||||
}
|
||||
self.default == ForwardingAction::Allow
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_identity(id: &str, scopes: Vec<&str>) -> Identity {
|
||||
Identity {
|
||||
id: id.to_string(),
|
||||
scopes: scopes.into_iter().map(|s| s.to_string()).collect(),
|
||||
resources: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarding_action_equality() {
|
||||
assert_eq!(ForwardingAction::Allow, ForwardingAction::Allow);
|
||||
assert_eq!(ForwardingAction::Deny, ForwardingAction::Deny);
|
||||
assert_ne!(ForwardingAction::Allow, ForwardingAction::Deny);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_all_allows_everything() {
|
||||
let policy = ForwardingPolicy::allow_all();
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(policy.check("example.com", 80, &identity, TransportKind::Tcp));
|
||||
assert!(policy.check(
|
||||
"10.0.0.1",
|
||||
22,
|
||||
&identity,
|
||||
TransportKind::Tls { server_name: None }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deny_all_denies_everything() {
|
||||
let policy = ForwardingPolicy::deny_all();
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(!policy.check("example.com", 80, &identity, TransportKind::Tcp));
|
||||
assert!(!policy.check(
|
||||
"10.0.0.1",
|
||||
22,
|
||||
&identity,
|
||||
TransportKind::Tls { server_name: None }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_match_wins_allowlist() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Host("allowed.example.com".to_string()),
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(policy.check("allowed.example.com", 80, &identity, TransportKind::Tcp));
|
||||
assert!(!policy.check("denied.example.com", 80, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_match_wins_blocklist() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Allow,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Host("blocked.example.com".to_string()),
|
||||
action: ForwardingAction::Deny,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(!policy.check("blocked.example.com", 80, &identity, TransportKind::Tcp));
|
||||
assert!(policy.check("allowed.example.com", 80, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_match_wins_ordering() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![
|
||||
ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
},
|
||||
ForwardingRule {
|
||||
target: TargetPattern::Host("blocked.example.com".to_string()),
|
||||
action: ForwardingAction::Deny,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
},
|
||||
],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(policy.check("blocked.example.com", 80, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_principals_matches_all() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
let identity1 = make_identity("user1", vec![]);
|
||||
let identity2 = make_identity("user2", vec![]);
|
||||
assert!(policy.check("example.com", 80, &identity1, TransportKind::Tcp));
|
||||
assert!(policy.check("example.com", 80, &identity2, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn principal_matching_by_id() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec!["SHA256:abc123".to_string()],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
let allowed = make_identity("SHA256:abc123", vec![]);
|
||||
let denied = make_identity("SHA256:other", vec![]);
|
||||
assert!(policy.check("example.com", 80, &allowed, TransportKind::Tcp));
|
||||
assert!(!policy.check("example.com", 80, &denied, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn principal_matching_by_scope() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec!["admin".to_string()],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
let allowed = make_identity("user1", vec!["admin"]);
|
||||
let denied = make_identity("user2", vec!["viewer"]);
|
||||
assert!(policy.check("example.com", 80, &allowed, TransportKind::Tcp));
|
||||
assert!(!policy.check("example.com", 80, &denied, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_transports_matches_all() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(policy.check("example.com", 80, &identity, TransportKind::Tcp));
|
||||
assert!(policy.check(
|
||||
"example.com",
|
||||
80,
|
||||
&identity,
|
||||
TransportKind::Tls { server_name: None }
|
||||
));
|
||||
assert!(policy.check(
|
||||
"example.com",
|
||||
80,
|
||||
&identity,
|
||||
TransportKind::Iroh {
|
||||
endpoint_id: String::new()
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_matching() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec![],
|
||||
transports: vec![TransportKind::Tls { server_name: None }],
|
||||
}],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(!policy.check("example.com", 443, &identity, TransportKind::Tcp));
|
||||
assert!(policy.check(
|
||||
"example.com",
|
||||
443,
|
||||
&identity,
|
||||
TransportKind::Tls { server_name: None }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_any_matches_all() {
|
||||
let pattern = TargetPattern::Any;
|
||||
assert!(pattern.matches("example.com", 80));
|
||||
assert!(pattern.matches("10.0.0.1", 22));
|
||||
assert!(pattern.matches("alknet-control", 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_host_exact_match() {
|
||||
let pattern = TargetPattern::Host("example.com".to_string());
|
||||
assert!(pattern.matches("example.com", 80));
|
||||
assert!(!pattern.matches("other.com", 80));
|
||||
assert!(!pattern.matches("sub.example.com", 80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_host_glob_match() {
|
||||
let pattern = TargetPattern::Host("*.example.com".to_string());
|
||||
assert!(pattern.matches("sub.example.com", 80));
|
||||
assert!(pattern.matches("a.example.com", 443));
|
||||
assert!(!pattern.matches("example.com", 80));
|
||||
assert!(!pattern.matches("xsub.example.com.org", 80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_host_glob_prefix() {
|
||||
let pattern = TargetPattern::Host("db-*".to_string());
|
||||
assert!(pattern.matches("db-primary", 5432));
|
||||
assert!(pattern.matches("db-replica", 5432));
|
||||
assert!(!pattern.matches("web-primary", 5432));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_host_glob_suffix() {
|
||||
let pattern = TargetPattern::Host("*.internal".to_string());
|
||||
assert!(pattern.matches("app.internal", 8080));
|
||||
assert!(pattern.matches("db.internal", 5432));
|
||||
assert!(!pattern.matches("app.external", 80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_cidr_matches_ip() {
|
||||
let network: IpNetwork = "10.0.0.0/8".parse().unwrap();
|
||||
let pattern = TargetPattern::Cidr(network);
|
||||
assert!(pattern.matches("10.0.0.1", 22));
|
||||
assert!(pattern.matches("10.255.255.255", 22));
|
||||
assert!(!pattern.matches("192.168.1.1", 22));
|
||||
assert!(!pattern.matches("not-an-ip", 22));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_cidr_ipv6() {
|
||||
let network: IpNetwork = "fd00::/8".parse().unwrap();
|
||||
let pattern = TargetPattern::Cidr(network);
|
||||
assert!(pattern.matches("fd00::1", 22));
|
||||
assert!(!pattern.matches("10.0.0.1", 22));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_port_range_matches() {
|
||||
let pattern = TargetPattern::PortRange("localhost".to_string(), 8080..8090);
|
||||
assert!(pattern.matches("localhost", 8080));
|
||||
assert!(pattern.matches("localhost", 8085));
|
||||
assert!(pattern.matches("localhost", 8089));
|
||||
assert!(!pattern.matches("localhost", 8079));
|
||||
assert!(!pattern.matches("localhost", 8090));
|
||||
assert!(!pattern.matches("otherhost", 8080));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_port_range_with_glob() {
|
||||
let pattern = TargetPattern::PortRange("*.internal".to_string(), 3000..4000);
|
||||
assert!(pattern.matches("app.internal", 3000));
|
||||
assert!(pattern.matches("app.internal", 3999));
|
||||
assert!(!pattern.matches("app.internal", 2999));
|
||||
assert!(!pattern.matches("app.internal", 4000));
|
||||
assert!(!pattern.matches("app.external", 3000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_pattern_alknet_prefix() {
|
||||
let pattern = TargetPattern::AlknetPrefix;
|
||||
assert!(pattern.matches("alknet-control", 0));
|
||||
assert!(pattern.matches("alknet-status", 0));
|
||||
assert!(pattern.matches("alknet-", 0));
|
||||
assert!(!pattern.matches("example.com", 0));
|
||||
assert!(!pattern.matches("alknet.example.com", 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_fallthrough_allow() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Allow,
|
||||
rules: vec![],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(policy.check("anything", 80, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_fallthrough_deny() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(!policy.check("anything", 80, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn combined_principal_and_transport_matching() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Host("restricted.example.com".to_string()),
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec!["admin".to_string()],
|
||||
transports: vec![TransportKind::Tls { server_name: None }],
|
||||
}],
|
||||
};
|
||||
let admin = make_identity("admin-user", vec!["admin"]);
|
||||
let viewer = make_identity("viewer-user", vec!["viewer"]);
|
||||
assert!(policy.check(
|
||||
"restricted.example.com",
|
||||
443,
|
||||
&admin,
|
||||
TransportKind::Tls { server_name: None }
|
||||
));
|
||||
assert!(!policy.check("restricted.example.com", 443, &admin, TransportKind::Tcp));
|
||||
assert!(!policy.check(
|
||||
"restricted.example.com",
|
||||
443,
|
||||
&viewer,
|
||||
TransportKind::Tls { server_name: None }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn webtransport_restricted_to_alknet() {
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Allow,
|
||||
rules: vec![
|
||||
ForwardingRule {
|
||||
target: TargetPattern::AlknetPrefix,
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec![],
|
||||
transports: vec![TransportKind::WebTransport { server_name: None }],
|
||||
},
|
||||
ForwardingRule {
|
||||
target: TargetPattern::Any,
|
||||
action: ForwardingAction::Deny,
|
||||
principals: vec![],
|
||||
transports: vec![TransportKind::WebTransport { server_name: None }],
|
||||
},
|
||||
],
|
||||
};
|
||||
let identity = make_identity("user1", vec![]);
|
||||
assert!(policy.check(
|
||||
"alknet-control",
|
||||
0,
|
||||
&identity,
|
||||
TransportKind::WebTransport { server_name: None }
|
||||
));
|
||||
assert!(!policy.check(
|
||||
"example.com",
|
||||
443,
|
||||
&identity,
|
||||
TransportKind::WebTransport { server_name: None }
|
||||
));
|
||||
assert!(policy.check("example.com", 443, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cidr_does_not_match_hostname() {
|
||||
let network: IpNetwork = "10.0.0.0/8".parse().unwrap();
|
||||
let pattern = TargetPattern::Cidr(network);
|
||||
assert!(!pattern.matches("example.com", 22));
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
pub mod config_service;
|
||||
pub mod dynamic_config;
|
||||
pub mod forwarding;
|
||||
pub mod static_config;
|
||||
|
||||
pub use config_service::ConfigServiceImpl;
|
||||
pub use dynamic_config::{
|
||||
new_dynamic_config, ApiKeyEntry, AuthPolicy, ConfigReloadHandle, DynamicConfig,
|
||||
RateLimitConfig, API_KEY_PREFIX,
|
||||
};
|
||||
pub use forwarding::{ForwardingAction, ForwardingPolicy, ForwardingRule, TargetPattern};
|
||||
pub use static_config::StaticConfig;
|
||||
@@ -1,281 +0,0 @@
|
||||
//! Static (immutable) server configuration resolved at startup.
|
||||
//!
|
||||
//! See [ADR-030](docs/architecture/decisions/030-dynamic-config.md).
|
||||
|
||||
use crate::interface::StreamInterfaceKind;
|
||||
use crate::server::handler::{ProxyConfig, ProxyMode};
|
||||
use crate::server::serve::{ListenerConfig, ServeTransportMode, StreamListenerConfig};
|
||||
use crate::transport::TransportKind;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
pub struct StaticConfig {
|
||||
pub transport_mode: ServeTransportMode,
|
||||
pub listen_addr: String,
|
||||
pub tls_cert: Option<String>,
|
||||
pub tls_key: Option<String>,
|
||||
pub acme_domain: Option<String>,
|
||||
pub stealth: bool,
|
||||
pub host_key: russh::keys::PrivateKey,
|
||||
pub host_key_algorithm: russh::keys::Algorithm,
|
||||
pub max_auth_attempts: usize,
|
||||
pub max_connections_per_ip: usize,
|
||||
pub proxy_config: Option<ProxyConfig>,
|
||||
pub iroh_relay: Option<String>,
|
||||
pub listeners: Vec<ListenerConfig>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for StaticConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("StaticConfig")
|
||||
.field("transport_mode", &self.transport_mode)
|
||||
.field("listen_addr", &self.listen_addr)
|
||||
.field("tls_cert", &self.tls_cert.as_ref().map(|_| "<redacted>"))
|
||||
.field("tls_key", &self.tls_key.as_ref().map(|_| "<redacted>"))
|
||||
.field("acme_domain", &self.acme_domain)
|
||||
.field("stealth", &self.stealth)
|
||||
.field("host_key_algorithm", &self.host_key_algorithm)
|
||||
.field("max_auth_attempts", &self.max_auth_attempts)
|
||||
.field("max_connections_per_ip", &self.max_connections_per_ip)
|
||||
.field("proxy_config", &self.proxy_config)
|
||||
.field("iroh_relay", &self.iroh_relay)
|
||||
.field("listeners", &self.listeners)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl StaticConfig {
|
||||
pub fn from_serve_options(
|
||||
opts: crate::server::serve::ServeOptions,
|
||||
) -> Result<(Self, crate::config::DynamicConfig), crate::error::ConfigError> {
|
||||
opts.validate()?;
|
||||
|
||||
let host_key = crate::auth::keys::load_private_key(opts.key.clone())?;
|
||||
let host_key_algorithm = host_key.algorithm();
|
||||
|
||||
let auth_config = crate::auth::ServerAuthConfig::from_keys_and_ca(
|
||||
opts.authorized_keys.clone(),
|
||||
opts.cert_authority.clone(),
|
||||
)?;
|
||||
|
||||
let auth_policy = crate::config::AuthPolicy::from_server_auth_config(auth_config);
|
||||
|
||||
let dynamic = crate::config::DynamicConfig::new(auth_policy);
|
||||
|
||||
let proxy_config = parse_proxy_config(opts.proxy.as_deref())?;
|
||||
|
||||
let listeners = if let Some(listeners) = opts.listeners {
|
||||
listeners
|
||||
} else {
|
||||
vec![ListenerConfig::Stream {
|
||||
config: StreamListenerConfig {
|
||||
transport_kind: match opts.transport_mode {
|
||||
ServeTransportMode::Tcp => TransportKind::Tcp,
|
||||
ServeTransportMode::Tls => TransportKind::Tls { server_name: None },
|
||||
ServeTransportMode::Iroh => TransportKind::Iroh {
|
||||
endpoint_id: String::new(),
|
||||
},
|
||||
},
|
||||
interface: StreamInterfaceKind::Ssh,
|
||||
listen_addr: opts.listen_addr.clone(),
|
||||
tls_cert: opts.tls_cert.clone(),
|
||||
tls_key: opts.tls_key.clone(),
|
||||
acme_domain: opts.acme_domain.clone(),
|
||||
stealth: opts.stealth,
|
||||
iroh_relay: opts.iroh_relay.clone(),
|
||||
},
|
||||
}]
|
||||
};
|
||||
|
||||
let static_config = StaticConfig {
|
||||
transport_mode: opts.transport_mode,
|
||||
listen_addr: opts.listen_addr,
|
||||
tls_cert: opts.tls_cert,
|
||||
tls_key: opts.tls_key,
|
||||
acme_domain: opts.acme_domain,
|
||||
stealth: opts.stealth,
|
||||
host_key,
|
||||
host_key_algorithm,
|
||||
max_auth_attempts: opts.max_auth_attempts,
|
||||
max_connections_per_ip: opts.max_connections_per_ip,
|
||||
proxy_config,
|
||||
iroh_relay: opts.iroh_relay,
|
||||
listeners,
|
||||
};
|
||||
|
||||
Ok((static_config, dynamic))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_proxy_config(
|
||||
proxy: Option<&str>,
|
||||
) -> Result<Option<ProxyConfig>, crate::error::ConfigError> {
|
||||
match proxy {
|
||||
None => Ok(None),
|
||||
Some(url) => {
|
||||
if let Some(rest) = url.strip_prefix("socks5://") {
|
||||
let addr: SocketAddr =
|
||||
rest.parse()
|
||||
.map_err(|e| crate::error::ConfigError::ProxyConfigInvalid {
|
||||
message: format!("invalid socks5 proxy address '{}': {}", rest, e),
|
||||
})?;
|
||||
Ok(Some(ProxyConfig {
|
||||
mode: ProxyMode::Socks5(addr),
|
||||
}))
|
||||
} else if let Some(rest) = url.strip_prefix("http://") {
|
||||
let addr: SocketAddr =
|
||||
rest.parse()
|
||||
.map_err(|e| crate::error::ConfigError::ProxyConfigInvalid {
|
||||
message: format!(
|
||||
"invalid http connect proxy address '{}': {}",
|
||||
rest, e
|
||||
),
|
||||
})?;
|
||||
Ok(Some(ProxyConfig {
|
||||
mode: ProxyMode::HttpConnect(addr),
|
||||
}))
|
||||
} else {
|
||||
Err(crate::error::ConfigError::ProxyConfigInvalid {
|
||||
message: format!("unsupported proxy URL scheme: {}", url),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::server::serve::ServeOptions;
|
||||
use crate::transport::TransportKind;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
|
||||
|
||||
fn make_key_source() -> KeySource {
|
||||
KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec())
|
||||
}
|
||||
|
||||
fn make_authorized_keys_source() -> KeySource {
|
||||
KeySource::Memory(ED25519_PUBLIC_KEY.as_bytes().to_vec())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_proxy_config_socks5() {
|
||||
let config = parse_proxy_config(Some("socks5://127.0.0.1:9050")).unwrap();
|
||||
assert!(config.is_some());
|
||||
match config.unwrap().mode {
|
||||
ProxyMode::Socks5(addr) => {
|
||||
assert_eq!(addr, "127.0.0.1:9050".parse().unwrap());
|
||||
}
|
||||
_ => panic!("expected Socks5"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_proxy_config_http() {
|
||||
let config = parse_proxy_config(Some("http://127.0.0.1:8080")).unwrap();
|
||||
assert!(config.is_some());
|
||||
match config.unwrap().mode {
|
||||
ProxyMode::HttpConnect(addr) => {
|
||||
assert_eq!(addr, "127.0.0.1:8080".parse().unwrap());
|
||||
}
|
||||
_ => panic!("expected HttpConnect"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_proxy_config_none() {
|
||||
assert!(parse_proxy_config(None).unwrap().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_proxy_config_invalid_scheme() {
|
||||
let result = parse_proxy_config(Some("ftp://127.0.0.1:9050"));
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
crate::error::ConfigError::ProxyConfigInvalid { message } => {
|
||||
assert!(message.contains("unsupported proxy URL scheme"));
|
||||
}
|
||||
e => panic!("expected ProxyConfigInvalid, got {:?}", e),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_proxy_config_invalid_address() {
|
||||
let result = parse_proxy_config(Some("socks5://not-an-address"));
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
crate::error::ConfigError::ProxyConfigInvalid { message } => {
|
||||
assert!(message.contains("invalid socks5 proxy address"));
|
||||
}
|
||||
e => panic!("expected ProxyConfigInvalid, got {:?}", e),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn static_config_from_serve_options_basic() {
|
||||
let opts =
|
||||
ServeOptions::new(make_key_source()).authorized_keys(make_authorized_keys_source());
|
||||
let (static_config, dynamic) = StaticConfig::from_serve_options(opts).unwrap();
|
||||
assert_eq!(static_config.listen_addr, "0.0.0.0:22");
|
||||
assert_eq!(static_config.max_auth_attempts, 10);
|
||||
assert!(dynamic.auth.authorized_keys.len() > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn static_config_from_serve_options_with_proxy() {
|
||||
let opts = ServeOptions::new(make_key_source())
|
||||
.authorized_keys(make_authorized_keys_source())
|
||||
.proxy("socks5://127.0.0.1:9050");
|
||||
let (static_config, _) = StaticConfig::from_serve_options(opts).unwrap();
|
||||
assert!(static_config.proxy_config.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn static_config_from_serve_options_with_listeners() {
|
||||
let listeners = vec![ListenerConfig::tcp("0.0.0.0:22")];
|
||||
let opts = ServeOptions::new(make_key_source())
|
||||
.authorized_keys(make_authorized_keys_source())
|
||||
.listeners(listeners);
|
||||
let (static_config, _) = StaticConfig::from_serve_options(opts).unwrap();
|
||||
assert_eq!(static_config.listeners.len(), 1);
|
||||
match &static_config.listeners[0] {
|
||||
ListenerConfig::Stream { config } => {
|
||||
assert_eq!(config.transport_kind, TransportKind::Tcp);
|
||||
}
|
||||
_ => panic!("expected Stream variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn static_config_from_serve_options_invalid_proxy_returns_err() {
|
||||
let opts = ServeOptions::new(make_key_source())
|
||||
.authorized_keys(make_authorized_keys_source())
|
||||
.proxy("ftp://bad-scheme");
|
||||
let result = StaticConfig::from_serve_options(opts);
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
crate::error::ConfigError::ProxyConfigInvalid { message } => {
|
||||
assert!(message.contains("unsupported proxy URL scheme"));
|
||||
}
|
||||
e => panic!("expected ProxyConfigInvalid, got {:?}", e),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn static_config_from_serve_options_malformed_proxy_address_returns_err() {
|
||||
let opts = ServeOptions::new(make_key_source())
|
||||
.authorized_keys(make_authorized_keys_source())
|
||||
.proxy("socks5://not-a-valid-addr");
|
||||
let result = StaticConfig::from_serve_options(opts);
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
crate::error::ConfigError::ProxyConfigInvalid { message } => {
|
||||
assert!(message.contains("invalid socks5 proxy address"));
|
||||
}
|
||||
e => panic!("expected ProxyConfigInvalid, got {:?}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,241 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::config::DynamicConfig;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub enum CredentialSet {
|
||||
ApiKey {
|
||||
header_name: String,
|
||||
token: String,
|
||||
},
|
||||
Basic {
|
||||
username: String,
|
||||
password: String,
|
||||
},
|
||||
Bearer {
|
||||
token: String,
|
||||
},
|
||||
S3AccessKey {
|
||||
access_key: String,
|
||||
secret_key: String,
|
||||
session_token: Option<String>,
|
||||
},
|
||||
OidcToken {
|
||||
access_token: String,
|
||||
refresh_token: Option<String>,
|
||||
expires_at: Option<u64>,
|
||||
},
|
||||
Custom {
|
||||
scheme: String,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
}
|
||||
|
||||
pub trait CredentialProvider: Send + Sync + 'static {
|
||||
fn get_credentials(&self, service: &str) -> Option<CredentialSet>;
|
||||
fn refresh_credentials(&self, service: &str) -> Option<CredentialSet>;
|
||||
}
|
||||
|
||||
pub struct ConfigCredentialProvider {
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
}
|
||||
|
||||
impl ConfigCredentialProvider {
|
||||
pub fn new(dynamic: Arc<ArcSwap<DynamicConfig>>) -> Self {
|
||||
Self { dynamic }
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialProvider for ConfigCredentialProvider {
|
||||
fn get_credentials(&self, service: &str) -> Option<CredentialSet> {
|
||||
let config = self.dynamic.load();
|
||||
config.credentials.get(service).cloned()
|
||||
}
|
||||
|
||||
fn refresh_credentials(&self, service: &str) -> Option<CredentialSet> {
|
||||
self.get_credentials(service)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SecretStoreCredentialProvider;
|
||||
|
||||
impl SecretStoreCredentialProvider {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SecretStoreCredentialProvider {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialProvider for SecretStoreCredentialProvider {
|
||||
fn get_credentials(&self, _service: &str) -> Option<CredentialSet> {
|
||||
None
|
||||
}
|
||||
|
||||
fn refresh_credentials(&self, _service: &str) -> Option<CredentialSet> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::AuthPolicy;
|
||||
|
||||
fn make_dynamic_with_credentials() -> Arc<ArcSwap<DynamicConfig>> {
|
||||
let mut credentials = HashMap::new();
|
||||
credentials.insert(
|
||||
"vast-ai".to_string(),
|
||||
CredentialSet::Bearer {
|
||||
token: "secret-token".to_string(),
|
||||
},
|
||||
);
|
||||
credentials.insert(
|
||||
"custom-service".to_string(),
|
||||
CredentialSet::ApiKey {
|
||||
header_name: "X-API-Key".to_string(),
|
||||
token: "api-key-123".to_string(),
|
||||
},
|
||||
);
|
||||
let config = DynamicConfig::new(AuthPolicy::empty()).with_credentials(credentials);
|
||||
Arc::new(ArcSwap::new(Arc::new(config)))
|
||||
}
|
||||
|
||||
fn make_dynamic_empty() -> Arc<ArcSwap<DynamicConfig>> {
|
||||
let config = DynamicConfig::default();
|
||||
Arc::new(ArcSwap::new(Arc::new(config)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_credential_provider_returns_configured_credentials() {
|
||||
let dynamic = make_dynamic_with_credentials();
|
||||
let provider = ConfigCredentialProvider::new(dynamic);
|
||||
let creds = provider.get_credentials("vast-ai");
|
||||
assert!(creds.is_some());
|
||||
match creds.unwrap() {
|
||||
CredentialSet::Bearer { token } => assert_eq!(token, "secret-token"),
|
||||
_ => panic!("expected Bearer variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_credential_provider_returns_api_key_variant() {
|
||||
let dynamic = make_dynamic_with_credentials();
|
||||
let provider = ConfigCredentialProvider::new(dynamic);
|
||||
let creds = provider.get_credentials("custom-service");
|
||||
assert!(creds.is_some());
|
||||
match creds.unwrap() {
|
||||
CredentialSet::ApiKey { header_name, token } => {
|
||||
assert_eq!(header_name, "X-API-Key");
|
||||
assert_eq!(token, "api-key-123");
|
||||
}
|
||||
_ => panic!("expected ApiKey variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_credential_provider_returns_none_for_unknown_service() {
|
||||
let dynamic = make_dynamic_with_credentials();
|
||||
let provider = ConfigCredentialProvider::new(dynamic);
|
||||
let creds = provider.get_credentials("nonexistent");
|
||||
assert!(creds.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_credential_provider_empty_config_returns_none() {
|
||||
let dynamic = make_dynamic_empty();
|
||||
let provider = ConfigCredentialProvider::new(dynamic);
|
||||
let creds = provider.get_credentials("vast-ai");
|
||||
assert!(creds.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secret_store_credential_provider_returns_none() {
|
||||
let provider = SecretStoreCredentialProvider::new();
|
||||
assert!(provider.get_credentials("vast-ai").is_none());
|
||||
assert!(provider.get_credentials("rustfs").is_none());
|
||||
assert!(provider.get_credentials("gitea").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secret_store_credential_provider_refresh_returns_none() {
|
||||
let provider = SecretStoreCredentialProvider::new();
|
||||
assert!(provider.refresh_credentials("vast-ai").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credential_set_bearer_serialization() {
|
||||
let creds = CredentialSet::Bearer {
|
||||
token: "tok".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&creds).unwrap();
|
||||
let deserialized: CredentialSet = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(creds, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credential_set_s3_access_key_serialization() {
|
||||
let creds = CredentialSet::S3AccessKey {
|
||||
access_key: "AKIA123".to_string(),
|
||||
secret_key: "secret".to_string(),
|
||||
session_token: Some("session".to_string()),
|
||||
};
|
||||
let json = serde_json::to_string(&creds).unwrap();
|
||||
let deserialized: CredentialSet = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(creds, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credential_set_oidc_token_serialization() {
|
||||
let creds = CredentialSet::OidcToken {
|
||||
access_token: "access".to_string(),
|
||||
refresh_token: Some("refresh".to_string()),
|
||||
expires_at: Some(1234567890),
|
||||
};
|
||||
let json = serde_json::to_string(&creds).unwrap();
|
||||
let deserialized: CredentialSet = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(creds, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credential_set_custom_serialization() {
|
||||
let mut params = HashMap::new();
|
||||
params.insert("key1".to_string(), "val1".to_string());
|
||||
let creds = CredentialSet::Custom {
|
||||
scheme: "X-Custom".to_string(),
|
||||
params,
|
||||
};
|
||||
let json = serde_json::to_string(&creds).unwrap();
|
||||
let deserialized: CredentialSet = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(creds, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credential_set_basic_serialization() {
|
||||
let creds = CredentialSet::Basic {
|
||||
username: "user".to_string(),
|
||||
password: "pass".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&creds).unwrap();
|
||||
let deserialized: CredentialSet = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(creds, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credential_set_clone() {
|
||||
let creds = CredentialSet::Bearer {
|
||||
token: "tok".to_string(),
|
||||
};
|
||||
let cloned = creds.clone();
|
||||
assert_eq!(creds, cloned);
|
||||
}
|
||||
}
|
||||
1859
crates/alknet-core/src/endpoint.rs
Normal file
1859
crates/alknet-core/src/endpoint.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,241 +0,0 @@
|
||||
//! Error types for alknet-core.
|
||||
//!
|
||||
//! Layered error hierarchy:
|
||||
//! - `TransportError` — connection/handshake/timeout errors (trigger reconnection on client)
|
||||
//! - `AuthError` — key rejection, certificate validation failures
|
||||
//! - `ChannelError` — per-channel failures (target unreachable, channel closed)
|
||||
//! - `ConfigError` — invalid configuration (flags, key files, bind failures)
|
||||
//! - `ForwardError` — port forward setup and connection failures
|
||||
|
||||
use std::io;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum TransportError {
|
||||
#[error("connection failed")]
|
||||
ConnectionFailed,
|
||||
#[error("handshake failed")]
|
||||
HandshakeFailed {
|
||||
#[source]
|
||||
source: io::Error,
|
||||
},
|
||||
#[error("transport timeout")]
|
||||
Timeout,
|
||||
#[error("proxy failed")]
|
||||
ProxyFailed {
|
||||
#[source]
|
||||
source: io::Error,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, thiserror::Error)]
|
||||
pub enum AuthError {
|
||||
#[error("key rejected")]
|
||||
KeyRejected,
|
||||
#[error("certificate invalid")]
|
||||
CertInvalid,
|
||||
#[error("certificate expired")]
|
||||
CertExpired,
|
||||
#[error("certificate principal mismatch")]
|
||||
CertPrincipalMismatch,
|
||||
#[error("no matching key")]
|
||||
NoMatchingKey,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ChannelError {
|
||||
#[error("target unreachable")]
|
||||
TargetUnreachable,
|
||||
#[error("proxy connect failed")]
|
||||
ProxyConnectFailed {
|
||||
#[source]
|
||||
source: io::Error,
|
||||
},
|
||||
#[error("channel closed")]
|
||||
ChannelClosed,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConfigError {
|
||||
#[error("invalid flag: {name}")]
|
||||
InvalidFlag { name: String },
|
||||
#[error("key file not found: {path}")]
|
||||
KeyFileNotFound { path: String },
|
||||
#[error("bind failed")]
|
||||
BindFailed {
|
||||
#[source]
|
||||
source: io::Error,
|
||||
},
|
||||
#[error("incompatible options")]
|
||||
IncompatibleOptions,
|
||||
#[error("invalid proxy config: {message}")]
|
||||
ProxyConfigInvalid { message: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ForwardError {
|
||||
#[error("invalid port forward spec: {spec}")]
|
||||
InvalidSpec { spec: String },
|
||||
#[error("bind failed")]
|
||||
BindFailed {
|
||||
#[source]
|
||||
source: io::Error,
|
||||
},
|
||||
#[error("channel open failed")]
|
||||
ChannelOpenFailed {
|
||||
#[source]
|
||||
source: Box<dyn std::error::Error + Send + Sync>,
|
||||
},
|
||||
#[error("connect to local target failed")]
|
||||
LocalConnectFailed {
|
||||
#[source]
|
||||
source: io::Error,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::error::Error;
|
||||
|
||||
#[test]
|
||||
fn transport_error_display() {
|
||||
assert_eq!(
|
||||
TransportError::ConnectionFailed.to_string(),
|
||||
"connection failed"
|
||||
);
|
||||
assert_eq!(
|
||||
TransportError::HandshakeFailed {
|
||||
source: io::Error::new(io::ErrorKind::ConnectionRefused, "tls failed")
|
||||
}
|
||||
.to_string(),
|
||||
"handshake failed"
|
||||
);
|
||||
assert_eq!(TransportError::Timeout.to_string(), "transport timeout");
|
||||
assert_eq!(
|
||||
TransportError::ProxyFailed {
|
||||
source: io::Error::new(io::ErrorKind::ConnectionRefused, "proxy err")
|
||||
}
|
||||
.to_string(),
|
||||
"proxy failed"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_error_display() {
|
||||
assert_eq!(AuthError::KeyRejected.to_string(), "key rejected");
|
||||
assert_eq!(AuthError::CertInvalid.to_string(), "certificate invalid");
|
||||
assert_eq!(AuthError::CertExpired.to_string(), "certificate expired");
|
||||
assert_eq!(
|
||||
AuthError::CertPrincipalMismatch.to_string(),
|
||||
"certificate principal mismatch"
|
||||
);
|
||||
assert_eq!(AuthError::NoMatchingKey.to_string(), "no matching key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_error_display() {
|
||||
assert_eq!(
|
||||
ChannelError::TargetUnreachable.to_string(),
|
||||
"target unreachable"
|
||||
);
|
||||
assert_eq!(
|
||||
ChannelError::ProxyConnectFailed {
|
||||
source: io::Error::new(io::ErrorKind::ConnectionRefused, "refused")
|
||||
}
|
||||
.to_string(),
|
||||
"proxy connect failed"
|
||||
);
|
||||
assert_eq!(ChannelError::ChannelClosed.to_string(), "channel closed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_error_display() {
|
||||
assert_eq!(
|
||||
ConfigError::InvalidFlag {
|
||||
name: "--bad".to_string()
|
||||
}
|
||||
.to_string(),
|
||||
"invalid flag: --bad"
|
||||
);
|
||||
assert_eq!(
|
||||
ConfigError::KeyFileNotFound {
|
||||
path: "/missing".to_string()
|
||||
}
|
||||
.to_string(),
|
||||
"key file not found: /missing"
|
||||
);
|
||||
assert_eq!(
|
||||
ConfigError::BindFailed {
|
||||
source: io::Error::new(io::ErrorKind::AddrInUse, "in use")
|
||||
}
|
||||
.to_string(),
|
||||
"bind failed"
|
||||
);
|
||||
assert_eq!(
|
||||
ConfigError::IncompatibleOptions.to_string(),
|
||||
"incompatible options"
|
||||
);
|
||||
assert_eq!(
|
||||
ConfigError::ProxyConfigInvalid {
|
||||
message: "bad proxy".to_string()
|
||||
}
|
||||
.to_string(),
|
||||
"invalid proxy config: bad proxy"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_source_chaining() {
|
||||
let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "refused");
|
||||
let transport_err = TransportError::HandshakeFailed { source: io_err };
|
||||
assert!(transport_err.source().is_some());
|
||||
|
||||
let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "proxy");
|
||||
let channel_err = ChannelError::ProxyConnectFailed { source: io_err };
|
||||
assert!(channel_err.source().is_some());
|
||||
|
||||
let io_err = io::Error::new(io::ErrorKind::AddrInUse, "addr");
|
||||
let config_err = ConfigError::BindFailed { source: io_err };
|
||||
assert!(config_err.source().is_some());
|
||||
|
||||
let plain = AuthError::KeyRejected;
|
||||
assert!(plain.source().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_error_display() {
|
||||
assert_eq!(
|
||||
ForwardError::InvalidSpec {
|
||||
spec: "bad".to_string()
|
||||
}
|
||||
.to_string(),
|
||||
"invalid port forward spec: bad"
|
||||
);
|
||||
assert_eq!(
|
||||
ForwardError::BindFailed {
|
||||
source: io::Error::new(io::ErrorKind::AddrInUse, "in use")
|
||||
}
|
||||
.to_string(),
|
||||
"bind failed"
|
||||
);
|
||||
assert_eq!(
|
||||
ForwardError::LocalConnectFailed {
|
||||
source: io::Error::new(io::ErrorKind::ConnectionRefused, "refused")
|
||||
}
|
||||
.to_string(),
|
||||
"connect to local target failed"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_error_source_chaining() {
|
||||
let io_err = io::Error::new(io::ErrorKind::AddrInUse, "in use");
|
||||
let forward_err = ForwardError::BindFailed { source: io_err };
|
||||
assert!(forward_err.source().is_some());
|
||||
|
||||
let plain = ForwardError::InvalidSpec {
|
||||
spec: "bad".to_string(),
|
||||
};
|
||||
assert!(plain.source().is_none());
|
||||
}
|
||||
}
|
||||
@@ -1,182 +0,0 @@
|
||||
use axum::extract::Request;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
|
||||
use crate::auth::{AuthToken, Identity, IdentityProvider};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct IdentityExt(pub Identity);
|
||||
|
||||
pub async fn auth_middleware(
|
||||
axum::extract::State(identity_provider): axum::extract::State<
|
||||
std::sync::Arc<dyn IdentityProvider>,
|
||||
>,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let auth_header = request
|
||||
.headers()
|
||||
.get(axum::http::header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
let token_str = match auth_header {
|
||||
Some(h) if h.starts_with("Bearer ") => &h[7..],
|
||||
_ => {
|
||||
return axum::http::StatusCode::UNAUTHORIZED.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let token = AuthToken {
|
||||
raw: token_str.as_bytes().to_vec(),
|
||||
};
|
||||
|
||||
match identity_provider.resolve_from_token(&token) {
|
||||
Some(identity) => {
|
||||
request.extensions_mut().insert(IdentityExt(identity));
|
||||
next.run(request).await
|
||||
}
|
||||
None => axum::http::StatusCode::UNAUTHORIZED.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::body::Body;
|
||||
use axum::http::{Request as HttpRequest, StatusCode};
|
||||
use axum::routing::get;
|
||||
use axum::Router;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tower::ServiceExt;
|
||||
|
||||
struct MockIdentityProvider {
|
||||
valid_token: String,
|
||||
identity: Identity,
|
||||
}
|
||||
|
||||
impl IdentityProvider for MockIdentityProvider {
|
||||
fn resolve_from_fingerprint(&self, _fingerprint: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, token: &AuthToken) -> Option<Identity> {
|
||||
let token_str = String::from_utf8_lossy(&token.raw);
|
||||
if token_str == self.valid_token {
|
||||
Some(self.identity.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn make_provider(valid_token: &str) -> Arc<dyn IdentityProvider> {
|
||||
let identity = Identity {
|
||||
id: "test-user".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
Arc::new(MockIdentityProvider {
|
||||
valid_token: valid_token.to_string(),
|
||||
identity,
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_middleware_extracts_bearer_token() {
|
||||
let provider = make_provider("alk_validtoken1");
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/test",
|
||||
get(|request: Request| async move {
|
||||
let has_identity = request.extensions().get::<IdentityExt>().is_some();
|
||||
if has_identity {
|
||||
StatusCode::OK.into_response()
|
||||
} else {
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}),
|
||||
)
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
provider,
|
||||
auth_middleware,
|
||||
));
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/test")
|
||||
.header("authorization", "Bearer alk_validtoken1")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_middleware_returns_401_for_missing_token() {
|
||||
let provider = make_provider("alk_validtoken1");
|
||||
let app = Router::new()
|
||||
.route("/test", get(|| async { StatusCode::OK.into_response() }))
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
provider,
|
||||
auth_middleware,
|
||||
));
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/test")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_middleware_returns_401_for_invalid_token() {
|
||||
let provider = make_provider("alk_validtoken1");
|
||||
let app = Router::new()
|
||||
.route("/test", get(|| async { StatusCode::OK.into_response() }))
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
provider,
|
||||
auth_middleware,
|
||||
));
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/test")
|
||||
.header("authorization", "Bearer alk_wrongtoken1")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_middleware_attaches_identity_to_extensions() {
|
||||
let provider = make_provider("alk_testidentity1");
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/test",
|
||||
get(|request: Request| async move {
|
||||
let identity = request.extensions().get::<IdentityExt>().unwrap();
|
||||
identity.0.id.clone()
|
||||
}),
|
||||
)
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
provider,
|
||||
auth_middleware,
|
||||
));
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/test")
|
||||
.header("authorization", "Bearer alk_testidentity1")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
|
||||
assert_eq!(&body[..], b"test-user");
|
||||
}
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
pub mod auth;
|
||||
pub mod router;
|
||||
|
||||
pub use auth::IdentityExt;
|
||||
pub use router::{build_router, serve_connection};
|
||||
@@ -1,150 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::response::IntoResponse;
|
||||
use axum::Router;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
use hyper_util::service::TowerToHyperService;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, BufReader};
|
||||
|
||||
use crate::auth::IdentityProvider;
|
||||
use crate::http::auth::auth_middleware;
|
||||
|
||||
async fn default_404() -> impl IntoResponse {
|
||||
axum::http::StatusCode::NOT_FOUND
|
||||
}
|
||||
|
||||
pub fn build_router(identity_provider: Arc<dyn IdentityProvider>) -> Router {
|
||||
Router::new()
|
||||
.fallback(default_404)
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
identity_provider,
|
||||
auth_middleware,
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn serve_connection<S>(stream: S, identity_provider: Arc<dyn IdentityProvider>)
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let app = build_router(identity_provider);
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let hyper_service = TowerToHyperService::new(app.into_service::<hyper::body::Incoming>());
|
||||
|
||||
let result = Builder::new(TokioExecutor::new())
|
||||
.serve_connection_with_upgrades(io, hyper_service)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
tracing::debug!("http connection error: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve_connection_from_reader<S>(
|
||||
reader: BufReader<S>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
) where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
serve_connection(reader, identity_provider).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::{AuthToken, Identity};
|
||||
use axum::body::Body;
|
||||
use axum::http::{Request as HttpRequest, StatusCode};
|
||||
use axum::response::IntoResponse;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tower::ServiceExt;
|
||||
|
||||
struct NullIdentityProvider;
|
||||
|
||||
impl IdentityProvider for NullIdentityProvider {
|
||||
fn resolve_from_fingerprint(&self, _fingerprint: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, _token: &AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn default_404_handler_returns_not_found() {
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(MockValidProvider);
|
||||
let app = build_router(provider);
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/anything")
|
||||
.header("authorization", "Bearer alk_sometoken1")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_auth_returns_401_before_404() {
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(MockValidProvider);
|
||||
let app = build_router(provider);
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/anything")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_auth_returns_401_before_404() {
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(NullIdentityProvider);
|
||||
let app = build_router(provider);
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/anything")
|
||||
.header("authorization", "Bearer alk_sometoken1")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unmatched_route_returns_404_with_valid_auth() {
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(MockValidProvider);
|
||||
let app = build_router(provider);
|
||||
|
||||
let req = HttpRequest::builder()
|
||||
.uri("/v1/unknown/op")
|
||||
.header("authorization", "Bearer alk_valid")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
struct MockValidProvider;
|
||||
|
||||
impl IdentityProvider for MockValidProvider {
|
||||
fn resolve_from_fingerprint(&self, _fingerprint: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, _token: &AuthToken) -> Option<Identity> {
|
||||
Some(Identity {
|
||||
id: "test".to_string(),
|
||||
scopes: vec![],
|
||||
resources: HashMap::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,270 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use russh::keys::PrivateKey;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::auth::IdentityProvider;
|
||||
use crate::config::DynamicConfig;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[non_exhaustive]
|
||||
pub enum StreamInterfaceKind {
|
||||
Ssh,
|
||||
RawFraming,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for StreamInterfaceKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
StreamInterfaceKind::Ssh => write!(f, "ssh"),
|
||||
StreamInterfaceKind::RawFraming => write!(f, "raw-framing"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[non_exhaustive]
|
||||
pub enum MessageInterfaceKind {
|
||||
Http,
|
||||
Dns,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for MessageInterfaceKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MessageInterfaceKind::Http => write!(f, "http"),
|
||||
MessageInterfaceKind::Dns => write!(f, "dns"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
pub enum InterfaceConfig {
|
||||
Ssh(SshInterfaceConfig),
|
||||
RawFraming(RawFramingConfig),
|
||||
}
|
||||
|
||||
impl InterfaceConfig {
|
||||
pub fn kind(&self) -> StreamInterfaceKind {
|
||||
#[allow(unreachable_patterns)]
|
||||
match self {
|
||||
InterfaceConfig::Ssh(_) => StreamInterfaceKind::Ssh,
|
||||
InterfaceConfig::RawFraming(_) => StreamInterfaceKind::RawFraming,
|
||||
_ => StreamInterfaceKind::Ssh,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
pub enum StreamInterfaceConfig {
|
||||
Ssh(SshInterfaceConfig),
|
||||
RawFraming(RawFramingConfig),
|
||||
}
|
||||
|
||||
impl StreamInterfaceConfig {
|
||||
pub fn kind(&self) -> StreamInterfaceKind {
|
||||
match self {
|
||||
StreamInterfaceConfig::Ssh(_) => StreamInterfaceKind::Ssh,
|
||||
StreamInterfaceConfig::RawFraming(_) => StreamInterfaceKind::RawFraming,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for StreamInterfaceConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
StreamInterfaceConfig::Ssh(_) => write!(f, "ssh"),
|
||||
StreamInterfaceConfig::RawFraming(_) => write!(f, "raw-framing"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[non_exhaustive]
|
||||
pub enum MessageInterfaceConfig {
|
||||
Http(HttpInterfaceConfig),
|
||||
Dns(DnsInterfaceConfig),
|
||||
}
|
||||
|
||||
impl MessageInterfaceConfig {
|
||||
pub fn kind(&self) -> MessageInterfaceKind {
|
||||
match self {
|
||||
MessageInterfaceConfig::Http(_) => MessageInterfaceKind::Http,
|
||||
MessageInterfaceConfig::Dns(_) => MessageInterfaceKind::Dns,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for MessageInterfaceConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MessageInterfaceConfig::Http(_) => write!(f, "http"),
|
||||
MessageInterfaceConfig::Dns(_) => write!(f, "dns"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SshInterfaceConfig {
|
||||
pub auth: Arc<dyn IdentityProvider>,
|
||||
pub forwarding: Arc<ArcSwap<DynamicConfig>>,
|
||||
pub host_key: Arc<PrivateKey>,
|
||||
}
|
||||
|
||||
pub struct RawFramingConfig {
|
||||
pub auth: Arc<dyn IdentityProvider>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct HttpInterfaceConfig {
|
||||
pub bind_addr: std::net::SocketAddr,
|
||||
pub tls: bool,
|
||||
pub stealth: bool,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for HttpInterfaceConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "http {}", self.bind_addr)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct DnsInterfaceConfig {
|
||||
pub bind_addr: std::net::SocketAddr,
|
||||
pub tls: bool,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DnsInterfaceConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "dns {}", self.bind_addr)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::ConfigIdentityProvider;
|
||||
|
||||
#[test]
|
||||
fn stream_interface_kind_display() {
|
||||
assert_eq!(StreamInterfaceKind::Ssh.to_string(), "ssh");
|
||||
assert_eq!(StreamInterfaceKind::RawFraming.to_string(), "raw-framing");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn message_interface_kind_display() {
|
||||
assert_eq!(MessageInterfaceKind::Http.to_string(), "http");
|
||||
assert_eq!(MessageInterfaceKind::Dns.to_string(), "dns");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_interface_config_kind() {
|
||||
let auth = Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new(
|
||||
ArcSwap::new(Arc::new(DynamicConfig::default())),
|
||||
)));
|
||||
let ssh_config = StreamInterfaceConfig::Ssh(SshInterfaceConfig {
|
||||
auth,
|
||||
forwarding: Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default()))),
|
||||
host_key: Arc::new(
|
||||
russh::keys::PrivateKey::random(
|
||||
&mut rand_core::OsRng,
|
||||
russh::keys::Algorithm::Ed25519,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
});
|
||||
assert_eq!(ssh_config.kind(), StreamInterfaceKind::Ssh);
|
||||
|
||||
let raw_config = StreamInterfaceConfig::RawFraming(RawFramingConfig {
|
||||
auth: Arc::new(ConfigIdentityProvider::new(Arc::new(ArcSwap::new(
|
||||
Arc::new(DynamicConfig::default()),
|
||||
)))),
|
||||
});
|
||||
assert_eq!(raw_config.kind(), StreamInterfaceKind::RawFraming);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn message_interface_config_kind() {
|
||||
let http_config = MessageInterfaceConfig::Http(HttpInterfaceConfig {
|
||||
bind_addr: "127.0.0.1:8080".parse().unwrap(),
|
||||
tls: false,
|
||||
stealth: false,
|
||||
});
|
||||
assert_eq!(http_config.kind(), MessageInterfaceKind::Http);
|
||||
|
||||
let dns_config = MessageInterfaceConfig::Dns(DnsInterfaceConfig {
|
||||
bind_addr: "127.0.0.1:53".parse().unwrap(),
|
||||
tls: false,
|
||||
});
|
||||
assert_eq!(dns_config.kind(), MessageInterfaceKind::Dns);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_interface_kind_equality() {
|
||||
assert_eq!(StreamInterfaceKind::Ssh, StreamInterfaceKind::Ssh);
|
||||
assert_eq!(
|
||||
StreamInterfaceKind::RawFraming,
|
||||
StreamInterfaceKind::RawFraming
|
||||
);
|
||||
assert_ne!(StreamInterfaceKind::Ssh, StreamInterfaceKind::RawFraming);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn message_interface_kind_equality() {
|
||||
assert_eq!(MessageInterfaceKind::Http, MessageInterfaceKind::Http);
|
||||
assert_eq!(MessageInterfaceKind::Dns, MessageInterfaceKind::Dns);
|
||||
assert_ne!(MessageInterfaceKind::Http, MessageInterfaceKind::Dns);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn raw_framing_config_minimal() {
|
||||
let auth: Arc<dyn IdentityProvider> = Arc::new(ConfigIdentityProvider::new(Arc::new(
|
||||
ArcSwap::new(Arc::new(DynamicConfig::default())),
|
||||
)));
|
||||
let _config = RawFramingConfig { auth };
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn http_interface_config_display() {
|
||||
let config = HttpInterfaceConfig {
|
||||
bind_addr: "127.0.0.1:8080".parse().unwrap(),
|
||||
tls: true,
|
||||
stealth: true,
|
||||
};
|
||||
assert_eq!(config.to_string(), "http 127.0.0.1:8080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dns_interface_config_display() {
|
||||
let config = DnsInterfaceConfig {
|
||||
bind_addr: "127.0.0.1:53".parse().unwrap(),
|
||||
tls: false,
|
||||
};
|
||||
assert_eq!(config.to_string(), "dns 127.0.0.1:53");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn http_interface_config_serialization() {
|
||||
let config = HttpInterfaceConfig {
|
||||
bind_addr: "127.0.0.1:8080".parse().unwrap(),
|
||||
tls: true,
|
||||
stealth: false,
|
||||
};
|
||||
let serialized = serde_json::to_string(&config).unwrap();
|
||||
let deserialized: HttpInterfaceConfig = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.bind_addr, config.bind_addr);
|
||||
assert_eq!(deserialized.tls, config.tls);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dns_interface_config_serialization() {
|
||||
let config = DnsInterfaceConfig {
|
||||
bind_addr: "0.0.0.0:53".parse().unwrap(),
|
||||
tls: true,
|
||||
};
|
||||
let serialized = serde_json::to_string(&config).unwrap();
|
||||
let deserialized: DnsInterfaceConfig = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.bind_addr, config.bind_addr);
|
||||
assert_eq!(deserialized.tls, config.tls);
|
||||
}
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::call::OperationEnv;
|
||||
use crate::interface::{InterfaceRequest, InterfaceResponse, MessageInterface};
|
||||
|
||||
pub struct DnsInterface {
|
||||
pub domain: String,
|
||||
pub identity_provider: Arc<dyn crate::auth::IdentityProvider>,
|
||||
pub registry: Arc<crate::call::OperationRegistry>,
|
||||
pub env: OperationEnv,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MessageInterface for DnsInterface {
|
||||
async fn handle_request(&self, _request: InterfaceRequest) -> Result<InterfaceResponse> {
|
||||
Ok(InterfaceResponse {
|
||||
result: Err(crate::call::CallError::new(
|
||||
"NOT_IMPLEMENTED",
|
||||
"DnsInterface is not yet implemented",
|
||||
false,
|
||||
)),
|
||||
status: 501,
|
||||
headers: std::collections::HashMap::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn dns_interface_type_exists() {
|
||||
let registry = Arc::new(crate::call::OperationRegistry::new());
|
||||
let _iface = DnsInterface {
|
||||
domain: "alk.dev".to_string(),
|
||||
identity_provider: Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new(
|
||||
arc_swap::ArcSwap::new(Arc::new(crate::config::DynamicConfig::default())),
|
||||
))),
|
||||
env: OperationEnv::local(crate::call::OperationRegistry::new()),
|
||||
registry,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::call::OperationEnv;
|
||||
use crate::interface::{InterfaceRequest, InterfaceResponse, MessageInterface};
|
||||
|
||||
pub struct HttpInterface {
|
||||
pub identity_provider: Arc<dyn crate::auth::IdentityProvider>,
|
||||
pub registry: Arc<crate::call::OperationRegistry>,
|
||||
pub env: OperationEnv,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MessageInterface for HttpInterface {
|
||||
async fn handle_request(&self, _request: InterfaceRequest) -> Result<InterfaceResponse> {
|
||||
Ok(InterfaceResponse {
|
||||
result: Err(crate::call::CallError::new(
|
||||
"NOT_IMPLEMENTED",
|
||||
"HttpInterface is not yet implemented",
|
||||
false,
|
||||
)),
|
||||
status: 501,
|
||||
headers: std::collections::HashMap::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "http")]
|
||||
impl HttpInterface {
|
||||
pub fn build_router(&self) -> axum::Router {
|
||||
crate::http::router::build_router(Arc::clone(&self.identity_provider))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn http_interface_type_exists() {
|
||||
let registry = Arc::new(crate::call::OperationRegistry::new());
|
||||
let _iface = HttpInterface {
|
||||
identity_provider: Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new(
|
||||
arc_swap::ArcSwap::new(Arc::new(crate::config::DynamicConfig::default())),
|
||||
))),
|
||||
env: OperationEnv::local(crate::call::OperationRegistry::new()),
|
||||
registry,
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(feature = "http")]
|
||||
#[test]
|
||||
fn http_interface_builds_router() {
|
||||
let registry = Arc::new(crate::call::OperationRegistry::new());
|
||||
let iface = HttpInterface {
|
||||
identity_provider: Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new(
|
||||
arc_swap::ArcSwap::new(Arc::new(crate::config::DynamicConfig::default())),
|
||||
))),
|
||||
env: OperationEnv::local(crate::call::OperationRegistry::new()),
|
||||
registry,
|
||||
};
|
||||
let _router = iface.build_router();
|
||||
}
|
||||
}
|
||||
@@ -1,140 +0,0 @@
|
||||
//! Interface layer (Layer 2) of the three-layer model (ADR-026, ADR-035).
|
||||
//!
|
||||
//! The Interface layer sits between Transport (Layer 1) and Protocol (Layer 3).
|
||||
//! It has two distinct patterns:
|
||||
//!
|
||||
//! - **StreamInterface** — consumes a `TransportStream`, produces a long-lived
|
||||
//! `Session` that yields `InterfaceEvent` frames. SSH and raw framing are
|
||||
//! `StreamInterface` implementations.
|
||||
//!
|
||||
//! - **MessageInterface** — handles individual `InterfaceRequest` →
|
||||
//! `InterfaceResponse` pairs. Manages its own transport (HTTP server, DNS
|
||||
//! server). HTTP and DNS are `MessageInterface` implementations.
|
||||
|
||||
pub mod config;
|
||||
pub mod dns;
|
||||
pub mod http;
|
||||
pub mod pairs;
|
||||
pub mod raw_framing;
|
||||
pub mod session;
|
||||
pub mod ssh;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub use config::{
|
||||
DnsInterfaceConfig, HttpInterfaceConfig, InterfaceConfig, MessageInterfaceConfig,
|
||||
MessageInterfaceKind, RawFramingConfig, SshInterfaceConfig, StreamInterfaceConfig,
|
||||
StreamInterfaceKind,
|
||||
};
|
||||
pub use dns::DnsInterface;
|
||||
pub use http::HttpInterface;
|
||||
pub use pairs::{is_valid_pair, TransportKindBase, VALID_TRANSPORT_INTERFACE_PAIRS};
|
||||
pub use raw_framing::{RawFramingInterface, RawFramingSession};
|
||||
pub use session::{InterfaceEvent, InterfaceSession};
|
||||
pub use ssh::{ControlChannelBridge, SshInterface, SshSession};
|
||||
|
||||
pub trait TransportStream: AsyncRead + AsyncWrite + Unpin + Send + 'static {}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin + Send + 'static> TransportStream for T {}
|
||||
|
||||
#[async_trait]
|
||||
pub trait StreamInterface: Send + Sync + 'static {
|
||||
type Session: InterfaceSession;
|
||||
|
||||
async fn accept(
|
||||
&self,
|
||||
stream: Box<dyn TransportStream>,
|
||||
config: &StreamInterfaceConfig,
|
||||
) -> Result<Self::Session>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait MessageInterface: Send + Sync + 'static {
|
||||
async fn handle_request(&self, request: InterfaceRequest) -> Result<InterfaceResponse>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InterfaceRequest {
|
||||
pub operation_path: String,
|
||||
pub input: serde_json::Value,
|
||||
pub auth_token: Option<crate::auth::AuthToken>,
|
||||
pub metadata: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InterfaceResponse {
|
||||
pub result: Result<serde_json::Value, crate::call::CallError>,
|
||||
pub status: u16,
|
||||
pub headers: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::duplex;
|
||||
|
||||
#[test]
|
||||
fn transport_stream_trait_bounds() {
|
||||
fn assert_transport_stream<S: TransportStream>() {}
|
||||
assert_transport_stream::<tokio::io::DuplexStream>();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_stream_from_duplex() {
|
||||
let (client, server) = duplex(1024);
|
||||
let _boxed: Box<dyn TransportStream> = Box::new(server);
|
||||
let _: Box<dyn TransportStream> = Box::new(client);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_request_fields() {
|
||||
let req = InterfaceRequest {
|
||||
operation_path: "/v1/head/auth/verify".to_string(),
|
||||
input: serde_json::json!({"key": "value"}),
|
||||
auth_token: None,
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
assert_eq!(req.operation_path, "/v1/head/auth/verify");
|
||||
assert!(req.auth_token.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_response_fields() {
|
||||
let resp = InterfaceResponse {
|
||||
result: Ok(serde_json::json!({"status": "ok"})),
|
||||
status: 200,
|
||||
headers: HashMap::new(),
|
||||
};
|
||||
assert_eq!(resp.status, 200);
|
||||
}
|
||||
|
||||
struct MockMessageInterface;
|
||||
|
||||
#[async_trait]
|
||||
impl MessageInterface for MockMessageInterface {
|
||||
async fn handle_request(&self, _request: InterfaceRequest) -> Result<InterfaceResponse> {
|
||||
Ok(InterfaceResponse {
|
||||
result: Ok(serde_json::json!({})),
|
||||
status: 200,
|
||||
headers: HashMap::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn message_interface_trait_compiles() {
|
||||
let iface = MockMessageInterface;
|
||||
let req = InterfaceRequest {
|
||||
operation_path: "/test".to_string(),
|
||||
input: serde_json::json!({}),
|
||||
auth_token: None,
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
let resp = iface.handle_request(req).await.unwrap();
|
||||
assert_eq!(resp.status, 200);
|
||||
}
|
||||
}
|
||||
@@ -1,122 +0,0 @@
|
||||
use crate::transport::TransportKind;
|
||||
|
||||
use super::config::StreamInterfaceKind;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TransportKindBase {
|
||||
Tcp,
|
||||
Tls,
|
||||
Iroh,
|
||||
WebTransport,
|
||||
}
|
||||
|
||||
fn transport_base(kind: &TransportKind) -> TransportKindBase {
|
||||
match kind {
|
||||
TransportKind::Tcp => TransportKindBase::Tcp,
|
||||
TransportKind::Tls { .. } => TransportKindBase::Tls,
|
||||
TransportKind::Iroh { .. } => TransportKindBase::Iroh,
|
||||
TransportKind::WebTransport { .. } => TransportKindBase::WebTransport,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_valid_pair(transport: &TransportKind, interface: StreamInterfaceKind) -> bool {
|
||||
let base = transport_base(transport);
|
||||
matches!(
|
||||
(base, interface),
|
||||
(TransportKindBase::Tcp, StreamInterfaceKind::Ssh)
|
||||
| (TransportKindBase::Tls, StreamInterfaceKind::Ssh)
|
||||
| (TransportKindBase::Iroh, StreamInterfaceKind::Ssh)
|
||||
| (TransportKindBase::WebTransport, StreamInterfaceKind::Ssh)
|
||||
| (
|
||||
TransportKindBase::WebTransport,
|
||||
StreamInterfaceKind::RawFraming
|
||||
)
|
||||
| (TransportKindBase::Tcp, StreamInterfaceKind::RawFraming)
|
||||
)
|
||||
}
|
||||
|
||||
pub const VALID_TRANSPORT_INTERFACE_PAIRS: &[(TransportKindBase, StreamInterfaceKind)] = &[
|
||||
(TransportKindBase::Tcp, StreamInterfaceKind::Ssh),
|
||||
(TransportKindBase::Tls, StreamInterfaceKind::Ssh),
|
||||
(TransportKindBase::Iroh, StreamInterfaceKind::Ssh),
|
||||
(TransportKindBase::WebTransport, StreamInterfaceKind::Ssh),
|
||||
(
|
||||
TransportKindBase::WebTransport,
|
||||
StreamInterfaceKind::RawFraming,
|
||||
),
|
||||
(TransportKindBase::Tcp, StreamInterfaceKind::RawFraming),
|
||||
];
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn valid_ssh_pairs() {
|
||||
assert!(is_valid_pair(&TransportKind::Tcp, StreamInterfaceKind::Ssh));
|
||||
assert!(is_valid_pair(
|
||||
&TransportKind::Tls { server_name: None },
|
||||
StreamInterfaceKind::Ssh
|
||||
));
|
||||
assert!(is_valid_pair(
|
||||
&TransportKind::Iroh {
|
||||
endpoint_id: String::new()
|
||||
},
|
||||
StreamInterfaceKind::Ssh
|
||||
));
|
||||
assert!(is_valid_pair(
|
||||
&TransportKind::WebTransport { server_name: None },
|
||||
StreamInterfaceKind::Ssh
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_raw_framing_pairs() {
|
||||
assert!(is_valid_pair(
|
||||
&TransportKind::Tcp,
|
||||
StreamInterfaceKind::RawFraming
|
||||
));
|
||||
assert!(is_valid_pair(
|
||||
&TransportKind::WebTransport { server_name: None },
|
||||
StreamInterfaceKind::RawFraming
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_pairs() {
|
||||
assert!(!is_valid_pair(
|
||||
&TransportKind::Iroh {
|
||||
endpoint_id: String::new()
|
||||
},
|
||||
StreamInterfaceKind::RawFraming
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_kind_base_classification() {
|
||||
assert_eq!(transport_base(&TransportKind::Tcp), TransportKindBase::Tcp);
|
||||
assert_eq!(
|
||||
transport_base(&TransportKind::Tls {
|
||||
server_name: Some("example.com".to_string())
|
||||
}),
|
||||
TransportKindBase::Tls
|
||||
);
|
||||
assert_eq!(
|
||||
transport_base(&TransportKind::Iroh {
|
||||
endpoint_id: "abc".to_string()
|
||||
}),
|
||||
TransportKindBase::Iroh
|
||||
);
|
||||
assert_eq!(
|
||||
transport_base(&TransportKind::WebTransport {
|
||||
server_name: Some("example.com".to_string())
|
||||
}),
|
||||
TransportKindBase::WebTransport
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_pairs_table_complete() {
|
||||
assert_eq!(VALID_TRANSPORT_INTERFACE_PAIRS.len(), 6);
|
||||
}
|
||||
}
|
||||
@@ -1,399 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
|
||||
|
||||
use crate::auth::{AuthToken, Identity, IdentityProvider};
|
||||
use crate::call::frame::{decode_with_remainder, encode};
|
||||
use crate::call::EventEnvelope;
|
||||
use crate::interface::session::{InterfaceEvent, InterfaceSession};
|
||||
use crate::interface::{StreamInterface, StreamInterfaceConfig, TransportStream};
|
||||
|
||||
const READ_BUF_SIZE: usize = 8192;
|
||||
|
||||
pub struct RawFramingInterface;
|
||||
|
||||
#[async_trait]
|
||||
impl StreamInterface for RawFramingInterface {
|
||||
type Session = RawFramingSession;
|
||||
|
||||
async fn accept(
|
||||
&self,
|
||||
stream: Box<dyn TransportStream>,
|
||||
config: &StreamInterfaceConfig,
|
||||
) -> Result<Self::Session> {
|
||||
let raw_config = match config {
|
||||
StreamInterfaceConfig::RawFraming(c) => c,
|
||||
StreamInterfaceConfig::Ssh(_) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"RawFramingInterface received SshInterfaceConfig"
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(RawFramingSession::new(stream, Arc::clone(&raw_config.auth)))
|
||||
}
|
||||
}
|
||||
|
||||
enum AuthState {
|
||||
Pending,
|
||||
Authenticated(Identity),
|
||||
Failed,
|
||||
}
|
||||
|
||||
pub struct RawFramingSession {
|
||||
reader: BufReader<tokio::io::ReadHalf<Box<dyn TransportStream>>>,
|
||||
writer: BufWriter<tokio::io::WriteHalf<Box<dyn TransportStream>>>,
|
||||
auth_state: AuthState,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
read_buf: Vec<u8>,
|
||||
}
|
||||
|
||||
impl RawFramingSession {
|
||||
pub fn new(
|
||||
stream: Box<dyn TransportStream>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
) -> Self {
|
||||
let (read_half, write_half) = tokio::io::split(stream);
|
||||
Self {
|
||||
reader: BufReader::new(read_half),
|
||||
writer: BufWriter::new(write_half),
|
||||
auth_state: AuthState::Pending,
|
||||
identity_provider,
|
||||
read_buf: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_frame(&mut self) -> Result<EventEnvelope> {
|
||||
loop {
|
||||
match decode_with_remainder(&self.read_buf) {
|
||||
Ok((envelope, consumed)) => {
|
||||
self.read_buf.drain(..consumed);
|
||||
return Ok(envelope);
|
||||
}
|
||||
Err(crate::call::frame::FrameDecodeError::TooShort { .. })
|
||||
| Err(crate::call::frame::FrameDecodeError::Incomplete { .. }) => {
|
||||
let mut tmp = [0u8; READ_BUF_SIZE];
|
||||
let n = self.reader.read(&mut tmp).await?;
|
||||
if n == 0 {
|
||||
return Err(anyhow::anyhow!("stream closed while reading frame"));
|
||||
}
|
||||
self.read_buf.extend_from_slice(&tmp[..n]);
|
||||
}
|
||||
Err(crate::call::frame::FrameDecodeError::Json(e)) => {
|
||||
return Err(anyhow::anyhow!("frame JSON decode error: {e}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn write_frame(&mut self, envelope: &EventEnvelope) -> Result<()> {
|
||||
let frame = encode(envelope);
|
||||
self.writer.write_all(&frame).await?;
|
||||
self.writer.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl InterfaceSession for RawFramingSession {
|
||||
async fn recv(&mut self) -> Option<InterfaceEvent> {
|
||||
match &self.auth_state {
|
||||
AuthState::Failed => return None,
|
||||
AuthState::Authenticated(_) => {
|
||||
let identity = match &self.auth_state {
|
||||
AuthState::Authenticated(id) => id.clone(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let envelope = match self.read_frame().await {
|
||||
Ok(e) => e,
|
||||
Err(_) => return None,
|
||||
};
|
||||
return Some(InterfaceEvent::with_identity(envelope, identity));
|
||||
}
|
||||
AuthState::Pending => {}
|
||||
}
|
||||
|
||||
let envelope = match self.read_frame().await {
|
||||
Ok(e) => e,
|
||||
Err(_) => {
|
||||
self.auth_state = AuthState::Failed;
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let token_raw = envelope.payload.as_str().unwrap_or("").as_bytes().to_vec();
|
||||
let token = AuthToken { raw: token_raw };
|
||||
|
||||
match self.identity_provider.resolve_from_token(&token) {
|
||||
Some(identity) => {
|
||||
self.auth_state = AuthState::Authenticated(identity.clone());
|
||||
Some(InterfaceEvent::with_identity(envelope, identity))
|
||||
}
|
||||
None => {
|
||||
self.auth_state = AuthState::Failed;
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send(&mut self, envelope: EventEnvelope) -> Result<()> {
|
||||
match self.auth_state {
|
||||
AuthState::Failed => Err(anyhow::anyhow!("session authentication failed")),
|
||||
_ => self.write_frame(&envelope).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::ConfigIdentityProvider;
|
||||
use crate::config::DynamicConfig;
|
||||
use crate::interface::RawFramingConfig;
|
||||
use arc_swap::ArcSwap;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_provider() -> Arc<dyn IdentityProvider> {
|
||||
Arc::new(ConfigIdentityProvider::new(Arc::new(ArcSwap::new(
|
||||
Arc::new(DynamicConfig::default()),
|
||||
))))
|
||||
}
|
||||
|
||||
fn make_provider_with_identity(
|
||||
identity: Identity,
|
||||
valid_token: &str,
|
||||
) -> (Arc<dyn IdentityProvider>, String) {
|
||||
struct MockProvider {
|
||||
identity: Identity,
|
||||
valid_token: String,
|
||||
}
|
||||
impl IdentityProvider for MockProvider {
|
||||
fn resolve_from_fingerprint(&self, _fp: &str) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
fn resolve_from_token(&self, token: &AuthToken) -> Option<Identity> {
|
||||
if token.raw == self.valid_token.as_bytes() {
|
||||
Some(self.identity.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
let provider = Arc::new(MockProvider {
|
||||
identity,
|
||||
valid_token: valid_token.to_string(),
|
||||
});
|
||||
(provider, valid_token.to_string())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn raw_framing_interface_accept_succeeds() {
|
||||
let iface = RawFramingInterface;
|
||||
let (_client, server) = tokio::io::duplex(1024);
|
||||
let stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let config = StreamInterfaceConfig::RawFraming(RawFramingConfig {
|
||||
auth: make_provider(),
|
||||
});
|
||||
let result = iface.accept(stream, &config).await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn raw_framing_interface_rejects_ssh_config() {
|
||||
let iface = RawFramingInterface;
|
||||
let (_client, server) = tokio::io::duplex(1024);
|
||||
let stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let config = StreamInterfaceConfig::Ssh(crate::interface::SshInterfaceConfig {
|
||||
auth: make_provider(),
|
||||
forwarding: Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default()))),
|
||||
host_key: Arc::new(
|
||||
russh::keys::PrivateKey::random(
|
||||
&mut rand_core::OsRng,
|
||||
russh::keys::Algorithm::Ed25519,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
});
|
||||
let result = iface.accept(stream, &config).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn raw_framing_session_round_trip() {
|
||||
let identity = Identity {
|
||||
id: "test-id".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let (provider, token_str) =
|
||||
make_provider_with_identity(identity.clone(), "valid-test-token");
|
||||
|
||||
let (client, server) = tokio::io::duplex(4096);
|
||||
let server_stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let client_stream: Box<dyn TransportStream> = Box::new(client);
|
||||
|
||||
let mut server_session = RawFramingSession::new(server_stream, provider);
|
||||
|
||||
let auth_envelope = EventEnvelope::new("auth", "auth-1", serde_json::json!(token_str));
|
||||
let auth_frame = encode(&auth_envelope);
|
||||
|
||||
let mut client_writer = tokio::io::BufWriter::new(client_stream);
|
||||
client_writer.write_all(&auth_frame).await.unwrap();
|
||||
client_writer.flush().await.unwrap();
|
||||
|
||||
let event = server_session.recv().await;
|
||||
assert!(event.is_some());
|
||||
let event = event.unwrap();
|
||||
assert!(event.identity.is_some());
|
||||
assert_eq!(event.identity.as_ref().unwrap().id, "test-id");
|
||||
|
||||
let data_envelope =
|
||||
EventEnvelope::call_requested("req-2", serde_json::json!({"op": "test"}));
|
||||
let data_frame = encode(&data_envelope);
|
||||
client_writer.write_all(&data_frame).await.unwrap();
|
||||
client_writer.flush().await.unwrap();
|
||||
|
||||
let event = server_session.recv().await;
|
||||
assert!(event.is_some());
|
||||
let event = event.unwrap();
|
||||
assert_eq!(event.envelope.r#type, "call.requested");
|
||||
assert_eq!(event.envelope.id, "req-2");
|
||||
assert!(event.identity.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn first_frame_auth_valid_token() {
|
||||
let identity = Identity {
|
||||
id: "auth-user".to_string(),
|
||||
scopes: vec!["admin".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let (provider, token_str) = make_provider_with_identity(identity, "my-valid-token");
|
||||
|
||||
let (client, server) = tokio::io::duplex(4096);
|
||||
let server_stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let client_stream: Box<dyn TransportStream> = Box::new(client);
|
||||
|
||||
let mut session = RawFramingSession::new(server_stream, provider);
|
||||
|
||||
let auth_envelope = EventEnvelope::new("auth", "auth-1", serde_json::json!(token_str));
|
||||
let frame = encode(&auth_envelope);
|
||||
let mut writer = tokio::io::BufWriter::new(client_stream);
|
||||
writer.write_all(&frame).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let event = session.recv().await;
|
||||
assert!(event.is_some());
|
||||
let event = event.unwrap();
|
||||
assert!(event.identity.is_some());
|
||||
assert_eq!(event.identity.as_ref().unwrap().id, "auth-user");
|
||||
assert_eq!(event.identity.as_ref().unwrap().scopes, vec!["admin"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn first_frame_auth_invalid_token() {
|
||||
let identity = Identity {
|
||||
id: "auth-user".to_string(),
|
||||
scopes: vec![],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let (provider, _) = make_provider_with_identity(identity, "correct-token");
|
||||
|
||||
let (client, server) = tokio::io::duplex(4096);
|
||||
let server_stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let client_stream: Box<dyn TransportStream> = Box::new(client);
|
||||
|
||||
let mut session = RawFramingSession::new(server_stream, provider);
|
||||
|
||||
let bad_envelope =
|
||||
EventEnvelope::new("auth", "auth-1", serde_json::json!("bad-token-value"));
|
||||
let frame = encode(&bad_envelope);
|
||||
let mut writer = tokio::io::BufWriter::new(client_stream);
|
||||
writer.write_all(&frame).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let event = session.recv().await;
|
||||
assert!(event.is_none());
|
||||
|
||||
let data_envelope = EventEnvelope::call_requested("req-2", serde_json::json!({}));
|
||||
let result = session.send(data_envelope).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn raw_framing_session_send() {
|
||||
let identity = Identity {
|
||||
id: "send-user".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let (provider, token_str) = make_provider_with_identity(identity, "send-token");
|
||||
|
||||
let (client, server) = tokio::io::duplex(4096);
|
||||
let server_stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let client_stream: Box<dyn TransportStream> = Box::new(client);
|
||||
|
||||
let mut server_session = RawFramingSession::new(server_stream, provider);
|
||||
|
||||
let auth_envelope = EventEnvelope::new("auth", "auth-1", serde_json::json!(token_str));
|
||||
let auth_frame = encode(&auth_envelope);
|
||||
let mut client_writer = tokio::io::BufWriter::new(client_stream);
|
||||
client_writer.write_all(&auth_frame).await.unwrap();
|
||||
client_writer.flush().await.unwrap();
|
||||
|
||||
let _ = server_session.recv().await;
|
||||
|
||||
let response = EventEnvelope::call_responded("req-1", serde_json::json!({"result": "ok"}));
|
||||
let send_result = server_session.send(response).await;
|
||||
assert!(send_result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn raw_framing_multiple_frames_over_duplex() {
|
||||
let identity = Identity {
|
||||
id: "multi-user".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let (provider, token_str) = make_provider_with_identity(identity, "multi-token");
|
||||
|
||||
let (client, server) = tokio::io::duplex(8192);
|
||||
let server_stream: Box<dyn TransportStream> = Box::new(server);
|
||||
let client_stream: Box<dyn TransportStream> = Box::new(client);
|
||||
|
||||
let mut session = RawFramingSession::new(server_stream, provider);
|
||||
let mut client_writer = tokio::io::BufWriter::new(client_stream);
|
||||
|
||||
let auth_envelope = EventEnvelope::new("auth", "auth-0", serde_json::json!(token_str));
|
||||
client_writer
|
||||
.write_all(&encode(&auth_envelope))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
for i in 1..=5 {
|
||||
let envelope =
|
||||
EventEnvelope::call_requested(format!("req-{i}"), serde_json::json!({"seq": i}));
|
||||
client_writer.write_all(&encode(&envelope)).await.unwrap();
|
||||
}
|
||||
client_writer.flush().await.unwrap();
|
||||
|
||||
let auth_event = session.recv().await;
|
||||
assert!(auth_event.is_some());
|
||||
assert!(auth_event.unwrap().identity.is_some());
|
||||
|
||||
for i in 1..=5 {
|
||||
let event = session.recv().await;
|
||||
assert!(event.is_some());
|
||||
let event = event.unwrap();
|
||||
assert_eq!(event.envelope.id, format!("req-{i}"));
|
||||
assert!(event.identity.is_some());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn raw_framing_interface_type_exists() {
|
||||
let _iface = RawFramingInterface;
|
||||
}
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::auth::Identity;
|
||||
use crate::call::EventEnvelope;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InterfaceEvent {
|
||||
pub envelope: EventEnvelope,
|
||||
pub identity: Option<Identity>,
|
||||
}
|
||||
|
||||
impl InterfaceEvent {
|
||||
pub fn new(envelope: EventEnvelope) -> Self {
|
||||
Self {
|
||||
envelope,
|
||||
identity: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_identity(envelope: EventEnvelope, identity: Identity) -> Self {
|
||||
Self {
|
||||
envelope,
|
||||
identity: Some(identity),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait InterfaceSession: Send {
|
||||
async fn recv(&mut self) -> Option<InterfaceEvent>;
|
||||
|
||||
async fn send(&mut self, envelope: EventEnvelope) -> Result<()>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn interface_event_new() {
|
||||
let envelope = EventEnvelope::call_requested("req-1", serde_json::json!({"op": "test"}));
|
||||
let event = InterfaceEvent::new(envelope.clone());
|
||||
assert_eq!(event.envelope, envelope);
|
||||
assert!(event.identity.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_event_with_identity() {
|
||||
let envelope = EventEnvelope::call_requested("req-1", serde_json::json!({"op": "test"}));
|
||||
let identity = Identity {
|
||||
id: "SHA256:abc123".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
let event = InterfaceEvent::with_identity(envelope.clone(), identity.clone());
|
||||
assert_eq!(event.envelope, envelope);
|
||||
assert!(event.identity.is_some());
|
||||
assert_eq!(event.identity.as_ref().unwrap().id, "SHA256:abc123");
|
||||
}
|
||||
}
|
||||
@@ -1,982 +0,0 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::Result;
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use russh::keys::ssh_key::HashAlg;
|
||||
use russh::server::{self, Config};
|
||||
use russh::Channel;
|
||||
use russh::ChannelId;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::auth::identity::{Identity, IdentityProvider};
|
||||
use crate::call::frame::{FrameFramedReader, FrameFramedWriter};
|
||||
use crate::call::EventEnvelope;
|
||||
use crate::config::DynamicConfig;
|
||||
use crate::interface::session::{InterfaceEvent, InterfaceSession};
|
||||
use crate::interface::{StreamInterface, StreamInterfaceConfig, TransportStream};
|
||||
use crate::server::control_channel::{
|
||||
ControlChannelHandler, ControlChannelRouter, DuplexStream, ALKNET_CONTROL_DESTINATION,
|
||||
ALKNET_PREFIX,
|
||||
};
|
||||
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||
use crate::transport::TransportKind;
|
||||
|
||||
struct SshHandler {
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
outbound_proxy: Option<crate::server::handler::ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
transport: TransportKind,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
connection_allowed: bool,
|
||||
auth_limiter: AuthAttemptLimiter,
|
||||
authenticated_identity: Option<Identity>,
|
||||
control_channel_router: ControlChannelRouter,
|
||||
bridge_event_tx: Option<mpsc::Sender<InterfaceEvent>>,
|
||||
bridge_envelope_rx: Option<mpsc::Receiver<EventEnvelope>>,
|
||||
connected_at: Instant,
|
||||
}
|
||||
|
||||
impl SshHandler {
|
||||
fn new(
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
outbound_proxy: Option<crate::server::handler::ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
transport: TransportKind,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
max_auth_attempts: usize,
|
||||
) -> Self {
|
||||
let allowed = if let Some(addr) = remote_addr {
|
||||
let ip = addr.ip();
|
||||
if connection_limiter.check(ip) {
|
||||
connection_limiter.on_connect(ip);
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
transport = %transport,
|
||||
"connection opened"
|
||||
);
|
||||
true
|
||||
} else {
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
transport = %transport,
|
||||
"connection rejected"
|
||||
);
|
||||
false
|
||||
}
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
Self {
|
||||
dynamic,
|
||||
identity_provider,
|
||||
outbound_proxy,
|
||||
remote_addr,
|
||||
transport,
|
||||
connection_limiter,
|
||||
connection_allowed: allowed,
|
||||
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
|
||||
authenticated_identity: None,
|
||||
control_channel_router: ControlChannelRouter::without_handler(),
|
||||
bridge_event_tx: None,
|
||||
bridge_envelope_rx: None,
|
||||
connected_at: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn with_control_channel_router(mut self, router: ControlChannelRouter) -> Self {
|
||||
self.control_channel_router = router;
|
||||
self
|
||||
}
|
||||
|
||||
fn with_bridge_channels(
|
||||
mut self,
|
||||
event_tx: mpsc::Sender<InterfaceEvent>,
|
||||
envelope_rx: mpsc::Receiver<EventEnvelope>,
|
||||
) -> Self {
|
||||
self.bridge_event_tx = Some(event_tx);
|
||||
self.bridge_envelope_rx = Some(envelope_rx);
|
||||
self
|
||||
}
|
||||
|
||||
fn has_control_channel_bridge(&self) -> bool {
|
||||
self.bridge_event_tx.is_some() && self.bridge_envelope_rx.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SshHandler {
|
||||
fn drop(&mut self) {
|
||||
if let Some(addr) = self.remote_addr {
|
||||
if self.connection_allowed {
|
||||
self.connection_limiter.on_disconnect(addr.ip());
|
||||
let duration = self.connected_at.elapsed();
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
duration_secs = duration.as_secs_f64(),
|
||||
"connection closed"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl server::Handler for SshHandler {
|
||||
type Error = russh::Error;
|
||||
|
||||
async fn auth_publickey(
|
||||
&mut self,
|
||||
user: &str,
|
||||
public_key: &russh::keys::ssh_key::PublicKey,
|
||||
) -> Result<server::Auth, Self::Error> {
|
||||
if !self.auth_limiter.check() {
|
||||
let remote_addr_display = self
|
||||
.remote_addr
|
||||
.map_or("unknown".to_string(), |a| a.to_string());
|
||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "reject",
|
||||
"auth attempt"
|
||||
);
|
||||
return Ok(server::Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
});
|
||||
}
|
||||
|
||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||
let remote_addr_display = self
|
||||
.remote_addr
|
||||
.map_or("unknown".to_string(), |a| a.to_string());
|
||||
|
||||
let identity = self
|
||||
.identity_provider
|
||||
.resolve_from_fingerprint(&fingerprint);
|
||||
|
||||
match identity {
|
||||
Some(id) => {
|
||||
self.authenticated_identity = Some(id);
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "accept",
|
||||
"auth attempt"
|
||||
);
|
||||
Ok(server::Auth::Accept)
|
||||
}
|
||||
None => {
|
||||
self.auth_limiter.on_failure();
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "reject",
|
||||
"auth attempt"
|
||||
);
|
||||
Ok(server::Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn channel_open_direct_tcpip(
|
||||
&mut self,
|
||||
channel: Channel<server::Msg>,
|
||||
host_to_connect: &str,
|
||||
port_to_connect: u32,
|
||||
originator_address: &str,
|
||||
originator_port: u32,
|
||||
_session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
if host_to_connect.starts_with(ALKNET_PREFIX) {
|
||||
if host_to_connect == ALKNET_CONTROL_DESTINATION && self.has_control_channel_bridge() {
|
||||
let event_tx = self.bridge_event_tx.take().unwrap();
|
||||
let envelope_rx = self.bridge_envelope_rx.take().unwrap();
|
||||
let identity = self.authenticated_identity.clone();
|
||||
tokio::spawn(async move {
|
||||
let stream = channel.into_stream();
|
||||
let (read_half, write_half) = tokio::io::split(stream);
|
||||
run_control_channel_bridge(
|
||||
read_half,
|
||||
write_half,
|
||||
identity,
|
||||
event_tx,
|
||||
envelope_rx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
let _ = (originator_address, originator_port);
|
||||
return Ok(true);
|
||||
}
|
||||
if self.control_channel_router.has_handler() {
|
||||
if let Some(handler) = self.control_channel_router.take_handler() {
|
||||
let stream: Box<dyn DuplexStream> = Box::new(channel.into_stream());
|
||||
tokio::spawn(async move {
|
||||
handler.handle_channel(stream).await;
|
||||
});
|
||||
}
|
||||
let _ = (originator_address, originator_port);
|
||||
return Ok(true);
|
||||
}
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let identity = self
|
||||
.authenticated_identity
|
||||
.clone()
|
||||
.unwrap_or_else(|| Identity {
|
||||
id: String::new(),
|
||||
scopes: vec![],
|
||||
resources: std::collections::HashMap::new(),
|
||||
});
|
||||
|
||||
let policy = self.dynamic.load();
|
||||
let allowed = policy.forwarding.check(
|
||||
host_to_connect,
|
||||
port_to_connect as u16,
|
||||
&identity,
|
||||
self.transport.clone(),
|
||||
);
|
||||
|
||||
if !allowed {
|
||||
tracing::info!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
target = %format!("{host_to_connect}:{port_to_connect}"),
|
||||
identity = %identity.id,
|
||||
transport = %self.transport,
|
||||
"forwarding denied by policy"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let target_host = host_to_connect.to_string();
|
||||
let target_port = port_to_connect;
|
||||
let proxy_config =
|
||||
self.outbound_proxy
|
||||
.clone()
|
||||
.unwrap_or(crate::server::handler::ProxyConfig {
|
||||
mode: crate::server::handler::ProxyMode::Direct,
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
let target = match format!("{target_host}:{target_port}")
|
||||
.parse::<std::net::SocketAddr>()
|
||||
{
|
||||
Ok(addr) => addr,
|
||||
Err(_) => {
|
||||
match tokio::net::lookup_host((&target_host[..], target_port as u16)).await {
|
||||
Ok(mut addrs) => match addrs.next() {
|
||||
Some(addr) => addr,
|
||||
None => return,
|
||||
},
|
||||
Err(_) => return,
|
||||
}
|
||||
}
|
||||
};
|
||||
crate::server::channel_proxy::proxy_channel(
|
||||
channel.into_stream(),
|
||||
target,
|
||||
&proxy_config,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let _ = (originator_address, originator_port);
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn channel_open_session(
|
||||
&mut self,
|
||||
_channel: Channel<server::Msg>,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
"rejected session channel (shell/exec not supported)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn channel_open_x11(
|
||||
&mut self,
|
||||
_channel: Channel<server::Msg>,
|
||||
_originator_address: &str,
|
||||
_originator_port: u32,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
"rejected x11 channel"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn channel_open_forwarded_tcpip(
|
||||
&mut self,
|
||||
_channel: Channel<server::Msg>,
|
||||
host_to_connect: &str,
|
||||
port_to_connect: u32,
|
||||
_originator_address: &str,
|
||||
_originator_port: u32,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
target = %format!("{host_to_connect}:{port_to_connect}"),
|
||||
"rejected forwarded-tcpip channel (remote port forwarding not supported)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn exec_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
data: &[u8],
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
data_len = data.len(),
|
||||
"rejected exec request on channel (shell/exec not supported)"
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn shell_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
"rejected shell request on channel"
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn subsystem_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
name: &str,
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
subsystem = name,
|
||||
"rejected subsystem request on channel"
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn pty_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
term: &str,
|
||||
col_width: u32,
|
||||
row_height: u32,
|
||||
pix_width: u32,
|
||||
pix_height: u32,
|
||||
modes: &[(russh::Pty, u32)],
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
term = term,
|
||||
"rejected pty request on channel"
|
||||
);
|
||||
let _ = (col_width, row_height, pix_width, pix_height, modes);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn env_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
variable_name: &str,
|
||||
variable_value: &str,
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
variable = variable_name,
|
||||
"rejected env request on channel"
|
||||
);
|
||||
let _ = variable_value;
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn x11_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
single_connection: bool,
|
||||
x11_auth_protocol: &str,
|
||||
x11_auth_cookie: &str,
|
||||
x11_screen_number: u32,
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
"rejected x11 request on channel"
|
||||
);
|
||||
let _ = (
|
||||
single_connection,
|
||||
x11_auth_protocol,
|
||||
x11_auth_cookie,
|
||||
x11_screen_number,
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn agent_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
"rejected agent forwarding request on channel"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn tcpip_forward(
|
||||
&mut self,
|
||||
address: &str,
|
||||
port: &mut u32,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
address = address,
|
||||
port = *port,
|
||||
"rejected tcpip-forward request (remote port forwarding not supported)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn cancel_tcpip_forward(
|
||||
&mut self,
|
||||
address: &str,
|
||||
port: u32,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
let _ = (address, port, session);
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn streamlocal_forward(
|
||||
&mut self,
|
||||
socket_path: &str,
|
||||
session: &mut server::Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
socket_path = socket_path,
|
||||
"rejected streamlocal-forward request"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn signal(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
signal: russh::Sig,
|
||||
session: &mut server::Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::debug!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
signal = ?signal,
|
||||
"received signal on channel (ignored)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SshInterface {
|
||||
config: Arc<Config>,
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
outbound_proxy: Option<crate::server::handler::ProxyConfig>,
|
||||
max_auth_attempts: usize,
|
||||
}
|
||||
|
||||
impl SshInterface {
|
||||
pub fn new(config: Arc<Config>, dynamic: Arc<ArcSwap<DynamicConfig>>) -> Self {
|
||||
Self {
|
||||
config,
|
||||
dynamic,
|
||||
connection_limiter: Arc::new(ConnectionRateLimiter::new(0)),
|
||||
outbound_proxy: None,
|
||||
max_auth_attempts: 10,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_connection_limiter(mut self, limiter: Arc<ConnectionRateLimiter>) -> Self {
|
||||
self.connection_limiter = limiter;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_outbound_proxy(
|
||||
mut self,
|
||||
proxy: Option<crate::server::handler::ProxyConfig>,
|
||||
) -> Self {
|
||||
self.outbound_proxy = proxy;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_auth_attempts(mut self, max: usize) -> Self {
|
||||
self.max_auth_attempts = max;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &Arc<Config> {
|
||||
&self.config
|
||||
}
|
||||
|
||||
pub fn dynamic(&self) -> &Arc<ArcSwap<DynamicConfig>> {
|
||||
&self.dynamic
|
||||
}
|
||||
|
||||
async fn accept_inner(
|
||||
&self,
|
||||
stream: Box<dyn TransportStream>,
|
||||
ssh_config: &crate::interface::SshInterfaceConfig,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
transport: TransportKind,
|
||||
) -> Result<SshSession> {
|
||||
let identity_provider = Arc::clone(&ssh_config.auth);
|
||||
let _forwarding = Arc::clone(&ssh_config.forwarding);
|
||||
|
||||
let (event_tx, event_rx) = mpsc::channel::<InterfaceEvent>(256);
|
||||
let (envelope_tx, envelope_rx) = mpsc::channel::<EventEnvelope>(256);
|
||||
|
||||
let handler = SshHandler::new(
|
||||
Arc::clone(&self.dynamic),
|
||||
identity_provider,
|
||||
self.outbound_proxy.clone(),
|
||||
remote_addr,
|
||||
transport,
|
||||
Arc::clone(&self.connection_limiter),
|
||||
self.max_auth_attempts,
|
||||
)
|
||||
.with_bridge_channels(event_tx, envelope_rx);
|
||||
|
||||
let running = server::run_stream(Arc::clone(&self.config), stream, handler).await?;
|
||||
let handle = running.handle();
|
||||
let join = tokio::spawn(async {
|
||||
let _ = running.await;
|
||||
});
|
||||
|
||||
Ok(SshSession {
|
||||
handle,
|
||||
_join: join,
|
||||
event_rx,
|
||||
envelope_tx,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StreamInterface for SshInterface {
|
||||
type Session = SshSession;
|
||||
|
||||
async fn accept(
|
||||
&self,
|
||||
stream: Box<dyn TransportStream>,
|
||||
config: &StreamInterfaceConfig,
|
||||
) -> Result<Self::Session> {
|
||||
let ssh_config = match config {
|
||||
StreamInterfaceConfig::Ssh(c) => c,
|
||||
StreamInterfaceConfig::RawFraming(_) => {
|
||||
return Err(anyhow::anyhow!("SshInterface received RawFramingConfig"));
|
||||
}
|
||||
};
|
||||
|
||||
self.accept_inner(stream, ssh_config, None, TransportKind::Tcp)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SshSession {
|
||||
handle: server::Handle,
|
||||
_join: tokio::task::JoinHandle<()>,
|
||||
event_rx: mpsc::Receiver<InterfaceEvent>,
|
||||
envelope_tx: mpsc::Sender<EventEnvelope>,
|
||||
}
|
||||
|
||||
impl SshSession {
|
||||
pub fn handle(&self) -> &server::Handle {
|
||||
&self.handle
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl InterfaceSession for SshSession {
|
||||
async fn recv(&mut self) -> Option<InterfaceEvent> {
|
||||
self.event_rx.recv().await
|
||||
}
|
||||
|
||||
async fn send(&mut self, envelope: EventEnvelope) -> Result<()> {
|
||||
self.envelope_tx
|
||||
.send(envelope)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("control channel bridge closed"))
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_control_channel_bridge<R, W>(
|
||||
read_half: R,
|
||||
write_half: W,
|
||||
identity: Option<Identity>,
|
||||
event_tx: mpsc::Sender<InterfaceEvent>,
|
||||
mut envelope_rx: mpsc::Receiver<EventEnvelope>,
|
||||
) where
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let mut reader = FrameFramedReader::new(read_half);
|
||||
let mut writer = FrameFramedWriter::new(write_half);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
frame = reader.read_frame() => {
|
||||
match frame {
|
||||
Ok(Some(envelope)) => {
|
||||
let event = match &identity {
|
||||
Some(id) => InterfaceEvent::with_identity(envelope, id.clone()),
|
||||
None => InterfaceEvent::new(envelope),
|
||||
};
|
||||
if event_tx.send(event).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
Ok(None) => return,
|
||||
Err(_) => return,
|
||||
}
|
||||
}
|
||||
envelope = envelope_rx.recv() => {
|
||||
match envelope {
|
||||
Some(envelope) => {
|
||||
if writer.write_frame(&envelope).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
None => return,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ControlChannelBridge {
|
||||
identity: Option<Identity>,
|
||||
}
|
||||
|
||||
impl ControlChannelBridge {
|
||||
pub fn new(identity: Option<Identity>) -> Self {
|
||||
Self { identity }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for ControlChannelBridge {
|
||||
async fn handle_channel(&self, stream: Box<dyn DuplexStream>) {
|
||||
let (event_tx, _event_rx) = mpsc::channel::<InterfaceEvent>(256);
|
||||
let (_envelope_tx, envelope_rx) = mpsc::channel::<EventEnvelope>(256);
|
||||
|
||||
let identity = self.identity.clone();
|
||||
let (read_half, write_half) = tokio::io::split(stream);
|
||||
tokio::spawn(run_control_channel_bridge(
|
||||
read_half,
|
||||
write_half,
|
||||
identity,
|
||||
event_tx,
|
||||
envelope_rx,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::call::frame::{FrameFramedReader, FrameFramedWriter};
|
||||
use tokio::io::duplex;
|
||||
|
||||
#[test]
|
||||
fn ssh_interface_constructs_with_config() {
|
||||
let config = Arc::new(Config {
|
||||
keys: vec![russh::keys::PrivateKey::random(
|
||||
&mut rand_core::OsRng,
|
||||
russh::keys::Algorithm::Ed25519,
|
||||
)
|
||||
.unwrap()],
|
||||
..Default::default()
|
||||
});
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||
|
||||
let iface = SshInterface::new(config, dynamic);
|
||||
assert!(iface.config().keys.len() >= 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssh_interface_builder_pattern() {
|
||||
let config = Arc::new(Config {
|
||||
keys: vec![russh::keys::PrivateKey::random(
|
||||
&mut rand_core::OsRng,
|
||||
russh::keys::Algorithm::Ed25519,
|
||||
)
|
||||
.unwrap()],
|
||||
..Default::default()
|
||||
});
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(5));
|
||||
|
||||
let iface = SshInterface::new(config, dynamic)
|
||||
.with_connection_limiter(limiter)
|
||||
.with_max_auth_attempts(3);
|
||||
|
||||
assert!(iface.config().keys.len() >= 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssh_handler_auth_delegates_to_identity_provider() {
|
||||
use std::collections::HashMap;
|
||||
|
||||
struct MockProvider {
|
||||
identities: HashMap<String, Identity>,
|
||||
}
|
||||
|
||||
impl IdentityProvider for MockProvider {
|
||||
fn resolve_from_fingerprint(&self, fp: &str) -> Option<Identity> {
|
||||
self.identities.get(fp).cloned()
|
||||
}
|
||||
fn resolve_from_token(&self, _t: &crate::auth::AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
let mut ids = HashMap::new();
|
||||
ids.insert(
|
||||
"SHA256:testkey".to_string(),
|
||||
Identity {
|
||||
id: "SHA256:testkey".to_string(),
|
||||
scopes: vec!["admin".to_string()],
|
||||
resources: HashMap::new(),
|
||||
},
|
||||
);
|
||||
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(MockProvider { identities: ids });
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(0));
|
||||
|
||||
let handler = SshHandler::new(
|
||||
dynamic,
|
||||
provider,
|
||||
None,
|
||||
None,
|
||||
TransportKind::Tcp,
|
||||
limiter,
|
||||
10,
|
||||
);
|
||||
|
||||
assert!(handler.authenticated_identity.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssh_handler_connection_rate_limiting() {
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(
|
||||
crate::auth::identity::ConfigIdentityProvider::new(Arc::clone(&dynamic)),
|
||||
);
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(1));
|
||||
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
|
||||
|
||||
let h1 = SshHandler::new(
|
||||
Arc::clone(&dynamic),
|
||||
Arc::clone(&provider),
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
Arc::clone(&limiter),
|
||||
10,
|
||||
);
|
||||
assert!(h1.connection_allowed);
|
||||
|
||||
let h2 = SshHandler::new(
|
||||
dynamic,
|
||||
provider,
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter,
|
||||
10,
|
||||
);
|
||||
assert!(!h2.connection_allowed);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_interface_rejects_raw_framing_config() {
|
||||
let config = Arc::new(Config {
|
||||
keys: vec![russh::keys::PrivateKey::random(
|
||||
&mut rand_core::OsRng,
|
||||
russh::keys::Algorithm::Ed25519,
|
||||
)
|
||||
.unwrap()],
|
||||
..Default::default()
|
||||
});
|
||||
let dynamic = Arc::new(ArcSwap::new(Arc::new(DynamicConfig::default())));
|
||||
let iface = SshInterface::new(config, dynamic);
|
||||
let (_client, server) = tokio::io::duplex(1024);
|
||||
let stream: Box<dyn TransportStream> = Box::new(server);
|
||||
|
||||
let raw_config = StreamInterfaceConfig::RawFraming(crate::interface::RawFramingConfig {
|
||||
auth: Arc::new(crate::auth::ConfigIdentityProvider::new(Arc::new(
|
||||
ArcSwap::new(Arc::new(DynamicConfig::default())),
|
||||
))),
|
||||
});
|
||||
let result = iface.accept(stream, &raw_config).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_session_round_trip_event_envelope() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<InterfaceEvent>(256);
|
||||
let (envelope_tx, envelope_rx) = mpsc::channel::<EventEnvelope>(256);
|
||||
|
||||
let identity = Identity {
|
||||
id: "SHA256:test".to_string(),
|
||||
scopes: vec![],
|
||||
resources: std::collections::HashMap::new(),
|
||||
};
|
||||
let identity_clone = identity.clone();
|
||||
|
||||
let (server_read, server_write) = tokio::io::split(server);
|
||||
tokio::spawn(run_control_channel_bridge(
|
||||
server_read,
|
||||
server_write,
|
||||
Some(identity_clone),
|
||||
event_tx,
|
||||
envelope_rx,
|
||||
));
|
||||
|
||||
let (client_read, client_write) = tokio::io::split(client);
|
||||
let mut client_reader = FrameFramedReader::new(client_read);
|
||||
let mut client_writer = FrameFramedWriter::new(client_write);
|
||||
|
||||
let envelope = EventEnvelope::call_requested("req-1", serde_json::json!({"op": "test"}));
|
||||
client_writer.write_frame(&envelope).await.unwrap();
|
||||
|
||||
let received_event =
|
||||
tokio::time::timeout(std::time::Duration::from_secs(2), event_rx.recv())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(received_event.envelope, envelope);
|
||||
assert_eq!(received_event.identity.as_ref().unwrap().id, "SHA256:test");
|
||||
|
||||
let response = EventEnvelope::call_responded("req-1", serde_json::json!({"result": 42}));
|
||||
envelope_tx.send(response.clone()).await.unwrap();
|
||||
|
||||
let read_back = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
client_reader.read_frame(),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(read_back, response);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_session_recv_without_identity() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let (event_tx, mut event_rx) = mpsc::channel::<InterfaceEvent>(256);
|
||||
let (_envelope_tx, envelope_rx) = mpsc::channel::<EventEnvelope>(256);
|
||||
|
||||
let (server_read, server_write) = tokio::io::split(server);
|
||||
tokio::spawn(run_control_channel_bridge(
|
||||
server_read,
|
||||
server_write,
|
||||
None,
|
||||
event_tx,
|
||||
envelope_rx,
|
||||
));
|
||||
|
||||
let (client_read, client_write) = tokio::io::split(client);
|
||||
let mut client_writer = FrameFramedWriter::new(client_write);
|
||||
let _client_reader = FrameFramedReader::new(client_read);
|
||||
|
||||
let envelope = EventEnvelope::call_requested("req-2", serde_json::json!({"op": "no-id"}));
|
||||
client_writer.write_frame(&envelope).await.unwrap();
|
||||
|
||||
let received_event =
|
||||
tokio::time::timeout(std::time::Duration::from_secs(2), event_rx.recv())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(received_event.envelope, envelope);
|
||||
assert!(received_event.identity.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn control_channel_router_with_handler_routes_data() {
|
||||
let called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
|
||||
let called_clone = called.clone();
|
||||
|
||||
struct TrackingHandler {
|
||||
called: std::sync::Arc<std::sync::atomic::AtomicBool>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for TrackingHandler {
|
||||
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {
|
||||
self.called.store(true, std::sync::atomic::Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
let router = ControlChannelRouter::with_handler(Box::new(TrackingHandler {
|
||||
called: called_clone,
|
||||
}));
|
||||
assert!(router.has_handler());
|
||||
|
||||
let (_client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
let result = router.route(stream).await;
|
||||
assert!(result.is_ok());
|
||||
assert!(called.load(std::sync::atomic::Ordering::SeqCst));
|
||||
}
|
||||
}
|
||||
@@ -1,110 +1,16 @@
|
||||
//! # alknet-core
|
||||
//! alknet-core: Core library for ALPN-based protocol dispatch.
|
||||
//!
|
||||
//! Core library for [Alknet](https://git.alk.dev/alkdev/alknet), a self-hostable SSH-based
|
||||
//! tunnel tool. This crate provides the transport abstraction, SOCKS5 server, port forwarding,
|
||||
//! authentication, and server handler — everything needed to build an alknet client or server
|
||||
//! on top of pluggable transports.
|
||||
//!
|
||||
//! > **Alpha software.** This crate depends on solid libraries (russh, tokio, rustls, iroh)
|
||||
//! > for core functionality, but the integration layer has not been battle-tested. Use with
|
||||
//! > caution and report issues.
|
||||
//!
|
||||
//! # Key concepts
|
||||
//!
|
||||
//! - **Transport trait** — produces a duplex byte stream (`AsyncRead + AsyncWrite + Unpin + Send`)
|
||||
//! that SSH consumes. Implementations: TCP, TLS, iroh (QUIC P2P).
|
||||
//! - **SOCKS5 server** — the primary client interface, listening on a local port and routing
|
||||
//! traffic through SSH channels.
|
||||
//! - **Port forwarding** — `-L` local and `-R` remote port forwards over SSH channels.
|
||||
//! - **Authentication** — Ed25519 public key and OpenSSH certificate authority. No passwords.
|
||||
//! - **Server handler** — accepts SSH connections via a `TransportAcceptor` and proxies
|
||||
//! `direct-tcpip` channel requests to targets (directly or via outbound proxy).
|
||||
//!
|
||||
//! # Feature flags
|
||||
//!
|
||||
//! | Feature | Default | Description |
|
||||
//! |---------|---------|-------------|
|
||||
//! | `tls` | yes | TLS transport via `tokio-rustls` |
|
||||
//! | `iroh` | yes | iroh QUIC P2P transport |
|
||||
//! | `acme` | no | ACME/Let's Encrypt auto-cert provisioning (implies `tls`) |
|
||||
//! | `irpc` | no | irpc service layer (AuthProtocol, AuthServiceImpl) |
|
||||
//! | `testutil` | no | Test utilities (for internal use) |
|
||||
//!
|
||||
//! # Quick example
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use std::sync::Arc;
|
||||
//! use alknet_core::transport::TcpTransport;
|
||||
//! use alknet_core::client::{ClientSession, ConnectOptions, TransportMode};
|
||||
//! use alknet_core::auth::keys::KeySource;
|
||||
//! use alknet_core::Transport;
|
||||
//!
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> anyhow::Result<()> {
|
||||
//! let opts = ConnectOptions::new(KeySource::File("/path/to/key".into()))
|
||||
//! .server("example.com:22")
|
||||
//! .transport_mode(TransportMode::Tcp);
|
||||
//! let transport = Arc::new(TcpTransport::new("example.com:22".parse()?));
|
||||
//! let session = ClientSession::new(opts, transport).await?;
|
||||
//! session.run().await?;
|
||||
//! Ok(())
|
||||
//! }
|
||||
//! ```
|
||||
//! Every handler crate depends on this crate. It provides the
|
||||
//! [`ProtocolHandler`][crate::types::ProtocolHandler] trait, the
|
||||
//! [`Connection`][crate::types::Connection] wrapper, auth primitives,
|
||||
//! hot-reloadable configuration, and the [`AlknetEndpoint`][crate::endpoint::AlknetEndpoint]
|
||||
//! that dispatches incoming QUIC connections by ALPN string.
|
||||
|
||||
pub mod auth;
|
||||
pub mod call;
|
||||
pub mod client;
|
||||
pub mod config;
|
||||
pub mod credentials;
|
||||
pub mod error;
|
||||
pub mod interface;
|
||||
pub mod server;
|
||||
pub mod socks5;
|
||||
pub mod transport;
|
||||
pub mod endpoint;
|
||||
pub mod store;
|
||||
pub mod types;
|
||||
|
||||
#[cfg(feature = "http")]
|
||||
pub mod http;
|
||||
|
||||
#[cfg(feature = "http")]
|
||||
pub use http::IdentityExt;
|
||||
|
||||
#[cfg(feature = "testutil")]
|
||||
pub mod testutil;
|
||||
|
||||
#[cfg(feature = "irpc")]
|
||||
pub use auth::{AuthProtocol, AuthResult, AuthServiceImpl};
|
||||
pub use auth::{AuthToken, ConfigIdentityProvider, Identity, IdentityProvider};
|
||||
pub use call::{
|
||||
decode as decode_frame, decode_with_remainder as decode_frame_with_remainder,
|
||||
encode as encode_frame,
|
||||
};
|
||||
pub use call::{
|
||||
register_default_operations, services_list_spec, services_schema_spec, AccessControl,
|
||||
CallError, EventEnvelope, FrameDecodeError, Handler, OperationContext, OperationEnv,
|
||||
OperationRegistry, OperationRegistryBuilder, OperationSpec, OperationType, PendingRequestMap,
|
||||
ResponseEnvelope,
|
||||
};
|
||||
pub use call::{CALL_ABORTED, CALL_COMPLETED, CALL_ERROR, CALL_REQUESTED, CALL_RESPONDED};
|
||||
pub use client::channel_manager::{ChannelManager, ForwardRequest};
|
||||
pub use client::connect::{ClientSession, ConnectError, ConnectOptions, TransportMode};
|
||||
pub use config::{
|
||||
AuthPolicy, ConfigReloadHandle, ConfigServiceImpl, DynamicConfig, ForwardingAction,
|
||||
ForwardingPolicy, ForwardingRule, RateLimitConfig, StaticConfig, TargetPattern,
|
||||
};
|
||||
pub use credentials::{
|
||||
ConfigCredentialProvider, CredentialProvider, CredentialSet, SecretStoreCredentialProvider,
|
||||
};
|
||||
pub use error::{AuthError, ChannelError, ConfigError, ForwardError, TransportError};
|
||||
pub use interface::{
|
||||
is_valid_pair, DnsInterface, DnsInterfaceConfig, HttpInterface, HttpInterfaceConfig,
|
||||
InterfaceConfig, InterfaceEvent, InterfaceRequest, InterfaceResponse, InterfaceSession,
|
||||
MessageInterface, MessageInterfaceConfig, MessageInterfaceKind, RawFramingConfig,
|
||||
RawFramingInterface, RawFramingSession, SshInterface, SshInterfaceConfig, SshSession,
|
||||
StreamInterface, StreamInterfaceConfig, StreamInterfaceKind, TransportKindBase,
|
||||
TransportStream, VALID_TRANSPORT_INTERFACE_PAIRS,
|
||||
};
|
||||
pub use server::serve::{
|
||||
DnsListenerConfig, HttpListenerConfig, ListenerConfig, ServeError, ServeOptions,
|
||||
ServeTransportMode, Server, StreamListenerConfig,
|
||||
};
|
||||
pub use transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
pub use auth::{IdentityProvider, IdentityStore};
|
||||
pub use store::{CredentialStore, EncryptedData, InMemoryCredentialStore, StoreError};
|
||||
|
||||
@@ -1,555 +0,0 @@
|
||||
//! Outbound connection proxy for SSH channel targets.
|
||||
//!
|
||||
//! Connects to the requested `host:port` either directly, via SOCKS5 proxy, or
|
||||
//! via HTTP CONNECT proxy, then proxies bytes bidirectionally between the SSH
|
||||
//! channel and the outbound TCP stream.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use super::handler::{ProxyConfig, ProxyMode};
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ChannelProxyError {
|
||||
#[error("connection refused")]
|
||||
ConnectionRefused,
|
||||
#[error("target unreachable")]
|
||||
TargetUnreachable,
|
||||
#[error("socks5 proxy handshake failed")]
|
||||
Socks5HandshakeFailed,
|
||||
#[error("socks5 proxy rejected connection")]
|
||||
Socks5ProxyRejected,
|
||||
#[error("http connect proxy handshake failed")]
|
||||
HttpConnectHandshakeFailed,
|
||||
#[error("http connect proxy rejected: {0}")]
|
||||
HttpConnectProxyRejected(String),
|
||||
#[error("io error")]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
pub async fn connect_outbound(
|
||||
target: SocketAddr,
|
||||
proxy: &ProxyConfig,
|
||||
) -> Result<TcpStream, ChannelProxyError> {
|
||||
match &proxy.mode {
|
||||
ProxyMode::Direct => connect_direct(target).await,
|
||||
ProxyMode::Socks5(addr) => connect_socks5(target, *addr).await,
|
||||
ProxyMode::HttpConnect(addr) => connect_http_connect(target, *addr).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_direct(target: SocketAddr) -> Result<TcpStream, ChannelProxyError> {
|
||||
TcpStream::connect(target)
|
||||
.await
|
||||
.map_err(|e| map_connection_error(e, target))
|
||||
}
|
||||
|
||||
async fn connect_socks5(
|
||||
target: SocketAddr,
|
||||
proxy_addr: SocketAddr,
|
||||
) -> Result<TcpStream, ChannelProxyError> {
|
||||
let mut stream = TcpStream::connect(proxy_addr)
|
||||
.await
|
||||
.map_err(ChannelProxyError::from)?;
|
||||
|
||||
stream.write_all(&[0x05, 0x01, 0x00]).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let mut resp = [0u8; 2];
|
||||
stream.read_exact(&mut resp).await?;
|
||||
if resp[0] != 0x05 || resp[1] != 0x00 {
|
||||
return Err(ChannelProxyError::Socks5HandshakeFailed);
|
||||
}
|
||||
|
||||
let ip_bytes = target.ip().to_string();
|
||||
let mut connect_req = vec![0x05, 0x01, 0x00, 0x03];
|
||||
connect_req.push(ip_bytes.len() as u8);
|
||||
connect_req.extend_from_slice(ip_bytes.as_bytes());
|
||||
connect_req.extend_from_slice(&target.port().to_be_bytes());
|
||||
stream.write_all(&connect_req).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let mut reply_header = [0u8; 4];
|
||||
stream.read_exact(&mut reply_header).await?;
|
||||
if reply_header[0] != 0x05 {
|
||||
return Err(ChannelProxyError::Socks5HandshakeFailed);
|
||||
}
|
||||
if reply_header[1] != 0x00 {
|
||||
return Err(ChannelProxyError::Socks5ProxyRejected);
|
||||
}
|
||||
|
||||
let atyp = reply_header[3];
|
||||
match atyp {
|
||||
0x01 => {
|
||||
let mut _addr = [0u8; 4];
|
||||
stream.read_exact(&mut _addr).await?;
|
||||
}
|
||||
0x04 => {
|
||||
let mut _addr = [0u8; 16];
|
||||
stream.read_exact(&mut _addr).await?;
|
||||
}
|
||||
0x03 => {
|
||||
let len = stream.read_u8().await?;
|
||||
let mut _domain = vec![0u8; len as usize];
|
||||
stream.read_exact(&mut _domain).await?;
|
||||
}
|
||||
_ => {
|
||||
return Err(ChannelProxyError::Socks5HandshakeFailed);
|
||||
}
|
||||
}
|
||||
let mut _port = [0u8; 2];
|
||||
stream.read_exact(&mut _port).await?;
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
async fn connect_http_connect(
|
||||
target: SocketAddr,
|
||||
proxy_addr: SocketAddr,
|
||||
) -> Result<TcpStream, ChannelProxyError> {
|
||||
let mut stream = TcpStream::connect(proxy_addr)
|
||||
.await
|
||||
.map_err(ChannelProxyError::from)?;
|
||||
|
||||
let connect_request = format!(
|
||||
"CONNECT {}:{} HTTP/1.1\r\nHost: {}:{}\r\n\r\n",
|
||||
target.ip(),
|
||||
target.port(),
|
||||
target.ip(),
|
||||
target.port()
|
||||
);
|
||||
stream.write_all(connect_request.as_bytes()).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let mut response = Vec::new();
|
||||
let mut buf = [0u8; 1024];
|
||||
loop {
|
||||
let n = stream.read(&mut buf).await?;
|
||||
if n == 0 {
|
||||
return Err(ChannelProxyError::HttpConnectHandshakeFailed);
|
||||
}
|
||||
response.extend_from_slice(&buf[..n]);
|
||||
if response.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let response_str = String::from_utf8_lossy(&response);
|
||||
let status_line = response_str.lines().next().unwrap_or("");
|
||||
|
||||
if status_line.contains("200") {
|
||||
Ok(stream)
|
||||
} else {
|
||||
Err(ChannelProxyError::HttpConnectProxyRejected(
|
||||
status_line.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn map_connection_error(e: std::io::Error, _target: SocketAddr) -> ChannelProxyError {
|
||||
match e.kind() {
|
||||
std::io::ErrorKind::ConnectionRefused => ChannelProxyError::ConnectionRefused,
|
||||
std::io::ErrorKind::AddrNotAvailable
|
||||
| std::io::ErrorKind::NetworkUnreachable
|
||||
| std::io::ErrorKind::HostUnreachable => ChannelProxyError::TargetUnreachable,
|
||||
_ => ChannelProxyError::Io(e),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn proxy_channel<S>(channel: S, target: SocketAddr, proxy: &ProxyConfig)
|
||||
where
|
||||
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
if let Ok(outbound) = connect_outbound(target, proxy).await {
|
||||
let (mut read_chan, mut write_chan) = tokio::io::split(channel);
|
||||
let (mut read_out, mut write_out) = outbound.into_split();
|
||||
|
||||
let client_to_target = tokio::spawn(async move {
|
||||
let _ = tokio::io::copy(&mut read_chan, &mut write_out).await;
|
||||
let _ = write_out.shutdown().await;
|
||||
});
|
||||
|
||||
let target_to_client = tokio::spawn(async move {
|
||||
let _ = tokio::io::copy(&mut read_out, &mut write_chan).await;
|
||||
let _ = write_chan.shutdown().await;
|
||||
});
|
||||
|
||||
let _ = client_to_target.await;
|
||||
let _ = target_to_client.await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
fn direct_config() -> ProxyConfig {
|
||||
ProxyConfig {
|
||||
mode: ProxyMode::Direct,
|
||||
}
|
||||
}
|
||||
|
||||
fn socks5_config(addr: SocketAddr) -> ProxyConfig {
|
||||
ProxyConfig {
|
||||
mode: ProxyMode::Socks5(addr),
|
||||
}
|
||||
}
|
||||
|
||||
fn http_connect_config(addr: SocketAddr) -> ProxyConfig {
|
||||
ProxyConfig {
|
||||
mode: ProxyMode::HttpConnect(addr),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn direct_connection_to_echo_server() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (mut sock, _) = listener.accept().await.unwrap();
|
||||
let mut buf = [0u8; 64];
|
||||
let n = sock.read(&mut buf).await.unwrap();
|
||||
sock.write_all(&buf[..n]).await.unwrap();
|
||||
});
|
||||
|
||||
let stream = connect_outbound(addr, &direct_config()).await.unwrap();
|
||||
let (mut read, mut write) = stream.into_split();
|
||||
write.write_all(b"hello").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
read.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello");
|
||||
|
||||
let _ = server.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn direct_connection_target_unreachable() {
|
||||
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
|
||||
let result = connect_outbound(target, &direct_config()).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn socks5_proxy_handshake() {
|
||||
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||
|
||||
let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let target_addr = target_listener.local_addr().unwrap();
|
||||
|
||||
let target_server = tokio::spawn(async move {
|
||||
let (mut sock, _) = target_listener.accept().await.unwrap();
|
||||
let mut buf = [0u8; 64];
|
||||
let n = sock.read(&mut buf).await.unwrap();
|
||||
sock.write_all(&buf[..n]).await.unwrap();
|
||||
});
|
||||
|
||||
let proxy_server = tokio::spawn(async move {
|
||||
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||
|
||||
let mut greeting = [0u8; 3];
|
||||
proxy_sock.read_exact(&mut greeting).await.unwrap();
|
||||
assert_eq!(greeting[0], 0x05);
|
||||
proxy_sock.write_all(&[0x05, 0x00]).await.unwrap();
|
||||
|
||||
let mut req_header = [0u8; 4];
|
||||
proxy_sock.read_exact(&mut req_header).await.unwrap();
|
||||
assert_eq!(req_header[0], 0x05);
|
||||
assert_eq!(req_header[1], 0x01);
|
||||
|
||||
let atyp = req_header[3];
|
||||
assert_eq!(atyp, 0x03);
|
||||
|
||||
let domain_len = proxy_sock.read_u8().await.unwrap() as usize;
|
||||
let mut domain = vec![0u8; domain_len];
|
||||
proxy_sock.read_exact(&mut domain).await.unwrap();
|
||||
let mut port_bytes = [0u8; 2];
|
||||
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
|
||||
|
||||
let target: SocketAddr = format!(
|
||||
"{}:{}",
|
||||
String::from_utf8_lossy(&domain),
|
||||
u16::from_be_bytes(port_bytes)
|
||||
)
|
||||
.parse()
|
||||
.unwrap();
|
||||
|
||||
let reply = vec![0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0];
|
||||
proxy_sock.write_all(&reply).await.unwrap();
|
||||
|
||||
let mut target_stream = TcpStream::connect(target).await.unwrap();
|
||||
let _ = tokio::io::copy_bidirectional(&mut proxy_sock, &mut target_stream).await;
|
||||
});
|
||||
|
||||
let config = socks5_config(proxy_addr);
|
||||
let mut stream = connect_outbound(target_addr, &config).await.unwrap();
|
||||
stream.write_all(b"hello socks").await.unwrap();
|
||||
let mut buf = [0u8; 11];
|
||||
stream.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello socks");
|
||||
drop(stream);
|
||||
|
||||
let _ = target_server.await;
|
||||
let _ = proxy_server.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn socks5_proxy_rejected() {
|
||||
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||
|
||||
let proxy_server = tokio::spawn(async move {
|
||||
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||
|
||||
let mut greeting = [0u8; 3];
|
||||
proxy_sock.read_exact(&mut greeting).await.unwrap();
|
||||
proxy_sock.write_all(&[0x05, 0x00]).await.unwrap();
|
||||
|
||||
let mut req_header = [0u8; 4];
|
||||
proxy_sock.read_exact(&mut req_header).await.unwrap();
|
||||
|
||||
let domain_len = proxy_sock.read_u8().await.unwrap() as usize;
|
||||
let mut domain = vec![0u8; domain_len];
|
||||
proxy_sock.read_exact(&mut domain).await.unwrap();
|
||||
let mut port_bytes = [0u8; 2];
|
||||
proxy_sock.read_exact(&mut port_bytes).await.unwrap();
|
||||
|
||||
let reply = vec![0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0];
|
||||
proxy_sock.write_all(&reply).await.unwrap();
|
||||
});
|
||||
|
||||
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||
let config = socks5_config(proxy_addr);
|
||||
let result = connect_outbound(target, &config).await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
ChannelProxyError::Socks5ProxyRejected
|
||||
));
|
||||
|
||||
let _ = proxy_server.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_connect_proxy_handshake() {
|
||||
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||
|
||||
let target_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let target_addr = target_listener.local_addr().unwrap();
|
||||
|
||||
let target_server = tokio::spawn(async move {
|
||||
let (mut sock, _) = target_listener.accept().await.unwrap();
|
||||
let mut buf = [0u8; 64];
|
||||
let n = sock.read(&mut buf).await.unwrap();
|
||||
sock.write_all(&buf[..n]).await.unwrap();
|
||||
});
|
||||
|
||||
let proxy_server = tokio::spawn(async move {
|
||||
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||
|
||||
let mut request = Vec::new();
|
||||
let mut buf = [0u8; 1024];
|
||||
loop {
|
||||
let n = proxy_sock.read(&mut buf).await.unwrap();
|
||||
request.extend_from_slice(&buf[..n]);
|
||||
if request.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
|
||||
proxy_sock.write_all(response.as_bytes()).await.unwrap();
|
||||
|
||||
let target_str = extract_connect_target(&String::from_utf8_lossy(&request));
|
||||
let mut target_stream = TcpStream::connect(target_str).await.unwrap();
|
||||
let _ = tokio::io::copy_bidirectional(&mut proxy_sock, &mut target_stream).await;
|
||||
});
|
||||
|
||||
let config = http_connect_config(proxy_addr);
|
||||
let mut stream = connect_outbound(target_addr, &config).await.unwrap();
|
||||
stream.write_all(b"hello http").await.unwrap();
|
||||
let mut buf = [0u8; 10];
|
||||
stream.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello http");
|
||||
drop(stream);
|
||||
|
||||
let _ = target_server.await;
|
||||
let _ = proxy_server.await;
|
||||
}
|
||||
|
||||
fn extract_connect_target(request: &str) -> String {
|
||||
let connect_line = request.lines().next().unwrap_or("");
|
||||
let parts: Vec<&str> = connect_line.split_whitespace().collect();
|
||||
if parts.len() >= 2 {
|
||||
parts[1].to_string()
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_connect_proxy_rejected() {
|
||||
let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let proxy_addr = proxy_listener.local_addr().unwrap();
|
||||
|
||||
let proxy_server = tokio::spawn(async move {
|
||||
let (mut proxy_sock, _) = proxy_listener.accept().await.unwrap();
|
||||
|
||||
let mut request = Vec::new();
|
||||
let mut buf = [0u8; 1024];
|
||||
loop {
|
||||
let n = proxy_sock.read(&mut buf).await.unwrap();
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
request.extend_from_slice(&buf[..n]);
|
||||
if request.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let response = "HTTP/1.1 403 Forbidden\r\n\r\n";
|
||||
proxy_sock.write_all(response.as_bytes()).await.unwrap();
|
||||
});
|
||||
|
||||
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||
let config = http_connect_config(proxy_addr);
|
||||
let result = connect_outbound(target, &config).await;
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
ChannelProxyError::HttpConnectProxyRejected(msg) => {
|
||||
assert!(msg.contains("403"));
|
||||
}
|
||||
other => panic!("expected HttpConnectProxyRejected, got {:?}", other),
|
||||
}
|
||||
|
||||
let _ = proxy_server.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn target_unreachable_returns_appropriate_error() {
|
||||
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
|
||||
let result = connect_outbound(target, &direct_config()).await;
|
||||
match result.unwrap_err() {
|
||||
ChannelProxyError::TargetUnreachable
|
||||
| ChannelProxyError::ConnectionRefused
|
||||
| ChannelProxyError::Io(_) => {}
|
||||
other => panic!("unexpected error type: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn socks5_proxy_unreachable() {
|
||||
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||
let bad_proxy: SocketAddr = "127.0.0.1:1".parse().unwrap();
|
||||
let config = socks5_config(bad_proxy);
|
||||
let result = connect_outbound(target, &config).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_connect_proxy_unreachable() {
|
||||
let target: SocketAddr = "127.0.0.1:9999".parse().unwrap();
|
||||
let bad_proxy: SocketAddr = "127.0.0.1:1".parse().unwrap();
|
||||
let config = http_connect_config(bad_proxy);
|
||||
let result = connect_outbound(target, &config).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
struct MockChannel {
|
||||
read_half: tokio::io::ReadHalf<DuplexStream>,
|
||||
write_half: tokio::io::WriteHalf<DuplexStream>,
|
||||
}
|
||||
|
||||
impl tokio::io::AsyncRead for MockChannel {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().read_half).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl tokio::io::AsyncWrite for MockChannel {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<std::io::Result<usize>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().write_half).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().write_half).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().write_half).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
fn make_mock_channel() -> (MockChannel, DuplexStream) {
|
||||
let (client, server) = duplex(4096);
|
||||
let (read_half, write_half) = tokio::io::split(client);
|
||||
(
|
||||
MockChannel {
|
||||
read_half,
|
||||
write_half,
|
||||
},
|
||||
server,
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn proxy_channel_bidirectional_data_flow() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let target_addr = listener.local_addr().unwrap();
|
||||
|
||||
let echo_server = tokio::spawn(async move {
|
||||
let (mut sock, _) = listener.accept().await.unwrap();
|
||||
let mut buf = [0u8; 64];
|
||||
let n = sock.read(&mut buf).await.unwrap();
|
||||
sock.write_all(&buf[..n]).await.unwrap();
|
||||
});
|
||||
|
||||
let (channel, mut channel_peer) = make_mock_channel();
|
||||
|
||||
let target = target_addr;
|
||||
let proxy = direct_config();
|
||||
tokio::spawn(async move {
|
||||
proxy_channel(channel, target, &proxy).await;
|
||||
});
|
||||
|
||||
channel_peer.write_all(b"ping").await.unwrap();
|
||||
channel_peer.flush().await.unwrap();
|
||||
|
||||
let mut buf = [0u8; 4];
|
||||
channel_peer.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"ping");
|
||||
|
||||
drop(channel_peer);
|
||||
let _ = echo_server.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn proxy_channel_target_unreachable_closes_cleanly() {
|
||||
let target: SocketAddr = "240.0.0.1:1".parse().unwrap();
|
||||
let (channel, _channel_peer) = make_mock_channel();
|
||||
|
||||
let proxy = direct_config();
|
||||
proxy_channel(channel, target, &proxy).await;
|
||||
}
|
||||
}
|
||||
@@ -1,196 +0,0 @@
|
||||
//! Control channel routing for reserved `alknet-*` destinations.
|
||||
//!
|
||||
//! SSH channels opened with a destination starting with `alknet-` are intercepted
|
||||
//! by the server and routed to a `ControlChannelHandler` instead of proxied to a
|
||||
//! TCP target. See ADR-018 for the design rationale.
|
||||
|
||||
use std::io;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub const ALKNET_CONTROL_DESTINATION: &str = "alknet-control";
|
||||
pub const ALKNET_PREFIX: &str = "alknet-";
|
||||
|
||||
pub fn is_reserved_destination(host: &str) -> bool {
|
||||
host.starts_with(ALKNET_PREFIX)
|
||||
}
|
||||
|
||||
pub trait DuplexStream: AsyncRead + AsyncWrite + Unpin + Send {}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin + Send> DuplexStream for T {}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ControlChannelHandler: Send + Sync {
|
||||
async fn handle_channel(&self, stream: Box<dyn DuplexStream>);
|
||||
}
|
||||
|
||||
pub struct ControlChannelRouter {
|
||||
handler: Option<Box<dyn ControlChannelHandler>>,
|
||||
}
|
||||
|
||||
impl ControlChannelRouter {
|
||||
pub fn new(handler: Option<Box<dyn ControlChannelHandler>>) -> Self {
|
||||
Self { handler }
|
||||
}
|
||||
|
||||
pub fn without_handler() -> Self {
|
||||
Self { handler: None }
|
||||
}
|
||||
|
||||
pub fn with_handler(handler: Box<dyn ControlChannelHandler>) -> Self {
|
||||
Self {
|
||||
handler: Some(handler),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_handler(&self) -> bool {
|
||||
self.handler.is_some()
|
||||
}
|
||||
|
||||
pub async fn route(&self, stream: Box<dyn DuplexStream>) -> io::Result<()> {
|
||||
match &self.handler {
|
||||
Some(handler) => {
|
||||
handler.handle_channel(stream).await;
|
||||
Ok(())
|
||||
}
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionRefused,
|
||||
"no control channel handler configured",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take_handler(&mut self) -> Option<Box<dyn ControlChannelHandler>> {
|
||||
self.handler.take()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::duplex;
|
||||
|
||||
#[test]
|
||||
fn alknet_control_destination_constant() {
|
||||
assert_eq!(ALKNET_CONTROL_DESTINATION, "alknet-control");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn alknet_prefix_constant() {
|
||||
assert_eq!(ALKNET_PREFIX, "alknet-");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reserved_destination_detected() {
|
||||
assert!(is_reserved_destination("alknet-control"));
|
||||
assert!(is_reserved_destination("alknet-status"));
|
||||
assert!(is_reserved_destination("alknet-events"));
|
||||
assert!(is_reserved_destination("alknet-"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_reserved_destination_passes_through() {
|
||||
assert!(!is_reserved_destination("example.com"));
|
||||
assert!(!is_reserved_destination("localhost"));
|
||||
assert!(!is_reserved_destination("192.168.1.1"));
|
||||
assert!(!is_reserved_destination("alknet.example.com"));
|
||||
assert!(!is_reserved_destination(""));
|
||||
assert!(!is_reserved_destination("alkne-control"));
|
||||
assert!(!is_reserved_destination("ALKNET-control"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prefix_matching_case_sensitive() {
|
||||
assert!(!is_reserved_destination("Alknet-control"));
|
||||
assert!(!is_reserved_destination("ALKNET-control"));
|
||||
assert!(is_reserved_destination("alknet-Control"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn router_without_handler_has_no_handler() {
|
||||
let router = ControlChannelRouter::without_handler();
|
||||
assert!(!router.has_handler());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn router_with_handler_has_handler() {
|
||||
struct DummyHandler;
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for DummyHandler {
|
||||
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {}
|
||||
}
|
||||
let router = ControlChannelRouter::with_handler(Box::new(DummyHandler));
|
||||
assert!(router.has_handler());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_without_handler_returns_error() {
|
||||
let router = ControlChannelRouter::without_handler();
|
||||
let (_client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
let result = router.route(stream).await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_with_handler_succeeds() {
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
struct TrackedHandler {
|
||||
called: Arc<AtomicBool>,
|
||||
}
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for TrackedHandler {
|
||||
async fn handle_channel(&self, _stream: Box<dyn DuplexStream>) {
|
||||
self.called.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
let called = Arc::new(AtomicBool::new(false));
|
||||
let handler = TrackedHandler {
|
||||
called: called.clone(),
|
||||
};
|
||||
let router = ControlChannelRouter::with_handler(Box::new(handler));
|
||||
let (_client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
let result = router.route(stream).await;
|
||||
assert!(result.is_ok());
|
||||
assert!(called.load(Ordering::SeqCst));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_with_handler_can_read_write() {
|
||||
struct EchoHandler;
|
||||
#[async_trait]
|
||||
impl ControlChannelHandler for EchoHandler {
|
||||
async fn handle_channel(&self, mut stream: Box<dyn DuplexStream>) {
|
||||
let mut buf = [0u8; 64];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
stream.write_all(&buf[..n]).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
let router = ControlChannelRouter::with_handler(Box::new(EchoHandler));
|
||||
let (client, server) = duplex(64);
|
||||
let stream: Box<dyn DuplexStream> = Box::new(server);
|
||||
tokio::spawn(async move {
|
||||
router.route(stream).await.unwrap();
|
||||
});
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
let mut client = client;
|
||||
client.write_all(b"hello").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
client.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn control_channel_destination_matches_prefix() {
|
||||
assert!(is_reserved_destination(ALKNET_CONTROL_DESTINATION));
|
||||
}
|
||||
}
|
||||
@@ -1,974 +0,0 @@
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use russh::keys::ssh_key::HashAlg;
|
||||
use russh::server::{Auth, Handler, Msg, Session};
|
||||
use russh::Channel;
|
||||
use russh::ChannelId;
|
||||
|
||||
use crate::auth::identity::{ConfigIdentityProvider, Identity, IdentityProvider};
|
||||
use crate::config::DynamicConfig;
|
||||
use crate::server::control_channel::{ControlChannelHandler, ControlChannelRouter, ALKNET_PREFIX};
|
||||
use crate::server::rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||
|
||||
pub use crate::transport::TransportKind;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProxyMode {
|
||||
Direct,
|
||||
Socks5(SocketAddr),
|
||||
HttpConnect(SocketAddr),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProxyConfig {
|
||||
pub mode: ProxyMode,
|
||||
}
|
||||
|
||||
pub struct ServerHandler {
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
#[allow(dead_code)]
|
||||
outbound_proxy: Option<ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
control_channel_router: ControlChannelRouter,
|
||||
#[allow(dead_code)]
|
||||
transport: TransportKind,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
connection_allowed: bool,
|
||||
auth_limiter: AuthAttemptLimiter,
|
||||
connected_at: Instant,
|
||||
authenticated_identity: Option<Identity>,
|
||||
}
|
||||
|
||||
impl ServerHandler {
|
||||
pub fn new(
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
outbound_proxy: Option<ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
transport: TransportKind,
|
||||
connection_limiter: Arc<ConnectionRateLimiter>,
|
||||
max_auth_attempts: usize,
|
||||
) -> Self {
|
||||
let identity_provider: Arc<dyn IdentityProvider> =
|
||||
Arc::new(ConfigIdentityProvider::new(Arc::clone(&dynamic)));
|
||||
|
||||
let allowed = if let Some(addr) = remote_addr {
|
||||
let ip = addr.ip();
|
||||
if connection_limiter.check(ip) {
|
||||
connection_limiter.on_connect(ip);
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
transport = %transport,
|
||||
"connection opened"
|
||||
);
|
||||
true
|
||||
} else {
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
transport = %transport,
|
||||
"connection rejected"
|
||||
);
|
||||
false
|
||||
}
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
Self {
|
||||
dynamic,
|
||||
identity_provider,
|
||||
outbound_proxy,
|
||||
remote_addr,
|
||||
control_channel_router: ControlChannelRouter::without_handler(),
|
||||
transport,
|
||||
connection_limiter,
|
||||
connection_allowed: allowed,
|
||||
auth_limiter: AuthAttemptLimiter::new(max_auth_attempts),
|
||||
connected_at: Instant::now(),
|
||||
authenticated_identity: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_identity_provider(mut self, provider: Arc<dyn IdentityProvider>) -> Self {
|
||||
self.identity_provider = provider;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn authenticated_identity(&self) -> Option<&Identity> {
|
||||
self.authenticated_identity.as_ref()
|
||||
}
|
||||
|
||||
pub fn is_connection_allowed(&self) -> bool {
|
||||
self.connection_allowed
|
||||
}
|
||||
|
||||
pub fn remote_ip(&self) -> Option<IpAddr> {
|
||||
self.remote_addr.map(|a| a.ip())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ServerHandler {
|
||||
fn drop(&mut self) {
|
||||
if let Some(addr) = self.remote_addr {
|
||||
if self.connection_allowed {
|
||||
self.connection_limiter.on_disconnect(addr.ip());
|
||||
}
|
||||
let duration = self.connected_at.elapsed();
|
||||
tracing::info!(
|
||||
remote_addr = %addr,
|
||||
duration_secs = duration.as_secs_f64(),
|
||||
"connection closed"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerHandler {
|
||||
pub fn with_control_channel_handler(mut self, handler: Box<dyn ControlChannelHandler>) -> Self {
|
||||
self.control_channel_router = ControlChannelRouter::with_handler(handler);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn control_channel_router(&self) -> &ControlChannelRouter {
|
||||
&self.control_channel_router
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Handler for ServerHandler {
|
||||
type Error = russh::Error;
|
||||
|
||||
async fn auth_publickey(
|
||||
&mut self,
|
||||
user: &str,
|
||||
public_key: &russh::keys::ssh_key::PublicKey,
|
||||
) -> Result<Auth, Self::Error> {
|
||||
if !self.auth_limiter.check() {
|
||||
let remote_addr_display = self
|
||||
.remote_addr
|
||||
.map_or("unknown".to_string(), |a| a.to_string());
|
||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "reject",
|
||||
"auth attempt"
|
||||
);
|
||||
return Ok(Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
});
|
||||
}
|
||||
|
||||
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
|
||||
let remote_addr_display = self
|
||||
.remote_addr
|
||||
.map_or("unknown".to_string(), |a| a.to_string());
|
||||
|
||||
let identity = self
|
||||
.identity_provider
|
||||
.resolve_from_fingerprint(&fingerprint);
|
||||
|
||||
match identity {
|
||||
Some(id) => {
|
||||
self.authenticated_identity = Some(id);
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "accept",
|
||||
"auth attempt"
|
||||
);
|
||||
Ok(Auth::Accept)
|
||||
}
|
||||
None => {
|
||||
self.auth_limiter.on_failure();
|
||||
tracing::info!(
|
||||
remote_addr = %remote_addr_display,
|
||||
user = user,
|
||||
key_fingerprint = %fingerprint,
|
||||
result = "reject",
|
||||
"auth attempt"
|
||||
);
|
||||
Ok(Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn channel_open_direct_tcpip(
|
||||
&mut self,
|
||||
channel: Channel<Msg>,
|
||||
host_to_connect: &str,
|
||||
port_to_connect: u32,
|
||||
originator_address: &str,
|
||||
originator_port: u32,
|
||||
_session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
if host_to_connect.starts_with(ALKNET_PREFIX) {
|
||||
if !self.control_channel_router.has_handler() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let _ = channel;
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let identity = self
|
||||
.authenticated_identity
|
||||
.clone()
|
||||
.unwrap_or_else(|| Identity {
|
||||
id: String::new(),
|
||||
scopes: vec![],
|
||||
resources: std::collections::HashMap::new(),
|
||||
});
|
||||
|
||||
let policy = self.dynamic.load();
|
||||
let allowed = policy.forwarding.check(
|
||||
host_to_connect,
|
||||
port_to_connect as u16,
|
||||
&identity,
|
||||
self.transport.clone(),
|
||||
);
|
||||
|
||||
if !allowed {
|
||||
tracing::info!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
target = %format!("{host_to_connect}:{port_to_connect}"),
|
||||
identity = %identity.id,
|
||||
transport = %self.transport,
|
||||
"forwarding denied by policy"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let target_host = host_to_connect.to_string();
|
||||
let target_port = port_to_connect;
|
||||
let proxy_config = self.outbound_proxy.clone().unwrap_or(ProxyConfig {
|
||||
mode: ProxyMode::Direct,
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
let target =
|
||||
match format!("{target_host}:{target_port}").parse::<std::net::SocketAddr>() {
|
||||
Ok(addr) => addr,
|
||||
Err(_) => match tokio::net::lookup_host((&target_host[..], target_port as u16))
|
||||
.await
|
||||
{
|
||||
Ok(mut addrs) => match addrs.next() {
|
||||
Some(addr) => addr,
|
||||
None => return,
|
||||
},
|
||||
Err(_) => return,
|
||||
},
|
||||
};
|
||||
crate::server::channel_proxy::proxy_channel(
|
||||
channel.into_stream(),
|
||||
target,
|
||||
&proxy_config,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let _ = (originator_address, originator_port);
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn channel_open_session(
|
||||
&mut self,
|
||||
_channel: Channel<Msg>,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
"rejected session channel (shell/exec not supported)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn channel_open_x11(
|
||||
&mut self,
|
||||
_channel: Channel<Msg>,
|
||||
_originator_address: &str,
|
||||
_originator_port: u32,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
"rejected x11 channel"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn channel_open_forwarded_tcpip(
|
||||
&mut self,
|
||||
_channel: Channel<Msg>,
|
||||
host_to_connect: &str,
|
||||
port_to_connect: u32,
|
||||
_originator_address: &str,
|
||||
_originator_port: u32,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
target = %format!("{host_to_connect}:{port_to_connect}"),
|
||||
"rejected forwarded-tcpip channel (remote port forwarding not supported)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn exec_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
data: &[u8],
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
data_len = data.len(),
|
||||
"rejected exec request on channel (shell/exec not supported)"
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn shell_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
"rejected shell request on channel"
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn subsystem_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
name: &str,
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
subsystem = name,
|
||||
"rejected subsystem request on channel"
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn pty_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
term: &str,
|
||||
col_width: u32,
|
||||
row_height: u32,
|
||||
pix_width: u32,
|
||||
pix_height: u32,
|
||||
modes: &[(russh::Pty, u32)],
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
term = term,
|
||||
"rejected pty request on channel"
|
||||
);
|
||||
let _ = (col_width, row_height, pix_width, pix_height, modes);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn env_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
variable_name: &str,
|
||||
variable_value: &str,
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
variable = variable_name,
|
||||
"rejected env request on channel"
|
||||
);
|
||||
let _ = variable_value;
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn x11_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
single_connection: bool,
|
||||
x11_auth_protocol: &str,
|
||||
x11_auth_cookie: &str,
|
||||
x11_screen_number: u32,
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
"rejected x11 request on channel"
|
||||
);
|
||||
let _ = (
|
||||
single_connection,
|
||||
x11_auth_protocol,
|
||||
x11_auth_cookie,
|
||||
x11_screen_number,
|
||||
);
|
||||
let _ = session.channel_failure(channel);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn agent_request(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
"rejected agent forwarding request on channel"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn tcpip_forward(
|
||||
&mut self,
|
||||
address: &str,
|
||||
port: &mut u32,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
address = address,
|
||||
port = *port,
|
||||
"rejected tcpip-forward request (remote port forwarding not supported)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn cancel_tcpip_forward(
|
||||
&mut self,
|
||||
address: &str,
|
||||
port: u32,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
let _ = (address, port, session);
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn streamlocal_forward(
|
||||
&mut self,
|
||||
socket_path: &str,
|
||||
session: &mut Session,
|
||||
) -> Result<bool, Self::Error> {
|
||||
tracing::warn!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
socket_path = socket_path,
|
||||
"rejected streamlocal-forward request"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn signal(
|
||||
&mut self,
|
||||
channel: ChannelId,
|
||||
signal: russh::Sig,
|
||||
session: &mut Session,
|
||||
) -> Result<(), Self::Error> {
|
||||
tracing::debug!(
|
||||
remote_addr = ?self.remote_addr,
|
||||
channel = %channel,
|
||||
signal = ?signal,
|
||||
"received signal on channel (ignored)"
|
||||
);
|
||||
let _ = session;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::keys::KeySource;
|
||||
use crate::auth::ServerAuthConfig;
|
||||
use crate::config::AuthPolicy;
|
||||
use russh::keys::{decode_secret_key, PrivateKey};
|
||||
use std::io::Write;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
const ED25519_PUBLIC_KEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE58icPJFLfckR4M1PzF3XSpF3AU3zP9C6QI6AQiS/TV ubuntu@ns528096";
|
||||
|
||||
fn make_authorized_keys_file(keys_content: &str) -> tempfile::NamedTempFile {
|
||||
let mut f = tempfile::NamedTempFile::new().unwrap();
|
||||
f.write_all(keys_content.as_bytes()).unwrap();
|
||||
f.flush().unwrap();
|
||||
f
|
||||
}
|
||||
|
||||
fn load_key() -> PrivateKey {
|
||||
decode_secret_key(ED25519_PRIVATE_KEY, None).unwrap()
|
||||
}
|
||||
|
||||
fn make_auth_config(keys_content: &str) -> Arc<ArcSwap<DynamicConfig>> {
|
||||
let f = make_authorized_keys_file(keys_content);
|
||||
let server_auth =
|
||||
ServerAuthConfig::from_keys_and_ca(Some(KeySource::File(f.path().to_path_buf())), None)
|
||||
.unwrap();
|
||||
let auth_policy = AuthPolicy::from_server_auth_config(server_auth);
|
||||
let dynamic = DynamicConfig::new(auth_policy);
|
||||
Arc::new(ArcSwap::new(Arc::new(dynamic)))
|
||||
}
|
||||
|
||||
fn make_empty_auth_config() -> Arc<ArcSwap<DynamicConfig>> {
|
||||
let dynamic = DynamicConfig::default();
|
||||
Arc::new(ArcSwap::new(Arc::new(dynamic)))
|
||||
}
|
||||
|
||||
fn default_limiter() -> Arc<ConnectionRateLimiter> {
|
||||
Arc::new(ConnectionRateLimiter::new(0))
|
||||
}
|
||||
|
||||
fn make_handler(
|
||||
dynamic: Arc<ArcSwap<DynamicConfig>>,
|
||||
outbound_proxy: Option<ProxyConfig>,
|
||||
remote_addr: Option<SocketAddr>,
|
||||
) -> ServerHandler {
|
||||
ServerHandler::new(
|
||||
dynamic,
|
||||
outbound_proxy,
|
||||
remote_addr,
|
||||
TransportKind::Tcp,
|
||||
default_limiter(),
|
||||
10,
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_delegation_accepts_known_key() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let mut handler = make_handler(auth_config, None, None);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||
assert_eq!(result, Auth::Accept);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_delegation_rejects_unknown_key() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let mut handler = make_handler(auth_config, None, None);
|
||||
|
||||
let other_key_text = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIHeLC1lWiCYrXsf/85O/pkbUFZ6OGIt49PX3nw8iRoXE other@host";
|
||||
let other_ssh_key =
|
||||
russh::keys::parse_public_key_base64(other_key_text.split_whitespace().nth(1).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let result = handler
|
||||
.auth_publickey("testuser", &other_ssh_key)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
Auth::Reject {
|
||||
proceed_with_methods: None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_delegation_empty_config_rejects_all() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let mut handler = make_handler(auth_config, None, None);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
Auth::Reject {
|
||||
proceed_with_methods: None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_logging_includes_remote_addr() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let remote_addr: SocketAddr = "203.0.113.50:12345".parse().unwrap();
|
||||
let mut handler = make_handler(auth_config, None, Some(remote_addr));
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reserved_alknet_destination_routing() {
|
||||
use crate::server::control_channel::is_reserved_destination;
|
||||
assert!(is_reserved_destination("alknet-control"));
|
||||
assert!(is_reserved_destination("alknet-status"));
|
||||
assert!(is_reserved_destination("alknet-events"));
|
||||
assert!(!is_reserved_destination("example.com"));
|
||||
assert!(!is_reserved_destination("localhost"));
|
||||
assert!(!is_reserved_destination("alknet.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_handler_without_control_handler_rejects_alknet_destinations() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let handler = make_handler(auth_config, None, None);
|
||||
assert!(!handler.control_channel_router().has_handler());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_mode_variants() {
|
||||
let direct = ProxyMode::Direct;
|
||||
let socks5 = ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap());
|
||||
let http = ProxyMode::HttpConnect("127.0.0.1:8080".parse().unwrap());
|
||||
|
||||
match direct {
|
||||
ProxyMode::Direct => {}
|
||||
_ => panic!("expected Direct"),
|
||||
}
|
||||
match socks5 {
|
||||
ProxyMode::Socks5(_) => {}
|
||||
_ => panic!("expected Socks5"),
|
||||
}
|
||||
match http {
|
||||
ProxyMode::HttpConnect(_) => {}
|
||||
_ => panic!("expected HttpConnect"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_handler_holds_config() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let proxy = Some(ProxyConfig {
|
||||
mode: ProxyMode::Socks5("127.0.0.1:9050".parse().unwrap()),
|
||||
});
|
||||
let remote: Option<SocketAddr> = Some("10.0.0.1:22".parse().unwrap());
|
||||
|
||||
let handler = make_handler(auth_config, proxy.clone(), remote);
|
||||
assert!(handler.outbound_proxy.is_some());
|
||||
assert!(handler.remote_addr.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_handler_per_connection() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let handler1 = make_handler(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
Some("10.0.0.1:22".parse().unwrap()),
|
||||
);
|
||||
let handler2 = make_handler(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
Some("10.0.0.2:22".parse().unwrap()),
|
||||
);
|
||||
|
||||
assert!(handler1.remote_addr != handler2.remote_addr);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_rate_limit_rejects_after_max_failures() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(0));
|
||||
let mut handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("10.0.0.1:22".parse().unwrap()),
|
||||
TransportKind::Tcp,
|
||||
limiter,
|
||||
2,
|
||||
);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
|
||||
let r1 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||
assert_eq!(
|
||||
r1,
|
||||
Auth::Reject {
|
||||
proceed_with_methods: None
|
||||
}
|
||||
);
|
||||
|
||||
let r2 = handler.auth_publickey("user", &ssh_key).await.unwrap();
|
||||
assert_eq!(
|
||||
r2,
|
||||
Auth::Reject {
|
||||
proceed_with_methods: None
|
||||
}
|
||||
);
|
||||
|
||||
assert!(!handler.auth_limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_rate_limit_blocks_over_limit() {
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(1));
|
||||
let auth_config = make_empty_auth_config();
|
||||
let addr: SocketAddr = "10.0.0.1:22".parse().unwrap();
|
||||
|
||||
let h1 = ServerHandler::new(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter.clone(),
|
||||
10,
|
||||
);
|
||||
assert!(h1.is_connection_allowed());
|
||||
|
||||
let h2 = ServerHandler::new(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter.clone(),
|
||||
10,
|
||||
);
|
||||
assert!(!h2.is_connection_allowed());
|
||||
|
||||
drop(h1);
|
||||
|
||||
let h3 = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some(addr),
|
||||
TransportKind::Tcp,
|
||||
limiter,
|
||||
10,
|
||||
);
|
||||
assert!(h3.is_connection_allowed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_kind_display() {
|
||||
assert_eq!(TransportKind::Tcp.to_string(), "tcp");
|
||||
assert_eq!(TransportKind::Tls { server_name: None }.to_string(), "tls");
|
||||
assert_eq!(
|
||||
TransportKind::Iroh {
|
||||
endpoint_id: String::new()
|
||||
}
|
||||
.to_string(),
|
||||
"iroh"
|
||||
);
|
||||
assert_eq!(
|
||||
TransportKind::WebTransport { server_name: None }.to_string(),
|
||||
"webtransport"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_log_includes_user_field() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let mut handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("203.0.113.50:12345".parse().unwrap()),
|
||||
TransportKind::Tls { server_name: None },
|
||||
Arc::new(ConnectionRateLimiter::new(0)),
|
||||
10,
|
||||
);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let _ = handler.auth_publickey("root", &ssh_key).await.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_closed_logs_duration_on_drop() {
|
||||
let auth_config = make_empty_auth_config();
|
||||
let _handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("203.0.113.50:12345".parse().unwrap()),
|
||||
TransportKind::Tcp,
|
||||
Arc::new(ConnectionRateLimiter::new(0)),
|
||||
10,
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn config_reload_new_keys_take_effect() {
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
let mut handler = ServerHandler::new(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
None,
|
||||
TransportKind::Tcp,
|
||||
default_limiter(),
|
||||
10,
|
||||
);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||
assert_eq!(result, Auth::Accept);
|
||||
drop(handler);
|
||||
|
||||
let new_dynamic = DynamicConfig::default();
|
||||
auth_config.store(Arc::new(new_dynamic));
|
||||
|
||||
let mut handler2 = ServerHandler::new(
|
||||
auth_config.clone(),
|
||||
None,
|
||||
None,
|
||||
TransportKind::Tcp,
|
||||
default_limiter(),
|
||||
10,
|
||||
);
|
||||
|
||||
let result2 = handler2.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||
assert_eq!(
|
||||
result2,
|
||||
Auth::Reject {
|
||||
proceed_with_methods: None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn forwarding_policy_deny_blocks_channel_open() {
|
||||
use crate::config::forwarding::{
|
||||
ForwardingAction, ForwardingPolicy, ForwardingRule, TargetPattern,
|
||||
};
|
||||
|
||||
let deny_policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Host("blocked.example.com".to_string()),
|
||||
action: ForwardingAction::Deny,
|
||||
principals: vec![],
|
||||
transports: vec![],
|
||||
}],
|
||||
};
|
||||
|
||||
let auth_config = make_auth_config(ED25519_PUBLIC_KEY);
|
||||
{
|
||||
let dynamic = auth_config.load();
|
||||
let new_dynamic = DynamicConfig {
|
||||
auth: dynamic.auth.clone(),
|
||||
forwarding: deny_policy,
|
||||
rate_limits: dynamic.rate_limits.clone(),
|
||||
credentials: dynamic.credentials.clone(),
|
||||
};
|
||||
drop(dynamic);
|
||||
auth_config.store(Arc::new(new_dynamic));
|
||||
}
|
||||
|
||||
let mut handler = ServerHandler::new(
|
||||
auth_config,
|
||||
None,
|
||||
Some("127.0.0.1:12345".parse().unwrap()),
|
||||
TransportKind::Tcp,
|
||||
default_limiter(),
|
||||
10,
|
||||
);
|
||||
|
||||
let ssh_key = load_key().public_key().clone();
|
||||
let result = handler.auth_publickey("testuser", &ssh_key).await.unwrap();
|
||||
assert_eq!(result, Auth::Accept);
|
||||
assert!(handler.authenticated_identity().is_some());
|
||||
|
||||
let identity = handler.authenticated_identity().unwrap();
|
||||
let dynamic = handler.dynamic.load();
|
||||
assert!(!dynamic.forwarding.check(
|
||||
"blocked.example.com",
|
||||
443,
|
||||
identity,
|
||||
TransportKind::Tcp
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarding_policy_deny_with_custom_identity() {
|
||||
use crate::config::forwarding::{
|
||||
ForwardingAction, ForwardingPolicy, ForwardingRule, TargetPattern,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
let mut resources = HashMap::new();
|
||||
resources.insert("service".to_string(), vec!["gitea".to_string()]);
|
||||
let identity = Identity {
|
||||
id: "SHA256:abc123".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources,
|
||||
};
|
||||
|
||||
let policy = ForwardingPolicy {
|
||||
default: ForwardingAction::Deny,
|
||||
rules: vec![ForwardingRule {
|
||||
target: TargetPattern::Host("allowed.example.com".to_string()),
|
||||
action: ForwardingAction::Allow,
|
||||
principals: vec!["SHA256:abc123".to_string()],
|
||||
transports: vec![TransportKind::Tcp],
|
||||
}],
|
||||
};
|
||||
|
||||
assert!(policy.check("allowed.example.com", 443, &identity, TransportKind::Tcp));
|
||||
assert!(!policy.check("denied.example.com", 443, &identity, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_handler_with_custom_identity_provider() {
|
||||
use std::collections::HashMap;
|
||||
|
||||
struct MockIdentityProvider {
|
||||
identities: HashMap<String, Identity>,
|
||||
}
|
||||
|
||||
impl IdentityProvider for MockIdentityProvider {
|
||||
fn resolve_from_fingerprint(&self, fingerprint: &str) -> Option<Identity> {
|
||||
self.identities.get(fingerprint).cloned()
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, _token: &crate::auth::AuthToken) -> Option<Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
let mut identities = HashMap::new();
|
||||
identities.insert(
|
||||
"SHA256:testkey".to_string(),
|
||||
Identity {
|
||||
id: "SHA256:testkey".to_string(),
|
||||
scopes: vec!["admin".to_string()],
|
||||
resources: HashMap::new(),
|
||||
},
|
||||
);
|
||||
|
||||
let provider = Arc::new(MockIdentityProvider { identities }) as Arc<dyn IdentityProvider>;
|
||||
let dynamic = make_empty_auth_config();
|
||||
|
||||
let handler = ServerHandler::new(
|
||||
dynamic,
|
||||
None,
|
||||
Some("10.0.0.1:22".parse().unwrap()),
|
||||
TransportKind::Tcp,
|
||||
default_limiter(),
|
||||
10,
|
||||
)
|
||||
.with_identity_provider(provider);
|
||||
|
||||
assert!(handler.authenticated_identity().is_none());
|
||||
}
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
//! Server-side SSH connection handling.
|
||||
//!
|
||||
//! Provides `Server` for accepting SSH connections over any transport and proxying
|
||||
//! `direct-tcpip` channel requests to targets. Supports Ed25519 and certificate-authority
|
||||
//! auth, connection rate limiting, auth attempt limiting, stealth mode (fake nginx 404),
|
||||
//! and outbound proxy routing (direct/SOCKS5/HTTP CONNECT).
|
||||
//!
|
||||
//! Destination hosts starting with `alknet-` are reserved for internal use (control channel, ADR-018).
|
||||
|
||||
pub mod channel_proxy;
|
||||
pub mod control_channel;
|
||||
pub mod handler;
|
||||
pub mod rate_limit;
|
||||
pub mod serve;
|
||||
pub mod stealth;
|
||||
|
||||
pub use channel_proxy::{connect_outbound, proxy_channel};
|
||||
pub use control_channel::{
|
||||
is_reserved_destination, ControlChannelHandler, ControlChannelRouter, DuplexStream,
|
||||
ALKNET_CONTROL_DESTINATION, ALKNET_PREFIX,
|
||||
};
|
||||
pub use handler::{ProxyConfig, ProxyMode, ServerHandler};
|
||||
pub use rate_limit::{AuthAttemptLimiter, ConnectionRateLimiter};
|
||||
pub use serve::{
|
||||
DnsListenerConfig, HttpListenerConfig, ListenerConfig, ServeError, ServeOptions,
|
||||
ServeTransportMode, Server, StreamListenerConfig,
|
||||
};
|
||||
|
||||
pub use crate::transport::TransportKind;
|
||||
pub use stealth::{
|
||||
detect_protocol, handle_http_stealth, send_fake_nginx_404, validate_stealth_config,
|
||||
ProtocolDetection,
|
||||
};
|
||||
@@ -1,200 +0,0 @@
|
||||
//! Connection rate limiting and auth attempt limiting.
|
||||
//!
|
||||
//! `ConnectionRateLimiter` tracks per-IP active connections (thread-safe).
|
||||
//! `AuthAttemptLimiter` caps failed auth attempts per connection.
|
||||
//! These complement fail2ban on Linux and provide abuse protection on all platforms.
|
||||
//! See ADR-013.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Mutex;
|
||||
|
||||
pub struct ConnectionRateLimiter {
|
||||
max_per_ip: usize,
|
||||
active: Mutex<HashMap<IpAddr, usize>>,
|
||||
}
|
||||
|
||||
impl ConnectionRateLimiter {
|
||||
pub fn new(max_per_ip: usize) -> Self {
|
||||
Self {
|
||||
max_per_ip,
|
||||
active: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(&self, ip: IpAddr) -> bool {
|
||||
if self.max_per_ip == 0 {
|
||||
return true;
|
||||
}
|
||||
let active = self.active.lock().unwrap();
|
||||
let count = active.get(&ip).copied().unwrap_or(0);
|
||||
count < self.max_per_ip
|
||||
}
|
||||
|
||||
pub fn on_connect(&self, ip: IpAddr) {
|
||||
let mut active = self.active.lock().unwrap();
|
||||
*active.entry(ip).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
pub fn on_disconnect(&self, ip: IpAddr) {
|
||||
let mut active = self.active.lock().unwrap();
|
||||
if let Some(count) = active.get_mut(&ip) {
|
||||
if *count > 1 {
|
||||
*count -= 1;
|
||||
} else {
|
||||
active.remove(&ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AuthAttemptLimiter {
|
||||
max_attempts: usize,
|
||||
failures: usize,
|
||||
}
|
||||
|
||||
impl AuthAttemptLimiter {
|
||||
pub fn new(max_attempts: usize) -> Self {
|
||||
Self {
|
||||
max_attempts,
|
||||
failures: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(&self) -> bool {
|
||||
if self.max_attempts == 0 {
|
||||
return true;
|
||||
}
|
||||
self.failures < self.max_attempts
|
||||
}
|
||||
|
||||
pub fn on_failure(&mut self) {
|
||||
self.failures += 1;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
|
||||
fn ip(n: u8) -> IpAddr {
|
||||
IpAddr::V4(Ipv4Addr::new(192, 168, 1, n))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_allows_when_under_limit() {
|
||||
let limiter = ConnectionRateLimiter::new(3);
|
||||
assert!(limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_blocks_when_at_limit() {
|
||||
let limiter = ConnectionRateLimiter::new(2);
|
||||
limiter.on_connect(ip(1));
|
||||
limiter.on_connect(ip(1));
|
||||
assert!(!limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_allows_after_disconnect() {
|
||||
let limiter = ConnectionRateLimiter::new(2);
|
||||
limiter.on_connect(ip(1));
|
||||
limiter.on_connect(ip(1));
|
||||
assert!(!limiter.check(ip(1)));
|
||||
limiter.on_disconnect(ip(1));
|
||||
assert!(limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_unlimited_when_zero() {
|
||||
let limiter = ConnectionRateLimiter::new(0);
|
||||
for _ in 0..100 {
|
||||
limiter.on_connect(ip(1));
|
||||
}
|
||||
assert!(limiter.check(ip(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_tracks_per_ip_independently() {
|
||||
let limiter = ConnectionRateLimiter::new(1);
|
||||
limiter.on_connect(ip(1));
|
||||
assert!(!limiter.check(ip(1)));
|
||||
assert!(limiter.check(ip(2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_ipv6() {
|
||||
let limiter = ConnectionRateLimiter::new(1);
|
||||
let ip6 = IpAddr::V6(Ipv6Addr::LOCALHOST);
|
||||
limiter.on_connect(ip6);
|
||||
assert!(!limiter.check(ip6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_disconnect_removes_zero_entry() {
|
||||
let limiter = ConnectionRateLimiter::new(3);
|
||||
limiter.on_connect(ip(1));
|
||||
limiter.on_disconnect(ip(1));
|
||||
{
|
||||
let active = limiter.active.lock().unwrap();
|
||||
assert!(!active.contains_key(&ip(1)));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_allows_when_under_limit() {
|
||||
let limiter = AuthAttemptLimiter::new(3);
|
||||
assert!(limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_blocks_after_max_failures() {
|
||||
let mut limiter = AuthAttemptLimiter::new(2);
|
||||
limiter.on_failure();
|
||||
limiter.on_failure();
|
||||
assert!(!limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_unlimited_when_zero() {
|
||||
let mut limiter = AuthAttemptLimiter::new(0);
|
||||
for _ in 0..100 {
|
||||
limiter.on_failure();
|
||||
}
|
||||
assert!(limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_limiter_still_allows_at_one_below_limit() {
|
||||
let mut limiter = AuthAttemptLimiter::new(3);
|
||||
limiter.on_failure();
|
||||
limiter.on_failure();
|
||||
assert!(limiter.check());
|
||||
limiter.on_failure();
|
||||
assert!(!limiter.check());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_limiter_thread_safety() {
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
let limiter = Arc::new(ConnectionRateLimiter::new(100));
|
||||
let mut handles = vec![];
|
||||
|
||||
for i in 0..10 {
|
||||
let lim = Arc::clone(&limiter);
|
||||
handles.push(thread::spawn(move || {
|
||||
let ip_addr = ip((i % 3) as u8 + 1);
|
||||
lim.on_connect(ip_addr);
|
||||
assert!(lim.check(ip_addr));
|
||||
lim.on_disconnect(ip_addr);
|
||||
}));
|
||||
}
|
||||
|
||||
for h in handles {
|
||||
h.join().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,316 +0,0 @@
|
||||
//! Stealth mode: protocol detection on TLS connections.
|
||||
//!
|
||||
//! When stealth mode is enabled with TLS transport, the server peeks at the first
|
||||
//! bytes after the TLS handshake to determine whether the client is speaking SSH
|
||||
//! or HTTP. When the `http` feature is enabled, detected HTTP traffic is routed to
|
||||
//! the axum router. When `http` is disabled, non-SSH connections receive a fake
|
||||
//! nginx 404 response, making the server appear as an ordinary web server to port
|
||||
//! scanners and DPI systems. See ADR-017.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
|
||||
|
||||
use crate::auth::IdentityProvider;
|
||||
|
||||
const SSH_BANNER_PREFIX: &[u8] = b"SSH-2.0-";
|
||||
const FAKE_NGINX_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nServer: nginx\r\n\r\n";
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ProtocolDetection {
|
||||
Ssh,
|
||||
Http,
|
||||
}
|
||||
|
||||
pub async fn detect_protocol<S>(stream: S) -> (ProtocolDetection, BufReader<S>)
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
let detection = match reader.fill_buf().await {
|
||||
Ok(buf) if buf.len() >= SSH_BANNER_PREFIX.len() => {
|
||||
if &buf[..SSH_BANNER_PREFIX.len()] == SSH_BANNER_PREFIX {
|
||||
ProtocolDetection::Ssh
|
||||
} else {
|
||||
ProtocolDetection::Http
|
||||
}
|
||||
}
|
||||
Ok(buf) if !buf.is_empty() => {
|
||||
if buf.starts_with(SSH_BANNER_PREFIX) {
|
||||
ProtocolDetection::Ssh
|
||||
} else {
|
||||
ProtocolDetection::Http
|
||||
}
|
||||
}
|
||||
_ => ProtocolDetection::Http,
|
||||
};
|
||||
|
||||
(detection, reader)
|
||||
}
|
||||
|
||||
pub async fn send_fake_nginx_404<S>(reader: &mut BufReader<S>)
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let _ = reader.get_mut().write_all(FAKE_NGINX_404).await;
|
||||
let _ = reader.get_mut().shutdown().await;
|
||||
}
|
||||
|
||||
#[cfg(feature = "http")]
|
||||
pub async fn handle_http_stealth<S>(
|
||||
reader: BufReader<S>,
|
||||
identity_provider: Arc<dyn IdentityProvider>,
|
||||
) where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
crate::http::router::serve_connection_from_reader(reader, identity_provider).await
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "http"))]
|
||||
pub async fn handle_http_stealth<S>(
|
||||
mut reader: BufReader<S>,
|
||||
_identity_provider: Arc<dyn IdentityProvider>,
|
||||
) where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
send_fake_nginx_404(&mut reader).await
|
||||
}
|
||||
|
||||
pub fn validate_stealth_config(stealth: bool, transport_is_tls: bool) -> Result<(), &'static str> {
|
||||
if stealth && !transport_is_tls {
|
||||
return Err("stealth mode requires TLS transport (--transport tls)");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
async fn write_and_detect(data: &[u8]) -> ProtocolDetection {
|
||||
let (client, server) = duplex(1024);
|
||||
let mut client = client;
|
||||
|
||||
client.write_all(data).await.unwrap();
|
||||
drop(client);
|
||||
|
||||
let (detection, _) = detect_protocol(server).await;
|
||||
detection
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_banner_detected() {
|
||||
let detection = write_and_detect(b"SSH-2.0-OpenSSH_9.0\r\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_banner_other_implementation() {
|
||||
let detection = write_and_detect(b"SSH-2.0-russh_0.49\r\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_banner_minimal() {
|
||||
let detection = write_and_detect(b"SSH-2.0-X\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_get_detected_as_http() {
|
||||
let detection = write_and_detect(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_post_detected_as_http() {
|
||||
let detection = write_and_detect(b"POST /api HTTP/1.1\r\nHost: example.com\r\n\r\n").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn random_data_detected_as_http() {
|
||||
let detection = write_and_detect(b"\x01\x02\x03\x04\x05\x06\x07\x08").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_stream_detected_as_http() {
|
||||
let (client, server) = duplex(1024);
|
||||
drop(client);
|
||||
let (detection, _) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ssh_banner_bytes_preserved_by_bufreader() {
|
||||
let (client, server) = duplex(1024);
|
||||
let mut client = client;
|
||||
|
||||
let banner = b"SSH-2.0-OpenSSH_9.0\r\n";
|
||||
client.write_all(banner).await.unwrap();
|
||||
client.write_all(b"subsequent data").await.unwrap();
|
||||
drop(client);
|
||||
|
||||
let (detection, mut reader) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Ssh);
|
||||
|
||||
let mut all_data = Vec::new();
|
||||
reader.read_to_end(&mut all_data).await.unwrap();
|
||||
assert!(
|
||||
all_data.starts_with(banner),
|
||||
"banner bytes must be preserved after detection"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fake_nginx_404_response() {
|
||||
let (client, server) = duplex(1024);
|
||||
let (mut client_read, mut client_write) = tokio::io::split(client);
|
||||
|
||||
client_write
|
||||
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
drop(client_write);
|
||||
|
||||
let (detection, mut reader) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
|
||||
send_fake_nginx_404(&mut reader).await;
|
||||
|
||||
let mut buf = [0u8; 256];
|
||||
let n = client_read.read(&mut buf).await.unwrap();
|
||||
let response = String::from_utf8_lossy(&buf[..n]);
|
||||
assert!(response.contains("HTTP/1.1 404 Not Found"));
|
||||
assert!(response.contains("Server: nginx"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn protocol_detection_enum_equality() {
|
||||
assert_eq!(ProtocolDetection::Ssh, ProtocolDetection::Ssh);
|
||||
assert_eq!(ProtocolDetection::Http, ProtocolDetection::Http);
|
||||
assert_ne!(ProtocolDetection::Ssh, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_stealth_without_tls_rejected() {
|
||||
let result = validate_stealth_config(true, false);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("TLS transport"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_stealth_with_tls_accepted() {
|
||||
let result = validate_stealth_config(true, true);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_no_stealth_with_tcp_accepted() {
|
||||
let result = validate_stealth_config(false, false);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_no_stealth_with_tls_accepted() {
|
||||
let result = validate_stealth_config(false, true);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn short_data_detected_as_http() {
|
||||
let detection = write_and_detect(b"GE").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn partial_ssh_prefix_detected_as_http() {
|
||||
let detection = write_and_detect(b"SSH-1.").await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_request_gets_404_then_closed() {
|
||||
let (client, server) = duplex(1024);
|
||||
let mut client = client;
|
||||
|
||||
client
|
||||
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (detection, mut reader) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
|
||||
send_fake_nginx_404(&mut reader).await;
|
||||
|
||||
let mut buf = [0u8; 256];
|
||||
let n = client.read(&mut buf).await.unwrap();
|
||||
let response = String::from_utf8_lossy(&buf[..n]);
|
||||
assert!(response.starts_with("HTTP/1.1 404 Not Found"));
|
||||
assert!(response.contains("Server: nginx"));
|
||||
|
||||
let mut extra = [0u8; 16];
|
||||
let result = client.read(&mut extra).await;
|
||||
assert!(result.is_err() || result.unwrap() == 0);
|
||||
}
|
||||
|
||||
#[cfg(feature = "http")]
|
||||
#[tokio::test]
|
||||
async fn stealth_handoff_routes_http_to_axum() {
|
||||
use crate::auth::{AuthToken, IdentityProvider};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
struct NullProvider;
|
||||
|
||||
impl IdentityProvider for NullProvider {
|
||||
fn resolve_from_fingerprint(
|
||||
&self,
|
||||
_fingerprint: &str,
|
||||
) -> Option<crate::auth::Identity> {
|
||||
None
|
||||
}
|
||||
|
||||
fn resolve_from_token(&self, _token: &AuthToken) -> Option<crate::auth::Identity> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
let (client, server) = duplex(4096);
|
||||
let (mut client_read, mut client_write) = tokio::io::split(client);
|
||||
|
||||
client_write
|
||||
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
drop(client_write);
|
||||
|
||||
let (detection, reader) = detect_protocol(server).await;
|
||||
assert_eq!(detection, ProtocolDetection::Http);
|
||||
|
||||
let provider: Arc<dyn IdentityProvider> = Arc::new(NullProvider);
|
||||
let handle = tokio::spawn(async move {
|
||||
handle_http_stealth(reader, provider).await;
|
||||
});
|
||||
|
||||
let mut buf = Vec::new();
|
||||
tokio::io::AsyncReadExt::read_to_end(&mut client_read, &mut buf)
|
||||
.await
|
||||
.unwrap();
|
||||
let response = String::from_utf8_lossy(&buf);
|
||||
assert!(
|
||||
response.contains("401"),
|
||||
"expected 401 from axum auth middleware, got: {response}"
|
||||
);
|
||||
assert!(
|
||||
!response.contains("nginx"),
|
||||
"should not contain fake nginx response when http feature is enabled"
|
||||
);
|
||||
|
||||
let _ = handle.await;
|
||||
}
|
||||
}
|
||||
@@ -1,490 +0,0 @@
|
||||
//! SOCKS5 proxy server.
|
||||
//!
|
||||
//! Listens on a local port and routes each SOCKS5 connection through an SSH
|
||||
//! `direct-tcpip` channel. Supports SOCKS5h (domain names resolved server-side)
|
||||
//! to prevent DNS leaks. Uses the `ChannelOpener` trait to abstract over the
|
||||
//! SSH channel mechanism, making it testable without a real SSH session.
|
||||
|
||||
mod protocol;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::debug;
|
||||
|
||||
use protocol::{Socks5Reply, Socks5Request, Socks5VersionMethod};
|
||||
|
||||
pub use protocol::Socks5Address;
|
||||
|
||||
const DEFAULT_SOCKS5_ADDR: &str = "127.0.0.1:1080";
|
||||
|
||||
pub trait ChannelOpener: Send + Sync + 'static {
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
|
||||
fn open_channel(
|
||||
&self,
|
||||
host: String,
|
||||
port: u16,
|
||||
) -> impl std::future::Future<Output = Result<Self::Stream, ChannelOpenError>> + Send;
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ChannelOpenError {
|
||||
#[error("session closed")]
|
||||
SessionClosed,
|
||||
#[error("channel open failed")]
|
||||
ChannelOpenFailed,
|
||||
#[error("connection refused")]
|
||||
ConnectionRefused,
|
||||
}
|
||||
|
||||
pub struct Socks5Server<C: ChannelOpener> {
|
||||
listen_addr: SocketAddr,
|
||||
channel_opener: Arc<C>,
|
||||
}
|
||||
|
||||
impl<C: ChannelOpener> Socks5Server<C> {
|
||||
pub fn new(channel_opener: C) -> Self {
|
||||
Self::with_addr(channel_opener, DEFAULT_SOCKS5_ADDR)
|
||||
}
|
||||
|
||||
pub fn with_addr(channel_opener: C, addr: &str) -> Self {
|
||||
let listen_addr: SocketAddr = addr.parse().expect("invalid SOCKS5 listen address");
|
||||
Self {
|
||||
listen_addr,
|
||||
channel_opener: Arc::new(channel_opener),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn listen_addr(&self) -> SocketAddr {
|
||||
self.listen_addr
|
||||
}
|
||||
|
||||
pub async fn run(self) -> Result<(), std::io::Error> {
|
||||
let listener = TcpListener::bind(self.listen_addr).await?;
|
||||
debug!("socks5 server listening on {}", self.listen_addr);
|
||||
loop {
|
||||
let (socket, _peer) = listener.accept().await?;
|
||||
let opener = Arc::clone(&self.channel_opener);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_socks5_connection(socket, opener).await {
|
||||
debug!("socks5 connection error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_socks5_connection<S, C>(mut socket: S, opener: Arc<C>) -> Result<(), Socks5Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
C: ChannelOpener,
|
||||
{
|
||||
let vm = Socks5VersionMethod::read_from(&mut socket).await?;
|
||||
if vm.version != 0x05 {
|
||||
return Err(Socks5Error::InvalidVersion(vm.version));
|
||||
}
|
||||
if !vm.methods.contains(&0x00) {
|
||||
let reply = [0x05, 0xFF];
|
||||
socket.write_all(&reply).await?;
|
||||
socket.shutdown().await?;
|
||||
return Err(Socks5Error::NoAcceptableAuth);
|
||||
}
|
||||
let reply = [0x05, 0x00];
|
||||
socket.write_all(&reply).await?;
|
||||
|
||||
let request = Socks5Request::read_from(&mut socket).await?;
|
||||
if request.version != 0x05 {
|
||||
return Err(Socks5Error::InvalidVersion(request.version));
|
||||
}
|
||||
if request.command != 0x01 {
|
||||
send_error_reply(&mut socket, Socks5Reply::command_not_supported()).await?;
|
||||
return Err(Socks5Error::UnsupportedCommand(request.command));
|
||||
}
|
||||
|
||||
let (host, port) = match &request.address {
|
||||
Socks5Address::Ipv4(addr) => (addr.to_string(), request.port),
|
||||
Socks5Address::Ipv6(addr) => (addr.to_string(), request.port),
|
||||
Socks5Address::Domain(name) => (name.clone(), request.port),
|
||||
};
|
||||
|
||||
match opener.open_channel(host, port).await {
|
||||
Ok(mut ssh_stream) => {
|
||||
let bind_addr = Socks5Address::Ipv4(std::net::Ipv4Addr::UNSPECIFIED);
|
||||
let reply = Socks5Reply::success(bind_addr, 0);
|
||||
reply.write_to(&mut socket).await?;
|
||||
tokio::io::copy_bidirectional(&mut socket, &mut ssh_stream).await?;
|
||||
Ok(())
|
||||
}
|
||||
Err(_) => {
|
||||
send_error_reply(&mut socket, Socks5Reply::connection_refused()).await?;
|
||||
Err(Socks5Error::ChannelOpenFailed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_error_reply<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
socket: &mut S,
|
||||
reply: Socks5Reply,
|
||||
) -> Result<(), Socks5Error> {
|
||||
reply.write_to(socket).await?;
|
||||
let _ = socket.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Socks5Error {
|
||||
#[error("invalid SOCKS version: {0}")]
|
||||
InvalidVersion(u8),
|
||||
#[error("no acceptable auth method")]
|
||||
NoAcceptableAuth,
|
||||
#[error("unsupported command: {0}")]
|
||||
UnsupportedCommand(u8),
|
||||
#[error("channel open failed")]
|
||||
ChannelOpenFailed,
|
||||
#[error("io error")]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
pub struct HandleChannelOpener<H: russh::client::Handler> {
|
||||
handle: Arc<Mutex<russh::client::Handle<H>>>,
|
||||
}
|
||||
|
||||
impl<H: russh::client::Handler> HandleChannelOpener<H> {
|
||||
pub fn new(handle: russh::client::Handle<H>) -> Self {
|
||||
Self {
|
||||
handle: Arc::new(Mutex::new(handle)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_arc(handle: Arc<Mutex<russh::client::Handle<H>>>) -> Self {
|
||||
Self { handle }
|
||||
}
|
||||
}
|
||||
|
||||
impl<H: russh::client::Handler + Send + Sync + 'static> ChannelOpener for HandleChannelOpener<H> {
|
||||
type Stream = russh::ChannelStream<russh::client::Msg>;
|
||||
|
||||
async fn open_channel(
|
||||
&self,
|
||||
host: String,
|
||||
port: u16,
|
||||
) -> Result<Self::Stream, ChannelOpenError> {
|
||||
let handle = self.handle.lock().await;
|
||||
if handle.is_closed() {
|
||||
return Err(ChannelOpenError::SessionClosed);
|
||||
}
|
||||
let channel = handle
|
||||
.channel_open_direct_tcpip(host, port as u32, "127.0.0.1", 0)
|
||||
.await
|
||||
.map_err(|_| ChannelOpenError::ChannelOpenFailed)?;
|
||||
Ok(channel.into_stream())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
|
||||
|
||||
struct MockChannelOpener {
|
||||
fail: bool,
|
||||
}
|
||||
|
||||
impl ChannelOpener for MockChannelOpener {
|
||||
type Stream = DuplexStream;
|
||||
|
||||
async fn open_channel(
|
||||
&self,
|
||||
_host: String,
|
||||
_port: u16,
|
||||
) -> Result<Self::Stream, ChannelOpenError> {
|
||||
if self.fail {
|
||||
Err(ChannelOpenError::ChannelOpenFailed)
|
||||
} else {
|
||||
let (client, _server) = duplex(4096);
|
||||
Ok(client)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_socks5_greeting(methods: &[u8]) -> Vec<u8> {
|
||||
let mut buf = vec![0x05, methods.len() as u8];
|
||||
buf.extend_from_slice(methods);
|
||||
buf
|
||||
}
|
||||
|
||||
fn build_socks5_connect_ipv4(addr: [u8; 4], port: u16) -> Vec<u8> {
|
||||
let mut buf = vec![0x05, 0x01, 0x00, 0x01];
|
||||
buf.extend_from_slice(&addr);
|
||||
buf.extend_from_slice(&port.to_be_bytes());
|
||||
buf
|
||||
}
|
||||
|
||||
fn build_socks5_connect_domain(domain: &str, port: u16) -> Vec<u8> {
|
||||
let mut buf = vec![0x05, 0x01, 0x00, 0x03];
|
||||
buf.push(domain.len() as u8);
|
||||
buf.extend_from_slice(domain.as_bytes());
|
||||
buf.extend_from_slice(&port.to_be_bytes());
|
||||
buf
|
||||
}
|
||||
|
||||
fn build_socks5_connect_ipv6(addr: [u8; 16], port: u16) -> Vec<u8> {
|
||||
let mut buf = vec![0x05, 0x01, 0x00, 0x04];
|
||||
buf.extend_from_slice(&addr);
|
||||
buf.extend_from_slice(&port.to_be_bytes());
|
||||
buf
|
||||
}
|
||||
|
||||
async fn do_handshake(client: &mut DuplexStream) -> [u8; 2] {
|
||||
client
|
||||
.write_all(&build_socks5_greeting(&[0x00]))
|
||||
.await
|
||||
.unwrap();
|
||||
client.flush().await.unwrap();
|
||||
let mut resp = [0u8; 2];
|
||||
client.read_exact(&mut resp).await.unwrap();
|
||||
resp
|
||||
}
|
||||
|
||||
async fn do_connect_ipv4(client: &mut DuplexStream, addr: [u8; 4], port: u16) -> Vec<u8> {
|
||||
client
|
||||
.write_all(&build_socks5_connect_ipv4(addr, port))
|
||||
.await
|
||||
.unwrap();
|
||||
client.flush().await.unwrap();
|
||||
let mut reply_buf = [0u8; 10];
|
||||
client.read_exact(&mut reply_buf).await.unwrap();
|
||||
reply_buf.to_vec()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_no_auth_method() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
let resp = do_handshake(&mut client).await;
|
||||
assert_eq!(resp, [0x05, 0x00]);
|
||||
|
||||
let reply_buf = do_connect_ipv4(&mut client, [127, 0, 0, 1], 80).await;
|
||||
assert_eq!(reply_buf[0], 0x05);
|
||||
assert_eq!(reply_buf[1], 0x00);
|
||||
|
||||
drop(client);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_rejects_no_acceptable_method() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
client
|
||||
.write_all(&build_socks5_greeting(&[0x02]))
|
||||
.await
|
||||
.unwrap();
|
||||
client.flush().await.unwrap();
|
||||
|
||||
let mut resp = [0u8; 2];
|
||||
client.read_exact(&mut resp).await.unwrap();
|
||||
assert_eq!(resp, [0x05, 0xFF]);
|
||||
|
||||
drop(client);
|
||||
let result = server_handle.await.unwrap();
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), Socks5Error::NoAcceptableAuth));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn address_type_ipv4() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
do_handshake(&mut client).await;
|
||||
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 443).await;
|
||||
assert_eq!(reply_buf[1], 0x00);
|
||||
|
||||
drop(client);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn address_type_domain() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
do_handshake(&mut client).await;
|
||||
|
||||
client
|
||||
.write_all(&build_socks5_connect_domain("example.com", 443))
|
||||
.await
|
||||
.unwrap();
|
||||
client.flush().await.unwrap();
|
||||
|
||||
let mut reply_buf = [0u8; 10];
|
||||
client.read_exact(&mut reply_buf).await.unwrap();
|
||||
assert_eq!(reply_buf[1], 0x00);
|
||||
|
||||
drop(client);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn address_type_ipv6() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
do_handshake(&mut client).await;
|
||||
|
||||
let ipv6_addr = [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
|
||||
client
|
||||
.write_all(&build_socks5_connect_ipv6(ipv6_addr, 443))
|
||||
.await
|
||||
.unwrap();
|
||||
client.flush().await.unwrap();
|
||||
|
||||
let mut reply_buf = [0u8; 10];
|
||||
client.read_exact(&mut reply_buf).await.unwrap();
|
||||
assert_eq!(reply_buf[0], 0x05);
|
||||
assert_eq!(reply_buf[1], 0x00);
|
||||
|
||||
drop(client);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn channel_open_failure_returns_socks5_error() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: true };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
do_handshake(&mut client).await;
|
||||
let reply_buf = do_connect_ipv4(&mut client, [10, 0, 0, 1], 80).await;
|
||||
assert_eq!(reply_buf[0], 0x05);
|
||||
assert_eq!(reply_buf[1], 0x05);
|
||||
|
||||
drop(client);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unsupported_command_returns_error() {
|
||||
let (mut client, server) = duplex(4096);
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { handle_socks5_connection(server, Arc::new(opener)).await });
|
||||
|
||||
do_handshake(&mut client).await;
|
||||
|
||||
let mut bind_req = vec![0x05, 0x02, 0x00, 0x01];
|
||||
bind_req.extend_from_slice(&[127, 0, 0, 1]);
|
||||
bind_req.extend_from_slice(&80u16.to_be_bytes());
|
||||
client.write_all(&bind_req).await.unwrap();
|
||||
client.flush().await.unwrap();
|
||||
|
||||
let mut reply_buf = [0u8; 10];
|
||||
client.read_exact(&mut reply_buf).await.unwrap();
|
||||
assert_eq!(reply_buf[1], 0x07);
|
||||
|
||||
drop(client);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn bidirectional_proxy_flow() {
|
||||
let (mut client_sock, server_sock) = duplex(4096);
|
||||
let (ssh_client, mut ssh_server) = duplex(4096);
|
||||
|
||||
let ssh_stream = Arc::new(Mutex::new(Some(ssh_client)));
|
||||
|
||||
struct ProxyOpener {
|
||||
stream: Arc<Mutex<Option<DuplexStream>>>,
|
||||
}
|
||||
|
||||
impl ChannelOpener for ProxyOpener {
|
||||
type Stream = DuplexStream;
|
||||
|
||||
async fn open_channel(
|
||||
&self,
|
||||
_host: String,
|
||||
_port: u16,
|
||||
) -> Result<Self::Stream, ChannelOpenError> {
|
||||
self.stream
|
||||
.lock()
|
||||
.await
|
||||
.take()
|
||||
.ok_or(ChannelOpenError::ChannelOpenFailed)
|
||||
}
|
||||
}
|
||||
|
||||
let opener = ProxyOpener {
|
||||
stream: Arc::clone(&ssh_stream),
|
||||
};
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(
|
||||
async move { handle_socks5_connection(server_sock, Arc::new(opener)).await },
|
||||
);
|
||||
|
||||
do_handshake(&mut client_sock).await;
|
||||
let reply_buf = do_connect_ipv4(&mut client_sock, [127, 0, 0, 1], 80).await;
|
||||
assert_eq!(reply_buf[1], 0x00);
|
||||
|
||||
let test_data = b"hello through tunnel";
|
||||
client_sock.write_all(test_data).await.unwrap();
|
||||
client_sock.flush().await.unwrap();
|
||||
|
||||
let mut received = vec![0u8; test_data.len()];
|
||||
AsyncReadExt::read_exact(&mut ssh_server, &mut received)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(&received, test_data);
|
||||
|
||||
let echo_data = b"response from tunnel";
|
||||
ssh_server.write_all(echo_data).await.unwrap();
|
||||
ssh_server.flush().await.unwrap();
|
||||
|
||||
let mut received_back = vec![0u8; echo_data.len()];
|
||||
client_sock.read_exact(&mut received_back).await.unwrap();
|
||||
assert_eq!(&received_back, echo_data);
|
||||
|
||||
drop(client_sock);
|
||||
drop(ssh_server);
|
||||
let _ = server_handle.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn default_listen_address() {
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
let server = Socks5Server::new(opener);
|
||||
assert_eq!(server.listen_addr(), "127.0.0.1:1080".parse().unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn custom_listen_address() {
|
||||
let opener = MockChannelOpener { fail: false };
|
||||
let server = Socks5Server::with_addr(opener, "127.0.0.1:9050");
|
||||
assert_eq!(server.listen_addr(), "127.0.0.1:9050".parse().unwrap());
|
||||
}
|
||||
}
|
||||
@@ -1,304 +0,0 @@
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Socks5Address {
|
||||
Ipv4(Ipv4Addr),
|
||||
Ipv6(Ipv6Addr),
|
||||
Domain(String),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Socks5VersionMethod {
|
||||
pub version: u8,
|
||||
pub methods: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Socks5VersionMethod {
|
||||
pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Self> {
|
||||
let version = reader.read_u8().await?;
|
||||
let nmethods = reader.read_u8().await?;
|
||||
let mut methods = vec![0u8; nmethods as usize];
|
||||
reader.read_exact(&mut methods).await?;
|
||||
Ok(Self { version, methods })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Socks5Request {
|
||||
pub version: u8,
|
||||
pub command: u8,
|
||||
pub address: Socks5Address,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
impl Socks5Request {
|
||||
pub async fn read_from<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Self> {
|
||||
let version = reader.read_u8().await?;
|
||||
let command = reader.read_u8().await?;
|
||||
let _rsv = reader.read_u8().await?;
|
||||
let atyp = reader.read_u8().await?;
|
||||
|
||||
let address = match atyp {
|
||||
0x01 => {
|
||||
let mut octets = [0u8; 4];
|
||||
reader.read_exact(&mut octets).await?;
|
||||
Socks5Address::Ipv4(Ipv4Addr::from(octets))
|
||||
}
|
||||
0x04 => {
|
||||
let mut octets = [0u8; 16];
|
||||
reader.read_exact(&mut octets).await?;
|
||||
Socks5Address::Ipv6(Ipv6Addr::from(octets))
|
||||
}
|
||||
0x03 => {
|
||||
let len = reader.read_u8().await?;
|
||||
let mut domain = vec![0u8; len as usize];
|
||||
reader.read_exact(&mut domain).await?;
|
||||
Socks5Address::Domain(String::from_utf8_lossy(&domain).into_owned())
|
||||
}
|
||||
_ => {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("unsupported address type: {atyp}"),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let port = reader.read_u16().await?;
|
||||
|
||||
Ok(Self {
|
||||
version,
|
||||
command,
|
||||
address,
|
||||
port,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Socks5Reply {
|
||||
pub version: u8,
|
||||
pub reply: u8,
|
||||
pub address: Socks5Address,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
impl Socks5Reply {
|
||||
pub fn success(address: Socks5Address, port: u16) -> Self {
|
||||
Self {
|
||||
version: 0x05,
|
||||
reply: 0x00,
|
||||
address,
|
||||
port,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn connection_refused() -> Self {
|
||||
Self {
|
||||
version: 0x05,
|
||||
reply: 0x05,
|
||||
address: Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED),
|
||||
port: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn command_not_supported() -> Self {
|
||||
Self {
|
||||
version: 0x05,
|
||||
reply: 0x07,
|
||||
address: Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED),
|
||||
port: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write_to<W: AsyncWrite + Unpin>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||
writer.write_u8(self.version).await?;
|
||||
writer.write_u8(self.reply).await?;
|
||||
writer.write_u8(0x00).await?;
|
||||
match &self.address {
|
||||
Socks5Address::Ipv4(addr) => {
|
||||
writer.write_u8(0x01).await?;
|
||||
writer.write_all(&addr.octets()).await?;
|
||||
}
|
||||
Socks5Address::Ipv6(addr) => {
|
||||
writer.write_u8(0x04).await?;
|
||||
writer.write_all(&addr.octets()).await?;
|
||||
}
|
||||
Socks5Address::Domain(name) => {
|
||||
writer.write_u8(0x03).await?;
|
||||
writer.write_u8(name.len() as u8).await?;
|
||||
writer.write_all(name.as_bytes()).await?;
|
||||
}
|
||||
}
|
||||
writer.write_u16(self.port).await?;
|
||||
writer.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Cursor;
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_version_method_no_auth() {
|
||||
let data = [0x05, 0x01, 0x00];
|
||||
let mut cursor = Cursor::new(&data[..]);
|
||||
let vm = Socks5VersionMethod::read_from(&mut cursor).await.unwrap();
|
||||
assert_eq!(vm.version, 0x05);
|
||||
assert_eq!(vm.methods, vec![0x00]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_version_method_multiple() {
|
||||
let data = [0x05, 0x02, 0x00, 0x02];
|
||||
let mut cursor = Cursor::new(&data[..]);
|
||||
let vm = Socks5VersionMethod::read_from(&mut cursor).await.unwrap();
|
||||
assert_eq!(vm.version, 0x05);
|
||||
assert_eq!(vm.methods, vec![0x00, 0x02]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_request_ipv4() {
|
||||
let mut data = vec![0x05, 0x01, 0x00, 0x01];
|
||||
data.extend_from_slice(&[10, 0, 0, 1]);
|
||||
data.extend_from_slice(&443u16.to_be_bytes());
|
||||
let mut cursor = Cursor::new(&data[..]);
|
||||
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||
assert_eq!(req.version, 0x05);
|
||||
assert_eq!(req.command, 0x01);
|
||||
assert_eq!(req.address, Socks5Address::Ipv4(Ipv4Addr::new(10, 0, 0, 1)));
|
||||
assert_eq!(req.port, 443);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_request_ipv6() {
|
||||
let mut data = vec![0x05, 0x01, 0x00, 0x04];
|
||||
let octets: [u8; 16] = [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
|
||||
data.extend_from_slice(&octets);
|
||||
data.extend_from_slice(&443u16.to_be_bytes());
|
||||
let mut cursor = Cursor::new(&data[..]);
|
||||
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||
assert_eq!(req.version, 0x05);
|
||||
assert_eq!(req.command, 0x01);
|
||||
assert!(matches!(req.address, Socks5Address::Ipv6(_)));
|
||||
assert_eq!(req.port, 443);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_request_domain() {
|
||||
let domain = "example.com";
|
||||
let mut data = vec![0x05, 0x01, 0x00, 0x03];
|
||||
data.push(domain.len() as u8);
|
||||
data.extend_from_slice(domain.as_bytes());
|
||||
data.extend_from_slice(&443u16.to_be_bytes());
|
||||
let mut cursor = Cursor::new(&data[..]);
|
||||
let req = Socks5Request::read_from(&mut cursor).await.unwrap();
|
||||
assert_eq!(req.version, 0x05);
|
||||
assert_eq!(req.command, 0x01);
|
||||
assert_eq!(
|
||||
req.address,
|
||||
Socks5Address::Domain("example.com".to_string())
|
||||
);
|
||||
assert_eq!(req.port, 443);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_request_unsupported_address_type() {
|
||||
let data = [0x05, 0x01, 0x00, 0x05];
|
||||
let mut cursor = Cursor::new(&data[..]);
|
||||
let result = Socks5Request::read_from(&mut cursor).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reply_success_ipv4() {
|
||||
let reply = Socks5Reply::success(Socks5Address::Ipv4(Ipv4Addr::UNSPECIFIED), 0);
|
||||
let mut buf = Vec::new();
|
||||
reply.write_to(&mut buf).await.unwrap();
|
||||
assert_eq!(buf[0], 0x05);
|
||||
assert_eq!(buf[1], 0x00);
|
||||
assert_eq!(buf[2], 0x00);
|
||||
assert_eq!(buf[3], 0x01);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reply_connection_refused() {
|
||||
let reply = Socks5Reply::connection_refused();
|
||||
let mut buf = Vec::new();
|
||||
reply.write_to(&mut buf).await.unwrap();
|
||||
assert_eq!(buf[0], 0x05);
|
||||
assert_eq!(buf[1], 0x05);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reply_command_not_supported() {
|
||||
let reply = Socks5Reply::command_not_supported();
|
||||
let mut buf = Vec::new();
|
||||
reply.write_to(&mut buf).await.unwrap();
|
||||
assert_eq!(buf[0], 0x05);
|
||||
assert_eq!(buf[1], 0x07);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn roundtrip_ipv4_reply() {
|
||||
let reply = Socks5Reply::success(Socks5Address::Ipv4(Ipv4Addr::new(127, 0, 0, 1)), 1080);
|
||||
let mut buf = Vec::new();
|
||||
reply.write_to(&mut buf).await.unwrap();
|
||||
|
||||
let mut cursor = Cursor::new(&buf[..]);
|
||||
let version = cursor.read_u8().await.unwrap();
|
||||
let _reply_code = cursor.read_u8().await.unwrap();
|
||||
let _rsv = cursor.read_u8().await.unwrap();
|
||||
let atyp = cursor.read_u8().await.unwrap();
|
||||
assert_eq!(version, 0x05);
|
||||
assert_eq!(atyp, 0x01);
|
||||
let mut octets = [0u8; 4];
|
||||
cursor.read_exact(&mut octets).await.unwrap();
|
||||
assert_eq!(Ipv4Addr::from(octets), Ipv4Addr::new(127, 0, 0, 1));
|
||||
let port = cursor.read_u16().await.unwrap();
|
||||
assert_eq!(port, 1080);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn roundtrip_ipv6_reply() {
|
||||
let addr = Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1);
|
||||
let reply = Socks5Reply::success(Socks5Address::Ipv6(addr), 443);
|
||||
let mut buf = Vec::new();
|
||||
reply.write_to(&mut buf).await.unwrap();
|
||||
|
||||
let mut cursor = Cursor::new(&buf[..]);
|
||||
let _version = cursor.read_u8().await.unwrap();
|
||||
let _reply_code = cursor.read_u8().await.unwrap();
|
||||
let _rsv = cursor.read_u8().await.unwrap();
|
||||
let atyp = cursor.read_u8().await.unwrap();
|
||||
assert_eq!(atyp, 0x04);
|
||||
let mut octets = [0u8; 16];
|
||||
cursor.read_exact(&mut octets).await.unwrap();
|
||||
assert_eq!(Ipv6Addr::from(octets), addr);
|
||||
let port = cursor.read_u16().await.unwrap();
|
||||
assert_eq!(port, 443);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn roundtrip_domain_reply() {
|
||||
let reply = Socks5Reply::success(Socks5Address::Domain("example.com".to_string()), 8080);
|
||||
let mut buf = Vec::new();
|
||||
reply.write_to(&mut buf).await.unwrap();
|
||||
|
||||
let mut cursor = Cursor::new(&buf[..]);
|
||||
let _version = cursor.read_u8().await.unwrap();
|
||||
let _reply_code = cursor.read_u8().await.unwrap();
|
||||
let _rsv = cursor.read_u8().await.unwrap();
|
||||
let atyp = cursor.read_u8().await.unwrap();
|
||||
assert_eq!(atyp, 0x03);
|
||||
let len = cursor.read_u8().await.unwrap();
|
||||
let mut domain = vec![0u8; len as usize];
|
||||
cursor.read_exact(&mut domain).await.unwrap();
|
||||
assert_eq!(String::from_utf8(domain).unwrap(), "example.com");
|
||||
let port = cursor.read_u16().await.unwrap();
|
||||
assert_eq!(port, 8080);
|
||||
}
|
||||
}
|
||||
203
crates/alknet-core/src/store.rs
Normal file
203
crates/alknet-core/src/store.rs
Normal file
@@ -0,0 +1,203 @@
|
||||
//! Credential store: `CredentialStore` repo trait, `InMemoryCredentialStore`
|
||||
//! default adapter, `EncryptedData` core mirror, and the shared `StoreError`.
|
||||
//!
|
||||
//! See `docs/architecture/crates/core/auth.md` and ADR-031 / ADR-035 for the
|
||||
//! full specification. The store persists `EncryptedData` blobs keyed by
|
||||
//! provider; it never decrypts (ADR-025 — the vault is the sole decryption
|
||||
//! boundary).
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum StoreError {
|
||||
#[error("backend error: {message}")]
|
||||
Backend { message: String },
|
||||
#[error("not found: {entity}")]
|
||||
NotFound { entity: String },
|
||||
#[error("serialization error: {message}")]
|
||||
Serialization { message: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct EncryptedData {
|
||||
pub key_version: u32,
|
||||
pub salt: Vec<u8>,
|
||||
pub iv: Vec<u8>,
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait CredentialStore: Send + Sync {
|
||||
fn get(&self, provider: &str) -> Option<EncryptedData>;
|
||||
async fn put(&self, provider: &str, data: &EncryptedData) -> Result<(), StoreError>;
|
||||
async fn delete(&self, provider: &str) -> Result<(), StoreError>;
|
||||
}
|
||||
|
||||
pub struct InMemoryCredentialStore {
|
||||
entries: RwLock<HashMap<String, EncryptedData>>,
|
||||
}
|
||||
|
||||
impl InMemoryCredentialStore {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
entries: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_entries(entries: HashMap<String, EncryptedData>) -> Self {
|
||||
Self {
|
||||
entries: RwLock::new(entries),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InMemoryCredentialStore {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CredentialStore for InMemoryCredentialStore {
|
||||
fn get(&self, provider: &str) -> Option<EncryptedData> {
|
||||
let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
|
||||
entries.get(provider).cloned()
|
||||
}
|
||||
|
||||
async fn put(&self, provider: &str, data: &EncryptedData) -> Result<(), StoreError> {
|
||||
let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
|
||||
entries.insert(provider.to_string(), data.clone());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete(&self, provider: &str) -> Result<(), StoreError> {
|
||||
let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
|
||||
entries.remove(provider);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn sample_encrypted_data() -> EncryptedData {
|
||||
EncryptedData {
|
||||
key_version: 2,
|
||||
salt: vec![],
|
||||
iv: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
data: vec![0xde, 0xad, 0xbe, 0xef],
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn in_memory_get_put_delete_round_trip() {
|
||||
let store = InMemoryCredentialStore::new();
|
||||
let data = sample_encrypted_data();
|
||||
|
||||
assert!(store.get("openai").is_none());
|
||||
|
||||
store.put("openai", &data).await.unwrap();
|
||||
let retrieved = store.get("openai").expect("provider should be present");
|
||||
assert_eq!(retrieved.key_version, data.key_version);
|
||||
assert_eq!(retrieved.salt, data.salt);
|
||||
assert_eq!(retrieved.iv, data.iv);
|
||||
assert_eq!(retrieved.data, data.data);
|
||||
|
||||
store.delete("openai").await.unwrap();
|
||||
assert!(store.get("openai").is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn in_memory_get_returns_none_for_missing_provider() {
|
||||
let store = InMemoryCredentialStore::new();
|
||||
assert!(store.get("never-configured").is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn in_memory_delete_missing_provider_is_ok() {
|
||||
let store = InMemoryCredentialStore::new();
|
||||
store.delete("absent").await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn in_memory_put_replaces_existing() {
|
||||
let store = InMemoryCredentialStore::new();
|
||||
let first = sample_encrypted_data();
|
||||
let mut second = sample_encrypted_data();
|
||||
second.data = vec![0xc0, 0xff, 0xee];
|
||||
|
||||
store.put("anthropic", &first).await.unwrap();
|
||||
store.put("anthropic", &second).await.unwrap();
|
||||
|
||||
let retrieved = store.get("anthropic").expect("provider should be present");
|
||||
assert_eq!(retrieved.data, second.data);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn in_memory_with_entries_seeds_store() {
|
||||
let mut entries = HashMap::new();
|
||||
entries.insert("github".to_string(), sample_encrypted_data());
|
||||
let store = InMemoryCredentialStore::with_entries(entries);
|
||||
|
||||
assert!(store.get("github").is_some());
|
||||
assert!(store.get("openai").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypted_data_serializes_and_deserializes_round_trip() {
|
||||
let data = sample_encrypted_data();
|
||||
let json = serde_json::to_string(&data).expect("serialize");
|
||||
let decoded: EncryptedData = serde_json::from_str(&json).expect("deserialize");
|
||||
assert_eq!(decoded.key_version, data.key_version);
|
||||
assert_eq!(decoded.salt, data.salt);
|
||||
assert_eq!(decoded.iv, data.iv);
|
||||
assert_eq!(decoded.data, data.data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypted_data_round_trips_non_empty_salt() {
|
||||
let data = EncryptedData {
|
||||
key_version: 1,
|
||||
salt: vec![0xab, 0xcd, 0xef],
|
||||
iv: vec![0; 12],
|
||||
data: vec![0x01, 0x02, 0x03],
|
||||
};
|
||||
let json = serde_json::to_string(&data).expect("serialize");
|
||||
let decoded: EncryptedData = serde_json::from_str(&json).expect("deserialize");
|
||||
assert_eq!(decoded.salt, data.salt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn store_error_display_formatting() {
|
||||
let backend = StoreError::Backend {
|
||||
message: "disk full".to_string(),
|
||||
};
|
||||
assert_eq!(backend.to_string(), "backend error: disk full");
|
||||
|
||||
let not_found = StoreError::NotFound {
|
||||
entity: "openai".to_string(),
|
||||
};
|
||||
assert_eq!(not_found.to_string(), "not found: openai");
|
||||
|
||||
let serialization = StoreError::Serialization {
|
||||
message: "invalid utf8".to_string(),
|
||||
};
|
||||
assert_eq!(
|
||||
serialization.to_string(),
|
||||
"serialization error: invalid utf8"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn store_error_is_non_exhaustive() {
|
||||
let err = StoreError::Backend {
|
||||
message: "x".to_string(),
|
||||
};
|
||||
let _ = err.to_string();
|
||||
}
|
||||
}
|
||||
@@ -1,141 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, DuplexStream};
|
||||
|
||||
#[cfg(feature = "transport-traits")]
|
||||
pub use crate::transport::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
#[cfg(not(feature = "transport-traits"))]
|
||||
pub use local_traits::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
#[cfg(not(feature = "transport-traits"))]
|
||||
mod local_traits {
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
#[async_trait]
|
||||
pub trait Transport: Send + Sync + 'static {
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
async fn connect(&self) -> Result<Self::Stream>;
|
||||
fn describe(&self) -> String;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait TransportAcceptor: Send + Sync + 'static {
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TransportInfo {
|
||||
pub remote_addr: Option<SocketAddr>,
|
||||
pub transport_kind: TransportKind,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TransportKind {
|
||||
Tcp,
|
||||
Tls { server_name: Option<String> },
|
||||
Iroh { endpoint_id: String },
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MockStream {
|
||||
inner: DuplexStream,
|
||||
}
|
||||
|
||||
impl MockStream {
|
||||
pub fn new(inner: DuplexStream) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for MockStream {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for MockStream {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<std::io::Result<usize>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpin for MockStream {}
|
||||
|
||||
pub struct MockTransport {
|
||||
buf_size: usize,
|
||||
}
|
||||
|
||||
impl MockTransport {
|
||||
pub fn new(buf_size: usize) -> Self {
|
||||
Self { buf_size }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for MockTransport {
|
||||
type Stream = MockStream;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream> {
|
||||
let (client, _) = tokio::io::duplex(self.buf_size);
|
||||
Ok(MockStream::new(client))
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"mock".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MockTransportAcceptor {
|
||||
buf_size: usize,
|
||||
}
|
||||
|
||||
impl MockTransportAcceptor {
|
||||
pub fn new(buf_size: usize) -> Self {
|
||||
Self { buf_size }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TransportAcceptor for MockTransportAcceptor {
|
||||
type Stream = MockStream;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let (_, server) = tokio::io::duplex(self.buf_size);
|
||||
let info = TransportInfo {
|
||||
remote_addr: None,
|
||||
transport_kind: TransportKind::Tcp,
|
||||
};
|
||||
Ok((MockStream::new(server), info))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mock_pair(buf_size: usize) -> (MockStream, MockStream) {
|
||||
let (client, server) = tokio::io::duplex(buf_size);
|
||||
(MockStream::new(client), MockStream::new(server))
|
||||
}
|
||||
@@ -1,352 +0,0 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use rustls::crypto::aws_lc_rs::default_provider;
|
||||
use rustls::ServerConfig;
|
||||
use rustls_acme::caches::DirCache;
|
||||
use rustls_acme::{AcmeConfig, AcmeState, ResolvesServerCertAcme};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsAcceptor as TokioTlsAcceptor;
|
||||
use tracing::{error, info};
|
||||
|
||||
use super::{TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AcmeMode {
|
||||
Domain { domain: String },
|
||||
Ip,
|
||||
}
|
||||
|
||||
pub struct AcmeCertProvider {
|
||||
mode: AcmeMode,
|
||||
cache_dir: Option<PathBuf>,
|
||||
directory_url: String,
|
||||
contact: Vec<String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AcmeCertProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AcmeCertProvider")
|
||||
.field("mode", &self.mode)
|
||||
.field("cache_dir", &self.cache_dir)
|
||||
.field("directory_url", &self.directory_url)
|
||||
.field("contact", &self.contact)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl AcmeCertProvider {
|
||||
pub fn new(mode: AcmeMode) -> Self {
|
||||
Self {
|
||||
mode,
|
||||
cache_dir: None,
|
||||
directory_url: rustls_acme::acme::LETS_ENCRYPT_STAGING_DIRECTORY.to_string(),
|
||||
contact: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn domain(domain: impl Into<String>) -> Self {
|
||||
Self::new(AcmeMode::Domain {
|
||||
domain: domain.into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn ip() -> Self {
|
||||
Self::new(AcmeMode::Ip)
|
||||
}
|
||||
|
||||
pub fn with_cache_dir(mut self, dir: impl Into<PathBuf>) -> Self {
|
||||
self.cache_dir = Some(dir.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_directory(mut self, url: impl Into<String>) -> Self {
|
||||
self.directory_url = url.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_production_directory(mut self) -> Self {
|
||||
self.directory_url = rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY.to_string();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_contact(mut self, contact: impl Into<String>) -> Self {
|
||||
self.contact.push(contact.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn mode(&self) -> &AcmeMode {
|
||||
&self.mode
|
||||
}
|
||||
|
||||
fn build_acme_state(&self) -> (AcmeState<std::io::Error>, Arc<ResolvesServerCertAcme>) {
|
||||
let domains: Vec<String> = match &self.mode {
|
||||
AcmeMode::Domain { domain } => vec![domain.clone()],
|
||||
AcmeMode::Ip => vec![],
|
||||
};
|
||||
|
||||
let base_config = AcmeConfig::new(domains)
|
||||
.directory(&self.directory_url)
|
||||
.contact(self.contact.clone());
|
||||
|
||||
let state = match &self.cache_dir {
|
||||
Some(cache_dir) => base_config.cache(DirCache::new(cache_dir.clone())).state(),
|
||||
None => base_config
|
||||
.cache(rustls_acme::caches::NoCache::default())
|
||||
.state(),
|
||||
};
|
||||
|
||||
let resolver = state.resolver();
|
||||
(state, resolver)
|
||||
}
|
||||
|
||||
pub fn build_server_config_with_resolver(
|
||||
&self,
|
||||
resolver: Arc<ResolvesServerCertAcme>,
|
||||
) -> Result<Arc<ServerConfig>> {
|
||||
let provider = default_provider().into();
|
||||
let mut config = ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()
|
||||
.map_err(|e| anyhow!("failed to set protocol versions: {}", e))?
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(resolver);
|
||||
config.alpn_protocols.push(ACME_TLS_ALPN_NAME.to_vec());
|
||||
Ok(Arc::new(config))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AcmeTlsAcceptor {
|
||||
listener: TcpListener,
|
||||
listen_addr: SocketAddr,
|
||||
#[allow(dead_code)]
|
||||
server_config: Arc<ServerConfig>,
|
||||
tokio_acceptor: TokioTlsAcceptor,
|
||||
}
|
||||
|
||||
impl AcmeTlsAcceptor {
|
||||
pub async fn bind_acme(addr: SocketAddr, provider: Arc<AcmeCertProvider>) -> Result<Self> {
|
||||
let (state, resolver) = provider.build_acme_state();
|
||||
|
||||
let server_config = provider.build_server_config_with_resolver(resolver.clone())?;
|
||||
|
||||
Self::spawn_state_worker(state, resolver);
|
||||
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let listen_addr = listener.local_addr()?;
|
||||
|
||||
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
|
||||
|
||||
Ok(Self {
|
||||
listener,
|
||||
listen_addr,
|
||||
server_config,
|
||||
tokio_acceptor,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn listen_addr(&self) -> SocketAddr {
|
||||
self.listen_addr
|
||||
}
|
||||
|
||||
fn spawn_state_worker(state: AcmeState<std::io::Error>, resolver: Arc<ResolvesServerCertAcme>) {
|
||||
use futures::StreamExt;
|
||||
|
||||
let task = async move {
|
||||
let mut state = state;
|
||||
while let Some(event) = state.next().await {
|
||||
match event {
|
||||
Ok(ok) => {
|
||||
if let rustls_acme::EventOk::DeployedNewCert = ok {
|
||||
info!("ACME: new certificate deployed");
|
||||
} else {
|
||||
info!("ACME event: {:?}", ok);
|
||||
}
|
||||
}
|
||||
Err(err) => error!("ACME event error: {:?}", err),
|
||||
}
|
||||
if Arc::strong_count(&resolver) == 1 {
|
||||
info!("ACME resolver dropped, stopping background task");
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
tokio::spawn(task);
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TransportAcceptor for AcmeTlsAcceptor {
|
||||
type Stream = tokio_rustls::server::TlsStream<tokio::net::TcpStream>;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let (tcp_stream, remote_addr) = self.listener.accept().await?;
|
||||
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
|
||||
|
||||
let server_name = tls_stream.get_ref().1.server_name().map(|s| s.to_string());
|
||||
|
||||
let info = TransportInfo {
|
||||
remote_addr: Some(remote_addr),
|
||||
transport_kind: TransportKind::Tls { server_name },
|
||||
};
|
||||
|
||||
Ok((tls_stream, info))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_domain_mode() {
|
||||
let provider = AcmeCertProvider::domain("example.com");
|
||||
assert!(matches!(provider.mode(), AcmeMode::Domain { .. }));
|
||||
if let AcmeMode::Domain { domain } = provider.mode() {
|
||||
assert_eq!(domain, "example.com");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_ip_mode() {
|
||||
let provider = AcmeCertProvider::ip();
|
||||
assert!(matches!(provider.mode(), AcmeMode::Ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_default_staging_directory() {
|
||||
let provider = AcmeCertProvider::domain("example.com");
|
||||
assert_eq!(
|
||||
provider.directory_url,
|
||||
rustls_acme::acme::LETS_ENCRYPT_STAGING_DIRECTORY
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_production_directory() {
|
||||
let provider = AcmeCertProvider::domain("example.com").with_production_directory();
|
||||
assert_eq!(
|
||||
provider.directory_url,
|
||||
rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_custom_directory() {
|
||||
let provider =
|
||||
AcmeCertProvider::domain("example.com").with_directory("https://custom.acme.dir/");
|
||||
assert_eq!(provider.directory_url, "https://custom.acme.dir/");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_with_cache_dir() {
|
||||
let provider = AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/acme_cache");
|
||||
assert_eq!(provider.cache_dir, Some(PathBuf::from("/tmp/acme_cache")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_with_contact() {
|
||||
let provider =
|
||||
AcmeCertProvider::domain("example.com").with_contact("mailto:admin@example.com");
|
||||
assert_eq!(
|
||||
provider.contact,
|
||||
vec!["mailto:admin@example.com".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_build_state_domain() {
|
||||
let provider = AcmeCertProvider::domain("example.com");
|
||||
let (_state, resolver) = provider.build_acme_state();
|
||||
assert!(Arc::strong_count(&resolver) >= 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_build_state_with_cache() {
|
||||
let provider = AcmeCertProvider::domain("example.com").with_cache_dir("/tmp/test_cache");
|
||||
let (_state, resolver) = provider.build_acme_state();
|
||||
assert!(Arc::strong_count(&resolver) >= 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_build_server_config() {
|
||||
let _ = default_provider().install_default();
|
||||
let provider = AcmeCertProvider::domain("example.com");
|
||||
let (_, resolver) = provider.build_acme_state();
|
||||
let config = provider
|
||||
.build_server_config_with_resolver(resolver)
|
||||
.unwrap();
|
||||
assert!(!config.alpn_protocols.is_empty());
|
||||
assert!(config
|
||||
.alpn_protocols
|
||||
.iter()
|
||||
.any(|p| p == ACME_TLS_ALPN_NAME));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_mode_domain_debug() {
|
||||
let mode = AcmeMode::Domain {
|
||||
domain: "test.example.com".to_string(),
|
||||
};
|
||||
let debug_str = format!("{:?}", mode);
|
||||
assert!(debug_str.contains("test.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_mode_ip_debug() {
|
||||
let mode = AcmeMode::Ip;
|
||||
let debug_str = format!("{:?}", mode);
|
||||
assert!(debug_str.contains("Ip"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn acme_cert_provider_builder_chain() {
|
||||
let provider = AcmeCertProvider::domain("test.example.com")
|
||||
.with_production_directory()
|
||||
.with_cache_dir("/tmp/cache")
|
||||
.with_contact("mailto:admin@test.example.com");
|
||||
assert!(matches!(provider.mode(), AcmeMode::Domain { .. }));
|
||||
assert_eq!(
|
||||
provider.directory_url,
|
||||
rustls_acme::acme::LETS_ENCRYPT_PRODUCTION_DIRECTORY
|
||||
);
|
||||
assert_eq!(provider.cache_dir, Some(PathBuf::from("/tmp/cache")));
|
||||
assert_eq!(provider.contact.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn acme_tls_acceptor_bind_acme() {
|
||||
let _ = default_provider().install_default();
|
||||
let provider = Arc::new(AcmeCertProvider::domain("example.com"));
|
||||
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
|
||||
let acceptor = AcmeTlsAcceptor::bind_acme(addr, provider).await.unwrap();
|
||||
assert_ne!(acceptor.listen_addr().port(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn acme_staging_domain_cert_provisioning() {
|
||||
let _ = default_provider().install_default();
|
||||
|
||||
let cache_dir = tempfile::tempdir().unwrap();
|
||||
let provider = Arc::new(
|
||||
AcmeCertProvider::domain("acme-test.example.com")
|
||||
.with_cache_dir(cache_dir.path())
|
||||
.with_contact("mailto:admin@example.com"),
|
||||
);
|
||||
|
||||
let addr: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
||||
let result = AcmeTlsAcceptor::bind_acme(addr, provider).await;
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"ACME TlsAcceptor should bind: {:?}",
|
||||
result.err()
|
||||
);
|
||||
|
||||
let acceptor = result.unwrap();
|
||||
assert_eq!(acceptor.listen_addr().port(), 443);
|
||||
}
|
||||
}
|
||||
@@ -1,328 +0,0 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use iroh::{
|
||||
endpoint::RecvStream, node_info::NodeIdExt, Endpoint, NodeId, RelayMap, RelayMode, RelayUrl,
|
||||
};
|
||||
use tokio::io;
|
||||
|
||||
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
pub const ALPN: &[u8] = b"alknet-ssh";
|
||||
const DEFAULT_RELAY_URL: &str = "https://relay.iroh.network/";
|
||||
|
||||
/// A client-side iroh QUIC P2P transport that connects to a remote iroh endpoint.
|
||||
///
|
||||
/// Connects via `Endpoint::connect(node_id, alpn)`, opens a bidirectional
|
||||
/// QUIC stream with `conn.open_bi()`, and joins the halves with
|
||||
/// `tokio::io::join(recv, send)` to produce a duplex stream for russh.
|
||||
/// Per ADR-003, `tokio::io::join` is used instead of a custom wrapper.
|
||||
///
|
||||
/// Use [`IrohTransport::new`] to create a standalone endpoint, or
|
||||
/// [`IrohTransport::from_endpoint`] to share an existing iroh `Endpoint`
|
||||
/// with other protocol handlers (blobs, gossip, docs).
|
||||
pub struct IrohTransport {
|
||||
node_id: NodeId,
|
||||
endpoint: Endpoint,
|
||||
owned: bool,
|
||||
}
|
||||
|
||||
impl IrohTransport {
|
||||
/// Create a new iroh transport with its own dedicated endpoint.
|
||||
///
|
||||
/// The endpoint is created with the `alknet-ssh` ALPN and the provided
|
||||
/// relay URL. Use this when alknet is the only iroh service on this node.
|
||||
pub async fn new(
|
||||
node_id: NodeId,
|
||||
relay_url: Option<RelayUrl>,
|
||||
proxy_url: Option<url::Url>,
|
||||
) -> Result<Self> {
|
||||
let relay_url = relay_url.unwrap_or_else(|| {
|
||||
DEFAULT_RELAY_URL
|
||||
.parse()
|
||||
.expect("default relay URL is valid")
|
||||
});
|
||||
let relay_map = RelayMap::from_url(relay_url);
|
||||
let mut builder = Endpoint::builder()
|
||||
.relay_mode(RelayMode::Custom(relay_map))
|
||||
.alpns(vec![ALPN.to_vec()]);
|
||||
if let Some(ref proxy) = proxy_url {
|
||||
builder = builder.proxy_url(proxy.clone());
|
||||
}
|
||||
let endpoint = builder.bind().await?;
|
||||
Ok(Self {
|
||||
node_id,
|
||||
endpoint,
|
||||
owned: true,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create an iroh transport using an existing shared endpoint.
|
||||
///
|
||||
/// The endpoint must already have the `alknet-ssh` ALPN registered
|
||||
/// (typically via [`iroh::protocol::Router::builder`]). This enables
|
||||
/// running alknet alongside iroh-blobs, iroh-gossip, iroh-docs, and
|
||||
/// other protocol handlers on the same QUIC endpoint — one connection
|
||||
/// per peer, multiplexed by ALPN.
|
||||
pub fn from_endpoint(node_id: NodeId, endpoint: Endpoint) -> Self {
|
||||
Self {
|
||||
node_id,
|
||||
endpoint,
|
||||
owned: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn endpoint_id(&self) -> String {
|
||||
self.endpoint.node_id().to_z32()
|
||||
}
|
||||
|
||||
pub fn endpoint(&self) -> &Endpoint {
|
||||
&self.endpoint
|
||||
}
|
||||
|
||||
pub fn owned(&self) -> bool {
|
||||
self.owned
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for IrohTransport {
|
||||
type Stream = io::Join<RecvStream, iroh::endpoint::SendStream>;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream> {
|
||||
let conn = self.endpoint.connect(self.node_id, ALPN).await?;
|
||||
let (send, recv) = conn.open_bi().await?;
|
||||
Ok(io::join(recv, send))
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
format!("iroh://{}", self.node_id.to_z32())
|
||||
}
|
||||
}
|
||||
|
||||
/// A server-side iroh QUIC P2P transport acceptor that listens for incoming connections.
|
||||
///
|
||||
/// Binds an iroh `Endpoint` with the configured relay URL and optional proxy
|
||||
/// (ADR-010). Accepts incoming connections, accepts bidirectional QUIC streams,
|
||||
/// and joins the halves with `tokio::io::join(recv, send)`. Exposes
|
||||
/// `endpoint_id()` for CLI display of the server's z-base-32 node ID.
|
||||
///
|
||||
/// Use [`IrohAcceptor::bind`] to create a standalone endpoint, or
|
||||
/// [`IrohAcceptor::from_endpoint`] to share an existing iroh `Endpoint`
|
||||
/// with other protocol handlers (blobs, gossip, docs).
|
||||
///
|
||||
/// When using `from_endpoint`, the alknet-ssh ALPN must be registered
|
||||
/// via an iroh `Router` that calls `Handler::accept()` on incoming
|
||||
/// connections with the `alknet-ssh` ALPN, then passes the accepted
|
||||
/// bidirectional stream to `russh::server::run_stream()`.
|
||||
pub struct IrohAcceptor {
|
||||
endpoint: Endpoint,
|
||||
owned: bool,
|
||||
}
|
||||
|
||||
impl IrohAcceptor {
|
||||
/// Bind a new iroh endpoint with a dedicated `alknet-ssh` ALPN.
|
||||
///
|
||||
/// Use this when alknet is the only iroh service on this node.
|
||||
pub async fn bind(relay_url: Option<RelayUrl>, proxy_url: Option<url::Url>) -> Result<Self> {
|
||||
let relay_url = relay_url.unwrap_or_else(|| {
|
||||
DEFAULT_RELAY_URL
|
||||
.parse()
|
||||
.expect("default relay URL is valid")
|
||||
});
|
||||
let relay_map = RelayMap::from_url(relay_url);
|
||||
let mut builder = Endpoint::builder()
|
||||
.relay_mode(RelayMode::Custom(relay_map))
|
||||
.alpns(vec![ALPN.to_vec()]);
|
||||
if let Some(ref proxy) = proxy_url {
|
||||
builder = builder.proxy_url(proxy.clone());
|
||||
}
|
||||
let endpoint = builder.bind().await?;
|
||||
Ok(Self {
|
||||
endpoint,
|
||||
owned: true,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create an iroh acceptor using an existing shared endpoint.
|
||||
///
|
||||
/// The endpoint must already have the `alknet-ssh` ALPN registered
|
||||
/// (typically via [`iroh::protocol::Router::builder`]). When using a
|
||||
/// shared endpoint, incoming connections with the `alknet-ssh` ALPN
|
||||
/// are routed by the Router to a `ProtocolHandler` that this acceptor
|
||||
/// does not manage — the caller is responsible for bridging the
|
||||
/// Router's `accept()` callback to this acceptor's stream handling.
|
||||
///
|
||||
/// For the standalone case where alknet owns the endpoint, use
|
||||
/// [`IrohAcceptor::bind`] instead, which handles the accept loop
|
||||
/// internally.
|
||||
pub fn from_endpoint(endpoint: Endpoint) -> Self {
|
||||
Self {
|
||||
endpoint,
|
||||
owned: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn endpoint_id(&self) -> String {
|
||||
self.endpoint.node_id().to_z32()
|
||||
}
|
||||
|
||||
pub fn endpoint(&self) -> &Endpoint {
|
||||
&self.endpoint
|
||||
}
|
||||
|
||||
pub fn owned(&self) -> bool {
|
||||
self.owned
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportAcceptor for IrohAcceptor {
|
||||
type Stream = io::Join<RecvStream, iroh::endpoint::SendStream>;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let incoming = self
|
||||
.endpoint
|
||||
.accept()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("endpoint closed"))?;
|
||||
let conn = incoming.await?;
|
||||
let node_id = conn.remote_node_id()?;
|
||||
let (send, recv) = conn.accept_bi().await?;
|
||||
let stream = io::join(recv, send);
|
||||
let info = TransportInfo {
|
||||
remote_addr: None,
|
||||
transport_kind: TransportKind::Iroh {
|
||||
endpoint_id: node_id.to_z32(),
|
||||
},
|
||||
};
|
||||
Ok((stream, info))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn iroh_acceptor_bind_creates_endpoint() {
|
||||
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
|
||||
let endpoint_id = acceptor.endpoint_id();
|
||||
assert!(!endpoint_id.is_empty());
|
||||
let parsed = NodeId::from_z32(&endpoint_id);
|
||||
assert!(parsed.is_ok());
|
||||
assert!(acceptor.owned());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn iroh_acceptor_bind_with_custom_relay() {
|
||||
let relay: RelayUrl = "https://relay.iroh.network/".parse().unwrap();
|
||||
let acceptor = IrohAcceptor::bind(Some(relay), None).await.unwrap();
|
||||
assert!(!acceptor.endpoint_id().is_empty());
|
||||
assert!(acceptor.owned());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn iroh_acceptor_from_endpoint() {
|
||||
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
|
||||
let endpoint = acceptor.endpoint.clone();
|
||||
let shared = IrohAcceptor::from_endpoint(endpoint);
|
||||
assert_eq!(shared.endpoint_id(), acceptor.endpoint_id());
|
||||
assert!(!shared.owned());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn iroh_transport_describe_format() {
|
||||
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into();
|
||||
let desc = format!("iroh://{}", node_id.to_z32());
|
||||
assert!(desc.starts_with("iroh://"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn iroh_transport_connect_builds_endpoint() {
|
||||
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into();
|
||||
let transport = IrohTransport::new(node_id, None, None).await.unwrap();
|
||||
assert!(transport.describe().starts_with("iroh://"));
|
||||
assert!(!transport.endpoint_id().is_empty());
|
||||
assert!(transport.owned());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn iroh_transport_from_endpoint() {
|
||||
let node_id: NodeId = iroh::SecretKey::generate(rand_core::OsRng).public().into();
|
||||
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
|
||||
let endpoint = acceptor.endpoint.clone();
|
||||
let transport = IrohTransport::from_endpoint(node_id, endpoint);
|
||||
assert!(transport.describe().starts_with("iroh://"));
|
||||
assert_eq!(transport.endpoint_id(), acceptor.endpoint_id());
|
||||
assert!(!transport.owned());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn iroh_client_connects_to_iroh_server() {
|
||||
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
|
||||
let server_node_id = acceptor.endpoint().node_id();
|
||||
|
||||
let transport = IrohTransport::new(server_node_id, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut addrs_watcher = acceptor.endpoint().direct_addresses();
|
||||
addrs_watcher.initialized().await.unwrap();
|
||||
let addr_set = addrs_watcher.get().unwrap().unwrap_or_default();
|
||||
for addr in addr_set {
|
||||
transport
|
||||
.endpoint
|
||||
.add_node_addr(iroh::NodeAddr::from_parts(
|
||||
server_node_id,
|
||||
None,
|
||||
vec![addr.addr],
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let accept_handle = tokio::spawn(async move {
|
||||
let (stream, info) = acceptor.accept().await.unwrap();
|
||||
assert!(matches!(info.transport_kind, TransportKind::Iroh { .. }));
|
||||
stream
|
||||
});
|
||||
|
||||
let _client_stream: io::Join<RecvStream, iroh::endpoint::SendStream> =
|
||||
transport.connect().await.unwrap();
|
||||
let _server_stream = accept_handle.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn iroh_shared_endpoint_client_connects_to_server() {
|
||||
let acceptor = IrohAcceptor::bind(None, None).await.unwrap();
|
||||
let server_node_id = acceptor.endpoint().node_id();
|
||||
let shared_endpoint = acceptor.endpoint().clone();
|
||||
|
||||
let transport = IrohTransport::from_endpoint(server_node_id, shared_endpoint);
|
||||
|
||||
let mut addrs_watcher = acceptor.endpoint().direct_addresses();
|
||||
addrs_watcher.initialized().await.unwrap();
|
||||
let addr_set = addrs_watcher.get().unwrap().unwrap_or_default();
|
||||
for addr in addr_set {
|
||||
transport
|
||||
.endpoint
|
||||
.add_node_addr(iroh::NodeAddr::from_parts(
|
||||
server_node_id,
|
||||
None,
|
||||
vec![addr.addr],
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let accept_handle = tokio::spawn(async move {
|
||||
let (stream, info) = acceptor.accept().await.unwrap();
|
||||
assert!(matches!(info.transport_kind, TransportKind::Iroh { .. }));
|
||||
stream
|
||||
});
|
||||
|
||||
let _client_stream: io::Join<RecvStream, iroh::endpoint::SendStream> =
|
||||
transport.connect().await.unwrap();
|
||||
let _server_stream = accept_handle.await.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -1,203 +0,0 @@
|
||||
//! Pluggable transport layer for Alknet.
|
||||
//!
|
||||
//! The transport layer produces a duplex byte stream (`AsyncRead + AsyncWrite + Unpin + Send`)
|
||||
//! that SSH consumes. This is the core architectural abstraction — SSH never opens its own
|
||||
//! network connections; it runs entirely over whatever stream the transport provides.
|
||||
//!
|
||||
//! Available transports (feature-gated):
|
||||
//! - `TcpTransport` / `TcpAcceptor` — always available, direct TCP
|
||||
//! - `TlsTransport` / `TlsAcceptor` — behind the `tls` feature, TCP + rustls
|
||||
//! - `IrohTransport` / `IrohAcceptor` — behind the `iroh` feature, QUIC P2P via iroh
|
||||
//! - `AcmeTlsAcceptor` — behind the `acme` feature, auto-provision TLS certs via Let's Encrypt
|
||||
//!
|
||||
//! See [ADR-001](docs/architecture/decisions/001-pluggable-transport.md) and
|
||||
//! [ADR-004](docs/architecture/decisions/004-ssh-over-transport.md) for design rationale.
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
mod iroh_transport;
|
||||
mod tcp;
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
pub use iroh_transport::{IrohAcceptor, IrohTransport, ALPN as IROH_ALPN};
|
||||
pub use tcp::{TcpAcceptor, TcpTransport};
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
mod tls;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
pub use tls::{AcmeConfig, TlsAcceptor, TlsTransport};
|
||||
|
||||
#[cfg(feature = "acme")]
|
||||
mod acme;
|
||||
|
||||
#[cfg(feature = "acme")]
|
||||
pub use acme::{AcmeCertProvider, AcmeMode, AcmeTlsAcceptor};
|
||||
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
/// Client-side transport trait. Produces a single duplex stream per connection.
|
||||
///
|
||||
/// Implementations connect to a remote endpoint and return a stream that SSH
|
||||
/// runs over via `russh::client::connect_stream()`. Each call to `connect()` creates
|
||||
/// a new stream — multiple sessions need multiple calls or multiple transports.
|
||||
#[async_trait]
|
||||
pub trait Transport: Send + Sync + 'static {
|
||||
/// The duplex stream type produced by this transport.
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
|
||||
/// Connect to the remote endpoint and return a duplex stream.
|
||||
async fn connect(&self) -> Result<Self::Stream>;
|
||||
|
||||
/// Return a human-readable description of this transport for logging.
|
||||
fn describe(&self) -> String;
|
||||
}
|
||||
|
||||
/// Server-side transport acceptor. Accepts incoming connections and returns streams.
|
||||
///
|
||||
/// Implementations bind to a local endpoint and produce streams that SSH
|
||||
/// runs over via `russh::server::run_stream()`.
|
||||
#[async_trait]
|
||||
pub trait TransportAcceptor: Send + Sync + 'static {
|
||||
/// The duplex stream type produced by this acceptor.
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
|
||||
/// Accept an incoming connection and return a duplex stream with metadata.
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)>;
|
||||
}
|
||||
|
||||
/// Metadata about an incoming transport connection.
|
||||
///
|
||||
/// Carries the remote address (if available) and the kind of transport
|
||||
/// used. The server handler uses this for logging and auth decisions.
|
||||
/// See ADR-001 for the pluggable transport rationale and ADR-004
|
||||
/// for why SSH runs entirely over the transport stream.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TransportInfo {
|
||||
pub remote_addr: Option<SocketAddr>,
|
||||
pub transport_kind: TransportKind,
|
||||
}
|
||||
|
||||
/// The kind of transport that produced a connection.
|
||||
///
|
||||
/// Each variant identifies the transport mechanism. Used by the
|
||||
/// server handler for logging and authorization decisions.
|
||||
/// See ADR-001 and ADR-004.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TransportKind {
|
||||
Tcp,
|
||||
Tls { server_name: Option<String> },
|
||||
Iroh { endpoint_id: String },
|
||||
WebTransport { server_name: Option<String> },
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TransportKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TransportKind::Tcp => write!(f, "tcp"),
|
||||
TransportKind::Tls { .. } => write!(f, "tls"),
|
||||
TransportKind::Iroh { .. } => write!(f, "iroh"),
|
||||
|
||||
TransportKind::WebTransport { .. } => write!(f, "webtransport"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{duplex, DuplexStream};
|
||||
|
||||
struct MockTransport;
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for MockTransport {
|
||||
type Stream = DuplexStream;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream> {
|
||||
let (stream, _) = duplex(1024);
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
"mock".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
struct MockAcceptor;
|
||||
|
||||
#[async_trait]
|
||||
impl TransportAcceptor for MockAcceptor {
|
||||
type Stream = DuplexStream;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let (stream, _) = duplex(1024);
|
||||
let info = TransportInfo {
|
||||
remote_addr: None,
|
||||
transport_kind: TransportKind::Tcp,
|
||||
};
|
||||
Ok((stream, info))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_trait_object() {
|
||||
let _boxed: Box<dyn Transport<Stream = DuplexStream>> = Box::new(MockTransport);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_acceptor_trait_object() {
|
||||
let _boxed: Box<dyn TransportAcceptor<Stream = DuplexStream>> = Box::new(MockAcceptor);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_connect_returns_stream() {
|
||||
let t = MockTransport;
|
||||
let _stream = t.connect().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transport_describe_returns_string() {
|
||||
let t = MockTransport;
|
||||
assert_eq!(t.describe(), "mock");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn acceptor_accept_returns_stream_and_info() {
|
||||
let a = MockAcceptor;
|
||||
let (_, info) = a.accept().await.unwrap();
|
||||
assert!(info.remote_addr.is_none());
|
||||
assert!(matches!(info.transport_kind, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_kind_variants() {
|
||||
let tcp = TransportKind::Tcp;
|
||||
let tls = TransportKind::Tls {
|
||||
server_name: Some("example.com".to_string()),
|
||||
};
|
||||
let iroh = TransportKind::Iroh {
|
||||
endpoint_id: "abc123".to_string(),
|
||||
};
|
||||
let wt = TransportKind::WebTransport {
|
||||
server_name: Some("example.com".to_string()),
|
||||
};
|
||||
|
||||
if let TransportKind::Tcp = tcp {}
|
||||
if let TransportKind::Tls {
|
||||
server_name: Some(name),
|
||||
} = tls
|
||||
{
|
||||
assert_eq!(name, "example.com");
|
||||
}
|
||||
if let TransportKind::Iroh { endpoint_id } = iroh {
|
||||
assert_eq!(endpoint_id, "abc123");
|
||||
}
|
||||
if let TransportKind::WebTransport { server_name } = wt {
|
||||
assert_eq!(server_name, Some("example.com".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,162 +0,0 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
/// A TCP-based client transport that connects to a remote address.
|
||||
///
|
||||
/// Connects via `TcpStream::connect(addr)`. Uses tokio's default
|
||||
/// connect timeout behavior: the OS controls connection timeout
|
||||
/// (typically ~2 minutes on Linux via `net.ipv4.tcp_syn_retries`).
|
||||
/// For custom timeouts, wrap `TcpTransport` with
|
||||
/// `tokio::time::timeout(duration, transport.connect())`.
|
||||
pub struct TcpTransport {
|
||||
addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl TcpTransport {
|
||||
pub fn new(addr: SocketAddr) -> Self {
|
||||
Self { addr }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for TcpTransport {
|
||||
type Stream = TcpStream;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream> {
|
||||
let stream = TcpStream::connect(self.addr).await?;
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
format!("tcp://{}", self.addr)
|
||||
}
|
||||
}
|
||||
|
||||
/// A TCP-based server transport acceptor that listens for incoming connections.
|
||||
///
|
||||
/// Binds via `TcpListener::bind(addr)`. Accepts connections and returns
|
||||
/// the stream together with `TransportInfo` containing the remote address
|
||||
/// and `TransportKind::Tcp`.
|
||||
pub struct TcpAcceptor {
|
||||
listener: TcpListener,
|
||||
listen_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl TcpAcceptor {
|
||||
/// Bind a TCP listener on the given address.
|
||||
///
|
||||
/// Returns the acceptor ready to receive connections.
|
||||
/// The actual bound address may differ from the requested one
|
||||
/// (e.g., when binding to port 0 the OS assigns an ephemeral port).
|
||||
pub async fn bind(addr: SocketAddr) -> Result<Self> {
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let listen_addr = listener.local_addr()?;
|
||||
Ok(Self {
|
||||
listener,
|
||||
listen_addr,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn listen_addr(&self) -> SocketAddr {
|
||||
self.listen_addr
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportAcceptor for TcpAcceptor {
|
||||
type Stream = TcpStream;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let (stream, remote_addr) = self.listener.accept().await?;
|
||||
let info = TransportInfo {
|
||||
remote_addr: Some(remote_addr),
|
||||
transport_kind: TransportKind::Tcp,
|
||||
};
|
||||
Ok((stream, info))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_transport_connect_creates_stream() {
|
||||
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
let transport = TcpTransport::new(addr);
|
||||
|
||||
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
|
||||
|
||||
let stream = transport.connect().await.unwrap();
|
||||
assert_eq!(stream.local_addr().unwrap().ip(), addr.ip());
|
||||
|
||||
let (_server_stream, info) = accept_handle.await.unwrap();
|
||||
assert!(info.remote_addr.is_some());
|
||||
assert!(matches!(info.transport_kind, TransportKind::Tcp));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_acceptor_accept_receives_connection() {
|
||||
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
|
||||
tokio::spawn(async move {
|
||||
TcpStream::connect(addr).await.unwrap();
|
||||
});
|
||||
|
||||
let (stream, info) = acceptor.accept().await.unwrap();
|
||||
assert!(info.remote_addr.is_some());
|
||||
assert!(matches!(info.transport_kind, TransportKind::Tcp));
|
||||
assert_eq!(
|
||||
info.remote_addr.unwrap().ip(),
|
||||
stream.peer_addr().unwrap().ip()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tcp_transport_describe_format() {
|
||||
let addr: SocketAddr = "1.2.3.4:22".parse().unwrap();
|
||||
let transport = TcpTransport::new(addr);
|
||||
assert_eq!(transport.describe(), "tcp://1.2.3.4:22");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_stream_is_duplex() {
|
||||
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
|
||||
let mut client = TcpStream::connect(addr).await.unwrap();
|
||||
let (mut server, _) = acceptor.accept().await.unwrap();
|
||||
|
||||
client.write_all(b"hello").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
server.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello");
|
||||
|
||||
server.write_all(b"world").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
client.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_acceptor_bind_port_zero_assigns_ephemeral() {
|
||||
let acceptor = TcpAcceptor::bind("127.0.0.1:0".parse().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_ne!(acceptor.listen_addr().port(), 0);
|
||||
}
|
||||
}
|
||||
@@ -1,429 +0,0 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
|
||||
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_rustls::{
|
||||
client::TlsStream as ClientTlsStream, TlsAcceptor as TokioTlsAcceptor, TlsConnector,
|
||||
};
|
||||
|
||||
#[cfg(feature = "acme")]
|
||||
use rustls::crypto::aws_lc_rs::default_provider;
|
||||
#[cfg(feature = "acme")]
|
||||
use rustls_acme::ResolvesServerCertAcme;
|
||||
|
||||
use super::{Transport, TransportAcceptor, TransportInfo, TransportKind};
|
||||
|
||||
#[cfg(feature = "acme")]
|
||||
const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1";
|
||||
|
||||
/// A TLS-based client transport that connects to a remote address over TLS.
|
||||
///
|
||||
/// Wraps a TCP connection with a TLS client session via `tokio_rustls::TlsConnector`.
|
||||
/// Supports insecure mode (accepts any certificate, for development) and
|
||||
/// custom root CA certificates for verification. The `tls_server_name` field
|
||||
/// overrides the SNI hostname sent during the TLS handshake (ADR-010).
|
||||
pub struct TlsTransport {
|
||||
addr: SocketAddr,
|
||||
tls_server_name: Option<String>,
|
||||
insecure: bool,
|
||||
root_cert: Option<CertificateDer<'static>>,
|
||||
}
|
||||
|
||||
impl TlsTransport {
|
||||
pub fn new(addr: SocketAddr) -> Self {
|
||||
Self {
|
||||
addr,
|
||||
tls_server_name: None,
|
||||
insecure: false,
|
||||
root_cert: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
|
||||
self.tls_server_name = Some(name.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_insecure(mut self, insecure: bool) -> Self {
|
||||
self.insecure = insecure;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_root_cert(mut self, cert: CertificateDer<'static>) -> Self {
|
||||
self.root_cert = Some(cert);
|
||||
self
|
||||
}
|
||||
|
||||
fn build_client_config(&self) -> Result<ClientConfig> {
|
||||
if self.insecure {
|
||||
let config = ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(NoVerifier))
|
||||
.with_no_client_auth();
|
||||
return Ok(config);
|
||||
}
|
||||
|
||||
let mut root_store = RootCertStore::empty();
|
||||
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
|
||||
if let Some(ref cert) = self.root_cert {
|
||||
root_store.add(cert.clone())?;
|
||||
}
|
||||
|
||||
let config = ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn resolve_server_name(&self) -> Result<ServerName<'static>> {
|
||||
let name = match &self.tls_server_name {
|
||||
Some(n) => n.clone(),
|
||||
None => self.addr.ip().to_string(),
|
||||
};
|
||||
ServerName::try_from(name.clone())
|
||||
.map_err(move |e| anyhow!("invalid server name '{}': {}", name, e))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for TlsTransport {
|
||||
type Stream = ClientTlsStream<TcpStream>;
|
||||
|
||||
async fn connect(&self) -> Result<Self::Stream> {
|
||||
let tcp_stream = TcpStream::connect(self.addr).await?;
|
||||
let config = self.build_client_config()?;
|
||||
let connector = TlsConnector::from(Arc::new(config));
|
||||
let server_name = self.resolve_server_name()?;
|
||||
let tls_stream = connector.connect(server_name, tcp_stream).await?;
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
fn describe(&self) -> String {
|
||||
format!("tls://{}", self.addr)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stub configuration for ACME certificate provisioning (ADR-008).
|
||||
/// Feature-gated behind the `acme` feature. When implemented, this will
|
||||
/// hold the ACME domain and challenge responder configuration.
|
||||
#[derive(Debug)]
|
||||
pub struct AcmeConfig {
|
||||
pub domain: String,
|
||||
}
|
||||
|
||||
/// A TLS-based server transport acceptor that accepts TCP connections
|
||||
/// and wraps them with TLS server sessions via `tokio_rustls::TlsAcceptor`.
|
||||
///
|
||||
/// Supports three certificate modes (ADR-008):
|
||||
/// - Manual certs via `bind()` with explicit cert/key
|
||||
/// - ACME certs via `bind_acme()` with an `AcmeCertProvider`
|
||||
/// - The stub `AcmeConfig` parameter in `bind()` is kept for backward compat
|
||||
pub struct TlsAcceptor {
|
||||
listener: TcpListener,
|
||||
listen_addr: SocketAddr,
|
||||
#[allow(dead_code)]
|
||||
server_config: Arc<ServerConfig>,
|
||||
tokio_acceptor: TokioTlsAcceptor,
|
||||
}
|
||||
|
||||
impl TlsAcceptor {
|
||||
pub async fn bind(
|
||||
addr: SocketAddr,
|
||||
tls_certs: Vec<CertificateDer<'static>>,
|
||||
tls_key: PrivateKeyDer<'static>,
|
||||
_acme_config: Option<AcmeConfig>,
|
||||
) -> Result<Self> {
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let listen_addr = listener.local_addr()?;
|
||||
|
||||
let server_config = ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(tls_certs, tls_key)?;
|
||||
|
||||
let server_config = Arc::new(server_config);
|
||||
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
|
||||
|
||||
Ok(Self {
|
||||
listener,
|
||||
listen_addr,
|
||||
server_config,
|
||||
tokio_acceptor,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "acme")]
|
||||
pub async fn bind_acme(
|
||||
addr: SocketAddr,
|
||||
acme_resolver: Arc<ResolvesServerCertAcme>,
|
||||
) -> Result<Self> {
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let listen_addr = listener.local_addr()?;
|
||||
|
||||
let provider = default_provider().into();
|
||||
let mut server_config = ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()
|
||||
.map_err(|e| anyhow!("failed to set protocol versions: {}", e))?
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(acme_resolver);
|
||||
server_config
|
||||
.alpn_protocols
|
||||
.push(ACME_TLS_ALPN_NAME.to_vec());
|
||||
|
||||
let server_config = Arc::new(server_config);
|
||||
let tokio_acceptor = TokioTlsAcceptor::from(server_config.clone());
|
||||
|
||||
Ok(Self {
|
||||
listener,
|
||||
listen_addr,
|
||||
server_config,
|
||||
tokio_acceptor,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn listen_addr(&self) -> SocketAddr {
|
||||
self.listen_addr
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportAcceptor for TlsAcceptor {
|
||||
type Stream = tokio_rustls::server::TlsStream<TcpStream>;
|
||||
|
||||
async fn accept(&self) -> Result<(Self::Stream, TransportInfo)> {
|
||||
let (tcp_stream, remote_addr) = self.listener.accept().await?;
|
||||
let tls_stream = self.tokio_acceptor.accept(tcp_stream).await?;
|
||||
|
||||
let server_name = tls_stream.get_ref().1.server_name().map(|s| s.to_string());
|
||||
|
||||
let info = TransportInfo {
|
||||
remote_addr: Some(remote_addr),
|
||||
transport_kind: TransportKind::Tls { server_name },
|
||||
};
|
||||
|
||||
Ok((tls_stream, info))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NoVerifier;
|
||||
|
||||
impl ServerCertVerifier for NoVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> std::result::Result<ServerCertVerified, rustls::Error> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_doc: &DigitallySignedStruct,
|
||||
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_doc: &DigitallySignedStruct,
|
||||
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA512,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rcgen::{CertificateParams, KeyPair};
|
||||
use rustls::crypto::aws_lc_rs::default_provider;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
fn ensure_crypto_provider() {
|
||||
let _ = default_provider().install_default();
|
||||
}
|
||||
|
||||
fn generate_self_signed_cert() -> (CertificateDer<'static>, PrivateKeyDer<'static>) {
|
||||
let params = CertificateParams::new(vec!["localhost".to_string()]).unwrap();
|
||||
let key_pair = KeyPair::generate().unwrap();
|
||||
let cert = params.self_signed(&key_pair).unwrap();
|
||||
let cert_der: CertificateDer<'static> = cert.into();
|
||||
let key_der = PrivateKeyDer::Pkcs8(key_pair.serialize_der().into());
|
||||
(cert_der, key_der)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_transport_describe_format() {
|
||||
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
|
||||
let transport = TlsTransport::new(addr).with_server_name("example.com");
|
||||
assert_eq!(transport.describe(), "tls://1.2.3.4:443");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_transport_describe_with_ip() {
|
||||
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
|
||||
let transport = TlsTransport::new(addr);
|
||||
assert_eq!(transport.describe(), "tls://1.2.3.4:443");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_transport_builder_methods() {
|
||||
let addr: SocketAddr = "1.2.3.4:443".parse().unwrap();
|
||||
let transport = TlsTransport::new(addr)
|
||||
.with_server_name("alknet.test")
|
||||
.with_insecure(true);
|
||||
assert_eq!(transport.tls_server_name, Some("alknet.test".to_string()));
|
||||
assert!(transport.insecure);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_connect_insecure_self_signed() {
|
||||
ensure_crypto_provider();
|
||||
let (cert_der, key_der) = generate_self_signed_cert();
|
||||
|
||||
let acceptor = TlsAcceptor::bind(
|
||||
"127.0.0.1:0".parse().unwrap(),
|
||||
vec![cert_der],
|
||||
key_der,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
|
||||
let transport = TlsTransport::new(addr)
|
||||
.with_server_name("localhost")
|
||||
.with_insecure(true);
|
||||
|
||||
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
|
||||
|
||||
let mut client = transport.connect().await.unwrap();
|
||||
|
||||
let (mut server, info) = accept_handle.await.unwrap();
|
||||
assert!(info.remote_addr.is_some());
|
||||
assert!(matches!(info.transport_kind, TransportKind::Tls { .. }));
|
||||
|
||||
client.write_all(b"hello tls").await.unwrap();
|
||||
let mut buf = [0u8; 9];
|
||||
server.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello tls");
|
||||
|
||||
server.write_all(b"reply").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
client.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"reply");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_acceptor_returns_server_name() {
|
||||
ensure_crypto_provider();
|
||||
let (cert_der, key_der) = generate_self_signed_cert();
|
||||
|
||||
let acceptor = TlsAcceptor::bind(
|
||||
"127.0.0.1:0".parse().unwrap(),
|
||||
vec![cert_der],
|
||||
key_der,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
|
||||
let transport = TlsTransport::new(addr)
|
||||
.with_server_name("localhost")
|
||||
.with_insecure(true);
|
||||
|
||||
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
|
||||
|
||||
let _client = transport.connect().await.unwrap();
|
||||
|
||||
let (_, info) = accept_handle.await.unwrap();
|
||||
if let TransportKind::Tls { server_name } = info.transport_kind {
|
||||
assert_eq!(server_name, Some("localhost".to_string()));
|
||||
} else {
|
||||
panic!("expected TransportKind::Tls");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_full_client_to_server_connection() {
|
||||
ensure_crypto_provider();
|
||||
let (cert_der, key_der) = generate_self_signed_cert();
|
||||
|
||||
let acceptor = TlsAcceptor::bind(
|
||||
"127.0.0.1:0".parse().unwrap(),
|
||||
vec![cert_der],
|
||||
key_der,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = acceptor.listen_addr();
|
||||
|
||||
let transport = TlsTransport::new(addr)
|
||||
.with_server_name("localhost")
|
||||
.with_insecure(true);
|
||||
|
||||
let accept_handle = tokio::spawn(async move { acceptor.accept().await.unwrap() });
|
||||
|
||||
let mut client = transport.connect().await.unwrap();
|
||||
let (mut server, _info) = accept_handle.await.unwrap();
|
||||
|
||||
let msg = b"alknet integration test";
|
||||
client.write_all(msg).await.unwrap();
|
||||
let mut buf = vec![0u8; msg.len()];
|
||||
server.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf[..], msg);
|
||||
|
||||
let reply = b"ok";
|
||||
server.write_all(reply).await.unwrap();
|
||||
let mut buf = [0u8; 2];
|
||||
client.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, reply);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_acceptor_bind_port_zero_assigns_ephemeral() {
|
||||
ensure_crypto_provider();
|
||||
let (cert_der, key_der) = generate_self_signed_cert();
|
||||
|
||||
let acceptor = TlsAcceptor::bind(
|
||||
"127.0.0.1:0".parse().unwrap(),
|
||||
vec![cert_der],
|
||||
key_der,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_ne!(acceptor.listen_addr().port(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_verifier_accepts_any_cert() {
|
||||
let verifier = NoVerifier;
|
||||
assert!(verifier.supported_verify_schemes().len() > 0);
|
||||
}
|
||||
}
|
||||
882
crates/alknet-core/src/types.rs
Normal file
882
crates/alknet-core/src/types.rs
Normal file
@@ -0,0 +1,882 @@
|
||||
//! Core types: `ProtocolHandler`, `HandlerError`, `Connection`, `BiStream`,
|
||||
//! `SendStream`, `RecvStream`, `StreamError`, `Capabilities`.
|
||||
//!
|
||||
//! See `docs/architecture/crates/core/core-types.md` for the full specification.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use zeroize::{Zeroize, ZeroizeOnDrop};
|
||||
|
||||
use crate::auth::{AuthContext, Identity};
|
||||
|
||||
pub struct Secret<T: Zeroize + Clone> {
|
||||
inner: T,
|
||||
}
|
||||
|
||||
impl<T: Zeroize + Clone> Secret<T> {
|
||||
pub fn new(value: T) -> Self {
|
||||
Self { inner: value }
|
||||
}
|
||||
|
||||
pub fn expose_secret(&self) -> &T {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Zeroize + Clone> Clone for Secret<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
inner: self.inner.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Zeroize + Clone> Zeroize for Secret<T> {
|
||||
fn zeroize(&mut self) {
|
||||
self.inner.zeroize();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Zeroize + Clone> Drop for Secret<T> {
|
||||
fn drop(&mut self) {
|
||||
self.inner.zeroize();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Zeroize + Clone> std::fmt::Debug for Secret<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("[REDACTED]")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Capabilities {
|
||||
entries: HashMap<String, Secret<String>>,
|
||||
}
|
||||
|
||||
impl Zeroize for Capabilities {
|
||||
fn zeroize(&mut self) {
|
||||
for (_, v) in self.entries.iter_mut() {
|
||||
v.zeroize();
|
||||
}
|
||||
self.entries.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl ZeroizeOnDrop for Capabilities {}
|
||||
|
||||
impl Clone for Capabilities {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
entries: self.entries.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Capabilities {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
entries: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_api_key(mut self, service: &str, key: String) -> Self {
|
||||
self.entries
|
||||
.insert(format!("api_key:{service}"), Secret::new(key));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_http_token(mut self, service: &str, token: String) -> Self {
|
||||
self.entries
|
||||
.insert(format!("http_token:{service}"), Secret::new(token));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn get(&self, service: &str) -> Option<&Secret<String>> {
|
||||
self.entries
|
||||
.get(&format!("api_key:{service}"))
|
||||
.or_else(|| self.entries.get(&format!("http_token:{service}")))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Capabilities {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Capabilities {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Capabilities")
|
||||
.field("entries", &format!("[{} redacted]", self.entries.len()))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum IdentityAlreadySet {
|
||||
#[error("connection identity already set")]
|
||||
AlreadySet,
|
||||
}
|
||||
|
||||
pub enum HandlerError {
|
||||
ConnectionClosed,
|
||||
StreamError(io::Error),
|
||||
AuthRequired,
|
||||
Internal(Box<dyn std::error::Error + Send + Sync>),
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for HandlerError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::ConnectionClosed => f.write_str("HandlerError::ConnectionClosed"),
|
||||
Self::StreamError(e) => f.debug_tuple("HandlerError::StreamError").field(e).finish(),
|
||||
Self::AuthRequired => f.write_str("HandlerError::AuthRequired"),
|
||||
Self::Internal(e) => f.debug_tuple("HandlerError::Internal").field(e).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for HandlerError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::ConnectionClosed => f.write_str("connection closed"),
|
||||
Self::StreamError(e) => write!(f, "stream error: {e}"),
|
||||
Self::AuthRequired => f.write_str("authentication required"),
|
||||
Self::Internal(e) => write!(f, "internal handler error: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for HandlerError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::StreamError(e) => Some(e),
|
||||
Self::Internal(e) => Some(e.as_ref()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum StreamError {
|
||||
ConnectionClosed,
|
||||
StreamClosed,
|
||||
Timeout,
|
||||
Internal(io::Error),
|
||||
}
|
||||
|
||||
impl From<StreamError> for HandlerError {
|
||||
fn from(e: StreamError) -> Self {
|
||||
match e {
|
||||
StreamError::ConnectionClosed => HandlerError::ConnectionClosed,
|
||||
StreamError::StreamClosed => HandlerError::StreamError(io::Error::new(
|
||||
io::ErrorKind::ConnectionReset,
|
||||
"stream closed",
|
||||
)),
|
||||
StreamError::Timeout => HandlerError::StreamError(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"stream timed out",
|
||||
)),
|
||||
StreamError::Internal(e) => HandlerError::StreamError(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for StreamError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::ConnectionClosed => f.write_str("StreamError::ConnectionClosed"),
|
||||
Self::StreamClosed => f.write_str("StreamError::StreamClosed"),
|
||||
Self::Timeout => f.write_str("StreamError::Timeout"),
|
||||
Self::Internal(e) => f.debug_tuple("StreamError::Internal").field(e).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for StreamError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::ConnectionClosed => f.write_str("connection closed"),
|
||||
Self::StreamClosed => f.write_str("stream closed"),
|
||||
Self::Timeout => f.write_str("stream timed out"),
|
||||
Self::Internal(e) => write!(f, "stream error: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for StreamError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::Internal(e) => Some(e),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ProtocolHandler: Send + Sync + 'static {
|
||||
fn alpn(&self) -> &'static [u8];
|
||||
async fn handle(&self, connection: Connection, auth: &AuthContext) -> Result<(), HandlerError>;
|
||||
}
|
||||
|
||||
pub trait BiStream: AsyncRead + AsyncWrite + Send + Unpin {}
|
||||
|
||||
enum SendStreamKind {
|
||||
#[cfg(feature = "quinn")]
|
||||
Quinn(quinn::SendStream),
|
||||
#[cfg(feature = "iroh")]
|
||||
Iroh(iroh::endpoint::SendStream),
|
||||
Mock(Box<dyn AsyncWrite + Send + Unpin>),
|
||||
}
|
||||
|
||||
enum RecvStreamKind {
|
||||
#[cfg(feature = "quinn")]
|
||||
Quinn(quinn::RecvStream),
|
||||
#[cfg(feature = "iroh")]
|
||||
Iroh(iroh::endpoint::RecvStream),
|
||||
Mock(Box<dyn AsyncRead + Send + Unpin>),
|
||||
}
|
||||
|
||||
pub struct SendStream {
|
||||
kind: SendStreamKind,
|
||||
}
|
||||
|
||||
pub struct RecvStream {
|
||||
kind: RecvStreamKind,
|
||||
}
|
||||
|
||||
impl SendStream {
|
||||
#[cfg(feature = "quinn")]
|
||||
fn from_quinn(stream: quinn::SendStream) -> Self {
|
||||
Self {
|
||||
kind: SendStreamKind::Quinn(stream),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
fn from_iroh(stream: iroh::endpoint::SendStream) -> Self {
|
||||
Self {
|
||||
kind: SendStreamKind::Iroh(stream),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn from_mock(stream: impl AsyncWrite + Send + Unpin + 'static) -> Self {
|
||||
Self {
|
||||
kind: SendStreamKind::Mock(Box::new(stream)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RecvStream {
|
||||
#[cfg(feature = "quinn")]
|
||||
fn from_quinn(stream: quinn::RecvStream) -> Self {
|
||||
Self {
|
||||
kind: RecvStreamKind::Quinn(stream),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
fn from_iroh(stream: iroh::endpoint::RecvStream) -> Self {
|
||||
Self {
|
||||
kind: RecvStreamKind::Iroh(stream),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn from_mock(stream: impl AsyncRead + Send + Unpin + 'static) -> Self {
|
||||
Self {
|
||||
kind: RecvStreamKind::Mock(Box::new(stream)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for SendStream {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<io::Result<usize>> {
|
||||
match &mut self.get_mut().kind {
|
||||
#[cfg(feature = "quinn")]
|
||||
SendStreamKind::Quinn(s) => AsyncWrite::poll_write(std::pin::Pin::new(s), cx, buf),
|
||||
#[cfg(feature = "iroh")]
|
||||
SendStreamKind::Iroh(s) => AsyncWrite::poll_write(std::pin::Pin::new(s), cx, buf),
|
||||
SendStreamKind::Mock(s) => {
|
||||
AsyncWrite::poll_write(std::pin::Pin::new(s.as_mut()), cx, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<io::Result<()>> {
|
||||
match &mut self.get_mut().kind {
|
||||
#[cfg(feature = "quinn")]
|
||||
SendStreamKind::Quinn(s) => AsyncWrite::poll_flush(std::pin::Pin::new(s), cx),
|
||||
#[cfg(feature = "iroh")]
|
||||
SendStreamKind::Iroh(s) => AsyncWrite::poll_flush(std::pin::Pin::new(s), cx),
|
||||
SendStreamKind::Mock(s) => AsyncWrite::poll_flush(std::pin::Pin::new(s.as_mut()), cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<io::Result<()>> {
|
||||
match &mut self.get_mut().kind {
|
||||
#[cfg(feature = "quinn")]
|
||||
SendStreamKind::Quinn(s) => AsyncWrite::poll_shutdown(std::pin::Pin::new(s), cx),
|
||||
#[cfg(feature = "iroh")]
|
||||
SendStreamKind::Iroh(s) => AsyncWrite::poll_shutdown(std::pin::Pin::new(s), cx),
|
||||
SendStreamKind::Mock(s) => {
|
||||
AsyncWrite::poll_shutdown(std::pin::Pin::new(s.as_mut()), cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for RecvStream {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<io::Result<()>> {
|
||||
match &mut self.get_mut().kind {
|
||||
#[cfg(feature = "quinn")]
|
||||
RecvStreamKind::Quinn(s) => AsyncRead::poll_read(std::pin::Pin::new(s), cx, buf),
|
||||
#[cfg(feature = "iroh")]
|
||||
RecvStreamKind::Iroh(s) => AsyncRead::poll_read(std::pin::Pin::new(s), cx, buf),
|
||||
RecvStreamKind::Mock(s) => {
|
||||
AsyncRead::poll_read(std::pin::Pin::new(s.as_mut()), cx, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum ConnectionKind {
|
||||
#[cfg(feature = "quinn")]
|
||||
Quinn(quinn::Connection),
|
||||
#[cfg(feature = "iroh")]
|
||||
Iroh(iroh::endpoint::Connection),
|
||||
Mock(Arc<dyn MockConnection + Send + Sync>),
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub trait MockConnection: Send + Sync {
|
||||
fn remote_alpn(&self) -> &[u8];
|
||||
fn remote_addr(&self) -> Option<SocketAddr>;
|
||||
fn close(&self, code: u32, reason: &str);
|
||||
}
|
||||
|
||||
pub struct Connection {
|
||||
kind: ConnectionKind,
|
||||
alpn: Vec<u8>,
|
||||
identity: OnceLock<Identity>,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
#[cfg(feature = "quinn")]
|
||||
pub fn from_quinn(conn: quinn::Connection) -> Self {
|
||||
Self::from_quinn_with_alpn(conn, Vec::new())
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
pub fn from_quinn_with_alpn(conn: quinn::Connection, alpn: Vec<u8>) -> Self {
|
||||
Self {
|
||||
kind: ConnectionKind::Quinn(conn),
|
||||
alpn,
|
||||
identity: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
pub fn from_iroh(conn: iroh::endpoint::Connection) -> Self {
|
||||
let alpn = conn.alpn().unwrap_or_default();
|
||||
Self {
|
||||
kind: ConnectionKind::Iroh(conn),
|
||||
alpn,
|
||||
identity: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn from_mock(mock: Arc<dyn MockConnection + Send + Sync>) -> Self {
|
||||
let alpn = mock.remote_alpn().to_vec();
|
||||
Self {
|
||||
kind: ConnectionKind::Mock(mock),
|
||||
alpn,
|
||||
identity: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), StreamError> {
|
||||
match &self.kind {
|
||||
#[cfg(feature = "quinn")]
|
||||
ConnectionKind::Quinn(c) => {
|
||||
let (send, recv) = c.accept_bi().await.map_err(map_quinn_connection_error)?;
|
||||
Ok((SendStream::from_quinn(send), RecvStream::from_quinn(recv)))
|
||||
}
|
||||
#[cfg(feature = "iroh")]
|
||||
ConnectionKind::Iroh(c) => {
|
||||
let (send, recv) = c.accept_bi().await.map_err(map_iroh_connection_error)?;
|
||||
Ok((SendStream::from_iroh(send), RecvStream::from_iroh(recv)))
|
||||
}
|
||||
ConnectionKind::Mock(_) => Err(StreamError::StreamClosed),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn open_bi(&self) -> Result<(SendStream, RecvStream), StreamError> {
|
||||
match &self.kind {
|
||||
#[cfg(feature = "quinn")]
|
||||
ConnectionKind::Quinn(c) => {
|
||||
let (send, recv) = c.open_bi().await.map_err(map_quinn_connection_error)?;
|
||||
Ok((SendStream::from_quinn(send), RecvStream::from_quinn(recv)))
|
||||
}
|
||||
#[cfg(feature = "iroh")]
|
||||
ConnectionKind::Iroh(c) => {
|
||||
let (send, recv) = c.open_bi().await.map_err(map_iroh_connection_error)?;
|
||||
Ok((SendStream::from_iroh(send), RecvStream::from_iroh(recv)))
|
||||
}
|
||||
ConnectionKind::Mock(_) => Err(StreamError::StreamClosed),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remote_alpn(&self) -> &[u8] {
|
||||
&self.alpn
|
||||
}
|
||||
|
||||
pub fn remote_addr(&self) -> Option<SocketAddr> {
|
||||
match &self.kind {
|
||||
#[cfg(feature = "quinn")]
|
||||
ConnectionKind::Quinn(c) => Some(c.remote_address()),
|
||||
#[cfg(feature = "iroh")]
|
||||
ConnectionKind::Iroh(_) => None,
|
||||
ConnectionKind::Mock(m) => m.remote_addr(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn close(&self, code: u32, reason: &str) {
|
||||
match &self.kind {
|
||||
#[cfg(feature = "quinn")]
|
||||
ConnectionKind::Quinn(c) => {
|
||||
let code = quinn::VarInt::from(code);
|
||||
c.close(code, reason.as_bytes());
|
||||
}
|
||||
#[cfg(feature = "iroh")]
|
||||
ConnectionKind::Iroh(c) => {
|
||||
let code = iroh::endpoint::VarInt::from(code);
|
||||
c.close(code, reason.as_bytes());
|
||||
}
|
||||
ConnectionKind::Mock(m) => m.close(code, reason),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_identity(&self, identity: Identity) -> Result<(), IdentityAlreadySet> {
|
||||
self.identity
|
||||
.set(identity)
|
||||
.map_err(|_| IdentityAlreadySet::AlreadySet)
|
||||
}
|
||||
|
||||
pub fn identity(&self) -> Option<&Identity> {
|
||||
self.identity.get()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
fn map_quinn_connection_error(e: quinn::ConnectionError) -> StreamError {
|
||||
use quinn::ConnectionError as E;
|
||||
match e {
|
||||
E::TimedOut => StreamError::Timeout,
|
||||
E::ConnectionClosed(_) | E::ApplicationClosed(_) | E::Reset => {
|
||||
StreamError::ConnectionClosed
|
||||
}
|
||||
other => StreamError::Internal(io::Error::other(other)),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
fn map_iroh_connection_error(e: iroh::endpoint::ConnectionError) -> StreamError {
|
||||
use iroh::endpoint::ConnectionError as E;
|
||||
match e {
|
||||
E::TimedOut => StreamError::Timeout,
|
||||
E::ConnectionClosed(_) | E::ApplicationClosed(_) | E::Reset => {
|
||||
StreamError::ConnectionClosed
|
||||
}
|
||||
other => StreamError::Internal(io::Error::other(other)),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
|
||||
struct MockConn {
|
||||
alpn: &'static [u8],
|
||||
addr: Option<SocketAddr>,
|
||||
closed: std::sync::Mutex<Option<(u32, String)>>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl MockConnection for MockConn {
|
||||
fn remote_alpn(&self) -> &[u8] {
|
||||
self.alpn
|
||||
}
|
||||
fn remote_addr(&self) -> Option<SocketAddr> {
|
||||
self.addr
|
||||
}
|
||||
fn close(&self, code: u32, reason: &str) {
|
||||
*self.closed.lock().unwrap() = Some((code, reason.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
fn mock_connection() -> Connection {
|
||||
Connection::from_mock(Arc::new(MockConn {
|
||||
alpn: b"alknet/test",
|
||||
addr: Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1234)),
|
||||
closed: std::sync::Mutex::new(None),
|
||||
}))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capabilities_new_is_empty() {
|
||||
let caps = Capabilities::new();
|
||||
assert!(caps.get("google").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capabilities_with_api_key_then_get() {
|
||||
let caps = Capabilities::new().with_api_key("google", "sekrit".to_string());
|
||||
let secret = caps.get("google").expect("api key present");
|
||||
assert_eq!(secret.expose_secret(), "sekrit");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capabilities_with_http_token_then_get() {
|
||||
let caps = Capabilities::new().with_http_token("github", "tok".to_string());
|
||||
let secret = caps.get("github").expect("http token present");
|
||||
assert_eq!(secret.expose_secret(), "tok");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capabilities_clone_preserves_entries() {
|
||||
let caps = Capabilities::new().with_api_key("google", "k".to_string());
|
||||
let cloned = caps.clone();
|
||||
assert_eq!(
|
||||
cloned.get("google").map(|s| s.expose_secret().clone()),
|
||||
Some("k".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
caps.get("google").map(|s| s.expose_secret().clone()),
|
||||
Some("k".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capabilities_zeroize_on_drop_clears_secret() {
|
||||
let mut secret = Secret::new("sensitive".to_string());
|
||||
secret.zeroize();
|
||||
assert_eq!(secret.expose_secret(), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capabilities_does_not_derive_serialize() {
|
||||
fn assert_not_serialize<T>() {}
|
||||
assert_not_serialize::<Capabilities>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capabilities_debug_redacts_entries() {
|
||||
let caps = Capabilities::new().with_api_key("google", "sekrit".to_string());
|
||||
let s = format!("{:?}", caps);
|
||||
assert!(s.contains("redacted"));
|
||||
assert!(!s.contains("sekrit"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secret_debug_redacts() {
|
||||
let secret = Secret::new("hidden".to_string());
|
||||
assert_eq!(format!("{:?}", secret), "[REDACTED]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_identity_once_succeeds_twice_errors() {
|
||||
let conn = mock_connection();
|
||||
let id = Identity {
|
||||
id: "alk_test".to_string(),
|
||||
scopes: vec!["relay:connect".to_string()],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
assert!(conn.set_identity(id.clone()).is_ok());
|
||||
assert!(matches!(
|
||||
conn.set_identity(id),
|
||||
Err(IdentityAlreadySet::AlreadySet)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_get_returns_set_value() {
|
||||
let conn = mock_connection();
|
||||
assert!(conn.identity().is_none());
|
||||
let id = Identity {
|
||||
id: "alk_test".to_string(),
|
||||
scopes: vec![],
|
||||
resources: HashMap::new(),
|
||||
};
|
||||
conn.set_identity(id.clone()).unwrap();
|
||||
assert_eq!(conn.identity(), Some(&id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_remote_alpn_and_addr_from_mock() {
|
||||
let conn = mock_connection();
|
||||
assert_eq!(conn.remote_alpn(), b"alknet/test");
|
||||
assert_eq!(
|
||||
conn.remote_addr(),
|
||||
Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1234))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_error_maps_to_handler_error() {
|
||||
assert!(matches!(
|
||||
HandlerError::from(StreamError::ConnectionClosed),
|
||||
HandlerError::ConnectionClosed
|
||||
));
|
||||
match HandlerError::from(StreamError::StreamClosed) {
|
||||
HandlerError::StreamError(e) => assert_eq!(e.kind(), io::ErrorKind::ConnectionReset),
|
||||
other => panic!("expected StreamError, got {other:?}"),
|
||||
}
|
||||
match HandlerError::from(StreamError::Timeout) {
|
||||
HandlerError::StreamError(e) => assert_eq!(e.kind(), io::ErrorKind::TimedOut),
|
||||
other => panic!("expected StreamError, got {other:?}"),
|
||||
}
|
||||
match HandlerError::from(StreamError::Internal(io::Error::other("x"))) {
|
||||
HandlerError::StreamError(e) => assert_eq!(e.kind(), io::ErrorKind::Other),
|
||||
other => panic!("expected StreamError, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_error_auth_required_constructible() {
|
||||
let e = HandlerError::AuthRequired;
|
||||
assert_eq!(format!("{e}"), "authentication required");
|
||||
}
|
||||
|
||||
// --- HandlerError / StreamError Debug + Display + source ---------------
|
||||
|
||||
#[test]
|
||||
fn handler_error_debug_covers_all_variants() {
|
||||
assert_eq!(
|
||||
format!("{:?}", HandlerError::ConnectionClosed),
|
||||
"HandlerError::ConnectionClosed"
|
||||
);
|
||||
let io_err = io::Error::new(io::ErrorKind::BrokenPipe, "boom");
|
||||
let dbg = format!("{:?}", HandlerError::StreamError(io_err));
|
||||
assert!(dbg.contains("HandlerError::StreamError"));
|
||||
assert_eq!(
|
||||
format!("{:?}", HandlerError::AuthRequired),
|
||||
"HandlerError::AuthRequired"
|
||||
);
|
||||
let inner: Box<dyn std::error::Error + Send + Sync> = "oops".into();
|
||||
let dbg = format!("{:?}", HandlerError::Internal(inner));
|
||||
assert!(dbg.contains("HandlerError::Internal"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_error_display_covers_all_variants() {
|
||||
assert_eq!(
|
||||
format!("{}", HandlerError::ConnectionClosed),
|
||||
"connection closed"
|
||||
);
|
||||
let io_err = io::Error::new(io::ErrorKind::BrokenPipe, "boom");
|
||||
let s = format!("{}", HandlerError::StreamError(io_err));
|
||||
assert!(s.starts_with("stream error: "));
|
||||
assert_eq!(
|
||||
format!("{}", HandlerError::AuthRequired),
|
||||
"authentication required"
|
||||
);
|
||||
let inner: Box<dyn std::error::Error + Send + Sync> = "oops".into();
|
||||
assert_eq!(
|
||||
format!("{}", HandlerError::Internal(inner)),
|
||||
"internal handler error: oops"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_error_source_covers_all_variants() {
|
||||
use std::error::Error;
|
||||
assert!(HandlerError::ConnectionClosed.source().is_none());
|
||||
assert!(HandlerError::AuthRequired.source().is_none());
|
||||
let stream_err =
|
||||
HandlerError::StreamError(io::Error::new(io::ErrorKind::BrokenPipe, "boom"));
|
||||
assert!(
|
||||
stream_err.source().is_some(),
|
||||
"StreamError must expose its io::Error as source"
|
||||
);
|
||||
let internal_inner: Box<dyn std::error::Error + Send + Sync> = "boom".into();
|
||||
let internal_err = HandlerError::Internal(internal_inner);
|
||||
assert!(
|
||||
internal_err.source().is_some(),
|
||||
"Internal must expose its inner error as source"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_error_debug_covers_all_variants() {
|
||||
assert_eq!(
|
||||
format!("{:?}", StreamError::ConnectionClosed),
|
||||
"StreamError::ConnectionClosed"
|
||||
);
|
||||
assert_eq!(
|
||||
format!("{:?}", StreamError::StreamClosed),
|
||||
"StreamError::StreamClosed"
|
||||
);
|
||||
assert_eq!(
|
||||
format!("{:?}", StreamError::Timeout),
|
||||
"StreamError::Timeout"
|
||||
);
|
||||
let dbg = format!("{:?}", StreamError::Internal(io::Error::other("x")));
|
||||
assert!(dbg.contains("StreamError::Internal"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_error_display_covers_all_variants() {
|
||||
assert_eq!(
|
||||
format!("{}", StreamError::ConnectionClosed),
|
||||
"connection closed"
|
||||
);
|
||||
assert_eq!(format!("{}", StreamError::StreamClosed), "stream closed");
|
||||
assert_eq!(format!("{}", StreamError::Timeout), "stream timed out");
|
||||
assert_eq!(
|
||||
format!("{}", StreamError::Internal(io::Error::other("boom"))),
|
||||
"stream error: boom"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_error_source_covers_all_variants() {
|
||||
use std::error::Error;
|
||||
assert!(StreamError::ConnectionClosed.source().is_none());
|
||||
assert!(StreamError::StreamClosed.source().is_none());
|
||||
assert!(StreamError::Timeout.source().is_none());
|
||||
let internal = StreamError::Internal(io::Error::other("x"));
|
||||
assert!(
|
||||
internal.source().is_some(),
|
||||
"Internal must expose its io::Error as source"
|
||||
);
|
||||
}
|
||||
|
||||
// --- map_*_connection_error -------------------------------------------
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
#[test]
|
||||
fn map_quinn_connection_error_timed_out_maps_to_timeout() {
|
||||
assert!(matches!(
|
||||
map_quinn_connection_error(quinn::ConnectionError::TimedOut),
|
||||
StreamError::Timeout
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
#[test]
|
||||
fn map_quinn_connection_error_reset_maps_to_connection_closed() {
|
||||
assert!(matches!(
|
||||
map_quinn_connection_error(quinn::ConnectionError::Reset),
|
||||
StreamError::ConnectionClosed
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
#[test]
|
||||
fn map_quinn_connection_error_application_closed_maps_to_connection_closed() {
|
||||
use bytes::Bytes;
|
||||
let close = quinn::ConnectionError::ApplicationClosed(quinn::ApplicationClose {
|
||||
error_code: quinn::VarInt::from_u32(1),
|
||||
reason: Bytes::new(),
|
||||
});
|
||||
assert!(matches!(
|
||||
map_quinn_connection_error(close),
|
||||
StreamError::ConnectionClosed
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(feature = "quinn")]
|
||||
#[test]
|
||||
fn map_quinn_connection_error_other_maps_to_internal() {
|
||||
let other = quinn::ConnectionError::VersionMismatch;
|
||||
match map_quinn_connection_error(other) {
|
||||
StreamError::Internal(e) => assert_eq!(e.kind(), io::ErrorKind::Other),
|
||||
other => panic!("expected StreamError::Internal, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
#[test]
|
||||
fn map_iroh_connection_error_timed_out_maps_to_timeout() {
|
||||
assert!(matches!(
|
||||
map_iroh_connection_error(iroh::endpoint::ConnectionError::TimedOut),
|
||||
StreamError::Timeout
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
#[test]
|
||||
fn map_iroh_connection_error_reset_maps_to_connection_closed() {
|
||||
assert!(matches!(
|
||||
map_iroh_connection_error(iroh::endpoint::ConnectionError::Reset),
|
||||
StreamError::ConnectionClosed
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
#[test]
|
||||
fn map_iroh_connection_error_application_closed_maps_to_connection_closed() {
|
||||
use bytes::Bytes;
|
||||
let close =
|
||||
iroh::endpoint::ConnectionError::ApplicationClosed(iroh::endpoint::ApplicationClose {
|
||||
error_code: iroh::endpoint::VarInt::from_u32(1),
|
||||
reason: Bytes::new(),
|
||||
});
|
||||
assert!(matches!(
|
||||
map_iroh_connection_error(close),
|
||||
StreamError::ConnectionClosed
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(feature = "iroh")]
|
||||
#[test]
|
||||
fn map_iroh_connection_error_other_maps_to_internal() {
|
||||
let other = iroh::endpoint::ConnectionError::VersionMismatch;
|
||||
match map_iroh_connection_error(other) {
|
||||
StreamError::Internal(e) => assert_eq!(e.kind(), io::ErrorKind::Other),
|
||||
other => panic!("expected StreamError::Internal, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// --- Capabilities zeroize + default -----------------------------------
|
||||
|
||||
#[test]
|
||||
fn capabilities_default_is_empty() {
|
||||
let caps = Capabilities::default();
|
||||
assert!(caps.get("anything").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capabilities_zeroize_clears_entries() {
|
||||
let mut caps = Capabilities::new()
|
||||
.with_api_key("svc-a", "k1".to_string())
|
||||
.with_http_token("svc-b", "t1".to_string());
|
||||
assert!(caps.get("svc-a").is_some());
|
||||
assert!(caps.get("svc-b").is_some());
|
||||
caps.zeroize();
|
||||
assert!(caps.get("svc-a").is_none());
|
||||
assert!(caps.get("svc-b").is_none());
|
||||
}
|
||||
}
|
||||
@@ -1,2 +0,0 @@
|
||||
#[tokio::test]
|
||||
async fn auth_placeholder() {}
|
||||
@@ -1,2 +0,0 @@
|
||||
#[tokio::test]
|
||||
async fn client_placeholder() {}
|
||||
@@ -1,2 +0,0 @@
|
||||
#[tokio::test]
|
||||
async fn server_placeholder() {}
|
||||
@@ -1,28 +0,0 @@
|
||||
use alknet_core::testutil::{
|
||||
mock_pair, MockTransport, MockTransportAcceptor, Transport, TransportAcceptor,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_transport_connect() {
|
||||
let transport = MockTransport::new(1024);
|
||||
let stream = transport.connect().await.unwrap();
|
||||
drop(stream);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_transport_acceptor_accept() {
|
||||
let acceptor = MockTransportAcceptor::new(1024);
|
||||
let (stream, info) = acceptor.accept().await.unwrap();
|
||||
drop(stream);
|
||||
drop(info);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_pair_communicates() {
|
||||
let (mut client, mut server) = mock_pair(1024);
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
client.write_all(b"hello").await.unwrap();
|
||||
let mut buf = [0u8; 5];
|
||||
server.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf, b"hello");
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
[package]
|
||||
name = "alknet-napi"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Node.js native addon for Alknet via napi-rs: connect() and serve() SSH tunnel functions"
|
||||
repository.workspace = true
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
alknet-core = { path = "../alknet-core", features = ["tls", "iroh"] }
|
||||
napi = { version = "3", features = ["async", "error_anyhow"] }
|
||||
napi-derive = "3"
|
||||
tokio = { version = "1", features = ["io-util", "sync", "rt", "macros", "net", "time", "signal"] }
|
||||
russh = "0.49"
|
||||
async-trait = "0.1"
|
||||
rustls-pemfile = "2"
|
||||
rustls-pki-types = "1"
|
||||
iroh = "0.34"
|
||||
ipnetwork = "0.21"
|
||||
url = "2"
|
||||
arc-swap = "1"
|
||||
tracing = "0.1"
|
||||
@@ -1,304 +0,0 @@
|
||||
//! NAPI `connect()` function and `AlknetStream` type.
|
||||
//!
|
||||
//! Opens a single SSH channel as a duplex stream for programmatic use.
|
||||
//! Unlike the CLI client, this does not start a SOCKS5 server or port forwards —
|
||||
//! it provides a raw stream that JavaScript code can read from and write to.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi_derive::napi;
|
||||
use russh::client;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use alknet_core::auth::client_auth::{ClientAuthConfig, ClientHandler};
|
||||
use alknet_core::auth::keys::KeySource;
|
||||
use alknet_core::transport::{IrohTransport, TcpTransport, TlsTransport, Transport};
|
||||
|
||||
const DEFAULT_HOST: &str = "alknet-control";
|
||||
const DEFAULT_PORT: u32 = 0;
|
||||
|
||||
#[napi(object)]
|
||||
pub struct AlknetConnectOptions {
|
||||
pub server: Option<String>,
|
||||
pub peer: Option<String>,
|
||||
pub transport: String,
|
||||
pub identity: Option<Either<String, Buffer>>,
|
||||
pub tls_server_name: Option<String>,
|
||||
pub insecure: Option<bool>,
|
||||
pub iroh_relay: Option<String>,
|
||||
pub proxy: Option<String>,
|
||||
}
|
||||
|
||||
fn resolve_key_source(identity: &Option<Either<String, Buffer>>) -> Result<KeySource> {
|
||||
match identity {
|
||||
None => Err(Error::new(
|
||||
Status::InvalidArg,
|
||||
"identity is required: provide a file path (string) or key data (Buffer)",
|
||||
)),
|
||||
Some(Either::A(path)) => Ok(KeySource::File(path.into())),
|
||||
Some(Either::B(buf)) => Ok(KeySource::Memory(buf.to_vec())),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_addr(addr_str: &str) -> Result<SocketAddr> {
|
||||
addr_str.parse().map_err(|e| {
|
||||
Error::new(
|
||||
Status::InvalidArg,
|
||||
format!("invalid server address '{}': {}", addr_str, e),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub struct AlknetStream {
|
||||
read: Arc<Mutex<tokio::io::ReadHalf<russh::ChannelStream<client::Msg>>>>,
|
||||
write: Arc<Mutex<tokio::io::WriteHalf<russh::ChannelStream<client::Msg>>>>,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl AlknetStream {
|
||||
#[napi]
|
||||
pub async fn read(&self, size: u32) -> Result<Buffer> {
|
||||
let mut buf = vec![0u8; size as usize];
|
||||
let mut guard = self.read.lock().await;
|
||||
let n = guard
|
||||
.read(&mut buf)
|
||||
.await
|
||||
.map_err(|e| Error::new(Status::GenericFailure, format!("read failed: {}", e)))?;
|
||||
if n == 0 {
|
||||
return Ok(Vec::<u8>::new().into());
|
||||
}
|
||||
buf.truncate(n);
|
||||
Ok(buf.into())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn write(&self, data: Buffer) -> Result<()> {
|
||||
let mut guard = self.write.lock().await;
|
||||
guard
|
||||
.write_all(&data)
|
||||
.await
|
||||
.map_err(|e| Error::new(Status::GenericFailure, format!("write failed: {}", e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn close(&self) -> Result<()> {
|
||||
let mut guard = self.write.lock().await;
|
||||
guard
|
||||
.shutdown()
|
||||
.await
|
||||
.map_err(|e| Error::new(Status::GenericFailure, format!("close failed: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn connect(options: AlknetConnectOptions) -> Result<AlknetStream> {
|
||||
let key_source = resolve_key_source(&options.identity)?;
|
||||
let auth_config = Arc::new(
|
||||
ClientAuthConfig::from_key_source(key_source)
|
||||
.map_err(|e| Error::new(Status::InvalidArg, format!("invalid identity key: {}", e)))?,
|
||||
);
|
||||
|
||||
let transport_mode = options.transport.to_lowercase();
|
||||
let handler = ClientHandler::from_config(&auth_config);
|
||||
let username = "alknet".to_string();
|
||||
|
||||
let config = Arc::new(client::Config::default());
|
||||
|
||||
let mut handle: client::Handle<ClientHandler> = match transport_mode.as_str() {
|
||||
"tcp" => {
|
||||
let server = options.server.as_ref().ok_or_else(|| {
|
||||
Error::new(Status::InvalidArg, "server is required for tcp transport")
|
||||
})?;
|
||||
let addr = parse_addr(server)?;
|
||||
let transport = TcpTransport::new(addr);
|
||||
let stream = transport.connect().await.map_err(|e| {
|
||||
Error::new(Status::GenericFailure, format!("tcp connect failed: {}", e))
|
||||
})?;
|
||||
client::connect_stream(config, stream, handler)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
Error::new(
|
||||
Status::GenericFailure,
|
||||
format!("ssh handshake failed: {}", e),
|
||||
)
|
||||
})?
|
||||
}
|
||||
"tls" => {
|
||||
let server = options.server.as_ref().ok_or_else(|| {
|
||||
Error::new(Status::InvalidArg, "server is required for tls transport")
|
||||
})?;
|
||||
let addr = parse_addr(server)?;
|
||||
let mut transport = TlsTransport::new(addr);
|
||||
if let Some(ref name) = options.tls_server_name {
|
||||
transport = transport.with_server_name(name);
|
||||
}
|
||||
if let Some(true) = options.insecure {
|
||||
transport = transport.with_insecure(true);
|
||||
}
|
||||
let stream = transport.connect().await.map_err(|e| {
|
||||
Error::new(Status::GenericFailure, format!("tls connect failed: {}", e))
|
||||
})?;
|
||||
client::connect_stream(config, stream, handler)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
Error::new(
|
||||
Status::GenericFailure,
|
||||
format!("ssh handshake failed: {}", e),
|
||||
)
|
||||
})?
|
||||
}
|
||||
"iroh" => {
|
||||
let peer = options.peer.as_ref().ok_or_else(|| {
|
||||
Error::new(Status::InvalidArg, "peer is required for iroh transport")
|
||||
})?;
|
||||
let node_id: iroh::NodeId = peer.parse().map_err(|e| {
|
||||
Error::new(
|
||||
Status::InvalidArg,
|
||||
format!("invalid iroh peer ID '{}': {}", peer, e),
|
||||
)
|
||||
})?;
|
||||
let relay_url: Option<iroh::RelayUrl> = match options.iroh_relay.as_deref() {
|
||||
Some(u) => Some(u.parse().map_err(|e| {
|
||||
Error::new(Status::InvalidArg, format!("invalid iroh relay URL: {}", e))
|
||||
})?),
|
||||
None => None,
|
||||
};
|
||||
let proxy_url: Option<url::Url> = match options.proxy.as_deref() {
|
||||
Some(u) => Some(u.parse().map_err(|e| {
|
||||
Error::new(Status::InvalidArg, format!("invalid proxy URL: {}", e))
|
||||
})?),
|
||||
None => None,
|
||||
};
|
||||
let transport = IrohTransport::new(node_id, relay_url, proxy_url)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
Error::new(
|
||||
Status::GenericFailure,
|
||||
format!("iroh endpoint setup failed: {}", e),
|
||||
)
|
||||
})?;
|
||||
let stream = transport.connect().await.map_err(|e| {
|
||||
Error::new(
|
||||
Status::GenericFailure,
|
||||
format!("iroh connect failed: {}", e),
|
||||
)
|
||||
})?;
|
||||
client::connect_stream(config, stream, handler)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
Error::new(
|
||||
Status::GenericFailure,
|
||||
format!("ssh handshake failed: {}", e),
|
||||
)
|
||||
})?
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::new(
|
||||
Status::InvalidArg,
|
||||
format!(
|
||||
"unknown transport '{}'; expected tcp, tls, or iroh",
|
||||
transport_mode
|
||||
),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let auth_ok = auth_config
|
||||
.authenticate(&mut handle, &username)
|
||||
.await
|
||||
.map_err(|e| Error::new(Status::GenericFailure, format!("ssh auth failed: {}", e)))?;
|
||||
if !auth_ok {
|
||||
return Err(Error::new(
|
||||
Status::GenericFailure,
|
||||
"ssh authentication rejected",
|
||||
));
|
||||
}
|
||||
|
||||
let channel = handle
|
||||
.channel_open_direct_tcpip(DEFAULT_HOST, DEFAULT_PORT, "127.0.0.1", 0)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
Error::new(
|
||||
Status::GenericFailure,
|
||||
format!("failed to open ssh channel: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
let stream = channel.into_stream();
|
||||
let (read_half, write_half) = tokio::io::split(stream);
|
||||
|
||||
Ok(AlknetStream {
|
||||
read: Arc::new(Mutex::new(read_half)),
|
||||
write: Arc::new(Mutex::new(write_half)),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
const ED25519_PRIVATE_KEY: &str = "-----BEGIN OPENSSH PRIVATE KEY-----\nb3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\nQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01QAAAJiQ+NvMkPjb\nzAAAAAtzc2gtZWQyNTUxOQAAACBOfInDyRS33JEeDNT8xd10qRdwFN8z/QukCOgEIkv01Q\nAAAECIWwJf7+7MOuZAOOWmoQbE9i/5GxjKsFrtJHjZ34E/fk58icPJFLfckR4M1PzF3XSp\nF3AU3zP9C6QI6AQiS/TVAAAAD3VidW50dUBuczUyODA5NgECAwQFBg==\n-----END OPENSSH PRIVATE KEY-----\n";
|
||||
|
||||
#[test]
|
||||
fn resolve_key_source_file_path() {
|
||||
let identity = Some(Either::<String, Buffer>::A("/path/to/key".to_string()));
|
||||
let result = resolve_key_source(&identity);
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
KeySource::File(p) => assert_eq!(p.to_str(), Some("/path/to/key")),
|
||||
_ => panic!("expected File variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_key_source_buffer() {
|
||||
let identity = Some(Either::<String, Buffer>::B(Buffer::from(
|
||||
ED25519_PRIVATE_KEY.as_bytes().to_vec(),
|
||||
)));
|
||||
let result = resolve_key_source(&identity);
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
KeySource::Memory(data) => assert!(!data.is_empty()),
|
||||
_ => panic!("expected Memory variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_key_source_missing() {
|
||||
let identity: Option<Either<String, Buffer>> = None;
|
||||
let result = resolve_key_source(&identity);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_addr_valid() {
|
||||
let addr = parse_addr("127.0.0.1:22");
|
||||
assert!(addr.is_ok());
|
||||
assert_eq!(addr.unwrap().port(), 22);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_addr_invalid() {
|
||||
let addr = parse_addr("not-an-address");
|
||||
assert!(addr.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_config_from_memory_key() {
|
||||
let source = KeySource::Memory(ED25519_PRIVATE_KEY.as_bytes().to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source);
|
||||
assert!(config.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_config_from_invalid_key() {
|
||||
let source = KeySource::Memory(b"not-a-key".to_vec());
|
||||
let config = ClientAuthConfig::from_key_source(source);
|
||||
assert!(config.is_err());
|
||||
}
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
//! # alknet-napi
|
||||
//!
|
||||
//! Node.js native addon for [Alknet](https://git.alk.dev/alkdev/alknet) via napi-rs.
|
||||
//! Exposes `connect()` and `serve()` functions for programmatic SSH tunnel creation.
|
||||
//!
|
||||
//! > **Alpha software.** The NAPI interface may change between versions.
|
||||
//!
|
||||
//! # Quick example (Node.js)
|
||||
//!
|
||||
//! ```js
|
||||
//! const { connect, serve } = require('alknet-napi');
|
||||
//!
|
||||
//! // Client: open a duplex SSH stream
|
||||
//! const stream = await connect({
|
||||
//! server: "example.com:22",
|
||||
//! transport: "tcp",
|
||||
//! identity: "/path/to/key",
|
||||
//! });
|
||||
//! await stream.write(Buffer.from("hello"));
|
||||
//! const data = await stream.read(1024);
|
||||
//! await stream.close();
|
||||
//! ```
|
||||
|
||||
#[allow(unused_imports)]
|
||||
#[macro_use]
|
||||
extern crate napi_derive;
|
||||
|
||||
mod connect;
|
||||
mod serve;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,47 +0,0 @@
|
||||
//! # alknet-secret
|
||||
//!
|
||||
//! BIP39 mnemonic generation, SLIP-0010 Ed25519 HD key derivation, AES-256-GCM
|
||||
//! encryption for external credentials, and the `SecretProtocol` irpc service.
|
||||
//!
|
||||
//! This crate is the only component that holds the master seed phrase. All other
|
||||
//! crates request derived keys through the `SecretProtocol` irpc service or the
|
||||
//! `SecretServiceHandle` local API.
|
||||
//!
|
||||
//! ## Crate Independence
|
||||
//!
|
||||
//! alknet-secret does **not** depend on alknet-core or alknet-storage. Per ADR-027,
|
||||
//! it is fully independent. The `EncryptedData` wire format is shared with
|
||||
//! alknet-storage by type-level compatibility, not a crate dependency.
|
||||
//!
|
||||
//! ## Security Model
|
||||
//!
|
||||
//! The seed phrase is never persisted to disk. It is entered at startup or via
|
||||
//! `Unlock` and held only in `Zeroize`-protected RAM (ADR-038). `Lock` purges
|
||||
//! the seed and all cached derived keys.
|
||||
//!
|
||||
//! ## Module Organization
|
||||
//!
|
||||
//! - [`mnemonic`] — BIP39 mnemonic generation, validation, and seed derivation
|
||||
//! - [`derivation`] — SLIP-0010 Ed25519 HD key derivation and path constants
|
||||
//! - [`encryption`] — AES-256-GCM encrypt/decrypt and `EncryptedData` type
|
||||
//! - [`protocol`] — `SecretProtocol` irpc service enum, `DerivedKey`, `KeyType`
|
||||
//! - [`service`] — `SecretService` implementation with Unlock/Lock lifecycle
|
||||
//! - [`ethereum`] — BIP-0032 secp256k1 HD key derivation (behind `secp256k1` feature)
|
||||
|
||||
pub mod cache;
|
||||
pub mod derivation;
|
||||
pub mod encryption;
|
||||
pub mod mnemonic;
|
||||
pub mod protocol;
|
||||
pub mod service;
|
||||
|
||||
#[cfg(feature = "secp256k1")]
|
||||
pub mod ethereum;
|
||||
|
||||
// Re-export primary public API
|
||||
pub use cache::CacheConfig;
|
||||
pub use derivation::{DerivationError, ExtendedPrivKey, PATHS};
|
||||
pub use encryption::{EncryptedData, EncryptionError};
|
||||
pub use mnemonic::{Language, Mnemonic, Seed};
|
||||
pub use protocol::{DerivedKey, KeyType, SecretMessage, SecretProtocol};
|
||||
pub use service::{SecretService, SecretServiceActor, SecretServiceError, SecretServiceHandle};
|
||||
@@ -1,310 +0,0 @@
|
||||
//! SecretProtocol irpc service definition and associated types.
|
||||
//!
|
||||
//! This module defines the `SecretProtocol` enum for irpc-based inter-service
|
||||
//! communication. The protocol supports unlock/lock lifecycle, key derivation,
|
||||
//! and encryption/decryption operations.
|
||||
//!
|
||||
//! # Protocol Operation
|
||||
//!
|
||||
//! The SecretProtocol follows a lifecycle: the service starts in a **locked**
|
||||
//! state where no derivation or encryption operations are possible. The `Unlock`
|
||||
//! call loads the seed into memory (derived from the mnemonic passphrase). After
|
||||
//! that, derive and encrypt/decrypt operations are available. The `Lock` call
|
||||
//! purges the seed and all cached keys.
|
||||
//!
|
||||
//! # Wire Format
|
||||
//!
|
||||
//! For local (in-process) calls, the protocol uses tokio channels directly.
|
||||
//! For remote (in-cluster) calls, the protocol is serialized with postcard.
|
||||
//! For cross-node (call protocol) exposure, the service is wrapped in an
|
||||
//! operation that serializes to JSON.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
use irpc::rpc_requests;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use zeroize::Zeroize;
|
||||
|
||||
use crate::encryption::EncryptedData;
|
||||
|
||||
/// The type of a derived key.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum KeyType {
|
||||
/// Ed25519 keypair (SLIP-0010 derivation).
|
||||
Ed25519,
|
||||
/// AES-256-GCM symmetric key (derived from seed, used for external credential encryption).
|
||||
Aes256Gcm,
|
||||
/// secp256k1 keypair (BIP-0032 derivation, for Ethereum signing).
|
||||
Secp256k1,
|
||||
}
|
||||
|
||||
/// A derived key pair (private key + public key).
|
||||
///
|
||||
/// The private key is sensitive material that is zeroized on drop (ADR-038).
|
||||
/// This type is **not** `Clone` — it is move-only. Consumers receive a
|
||||
/// `DerivedKey` by value and must zeroize it when done (handled automatically
|
||||
/// by `#[zeroize(drop)]`).
|
||||
///
|
||||
/// Serialization redacts the `private_key` field for human-readable formats
|
||||
/// (JSON) for safety, showing `"[REDACTED]"` instead of the key bytes. For
|
||||
/// binary formats (postcard, used by irpc), the actual bytes are serialized
|
||||
/// so that remote communication works correctly. Deserialization always reads
|
||||
/// the full bytes.
|
||||
#[derive(Zeroize, Deserialize)]
|
||||
#[zeroize(drop)]
|
||||
pub struct DerivedKey {
|
||||
/// The type of key that was derived.
|
||||
#[zeroize(skip)]
|
||||
pub key_type: KeyType,
|
||||
/// The private key bytes (sensitive — zeroized on drop).
|
||||
#[zeroize]
|
||||
#[serde(deserialize_with = "deserialize_private_key")]
|
||||
pub private_key: Vec<u8>,
|
||||
/// The public key bytes.
|
||||
#[zeroize(skip)]
|
||||
pub public_key: Vec<u8>,
|
||||
}
|
||||
|
||||
fn deserialize_private_key<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
|
||||
Vec::<u8>::deserialize(d)
|
||||
}
|
||||
|
||||
impl fmt::Debug for DerivedKey {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("DerivedKey")
|
||||
.field("key_type", &self.key_type)
|
||||
.field("private_key", &"[REDACTED]")
|
||||
.field("public_key", &self.public_key)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for DerivedKey {
|
||||
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
|
||||
use serde::ser::SerializeStruct;
|
||||
if s.is_human_readable() {
|
||||
let mut state = s.serialize_struct("DerivedKey", 3)?;
|
||||
state.serialize_field("key_type", &self.key_type)?;
|
||||
state.serialize_field("private_key", "[REDACTED]")?;
|
||||
state.serialize_field("public_key", &self.public_key)?;
|
||||
state.end()
|
||||
} else {
|
||||
let mut state = s.serialize_struct("DerivedKey", 3)?;
|
||||
state.serialize_field("key_type", &self.key_type)?;
|
||||
state.serialize_field("private_key", &self.private_key)?;
|
||||
state.serialize_field("public_key", &self.public_key)?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// SecretProtocol service definition.
|
||||
///
|
||||
/// This is the irpc protocol enum that defines all secret service operations.
|
||||
/// The `#[rpc_requests]` macro generates:
|
||||
/// - **`SecretMessage`**: message enum with `WithChannels` wrappers for each variant
|
||||
/// - **`Channels<SecretProtocol>`** impls for each wrapper type
|
||||
/// - **`From`** impls for protocol enum and message enum conversions
|
||||
/// - **`Service`** and **`RemoteService`** trait impls for remote dispatch
|
||||
///
|
||||
/// # State Requirements
|
||||
///
|
||||
/// All operations except `Unlock` require the service to be in an **unlocked**
|
||||
/// state. Calling derive/encrypt/decrypt on a locked service returns an error.
|
||||
#[rpc_requests(message = SecretMessage, no_spans)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub enum SecretProtocol {
|
||||
/// Derive an Ed25519 keypair at the given path.
|
||||
///
|
||||
/// Path format: `m/74'/0'/0'/0'` (SLIP-0010 hardened-only notation).
|
||||
/// Returns a `DerivedKey` with `KeyType::Ed25519`.
|
||||
#[rpc(tx = irpc::channel::oneshot::Sender<Result<DerivedKey, crate::service::SecretServiceError>>)]
|
||||
#[wrap(DeriveEd25519)]
|
||||
DeriveEd25519 {
|
||||
/// SLIP-0010 derivation path (e.g., "m/74'/0'/0'/0'").
|
||||
path: String,
|
||||
},
|
||||
|
||||
/// Derive an AES-256-GCM encryption key at the given path.
|
||||
///
|
||||
/// The default encryption path is `m/74'/2'/0'/0'`.
|
||||
/// Returns a `DerivedKey` with `KeyType::Aes256Gcm`.
|
||||
#[rpc(tx = irpc::channel::oneshot::Sender<Result<DerivedKey, crate::service::SecretServiceError>>)]
|
||||
#[wrap(DeriveEncryptionKey)]
|
||||
DeriveEncryptionKey {
|
||||
/// SLIP-0010 derivation path for the encryption key.
|
||||
path: String,
|
||||
},
|
||||
|
||||
/// Derive a secp256k1 (Ethereum) keypair at the given path.
|
||||
///
|
||||
/// The default Ethereum path is `m/44'/60'/0'/0/0`.
|
||||
/// Returns a `DerivedKey` with `KeyType::Secp256k1`.
|
||||
#[rpc(tx = irpc::channel::oneshot::Sender<Result<DerivedKey, crate::service::SecretServiceError>>)]
|
||||
#[wrap(DeriveEthereumKey)]
|
||||
DeriveEthereumKey {
|
||||
/// BIP-0032 derivation path (e.g., "m/44'/60'/0'/0/0").
|
||||
path: String,
|
||||
},
|
||||
|
||||
/// Derive a deterministic password at the given path.
|
||||
///
|
||||
/// Path format: `m/74'/1'/0'/{hash}'` (SLIP-0010 hardened notation).
|
||||
/// The `length` parameter controls the output length.
|
||||
#[rpc(tx = irpc::channel::oneshot::Sender<Result<Vec<u8>, crate::service::SecretServiceError>>)]
|
||||
#[wrap(DerivePassword)]
|
||||
DerivePassword {
|
||||
/// SLIP-0010 derivation path for the password.
|
||||
path: String,
|
||||
/// Desired password length in bytes.
|
||||
length: usize,
|
||||
},
|
||||
|
||||
/// Encrypt plaintext using a derived encryption key.
|
||||
///
|
||||
/// The key is derived at the path `m/74'/2'/0'/0'` with the given version.
|
||||
/// Returns an `EncryptedData` blob suitable for storage.
|
||||
#[rpc(tx = irpc::channel::oneshot::Sender<Result<EncryptedData, crate::service::SecretServiceError>>)]
|
||||
#[wrap(Encrypt)]
|
||||
Encrypt {
|
||||
/// The plaintext string to encrypt.
|
||||
plaintext: String,
|
||||
/// The key version for rotation tracking.
|
||||
key_version: u32,
|
||||
},
|
||||
|
||||
/// Decrypt an `EncryptedData` blob back to plaintext.
|
||||
///
|
||||
/// The key is derived from the seed at the path indicated by the key version.
|
||||
#[rpc(tx = irpc::channel::oneshot::Sender<Result<String, crate::service::SecretServiceError>>)]
|
||||
#[wrap(Decrypt)]
|
||||
Decrypt {
|
||||
/// The encrypted data blob to decrypt.
|
||||
encrypted: EncryptedData,
|
||||
},
|
||||
|
||||
/// Lock the service, purging the seed and all cached derived keys.
|
||||
///
|
||||
/// After locking, no derive/encrypt/decrypt operations are possible
|
||||
/// until `Unlock` is called again. Calls `zeroize()` on all sensitive
|
||||
/// material (ADR-038).
|
||||
#[rpc(tx = irpc::channel::oneshot::Sender<Result<(), crate::service::SecretServiceError>>)]
|
||||
#[wrap(Lock)]
|
||||
Lock,
|
||||
|
||||
/// Unlock the service with a BIP39 mnemonic and optional passphrase.
|
||||
///
|
||||
/// The mnemonic is the space-separated BIP39 word list. The passphrase is
|
||||
/// the optional BIP39 password extension (the "25th word"). After unlocking,
|
||||
/// derive and encrypt/decrypt operations are available.
|
||||
#[rpc(tx = irpc::channel::oneshot::Sender<Result<(), crate::service::SecretServiceError>>)]
|
||||
#[wrap(Unlock)]
|
||||
Unlock {
|
||||
/// The BIP39 mnemonic phrase (space-separated word list).
|
||||
mnemonic: String,
|
||||
/// Optional BIP39 passphrase (the "25th word" password extension).
|
||||
passphrase: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_test_key() -> DerivedKey {
|
||||
DerivedKey {
|
||||
key_type: KeyType::Ed25519,
|
||||
private_key: vec![0xABu8; 32],
|
||||
public_key: vec![0xCDu8; 32],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derived_key_debug_redacts_private_key() {
|
||||
let key = make_test_key();
|
||||
let debug_output = format!("{:?}", key);
|
||||
assert!(
|
||||
!debug_output.contains("AB"),
|
||||
"Debug must not leak private_key bytes"
|
||||
);
|
||||
assert!(
|
||||
debug_output.contains("[REDACTED]"),
|
||||
"Debug must show [REDACTED] for private_key"
|
||||
);
|
||||
assert!(debug_output.contains("Ed25519"), "Debug must show key_type");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derived_key_serialize_redacts_private_key_json() {
|
||||
let key = make_test_key();
|
||||
let json = serde_json::to_string(&key).unwrap();
|
||||
assert!(
|
||||
!json.contains("AB"),
|
||||
"JSON must not contain private_key bytes"
|
||||
);
|
||||
assert!(
|
||||
json.contains("[REDACTED]"),
|
||||
"JSON must show [REDACTED] for private_key"
|
||||
);
|
||||
assert!(json.contains("Ed25519"), "JSON must contain key_type");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derived_key_serialize_preserves_bytes_postcard() {
|
||||
let key = make_test_key();
|
||||
let bytes = postcard::to_allocvec(&key).unwrap();
|
||||
let restored: DerivedKey = postcard::from_bytes(&bytes).unwrap();
|
||||
assert_eq!(
|
||||
restored.private_key,
|
||||
vec![0xABu8; 32],
|
||||
"postcard must preserve private_key bytes"
|
||||
);
|
||||
assert_eq!(
|
||||
restored.public_key,
|
||||
vec![0xCDu8; 32],
|
||||
"postcard must preserve public_key bytes"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derived_key_deserialize_preserves_bytes() {
|
||||
let key = make_test_key();
|
||||
let bytes = postcard::to_allocvec(&key.private_key).unwrap();
|
||||
let restored: Vec<u8> = postcard::from_bytes(&bytes).unwrap();
|
||||
assert_eq!(
|
||||
restored,
|
||||
vec![0xABu8; 32],
|
||||
"Deserialization must preserve private_key bytes"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derived_key_zeroize_on_drop() {
|
||||
let key = DerivedKey {
|
||||
key_type: KeyType::Aes256Gcm,
|
||||
private_key: vec![0xFFu8; 32],
|
||||
public_key: vec![0x00u8; 32],
|
||||
};
|
||||
drop(key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derived_key_not_clone() {
|
||||
let key = make_test_key();
|
||||
let _moved = key;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derived_key_zeroize_method_overwrites_private_key() {
|
||||
let mut key = make_test_key();
|
||||
assert_ne!(key.private_key, vec![0u8; 32]);
|
||||
assert!(!key.private_key.is_empty());
|
||||
|
||||
key.zeroize();
|
||||
|
||||
assert!(
|
||||
key.private_key.is_empty(),
|
||||
"zeroize() must clear the private_key Vec"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,972 +0,0 @@
|
||||
//! SecretService implementation with Unlock/Lock lifecycle.
|
||||
//!
|
||||
//! The `SecretService` is the primary runtime interface for key management.
|
||||
//! It holds the master seed in `Zeroize`-protected memory and provides methods
|
||||
//! for the Unlock/Lock lifecycle, key derivation, and encryption/decryption.
|
||||
//!
|
||||
//! # Lifecycle
|
||||
//!
|
||||
//! ```text
|
||||
//! Unlock(passphrase)
|
||||
//! → validate mnemonic (if restoring) or generate new
|
||||
//! → derive master key from seed
|
||||
//! → store seed in SeedHolder (Zeroize-protected)
|
||||
//! → cache empty (keys derived on demand)
|
||||
//!
|
||||
//! DeriveEd25519/DeriveEncryptionKey/Encrypt/Decrypt
|
||||
//! → require unlocked state (ServiceLocked error if locked)
|
||||
//! → derive key, return result
|
||||
//! → optionally cache derived key
|
||||
//!
|
||||
//! Lock
|
||||
//! → zeroize all cached derived keys
|
||||
//! → zeroize seed
|
||||
//! → drop all sensitive material
|
||||
//! → service returns to locked state
|
||||
//! ```
|
||||
//!
|
||||
//! # Dispatch Paths
|
||||
//!
|
||||
//! There are two ways to interact with the secret service:
|
||||
//!
|
||||
//! 1. **Local (in-process)**: `SecretServiceHandle` wraps `SecretServiceInner`
|
||||
//! behind `Arc<RwLock<>>` and provides direct method calls without serialization.
|
||||
//! 2. **Remote (in-cluster)**: `SecretServiceActor` processes `SecretMessage`
|
||||
//! variants from an mpsc channel and dispatches to the handle methods.
|
||||
//!
|
||||
//! # Assembly
|
||||
//!
|
||||
//! The `SecretService` is assembled by the CLI binary or NAPI layer. Per ADR-027,
|
||||
//! alknet-core never sees the secret service directly — it is wired through the
|
||||
//! `OperationEnv` dispatch mechanism. For minimal deployments, no secret service
|
||||
//! is available (the `SecretStoreCredentialProvider` returns `None`).
|
||||
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
||||
use base64::Engine;
|
||||
use irpc::WithChannels;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::cache::{CacheConfig, CachedKey, KeyCache};
|
||||
use crate::derivation::{self, DerivationError, PATHS};
|
||||
use crate::encryption::{self, EncryptedData, EncryptionKey};
|
||||
use crate::mnemonic::{Language, Mnemonic, Seed};
|
||||
use crate::protocol::{
|
||||
Decrypt, DeriveEd25519, DeriveEncryptionKey, DeriveEthereumKey, DerivePassword, Encrypt,
|
||||
SecretMessage, SecretProtocol, Unlock,
|
||||
};
|
||||
use crate::protocol::{DerivedKey, KeyType};
|
||||
|
||||
/// Handle to a running SecretService for local (in-process) use.
|
||||
///
|
||||
/// This is the primary API for local secret operations. It wraps the
|
||||
/// service state in an `Arc<RwLock<>>` for thread-safe access.
|
||||
#[derive(Clone)]
|
||||
pub struct SecretServiceHandle {
|
||||
inner: Arc<RwLock<SecretServiceInner>>,
|
||||
}
|
||||
|
||||
/// Internal state of the secret service.
|
||||
struct SecretServiceInner {
|
||||
/// The mnemonic phrase, if unlocked. None if locked.
|
||||
mnemonic: Option<Mnemonic>,
|
||||
/// The master seed, if unlocked. None if locked.
|
||||
seed: Option<Seed>,
|
||||
/// Whether the service is unlocked.
|
||||
unlocked: bool,
|
||||
/// TTL-based key cache with LRU eviction.
|
||||
cache: KeyCache,
|
||||
}
|
||||
|
||||
/// Errors that can occur during secret service operations.
|
||||
#[derive(Debug, thiserror::Error, Serialize, Deserialize)]
|
||||
pub enum SecretServiceError {
|
||||
#[error("service is locked; call Unlock first")]
|
||||
ServiceLocked,
|
||||
#[error("service is already unlocked")]
|
||||
AlreadyUnlocked,
|
||||
#[error("mnemonic error: {0}")]
|
||||
Mnemonic(String),
|
||||
#[error("derivation error: {0}")]
|
||||
Derivation(String),
|
||||
#[error("encryption error: {0}")]
|
||||
Encryption(String),
|
||||
#[error("invalid path: {0}")]
|
||||
InvalidPath(String),
|
||||
#[error("unsupported key type")]
|
||||
UnsupportedKeyType,
|
||||
}
|
||||
|
||||
impl From<crate::mnemonic::MnemonicError> for SecretServiceError {
|
||||
fn from(e: crate::mnemonic::MnemonicError) -> Self {
|
||||
SecretServiceError::Mnemonic(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DerivationError> for SecretServiceError {
|
||||
fn from(e: DerivationError) -> Self {
|
||||
SecretServiceError::Derivation(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<encryption::EncryptionError> for SecretServiceError {
|
||||
fn from(e: encryption::EncryptionError) -> Self {
|
||||
SecretServiceError::Encryption(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl SecretServiceHandle {
|
||||
/// Create a new SecretServiceHandle in the locked state with default cache config.
|
||||
pub fn new() -> Self {
|
||||
Self::with_cache_config(CacheConfig::default())
|
||||
}
|
||||
|
||||
/// Create a new SecretServiceHandle with the given cache configuration.
|
||||
pub fn with_cache_config(config: CacheConfig) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(RwLock::new(SecretServiceInner {
|
||||
mnemonic: None,
|
||||
seed: None,
|
||||
unlocked: false,
|
||||
cache: KeyCache::new(config),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
/// Unlock the service with an existing mnemonic phrase.
|
||||
///
|
||||
/// The passphrase is the BIP39 password (may be empty string for none).
|
||||
/// After unlocking, derive and encrypt/decrypt operations are available.
|
||||
pub fn unlock(&self, phrase: &str, passphrase: Option<&str>) -> Result<(), SecretServiceError> {
|
||||
let mut inner = self.inner.write().unwrap();
|
||||
if inner.unlocked {
|
||||
return Err(SecretServiceError::AlreadyUnlocked);
|
||||
}
|
||||
|
||||
let mnemonic = Mnemonic::from_phrase(phrase, Language::English)?;
|
||||
let seed = mnemonic.to_seed(passphrase);
|
||||
|
||||
inner.mnemonic = Some(mnemonic);
|
||||
inner.seed = Some(seed);
|
||||
inner.unlocked = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Unlock the service with a new randomly generated mnemonic.
|
||||
///
|
||||
/// Returns the generated mnemonic phrase. Store this phrase securely —
|
||||
/// it is the root of trust for all derived keys.
|
||||
pub fn unlock_new(&self, word_count: usize) -> Result<String, SecretServiceError> {
|
||||
let mut inner = self.inner.write().unwrap();
|
||||
if inner.unlocked {
|
||||
return Err(SecretServiceError::AlreadyUnlocked);
|
||||
}
|
||||
|
||||
let mnemonic = Mnemonic::generate(word_count)?;
|
||||
let seed = mnemonic.to_seed(None);
|
||||
let phrase = mnemonic.phrase().to_string();
|
||||
|
||||
inner.mnemonic = Some(mnemonic);
|
||||
inner.seed = Some(seed);
|
||||
inner.unlocked = true;
|
||||
Ok(phrase)
|
||||
}
|
||||
|
||||
/// Lock the service, purging the seed and all cached derived keys.
|
||||
///
|
||||
/// After locking, no derive/encrypt/decrypt operations are possible
|
||||
/// until `unlock` is called again. Calls `zeroize()` on all sensitive
|
||||
/// material per ADR-038.
|
||||
pub fn lock(&self) {
|
||||
let mut inner = self.inner.write().unwrap();
|
||||
inner.cache.clear();
|
||||
inner.seed = None;
|
||||
inner.mnemonic = None;
|
||||
inner.unlocked = false;
|
||||
}
|
||||
|
||||
/// Check whether the service is currently unlocked.
|
||||
pub fn is_unlocked(&self) -> bool {
|
||||
self.inner.read().unwrap().unlocked
|
||||
}
|
||||
|
||||
/// Derive an Ed25519 keypair at the given path.
|
||||
pub fn derive_ed25519(&self, path: &str) -> Result<DerivedKey, SecretServiceError> {
|
||||
let mut inner = self.inner.write().unwrap();
|
||||
if !inner.unlocked {
|
||||
return Err(SecretServiceError::ServiceLocked);
|
||||
}
|
||||
|
||||
if let Some(cached) = inner.cache.get(path) {
|
||||
return Ok(DerivedKey {
|
||||
key_type: cached.key_type.clone(),
|
||||
private_key: cached.private_key.clone(),
|
||||
public_key: cached.public_key.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let seed = inner
|
||||
.seed
|
||||
.as_ref()
|
||||
.ok_or(SecretServiceError::ServiceLocked)?;
|
||||
let key = derivation::derive_path_from_seed(seed.as_bytes(), path)?;
|
||||
let private_key = key.private_key().to_vec();
|
||||
let public_key = key.public_key().to_vec();
|
||||
let cached = CachedKey::new(KeyType::Ed25519, private_key.clone(), public_key.clone());
|
||||
inner.cache.insert(path, cached);
|
||||
Ok(DerivedKey {
|
||||
key_type: KeyType::Ed25519,
|
||||
private_key,
|
||||
public_key,
|
||||
})
|
||||
}
|
||||
|
||||
/// Derive an AES-256-GCM encryption key at the given path.
|
||||
pub fn derive_encryption_key(&self, path: &str) -> Result<DerivedKey, SecretServiceError> {
|
||||
let mut inner = self.inner.write().unwrap();
|
||||
if !inner.unlocked {
|
||||
return Err(SecretServiceError::ServiceLocked);
|
||||
}
|
||||
|
||||
if let Some(cached) = inner.cache.get(path) {
|
||||
return Ok(DerivedKey {
|
||||
key_type: cached.key_type.clone(),
|
||||
private_key: cached.private_key.clone(),
|
||||
public_key: cached.public_key.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let seed = inner
|
||||
.seed
|
||||
.as_ref()
|
||||
.ok_or(SecretServiceError::ServiceLocked)?;
|
||||
let key = derivation::derive_path_from_seed(seed.as_bytes(), path)?;
|
||||
let private_key = key.private_key().to_vec();
|
||||
let public_key = key.public_key().to_vec();
|
||||
let cached = CachedKey::new(KeyType::Aes256Gcm, private_key.clone(), public_key.clone());
|
||||
inner.cache.insert(path, cached);
|
||||
Ok(DerivedKey {
|
||||
key_type: KeyType::Aes256Gcm,
|
||||
private_key,
|
||||
public_key,
|
||||
})
|
||||
}
|
||||
|
||||
/// Derive a secp256k1 (Ethereum) keypair at the given path.
|
||||
///
|
||||
/// Uses BIP-0032 derivation (HMAC-SHA512 with "Bitcoin seed") when the
|
||||
/// `secp256k1` feature is enabled. Returns `UnsupportedKeyType` when the
|
||||
/// feature is disabled.
|
||||
pub fn derive_ethereum_key(&self, path: &str) -> Result<DerivedKey, SecretServiceError> {
|
||||
#[cfg(feature = "secp256k1")]
|
||||
{
|
||||
let mut inner = self.inner.write().unwrap();
|
||||
if !inner.unlocked {
|
||||
return Err(SecretServiceError::ServiceLocked);
|
||||
}
|
||||
|
||||
if let Some(cached) = inner.cache.get(path) {
|
||||
return Ok(DerivedKey {
|
||||
key_type: cached.key_type.clone(),
|
||||
private_key: cached.private_key.clone(),
|
||||
public_key: cached.public_key.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let seed = inner
|
||||
.seed
|
||||
.as_ref()
|
||||
.ok_or(SecretServiceError::ServiceLocked)?;
|
||||
|
||||
let key = crate::ethereum::derive_secp256k1_path(seed.as_bytes(), path)?;
|
||||
let private_key = key.private_key().to_vec();
|
||||
let public_key = key.public_key().to_vec();
|
||||
let cached =
|
||||
CachedKey::new(KeyType::Secp256k1, private_key.clone(), public_key.clone());
|
||||
inner.cache.insert(path, cached);
|
||||
Ok(DerivedKey {
|
||||
key_type: KeyType::Secp256k1,
|
||||
private_key,
|
||||
public_key,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "secp256k1"))]
|
||||
{
|
||||
let _ = path;
|
||||
Err(SecretServiceError::UnsupportedKeyType)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn derive_password(
|
||||
&self,
|
||||
path: &str,
|
||||
length: usize,
|
||||
) -> Result<Vec<u8>, SecretServiceError> {
|
||||
let inner = self.inner.read().unwrap();
|
||||
if !inner.unlocked {
|
||||
return Err(SecretServiceError::ServiceLocked);
|
||||
}
|
||||
let seed = inner
|
||||
.seed
|
||||
.as_ref()
|
||||
.ok_or(SecretServiceError::ServiceLocked)?;
|
||||
|
||||
let key = derivation::derive_path_from_seed(seed.as_bytes(), path)?;
|
||||
let private_key = key.private_key();
|
||||
let truncated_len = length.min(private_key.len());
|
||||
let result = private_key[..truncated_len].to_vec();
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn derive_password_string(
|
||||
&self,
|
||||
path: &str,
|
||||
length: usize,
|
||||
) -> Result<String, SecretServiceError> {
|
||||
let bytes = self.derive_password(path, length)?;
|
||||
Ok(URL_SAFE_NO_PAD.encode(&bytes))
|
||||
}
|
||||
|
||||
/// Encrypt plaintext using the derived encryption key.
|
||||
///
|
||||
/// Uses the key at path `m/74'/2'/0'/0'` (PATHS::ENCRYPTION) by default.
|
||||
pub fn encrypt(
|
||||
&self,
|
||||
plaintext: &str,
|
||||
key_version: u32,
|
||||
) -> Result<EncryptedData, SecretServiceError> {
|
||||
let mut inner = self.inner.write().unwrap();
|
||||
if !inner.unlocked {
|
||||
return Err(SecretServiceError::ServiceLocked);
|
||||
}
|
||||
|
||||
let private_key = if let Some(cached) = inner.cache.get(PATHS::ENCRYPTION) {
|
||||
cached.private_key.clone()
|
||||
} else {
|
||||
let seed = inner
|
||||
.seed
|
||||
.as_ref()
|
||||
.ok_or(SecretServiceError::ServiceLocked)?;
|
||||
let derived = derivation::derive_path_from_seed(seed.as_bytes(), PATHS::ENCRYPTION)?;
|
||||
let pk = derived.private_key().to_vec();
|
||||
let pubk = derived.public_key().to_vec();
|
||||
let cached = CachedKey::new(KeyType::Aes256Gcm, pk.clone(), pubk);
|
||||
inner.cache.insert(PATHS::ENCRYPTION, cached);
|
||||
pk
|
||||
};
|
||||
|
||||
let enc_key = EncryptionKey::from_derived_bytes(&private_key, key_version);
|
||||
|
||||
encryption::encrypt(plaintext, &enc_key).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
/// Decrypt an EncryptedData blob using the derived encryption key.
|
||||
pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<String, SecretServiceError> {
|
||||
let mut inner = self.inner.write().unwrap();
|
||||
if !inner.unlocked {
|
||||
return Err(SecretServiceError::ServiceLocked);
|
||||
}
|
||||
|
||||
let private_key = if let Some(cached) = inner.cache.get(PATHS::ENCRYPTION) {
|
||||
cached.private_key.clone()
|
||||
} else {
|
||||
let seed = inner
|
||||
.seed
|
||||
.as_ref()
|
||||
.ok_or(SecretServiceError::ServiceLocked)?;
|
||||
let derived = derivation::derive_path_from_seed(seed.as_bytes(), PATHS::ENCRYPTION)?;
|
||||
let pk = derived.private_key().to_vec();
|
||||
let pubk = derived.public_key().to_vec();
|
||||
let cached = CachedKey::new(KeyType::Aes256Gcm, pk.clone(), pubk);
|
||||
inner.cache.insert(PATHS::ENCRYPTION, cached);
|
||||
pk
|
||||
};
|
||||
|
||||
let enc_key = EncryptionKey::from_derived_bytes(&private_key, encrypted.key_version);
|
||||
|
||||
encryption::decrypt(encrypted, &enc_key).map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SecretServiceHandle {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// The SecretService manages the lifecycle of the master seed and provides
|
||||
/// secret operations. This is the type used by the irpc service handler.
|
||||
///
|
||||
/// For local (in-process) use, prefer `SecretServiceHandle` which wraps
|
||||
/// this in thread-safe locks.
|
||||
pub struct SecretService {
|
||||
handle: SecretServiceHandle,
|
||||
}
|
||||
|
||||
impl SecretService {
|
||||
/// Create a new SecretService in the locked state.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
handle: SecretServiceHandle::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a handle for local (in-process) use.
|
||||
pub fn handle(&self) -> &SecretServiceHandle {
|
||||
&self.handle
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SecretService {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Actor that processes `SecretMessage` variants and dispatches to `SecretServiceHandle`.
|
||||
///
|
||||
/// The actor runs as a `tokio::task`, receives messages from an mpsc channel,
|
||||
/// dispatches to the handle methods, and sends responses through oneshot channels.
|
||||
///
|
||||
/// # Usage
|
||||
///
|
||||
/// ```ignore
|
||||
/// let handle = SecretServiceHandle::new();
|
||||
/// let (client, actor) = SecretServiceActor::spawn(handle);
|
||||
/// tokio::task::spawn(actor.run(rx));
|
||||
/// // Use client to send messages
|
||||
/// ```
|
||||
pub struct SecretServiceActor {
|
||||
handle: SecretServiceHandle,
|
||||
}
|
||||
|
||||
impl SecretServiceActor {
|
||||
/// Create a new actor wrapping the given handle.
|
||||
pub fn new(handle: SecretServiceHandle) -> Self {
|
||||
Self { handle }
|
||||
}
|
||||
|
||||
/// Run the actor message loop, processing `SecretMessage` variants.
|
||||
///
|
||||
/// This method runs until the receiver channel is closed. Each message
|
||||
/// variant is dispatched to the corresponding `SecretServiceHandle` method
|
||||
/// and the response is sent through the oneshot channel embedded in the message.
|
||||
pub async fn run(mut self, mut rx: tokio::sync::mpsc::Receiver<SecretMessage>) {
|
||||
while let Some(msg) = rx.recv().await {
|
||||
self.handle_message(msg);
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawn the actor as a `tokio::task` and return a `Client<SecretProtocol>` for sending messages.
|
||||
///
|
||||
/// The actor runs on a tokio task and processes messages from the mpsc channel.
|
||||
/// The returned `Client<SecretProtocol>` can be used to send `SecretMessage` variants
|
||||
/// to the actor.
|
||||
pub fn spawn(
|
||||
handle: SecretServiceHandle,
|
||||
) -> (irpc::Client<SecretProtocol>, SecretServiceActor) {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(64);
|
||||
let client = irpc::Client::local(tx);
|
||||
let actor = Self::new(handle.clone());
|
||||
tokio::task::spawn(actor.run(rx));
|
||||
(client, Self::new(handle))
|
||||
}
|
||||
|
||||
/// Handle a single `SecretMessage` by dispatching to the appropriate handle method.
|
||||
fn handle_message(&mut self, msg: SecretMessage) {
|
||||
match msg {
|
||||
SecretMessage::DeriveEd25519(msg) => {
|
||||
let WithChannels { inner, tx, .. } = msg;
|
||||
let DeriveEd25519 { path } = inner;
|
||||
let result = self.handle.derive_ed25519(&path);
|
||||
tokio::spawn(async move {
|
||||
let _ = tx.send(result).await;
|
||||
});
|
||||
}
|
||||
SecretMessage::DeriveEncryptionKey(msg) => {
|
||||
let WithChannels { inner, tx, .. } = msg;
|
||||
let DeriveEncryptionKey { path } = inner;
|
||||
let result = self.handle.derive_encryption_key(&path);
|
||||
tokio::spawn(async move {
|
||||
let _ = tx.send(result).await;
|
||||
});
|
||||
}
|
||||
SecretMessage::DeriveEthereumKey(msg) => {
|
||||
let WithChannels { inner, tx, .. } = msg;
|
||||
let DeriveEthereumKey { path } = inner;
|
||||
let result = self.handle.derive_ethereum_key(&path);
|
||||
tokio::spawn(async move {
|
||||
let _ = tx.send(result).await;
|
||||
});
|
||||
}
|
||||
SecretMessage::DerivePassword(msg) => {
|
||||
let WithChannels { inner, tx, .. } = msg;
|
||||
let DerivePassword { path, length } = inner;
|
||||
let result = self.handle.derive_password(&path, length);
|
||||
tokio::spawn(async move {
|
||||
let _ = tx.send(result).await;
|
||||
});
|
||||
}
|
||||
SecretMessage::Encrypt(msg) => {
|
||||
let WithChannels { inner, tx, .. } = msg;
|
||||
let Encrypt {
|
||||
plaintext,
|
||||
key_version,
|
||||
} = inner;
|
||||
let result = self.handle.encrypt(&plaintext, key_version);
|
||||
tokio::spawn(async move {
|
||||
let _ = tx.send(result).await;
|
||||
});
|
||||
}
|
||||
SecretMessage::Decrypt(msg) => {
|
||||
let WithChannels { inner, tx, .. } = msg;
|
||||
let Decrypt { encrypted } = inner;
|
||||
let result = self.handle.decrypt(&encrypted);
|
||||
tokio::spawn(async move {
|
||||
let _ = tx.send(result).await;
|
||||
});
|
||||
}
|
||||
SecretMessage::Lock(msg) => {
|
||||
let WithChannels { inner: _, tx, .. } = msg;
|
||||
self.handle.lock();
|
||||
tokio::spawn(async move {
|
||||
let _ = tx.send(Ok(())).await;
|
||||
});
|
||||
}
|
||||
SecretMessage::Unlock(msg) => {
|
||||
let WithChannels { inner, tx, .. } = msg;
|
||||
let Unlock {
|
||||
mnemonic,
|
||||
passphrase,
|
||||
} = inner;
|
||||
let result = self.handle.unlock(&mnemonic, passphrase.as_deref());
|
||||
tokio::spawn(async move {
|
||||
let _ = tx.send(result).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::protocol::Lock;
|
||||
use irpc::channel::oneshot;
|
||||
use irpc::WithChannels;
|
||||
|
||||
#[test]
|
||||
fn test_service_starts_locked() {
|
||||
let service = SecretServiceHandle::new();
|
||||
assert!(!service.is_unlocked());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unlock_new_generates_mnemonic() {
|
||||
let service = SecretServiceHandle::new();
|
||||
let phrase = service.unlock_new(24).unwrap();
|
||||
assert!(!phrase.is_empty());
|
||||
assert!(service.is_unlocked());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lock_purges_state() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
assert!(service.is_unlocked());
|
||||
|
||||
service.lock();
|
||||
assert!(!service.is_unlocked());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derive_on_locked_fails() {
|
||||
let service = SecretServiceHandle::new();
|
||||
let result = service.derive_ed25519(PATHS::IDENTITY);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encrypt_on_locked_fails() {
|
||||
let service = SecretServiceHandle::new();
|
||||
let result = service.encrypt("secret", 1);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_full_lifecycle() {
|
||||
let service = SecretServiceHandle::new();
|
||||
|
||||
assert!(!service.is_unlocked());
|
||||
|
||||
assert!(service.derive_ed25519(PATHS::IDENTITY).is_err());
|
||||
|
||||
let _phrase = service.unlock_new(24).unwrap();
|
||||
assert!(service.is_unlocked());
|
||||
|
||||
let key = service.derive_ed25519(PATHS::IDENTITY).unwrap();
|
||||
assert!(!key.private_key.is_empty());
|
||||
|
||||
service.lock();
|
||||
assert!(!service.is_unlocked());
|
||||
|
||||
assert!(service.derive_ed25519(PATHS::IDENTITY).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unlock_with_known_phrase() {
|
||||
let service = SecretServiceHandle::new();
|
||||
|
||||
let phrase = service.unlock_new(24).unwrap();
|
||||
service.lock();
|
||||
|
||||
service.unlock(&phrase, None).unwrap();
|
||||
assert!(service.is_unlocked());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_double_unlock_fails() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
let result = service.unlock_new(12);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encrypt_decrypt_lifecycle() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
let plaintext = "my-api-key-12345";
|
||||
let encrypted = service.encrypt(plaintext, 1).unwrap();
|
||||
let decrypted = service.decrypt(&encrypted).unwrap();
|
||||
assert_eq!(decrypted, plaintext);
|
||||
|
||||
service.lock();
|
||||
assert!(service.decrypt(&encrypted).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derive_password_deterministic() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
let path = "m/74'/1'/0'/12345'";
|
||||
let pw1 = service.derive_password(path, 16).unwrap();
|
||||
let pw2 = service.derive_password(path, 16).unwrap();
|
||||
assert_eq!(pw1, pw2, "derive_password must be deterministic");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derive_password_different_paths() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
let pw_a = service.derive_password("m/74'/1'/0'/100'", 16).unwrap();
|
||||
let pw_b = service.derive_password("m/74'/1'/0'/200'", 16).unwrap();
|
||||
assert_ne!(
|
||||
pw_a, pw_b,
|
||||
"different paths must produce different passwords"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derive_password_length_truncation() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
let path = "m/74'/1'/0'/999'";
|
||||
let pw_full = service.derive_password(path, 32).unwrap();
|
||||
let pw_short = service.derive_password(path, 16).unwrap();
|
||||
|
||||
assert_eq!(pw_short.len(), 16);
|
||||
assert_eq!(pw_full.len(), 32);
|
||||
assert_eq!(
|
||||
&pw_full[..16],
|
||||
&pw_short[..],
|
||||
"truncated bytes must match prefix of full key"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derive_password_locked_error() {
|
||||
let service = SecretServiceHandle::new();
|
||||
let result = service.derive_password("m/74'/1'/0'/1'", 16);
|
||||
assert!(matches!(result, Err(SecretServiceError::ServiceLocked)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derive_password_string_base64url() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
let path = "m/74'/1'/0'/42'";
|
||||
let encoded = service.derive_password_string(path, 16).unwrap();
|
||||
|
||||
assert!(!encoded.contains('='), "Base64url must not contain padding");
|
||||
assert!(
|
||||
encoded
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
|
||||
"Base64url must only contain URL-safe characters"
|
||||
);
|
||||
|
||||
let raw_bytes = service.derive_password(path, 16).unwrap();
|
||||
let decoded = URL_SAFE_NO_PAD.decode(&encoded).unwrap();
|
||||
assert_eq!(raw_bytes, decoded);
|
||||
}
|
||||
|
||||
#[cfg(feature = "secp256k1")]
|
||||
#[test]
|
||||
fn test_derive_ethereum_key_bip32() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
let key = service.derive_ethereum_key(PATHS::ETHEREUM).unwrap();
|
||||
assert_eq!(key.key_type, KeyType::Secp256k1);
|
||||
assert_eq!(key.private_key.len(), 32);
|
||||
assert_eq!(key.public_key.len(), 33);
|
||||
}
|
||||
|
||||
#[cfg(feature = "secp256k1")]
|
||||
#[test]
|
||||
fn test_ethereum_key_differs_from_ed25519() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
let eth_key = service.derive_ethereum_key(PATHS::ETHEREUM).unwrap();
|
||||
let ed_key = service.derive_ed25519(PATHS::IDENTITY).unwrap();
|
||||
|
||||
assert_ne!(eth_key.private_key, ed_key.private_key);
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "secp256k1"))]
|
||||
#[test]
|
||||
fn test_derive_ethereum_key_unsupported_without_feature() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
let result = service.derive_ethereum_key(PATHS::ETHEREUM);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(SecretServiceError::UnsupportedKeyType)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_hit_avoids_re_derivation() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
let key1 = service.derive_ed25519(PATHS::IDENTITY).unwrap();
|
||||
let key2 = service.derive_ed25519(PATHS::IDENTITY).unwrap();
|
||||
|
||||
assert_eq!(key1.private_key, key2.private_key);
|
||||
assert_eq!(key1.public_key, key2.public_key);
|
||||
|
||||
let cache_len = service.inner.read().unwrap().cache.len();
|
||||
assert_eq!(cache_len, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_miss_derives_and_caches() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
assert_eq!(service.inner.read().unwrap().cache.len(), 0);
|
||||
|
||||
service.derive_ed25519(PATHS::IDENTITY).unwrap();
|
||||
|
||||
assert_eq!(service.inner.read().unwrap().cache.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expired_entry_evicted_on_access() {
|
||||
let config = crate::cache::CacheConfig::new(std::time::Duration::from_millis(5), 64);
|
||||
let service = SecretServiceHandle::with_cache_config(config);
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
let key1 = service.derive_ed25519(PATHS::IDENTITY).unwrap();
|
||||
assert_eq!(service.inner.read().unwrap().cache.len(), 1);
|
||||
|
||||
std::thread::sleep(std::time::Duration::from_millis(10));
|
||||
|
||||
let key2 = service.derive_ed25519(PATHS::IDENTITY).unwrap();
|
||||
assert_eq!(key1.private_key, key2.private_key);
|
||||
assert_eq!(service.inner.read().unwrap().cache.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lru_eviction_when_over_max_entries() {
|
||||
let config = crate::cache::CacheConfig::new(std::time::Duration::from_secs(3600), 2);
|
||||
let service = SecretServiceHandle::with_cache_config(config);
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
service.derive_ed25519(PATHS::IDENTITY).unwrap();
|
||||
service.derive_ed25519(PATHS::SSH_HOST).unwrap();
|
||||
assert_eq!(service.inner.read().unwrap().cache.len(), 2);
|
||||
|
||||
service.derive_ed25519(PATHS::ENCRYPTION).unwrap();
|
||||
assert_eq!(service.inner.read().unwrap().cache.len(), 2);
|
||||
|
||||
let mut inner = service.inner.write().unwrap();
|
||||
assert!(inner.cache.get(PATHS::IDENTITY).is_none());
|
||||
assert!(inner.cache.get(PATHS::SSH_HOST).is_some());
|
||||
assert!(inner.cache.get(PATHS::ENCRYPTION).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lock_clears_all_cache_entries() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
service.derive_ed25519(PATHS::IDENTITY).unwrap();
|
||||
service.derive_ed25519(PATHS::SSH_HOST).unwrap();
|
||||
assert_eq!(service.inner.read().unwrap().cache.len(), 2);
|
||||
|
||||
service.lock();
|
||||
|
||||
assert_eq!(service.inner.read().unwrap().cache.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encrypt_decrypt_uses_cached_encryption_key() {
|
||||
let service = SecretServiceHandle::new();
|
||||
service.unlock_new(24).unwrap();
|
||||
|
||||
let plaintext = "cached-encryption-test";
|
||||
let encrypted = service.encrypt(plaintext, 1).unwrap();
|
||||
assert_eq!(service.inner.read().unwrap().cache.len(), 1);
|
||||
|
||||
let decrypted = service.decrypt(&encrypted).unwrap();
|
||||
assert_eq!(decrypted, plaintext);
|
||||
|
||||
assert_eq!(service.inner.read().unwrap().cache.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_actor_unlock_responds_successfully() {
|
||||
let handle = SecretServiceHandle::new();
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(64);
|
||||
let actor = SecretServiceActor::new(handle);
|
||||
tokio::task::spawn(actor.run(rx));
|
||||
|
||||
let (resp_tx, resp_rx) = oneshot::channel();
|
||||
let msg = SecretMessage::Unlock(WithChannels::from((
|
||||
Unlock {
|
||||
mnemonic: "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about".to_string(),
|
||||
passphrase: None,
|
||||
},
|
||||
resp_tx,
|
||||
)));
|
||||
tx.send(msg).await.unwrap();
|
||||
|
||||
let result = resp_rx.await.unwrap();
|
||||
assert!(result.is_ok(), "Unlock via actor must succeed");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_actor_derive_ed25519_returns_key() {
|
||||
let handle = SecretServiceHandle::new();
|
||||
handle.unlock_new(24).unwrap();
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(64);
|
||||
let actor = SecretServiceActor::new(handle);
|
||||
tokio::task::spawn(actor.run(rx));
|
||||
|
||||
let (resp_tx, resp_rx) = oneshot::channel();
|
||||
let msg = SecretMessage::DeriveEd25519(WithChannels::from((
|
||||
DeriveEd25519 {
|
||||
path: PATHS::IDENTITY.to_string(),
|
||||
},
|
||||
resp_tx,
|
||||
)));
|
||||
tx.send(msg).await.unwrap();
|
||||
|
||||
let result = resp_rx.await.unwrap();
|
||||
assert!(result.is_ok(), "DeriveEd25519 via actor must succeed");
|
||||
let key = result.unwrap();
|
||||
assert!(
|
||||
!key.private_key.is_empty(),
|
||||
"DerivedKey must have private_key"
|
||||
);
|
||||
assert_eq!(key.key_type, KeyType::Ed25519);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_actor_lock_clears_state() {
|
||||
let handle = SecretServiceHandle::new();
|
||||
handle.unlock_new(24).unwrap();
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(64);
|
||||
let actor = SecretServiceActor::new(handle.clone());
|
||||
tokio::task::spawn(actor.run(rx));
|
||||
|
||||
let (resp_tx, resp_rx): (oneshot::Sender<Result<(), SecretServiceError>>, _) =
|
||||
oneshot::channel();
|
||||
let msg = SecretMessage::Lock(WithChannels::from((Lock, resp_tx)));
|
||||
tx.send(msg).await.unwrap();
|
||||
|
||||
let result = resp_rx.await.unwrap();
|
||||
assert!(result.is_ok(), "Lock via actor must succeed");
|
||||
assert!(!handle.is_unlocked(), "Handle must be locked after Lock");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unlock_with_passphrase_produces_different_seed() {
|
||||
let service_a = SecretServiceHandle::new();
|
||||
let service_b = SecretServiceHandle::new();
|
||||
|
||||
let phrase = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";
|
||||
|
||||
service_a.unlock(phrase, None).unwrap();
|
||||
let key_a = service_a.derive_ed25519(PATHS::IDENTITY).unwrap();
|
||||
|
||||
service_a.lock();
|
||||
|
||||
service_a.unlock(phrase, Some("TREZOR")).unwrap();
|
||||
let key_b = service_a.derive_ed25519(PATHS::IDENTITY).unwrap();
|
||||
|
||||
assert_ne!(
|
||||
key_a.private_key, key_b.private_key,
|
||||
"Unlock with passphrase must produce different seed than without"
|
||||
);
|
||||
|
||||
service_a.lock();
|
||||
|
||||
service_b.unlock(phrase, None).unwrap();
|
||||
let key_c = service_b.derive_ed25519(PATHS::IDENTITY).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
key_a.private_key, key_c.private_key,
|
||||
"Unlock with None passphrase must produce same seed as another None passphrase unlock"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_actor_unlock_with_passphrase() {
|
||||
let handle = SecretServiceHandle::new();
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(64);
|
||||
let actor = SecretServiceActor::new(handle);
|
||||
tokio::task::spawn(actor.run(rx));
|
||||
|
||||
let mnemonic = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";
|
||||
|
||||
let (resp_tx, resp_rx) = oneshot::channel();
|
||||
let msg = SecretMessage::Unlock(WithChannels::from((
|
||||
Unlock {
|
||||
mnemonic: mnemonic.to_string(),
|
||||
passphrase: Some("TREZOR".to_string()),
|
||||
},
|
||||
resp_tx,
|
||||
)));
|
||||
tx.send(msg).await.unwrap();
|
||||
|
||||
let result = resp_rx.await.unwrap();
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Unlock with passphrase via actor must succeed"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,20 +1,20 @@
|
||||
[package]
|
||||
name = "alknet-secret"
|
||||
name = "alknet-vault"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "BIP39 mnemonic generation, SLIP-0010 Ed25519 HD key derivation, AES-256-GCM encryption, and SecretProtocol irpc service for alknet"
|
||||
description = "Local key vault: BIP39 mnemonic generation, SLIP-0010 Ed25519 HD key derivation, AES-256-GCM encryption for securing provider keys, credentials, and identity material"
|
||||
repository.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "alknet_secret"
|
||||
name = "alknet_vault"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
secp256k1 = ["dep:secp256k1"]
|
||||
|
||||
[dependencies]
|
||||
bip39 = { version = "2", features = ["rand"] }
|
||||
bip39 = { version = "2", features = ["rand", "zeroize"] }
|
||||
ed25519-bip32 = "0.4"
|
||||
aes-gcm = "0.10"
|
||||
sha2 = "0.10"
|
||||
@@ -25,11 +25,7 @@ zeroize = { version = "1", features = ["derive"] }
|
||||
hmac = "0.12"
|
||||
rand = "0.8"
|
||||
base64 = "0.22"
|
||||
irpc = { workspace = true }
|
||||
irpc-derive = { workspace = true }
|
||||
tokio = { version = "1", features = ["sync", "rt", "macros"] }
|
||||
secp256k1 = { version = "0.29", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
hex = "0.4"
|
||||
postcard = { version = "1", features = ["alloc"] }
|
||||
hex = "0.4"
|
||||
@@ -1,4 +1,4 @@
|
||||
//! TTL-based key cache with LRU eviction for SecretService.
|
||||
//! TTL-based key cache with LRU eviction for VaultService.
|
||||
//!
|
||||
//! The `KeyCache` stores derived key material keyed by derivation path. Entries
|
||||
//! expire after a configurable TTL (default: 1 hour) and are evicted lazily on
|
||||
@@ -10,7 +10,7 @@ use std::time::{Duration, Instant};
|
||||
|
||||
use zeroize::Zeroize;
|
||||
|
||||
use crate::protocol::KeyType;
|
||||
use crate::protocol::{DerivedKey, KeyType};
|
||||
|
||||
/// Default TTL for cached keys (1 hour).
|
||||
pub const DEFAULT_TTL: Duration = Duration::from_secs(3600);
|
||||
@@ -18,47 +18,53 @@ pub const DEFAULT_TTL: Duration = Duration::from_secs(3600);
|
||||
/// Default maximum number of cache entries.
|
||||
pub const DEFAULT_MAX_ENTRIES: usize = 64;
|
||||
|
||||
/// A cached derived key with metadata for TTL and LRU tracking.
|
||||
/// A cached derived key. Wraps a `DerivedKey` with cache metadata.
|
||||
///
|
||||
/// The `private_key` field is zeroized on drop via `#[zeroize(drop)]`.
|
||||
/// This is a separate internal type from `DerivedKey` — it holds the same
|
||||
/// data but is managed within the cache lifecycle.
|
||||
/// Derives `Zeroize` and `ZeroizeOnDrop` — the private key is zeroized
|
||||
/// when the entry is evicted (LRU/TTL) or the cache is cleared.
|
||||
#[derive(Zeroize)]
|
||||
#[zeroize(drop)]
|
||||
pub struct CachedKey {
|
||||
/// When this key was derived (for TTL checking).
|
||||
/// The derived key (zeroized on drop).
|
||||
#[zeroize(skip)]
|
||||
pub derived_at: Instant,
|
||||
/// The type of key that was derived.
|
||||
pub key: DerivedKey,
|
||||
/// When the entry was inserted (for TTL).
|
||||
#[zeroize(skip)]
|
||||
pub key_type: KeyType,
|
||||
/// The private key bytes (sensitive — zeroized on drop).
|
||||
#[zeroize]
|
||||
pub private_key: Vec<u8>,
|
||||
/// The public key bytes.
|
||||
#[zeroize(skip)]
|
||||
pub public_key: Vec<u8>,
|
||||
pub cached_at: Instant,
|
||||
/// Last access time for LRU ordering.
|
||||
#[zeroize(skip)]
|
||||
last_accessed: Instant,
|
||||
}
|
||||
|
||||
impl CachedKey {
|
||||
/// Create a new `CachedKey` from derived key material.
|
||||
pub fn new(key_type: KeyType, private_key: Vec<u8>, public_key: Vec<u8>) -> Self {
|
||||
/// Create a new `CachedKey` from a `DerivedKey`.
|
||||
pub fn new(key: DerivedKey) -> Self {
|
||||
let now = Instant::now();
|
||||
Self {
|
||||
derived_at: now,
|
||||
key_type,
|
||||
private_key,
|
||||
public_key,
|
||||
key,
|
||||
cached_at: now,
|
||||
last_accessed: now,
|
||||
}
|
||||
}
|
||||
|
||||
/// The key type of the cached derived key.
|
||||
pub fn key_type(&self) -> &KeyType {
|
||||
&self.key.key_type
|
||||
}
|
||||
|
||||
/// The private key bytes of the cached derived key.
|
||||
pub fn private_key(&self) -> &[u8] {
|
||||
&self.key.private_key
|
||||
}
|
||||
|
||||
/// The public key bytes of the cached derived key.
|
||||
pub fn public_key(&self) -> &[u8] {
|
||||
&self.key.public_key
|
||||
}
|
||||
|
||||
/// Check whether this cached entry has expired.
|
||||
pub fn is_expired(&self, ttl: Duration) -> bool {
|
||||
Instant::now().duration_since(self.derived_at) > ttl
|
||||
Instant::now().duration_since(self.cached_at) > ttl
|
||||
}
|
||||
|
||||
/// Touch the entry to update its last-accessed time (for LRU).
|
||||
@@ -206,12 +212,92 @@ impl Default for KeyCache {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod drop_tracker {
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
struct DropTrackedKey {
|
||||
flag: Arc<AtomicBool>,
|
||||
bytes: Vec<u8>,
|
||||
}
|
||||
|
||||
impl DropTrackedKey {
|
||||
fn new(flag: &Arc<AtomicBool>) -> Self {
|
||||
Self {
|
||||
flag: flag.clone(),
|
||||
bytes: vec![0xABu8; 32],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DropTrackedKey {
|
||||
fn drop(&mut self) {
|
||||
for b in self.bytes.iter_mut() {
|
||||
*b = 0;
|
||||
}
|
||||
self.flag.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hashmap_clear_drops_values_triggering_drop_impls() {
|
||||
let flag1 = Arc::new(AtomicBool::new(false));
|
||||
let flag2 = Arc::new(AtomicBool::new(false));
|
||||
let mut map: HashMap<String, DropTrackedKey> = HashMap::new();
|
||||
map.insert("path1".to_string(), DropTrackedKey::new(&flag1));
|
||||
map.insert("path2".to_string(), DropTrackedKey::new(&flag2));
|
||||
|
||||
assert!(!flag1.load(Ordering::SeqCst));
|
||||
assert!(!flag2.load(Ordering::SeqCst));
|
||||
|
||||
map.clear();
|
||||
|
||||
assert!(flag1.load(Ordering::SeqCst));
|
||||
assert!(flag2.load(Ordering::SeqCst));
|
||||
assert!(map.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hashmap_remove_drops_value_triggering_drop_impl() {
|
||||
let flag = Arc::new(AtomicBool::new(false));
|
||||
let mut map: HashMap<String, DropTrackedKey> = HashMap::new();
|
||||
map.insert("path1".to_string(), DropTrackedKey::new(&flag));
|
||||
|
||||
assert!(!flag.load(Ordering::SeqCst));
|
||||
|
||||
map.remove("path1");
|
||||
|
||||
assert!(flag.load(Ordering::SeqCst));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hashmap_insert_replace_drops_old_value() {
|
||||
let flag_old = Arc::new(AtomicBool::new(false));
|
||||
let mut map: HashMap<String, DropTrackedKey> = HashMap::new();
|
||||
map.insert("path1".to_string(), DropTrackedKey::new(&flag_old));
|
||||
|
||||
assert!(!flag_old.load(Ordering::SeqCst));
|
||||
|
||||
let flag_new = Arc::new(AtomicBool::new(false));
|
||||
map.insert("path1".to_string(), DropTrackedKey::new(&flag_new));
|
||||
|
||||
assert!(flag_old.load(Ordering::SeqCst));
|
||||
assert!(!flag_new.load(Ordering::SeqCst));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_cached_key(key_type: KeyType) -> CachedKey {
|
||||
CachedKey::new(key_type, vec![0xABu8; 32], vec![0xCDu8; 32])
|
||||
CachedKey::new(DerivedKey {
|
||||
key_type,
|
||||
private_key: vec![0xABu8; 32],
|
||||
public_key: vec![0xCDu8; 32],
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -220,7 +306,7 @@ mod tests {
|
||||
cache.insert("m/74'/0'/0'/0'", make_cached_key(KeyType::Ed25519));
|
||||
|
||||
let entry = cache.get("m/74'/0'/0'/0'").unwrap();
|
||||
assert_eq!(entry.key_type, KeyType::Ed25519);
|
||||
assert_eq!(*entry.key_type(), KeyType::Ed25519);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -231,8 +317,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_cache_expired_entry_evicted_on_access() {
|
||||
let mut config = CacheConfig::default();
|
||||
config.ttl = Duration::from_millis(1);
|
||||
let config = CacheConfig {
|
||||
ttl: Duration::from_millis(1),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut cache = KeyCache::new(config);
|
||||
cache.insert("m/74'/0'/0'/0'", make_cached_key(KeyType::Ed25519));
|
||||
@@ -245,8 +333,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_cache_lru_eviction() {
|
||||
let mut config = CacheConfig::default();
|
||||
config.max_entries = 3;
|
||||
let config = CacheConfig {
|
||||
max_entries: 3,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut cache = KeyCache::new(config);
|
||||
|
||||
@@ -267,8 +357,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_cache_lru_access_reorders() {
|
||||
let mut config = CacheConfig::default();
|
||||
config.max_entries = 3;
|
||||
let config = CacheConfig {
|
||||
max_entries: 3,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut cache = KeyCache::new(config);
|
||||
|
||||
@@ -303,8 +395,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_evict_expired_removes_only_expired() {
|
||||
let mut config = CacheConfig::default();
|
||||
config.ttl = Duration::from_millis(10);
|
||||
let config = CacheConfig {
|
||||
ttl: Duration::from_millis(10),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut cache = KeyCache::new(config);
|
||||
cache.insert("path1", make_cached_key(KeyType::Ed25519));
|
||||
@@ -324,16 +418,80 @@ mod tests {
|
||||
let mut cache = KeyCache::with_defaults();
|
||||
cache.insert(
|
||||
"path1",
|
||||
CachedKey::new(KeyType::Ed25519, vec![1u8; 32], vec![2u8; 32]),
|
||||
CachedKey::new(DerivedKey {
|
||||
key_type: KeyType::Ed25519,
|
||||
private_key: vec![1u8; 32],
|
||||
public_key: vec![2u8; 32],
|
||||
}),
|
||||
);
|
||||
cache.insert(
|
||||
"path1",
|
||||
CachedKey::new(KeyType::Aes256Gcm, vec![3u8; 32], vec![4u8; 32]),
|
||||
CachedKey::new(DerivedKey {
|
||||
key_type: KeyType::Aes256Gcm,
|
||||
private_key: vec![3u8; 32],
|
||||
public_key: vec![4u8; 32],
|
||||
}),
|
||||
);
|
||||
|
||||
let entry = cache.get("path1").unwrap();
|
||||
assert_eq!(entry.key_type, KeyType::Aes256Gcm);
|
||||
assert_eq!(entry.private_key, vec![3u8; 32]);
|
||||
assert_eq!(*entry.key_type(), KeyType::Aes256Gcm);
|
||||
assert_eq!(entry.private_key(), vec![3u8; 32]);
|
||||
assert_eq!(cache.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lru_eviction_drops_evicted_cached_key() {
|
||||
let config = CacheConfig {
|
||||
max_entries: 2,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut cache = KeyCache::new(config);
|
||||
|
||||
cache.insert("path1", make_cached_key(KeyType::Ed25519));
|
||||
cache.insert("path2", make_cached_key(KeyType::Aes256Gcm));
|
||||
assert_eq!(cache.len(), 2);
|
||||
|
||||
cache.insert("path3", make_cached_key(KeyType::Secp256k1));
|
||||
|
||||
assert_eq!(cache.len(), 2);
|
||||
assert!(cache.get("path1").is_none());
|
||||
assert!(cache.get("path2").is_some());
|
||||
assert!(cache.get("path3").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ttl_expiry_evicts_entry_on_access() {
|
||||
let config = CacheConfig {
|
||||
ttl: Duration::from_millis(1),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut cache = KeyCache::new(config);
|
||||
cache.insert("path1", make_cached_key(KeyType::Ed25519));
|
||||
assert_eq!(cache.len(), 1);
|
||||
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
|
||||
assert!(cache.get("path1").is_none());
|
||||
assert_eq!(cache.len(), 0);
|
||||
assert!(cache.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clear_removes_all_entries_and_empties_cache() {
|
||||
let mut cache = KeyCache::with_defaults();
|
||||
cache.insert("path1", make_cached_key(KeyType::Ed25519));
|
||||
cache.insert("path2", make_cached_key(KeyType::Aes256Gcm));
|
||||
cache.insert("path3", make_cached_key(KeyType::Secp256k1));
|
||||
assert_eq!(cache.len(), 3);
|
||||
|
||||
cache.clear();
|
||||
|
||||
assert_eq!(cache.len(), 0);
|
||||
assert!(cache.is_empty());
|
||||
assert!(cache.get("path1").is_none());
|
||||
assert!(cache.get("path2").is_none());
|
||||
assert!(cache.get("path3").is_none());
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,6 @@
|
||||
//! | `m/74'/0'/0'/0'` | Primary identity keypair | Ed25519 (alknet auth) |
|
||||
//! | `m/74'/0'/0'/{n}'` | Worker/device identity | Ed25519 |
|
||||
//! | `m/74'/0'/1'/0'` | SSH host key | Ed25519 |
|
||||
//! | `m/74'/1'/0'/{hash}'` | Site-specific password | Deterministic |
|
||||
//! | `m/74'/2'/0'/0'` | Encryption key for external credentials | AES-256-GCM |
|
||||
//! | `m/44'/60'/0'/0/0` | Ethereum signing key | secp256k1 |
|
||||
|
||||
@@ -24,7 +23,7 @@ type HmacSha512 = Hmac<Sha512>;
|
||||
|
||||
/// Well-known derivation path constants for alknet key material.
|
||||
///
|
||||
/// These paths are defined once and referenced by both the secret service and
|
||||
/// These paths are defined once and referenced by both the vault service and
|
||||
/// external consumers that need to request specific key types.
|
||||
#[allow(non_snake_case)]
|
||||
pub mod PATHS {
|
||||
@@ -52,13 +51,21 @@ pub fn device_path(index: u32) -> String {
|
||||
format!("m/74'/0'/0'/{}'", index)
|
||||
}
|
||||
|
||||
/// Construct a site-specific password derivation path with the given hash.
|
||||
/// Construct the version-indexed encryption key derivation path (ADR-021).
|
||||
///
|
||||
/// Path: `m/74'/1'/0'/{hash}'`
|
||||
pub fn site_password_path(site_hash: &str) -> String {
|
||||
format!("m/74'/1'/0'/{}'", site_hash)
|
||||
/// Maps a key version to its derivation path: v2 → `m/74'/2'/0'/0'`
|
||||
/// (which is `PATHS::ENCRYPTION`), v3 → `m/74'/2'/0'/1'`, etc. Returns
|
||||
/// `DerivationError::InvalidPath` for `version < 2` — v1 is reserved for
|
||||
/// the TypeScript PBKDF2 legacy (ADR-020), which the vault cannot derive,
|
||||
/// and v0 is meaningless.
|
||||
pub fn encryption_path_for_version(version: u32) -> Result<String, DerivationError> {
|
||||
if version < 2 {
|
||||
return Err(DerivationError::InvalidPath(format!(
|
||||
"key version {version} has no derivable path (v1 is TS PBKDF2 legacy)"
|
||||
)));
|
||||
}
|
||||
Ok(format!("m/74'/2'/0'/{}'", version - 2))
|
||||
}
|
||||
|
||||
/// A derived extended private key with its public key.
|
||||
///
|
||||
/// Contains the private key bytes and public key bytes from
|
||||
@@ -101,8 +108,8 @@ impl ExtendedPrivKey {
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use alknet_secret::derivation::{derive_path_from_seed, PATHS};
|
||||
/// use alknet_secret::mnemonic::Mnemonic;
|
||||
/// use alknet_vault::derivation::{derive_path_from_seed, PATHS};
|
||||
/// use alknet_vault::mnemonic::Mnemonic;
|
||||
///
|
||||
/// let mnemonic = Mnemonic::generate(24).unwrap();
|
||||
/// let seed = mnemonic.to_seed(None);
|
||||
@@ -249,8 +256,34 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_site_password_path() {
|
||||
assert_eq!(site_password_path("abc123"), "m/74'/1'/0'/abc123'");
|
||||
fn test_encryption_path_for_version_v2() {
|
||||
assert_eq!(encryption_path_for_version(2).unwrap(), PATHS::ENCRYPTION);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encryption_path_for_version_v3() {
|
||||
assert_eq!(encryption_path_for_version(3).unwrap(), "m/74'/2'/0'/1'");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encryption_path_for_version_v4() {
|
||||
assert_eq!(encryption_path_for_version(4).unwrap(), "m/74'/2'/0'/2'");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encryption_path_for_version_rejects_v1() {
|
||||
assert!(matches!(
|
||||
encryption_path_for_version(1),
|
||||
Err(DerivationError::InvalidPath(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encryption_path_for_version_rejects_v0() {
|
||||
assert!(matches!(
|
||||
encryption_path_for_version(0),
|
||||
Err(DerivationError::InvalidPath(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user