Compare commits

...

14 Commits

Author SHA1 Message Date
Yeachan-Heo
ccebabe605 Preserve verified session persistence while syncing remote runtime branch history
Origin/rcc/runtime advanced independently while this branch implemented
conversation history persistence. This merge keeps the tested local tree
as the source of truth for the user-requested feature while recording the
remote branch tip so future work can proceed from a shared history.

Constraint: Push required incorporating origin/rcc/runtime history without breaking the verified session-persistence implementation
Rejected: Force-push over origin/rcc/runtime | would discard remote branch history
Confidence: medium
Scope-risk: narrow
Reversibility: clean
Directive: Before the next broad CLI/runtime refactor, compare this branch against origin/rcc/runtime for any remote-only startup behavior worth porting deliberately
Tested: cargo fmt; cargo clippy --workspace --all-targets -- -D warnings; cargo test --workspace
Not-tested: Remote-only runtime startup semantics not exercised by the session persistence change
2026-04-01 01:02:05 +00:00
Yeachan-Heo
146260083c Persist CLI conversation history across sessions
The Rust CLI now stores managed sessions under ~/.claude/sessions,
records additive session metadata in the canonical JSON transcript,
and exposes a /sessions listing alias alongside ID-or-path resume.
Inactive oversized sessions are compacted automatically so old
transcripts remain resumable without growing unchecked.

Constraint: Session JSON must stay backward-compatible with legacy files that lack metadata
Constraint: Managed sessions must use a single canonical JSON file per session without new dependencies
Rejected: Sidecar metadata/index files | duplicated state and diverged from the requested single-file persistence model
Confidence: high
Scope-risk: moderate
Reversibility: clean
Directive: Keep CLI policy in the CLI; only add transcript-adjacent metadata to runtime::Session unless another consumer truly needs more
Tested: cargo fmt; cargo clippy --workspace --all-targets -- -D warnings; cargo test --workspace
Not-tested: Manual interactive REPL smoke test against the live Anthropic API
2026-04-01 00:58:14 +00:00
Yeachan-Heo
d6341d54c1 feat: config discovery and CLAUDE.md loading (cherry-picked from rcc/runtime) 2026-04-01 00:40:34 +00:00
Yeachan-Heo
cd01d0e387 Honor Claude config defaults across runtime sessions
The runtime now discovers both legacy and current Claude config files at
user and project scope, merges them in precedence order, and carries the
resolved model, permission mode, instruction files, and MCP server
configuration into session startup.

This keeps CLI defaults aligned with project policy and exposes configured
MCP tools without requiring manual flags.

Constraint: Must support both legacy .claude.json and current .claude/settings.json layouts
Constraint: Session startup must preserve CLI flag precedence over config defaults
Rejected: Read only project settings files | would ignore user-scoped defaults and MCP servers
Rejected: Delay MCP tool discovery until first tool call | model would not see configured MCP tools during planning
Confidence: high
Scope-risk: moderate
Reversibility: clean
Directive: Keep config precedence synchronized between prompt loading, session startup, and status reporting
Tested: cargo fmt --all --check; cargo clippy --workspace --all-targets --all-features -- -D warnings; cargo test --workspace --all-features
Not-tested: Live remote MCP servers and interactive REPL session startup against external services
2026-04-01 00:36:32 +00:00
Yeachan-Heo
863958b94c Merge remote-tracking branch 'origin/rcc/api' into dev/rust 2026-04-01 00:30:20 +00:00
Yeachan-Heo
9455280f24 Enable saved OAuth startup auth without breaking local version output
Startup auth was split between the CLI and API crates, which made saved OAuth refresh behavior eager and easy to drift. This change adds a startup-specific resolver in the API layer, keeps env-only auth semantics intact, preserves saved refresh tokens when refresh responses omit them, and lets the CLI reuse the shared resolver while keeping --version on a purely local path.

Constraint: Saved OAuth credentials live in ~/.claude/credentials.json and must remain compatible with existing runtime helpers
Constraint: --version must not require config loading or any API/auth client initialization
Rejected: Keep refresh orchestration only in rusty-claude-cli | would preserve split auth policy and lazy-load bugs
Rejected: Change AnthropicClient::from_env to load config | would broaden configless API semantics for non-CLI callers
Confidence: high
Scope-risk: moderate
Reversibility: clean
Directive: Keep startup-only OAuth refresh separate from AuthSource::from_env() / AnthropicClient::from_env() unless all non-CLI callers are re-evaluated
Tested: cargo fmt --all; cargo build; cargo clippy --workspace --all-targets -- -D warnings; cargo test; cargo run -p rusty-claude-cli -- --version
Not-tested: Live OAuth refresh against a real auth server
2026-04-01 00:24:55 +00:00
Yeachan-Heo
c92403994d Merge remote-tracking branch 'origin/rcc/cli' into dev/rust
# Conflicts:
#	rust/crates/rusty-claude-cli/src/main.rs
2026-04-01 00:20:39 +00:00
Yeachan-Heo
e2f061fd08 Enforce tool permissions before execution
The Rust CLI/runtime now models permissions as ordered access levels, derives tool requirements from the shared tool specs, and prompts REPL users before one-off danger-full-access escalations from workspace-write sessions. This also wires explicit --permission-mode parsing and makes /permissions operate on the live session state instead of an implicit env-derived default.

Constraint: Must preserve the existing three user-facing modes read-only, workspace-write, and danger-full-access

Constraint: Must avoid new dependencies and keep enforcement inside the existing runtime/tool plumbing

Rejected: Keep the old Allow/Deny/Prompt policy model | could not represent ordered tool requirements across the CLI surface

Rejected: Continue sourcing live session mode solely from RUSTY_CLAUDE_PERMISSION_MODE | /permissions would not reliably reflect the current session state

Confidence: high

Scope-risk: moderate

Reversibility: clean

Directive: Add required_permission entries for new tools before exposing them to the runtime

Tested: cargo fmt; cargo clippy --workspace --all-targets -- -D warnings; cargo test -q

Not-tested: Manual interactive REPL approval flow in a live Anthropic session
2026-04-01 00:06:15 +00:00
Yeachan-Heo
c139fe9bee Merge remote-tracking branch 'origin/rcc/api' into dev/rust
# Conflicts:
#	rust/crates/rusty-claude-cli/src/main.rs
2026-03-31 23:41:08 +00:00
Yeachan-Heo
842abcfe85 Merge remote-tracking branch 'origin/rcc/cli' into dev/rust 2026-03-31 23:40:35 +00:00
Yeachan-Heo
807e29c8a1 Merge remote-tracking branch 'origin/rcc/tools' into dev/rust 2026-03-31 23:40:35 +00:00
Yeachan-Heo
32e89df631 Enable Claude OAuth login without requiring API keys
This adds an end-to-end OAuth PKCE login/logout path to the Rust CLI,
persists OAuth credentials under the Claude config home, and teaches the
API client to use persisted bearer credentials with refresh support when
env-based API credentials are absent.

Constraint: Reuse existing runtime OAuth primitives and keep browser/callback orchestration in the CLI
Constraint: Preserve auth precedence as API key, then auth-token env, then persisted OAuth credentials
Rejected: Put browser launch and token exchange entirely in runtime | caused boundary creep across shared crates
Rejected: Duplicate credential parsing in CLI and api | increased drift and refresh inconsistency
Confidence: medium
Scope-risk: moderate
Reversibility: clean
Directive: Keep logout non-destructive to unrelated credentials.json fields and do not silently fall back to stale expired tokens
Tested: cargo fmt; cargo clippy --workspace --all-targets -- -D warnings; cargo test
Not-tested: Manual live Anthropic OAuth browser flow against real authorize/token endpoints
2026-03-31 23:38:05 +00:00
Yeachan-Heo
1f8cfbce38 Prevent tool regressions by locking down dispatch-level edge cases
The tools crate already covered several higher-level commands, but the
public dispatch surface still lacked direct tests for shell and file
operations plus several error-path behaviors. This change expands the
existing lib.rs unit suite to cover the requested tools through
`execute_tool`, adds deterministic temp-path helpers, and hardens
assertions around invalid inputs and tricky offset/background behavior.

Constraint: No new dependencies; coverage had to stay within the existing crate test structure
Rejected: Split coverage into new integration tests under tests/ | would require broader visibility churn for little gain
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: Keep future tool-coverage additions on the public dispatch surface unless a lower-level helper contract specifically needs direct testing
Tested: cargo fmt --all; cargo clippy -p tools --all-targets --all-features -- -D warnings; cargo test -p tools
Not-tested: Cross-platform shell/runtime differences beyond the current Linux-like CI environment
2026-03-31 23:33:05 +00:00
Yeachan-Heo
1e5002b521 Add MCP server orchestration so configured stdio tools can be discovered and called
The runtime crate already had typed MCP config parsing, bootstrap metadata,
and stdio JSON-RPC transport primitives, but it lacked the stateful layer
that owns configured subprocesses and routes discovered tools back to the
right server. This change adds a thin lazy McpServerManager in mcp_stdio,
keeps unsupported transports explicit, and locks the behavior with
subprocess-backed discovery, routing, reuse, shutdown, and error tests.

Constraint: Keep the change narrow to the runtime crate and stdio transport only
Constraint: Reuse existing MCP config/bootstrap/process helpers instead of adding new dependencies
Rejected: Eagerly spawn all configured servers at construction | unnecessary startup cost and failure coupling
Rejected: Spawn a fresh process per request | defeats lifecycle management and tool routing cache
Confidence: high
Scope-risk: moderate
Reversibility: clean
Directive: Keep higher-level runtime/session integration separate until a caller needs this manager surface
Tested: cargo fmt --all; cargo clippy -p runtime --all-targets -- -D warnings; cargo test -p runtime
Not-tested: Integration into conversation/runtime flows outside direct manager APIs
2026-03-31 23:31:37 +00:00
22 changed files with 2967 additions and 204 deletions

View File

@@ -0,0 +1 @@
{"messages":[],"version":1}

View File

@@ -0,0 +1 @@
{"messages":[{"blocks":[{"text":"Say hello in one sentence","type":"text"}],"role":"user"},{"blocks":[{"text":"Hello! I'm Claude, an AI assistant ready to help you with software engineering tasks, code analysis, debugging, or any other programming challenges you might have.","type":"text"}],"role":"assistant","usage":{"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"input_tokens":11,"output_tokens":32}}],"version":1}

1
rust/Cargo.lock generated
View File

@@ -22,6 +22,7 @@ name = "api"
version = "0.1.0"
dependencies = [
"reqwest",
"runtime",
"serde",
"serde_json",
"tokio",

View File

@@ -64,6 +64,26 @@ cd rust
cargo run -p rusty-claude-cli -- --version
```
### Login with OAuth
Configure `settings.json` with an `oauth` block containing `clientId`, `authorizeUrl`, `tokenUrl`, optional `callbackPort`, and optional `scopes`, then run:
```bash
cd rust
cargo run -p rusty-claude-cli -- login
```
This opens the browser, listens on the configured localhost callback, exchanges the auth code for tokens, and stores OAuth credentials in `~/.claude/credentials.json` (or `$CLAUDE_CONFIG_HOME/credentials.json`).
### Logout
```bash
cd rust
cargo run -p rusty-claude-cli -- logout
```
This removes only the stored OAuth credentials and preserves unrelated JSON fields in `credentials.json`.
## Usage examples
### 1) Prompt mode
@@ -113,6 +133,7 @@ Inside the REPL, useful commands include:
/diff
/version
/export notes.txt
/sessions
/session list
/exit
```
@@ -123,14 +144,14 @@ Inspect or maintain a saved session file without entering the REPL:
```bash
cd rust
cargo run -p rusty-claude-cli -- --resume session.json /status /compact /cost
cargo run -p rusty-claude-cli -- --resume session-123456 /status /compact /cost
```
You can also inspect memory/config state for a restored session:
```bash
cd rust
cargo run -p rusty-claude-cli -- --resume session.json /memory /config
cargo run -p rusty-claude-cli -- --resume ~/.claude/sessions/session-123456.json /memory /config
```
## Available commands
@@ -138,7 +159,7 @@ cargo run -p rusty-claude-cli -- --resume session.json /memory /config
### Top-level CLI commands
- `prompt <text...>` — run one prompt non-interactively
- `--resume <session.json> [/commands...]` — inspect or maintain a saved session
- `--resume <session-id-or-path> [/commands...]` — inspect or maintain a saved session stored under `~/.claude/sessions/`
- `dump-manifests` — print extracted upstream manifest counts
- `bootstrap-plan` — print the current bootstrap skeleton
- `system-prompt [--cwd PATH] [--date YYYY-MM-DD]` — render the synthesized system prompt
@@ -156,13 +177,14 @@ cargo run -p rusty-claude-cli -- --resume session.json /memory /config
- `/permissions [read-only|workspace-write|danger-full-access]` — inspect or switch permissions
- `/clear [--confirm]` — clear the current local session
- `/cost` — show token usage totals
- `/resume <session-path>` — load a saved session into the REPL
- `/resume <session-id-or-path>` — load a saved session into the REPL
- `/config [env|hooks|model]` — inspect discovered Claude config
- `/memory` — inspect loaded instruction memory files
- `/init` — create a starter `CLAUDE.md`
- `/diff` — show the current git diff for the workspace
- `/version` — print version and build metadata locally
- `/export [file]` — export the current conversation transcript
- `/sessions` — list recent managed local sessions from `~/.claude/sessions/`
- `/session [list|switch <session-id>]` — inspect or switch managed local sessions
- `/exit` — leave the REPL
@@ -170,8 +192,9 @@ cargo run -p rusty-claude-cli -- --resume session.json /memory /config
### Anthropic/API
- `ANTHROPIC_AUTH_TOKEN`preferred bearer token for API auth
- `ANTHROPIC_API_KEY`legacy API key fallback if auth token is unset
- `ANTHROPIC_API_KEY`highest-precedence API credential
- `ANTHROPIC_AUTH_TOKEN`bearer-token override used when no API key is set
- Persisted OAuth credentials in `~/.claude/credentials.json` — used when neither env var is set
- `ANTHROPIC_BASE_URL` — override the Anthropic API base URL
- `ANTHROPIC_MODEL` — default model used by selected live integration tests

View File

@@ -7,6 +7,7 @@ publish.workspace = true
[dependencies]
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
runtime = { path = "../runtime" }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] }

View File

@@ -1,6 +1,10 @@
use std::collections::VecDeque;
use std::time::Duration;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use runtime::{
load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
OAuthTokenExchangeRequest,
};
use serde::Deserialize;
use crate::error::ApiError;
@@ -81,11 +85,12 @@ impl AuthSource {
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct OAuthTokenSet {
pub access_token: String,
pub refresh_token: Option<String>,
pub expires_at: Option<u64>,
#[serde(default)]
pub scopes: Vec<String>,
}
@@ -131,7 +136,7 @@ impl AnthropicClient {
}
pub fn from_env() -> Result<Self, ApiError> {
Ok(Self::from_auth(AuthSource::from_env()?).with_base_url(read_base_url()))
Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
}
#[must_use]
@@ -225,6 +230,46 @@ impl AnthropicClient {
})
}
pub async fn exchange_oauth_code(
&self,
config: &OAuthConfig,
request: &OAuthTokenExchangeRequest,
) -> Result<OAuthTokenSet, ApiError> {
let response = self
.http
.post(&config.token_url)
.header("content-type", "application/x-www-form-urlencoded")
.form(&request.form_params())
.send()
.await
.map_err(ApiError::from)?;
let response = expect_success(response).await?;
response
.json::<OAuthTokenSet>()
.await
.map_err(ApiError::from)
}
pub async fn refresh_oauth_token(
&self,
config: &OAuthConfig,
request: &OAuthRefreshRequest,
) -> Result<OAuthTokenSet, ApiError> {
let response = self
.http
.post(&config.token_url)
.header("content-type", "application/x-www-form-urlencoded")
.form(&request.form_params())
.send()
.await
.map_err(ApiError::from)?;
let response = expect_success(response).await?;
response
.json::<OAuthTokenSet>()
.await
.map_err(ApiError::from)
}
async fn send_with_retry(
&self,
request: &MessageRequest,
@@ -304,6 +349,153 @@ impl AnthropicClient {
}
}
impl AuthSource {
pub fn from_env_or_saved() -> Result<Self, ApiError> {
if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
api_key,
bearer_token,
}),
None => Ok(Self::ApiKey(api_key)),
};
}
if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
return Ok(Self::BearerToken(bearer_token));
}
match load_saved_oauth_token() {
Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
if token_set.refresh_token.is_some() {
Err(ApiError::Auth(
"saved OAuth token is expired; load runtime OAuth config to refresh it"
.to_string(),
))
} else {
Err(ApiError::ExpiredOAuthToken)
}
}
Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
Ok(None) => Err(ApiError::MissingApiKey),
Err(error) => Err(error),
}
}
}
#[must_use]
pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
token_set
.expires_at
.is_some_and(|expires_at| expires_at <= now_unix_timestamp())
}
pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
let Some(token_set) = load_saved_oauth_token()? else {
return Ok(None);
};
resolve_saved_oauth_token_set(config, token_set).map(Some)
}
pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
where
F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
{
if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
api_key,
bearer_token,
}),
None => Ok(AuthSource::ApiKey(api_key)),
};
}
if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
return Ok(AuthSource::BearerToken(bearer_token));
}
let Some(token_set) = load_saved_oauth_token()? else {
return Err(ApiError::MissingApiKey);
};
if !oauth_token_is_expired(&token_set) {
return Ok(AuthSource::BearerToken(token_set.access_token));
}
if token_set.refresh_token.is_none() {
return Err(ApiError::ExpiredOAuthToken);
}
let Some(config) = load_oauth_config()? else {
return Err(ApiError::Auth(
"saved OAuth token is expired; runtime OAuth config is missing".to_string(),
));
};
Ok(AuthSource::from(resolve_saved_oauth_token_set(
&config, token_set,
)?))
}
fn resolve_saved_oauth_token_set(
config: &OAuthConfig,
token_set: OAuthTokenSet,
) -> Result<OAuthTokenSet, ApiError> {
if !oauth_token_is_expired(&token_set) {
return Ok(token_set);
}
let Some(refresh_token) = token_set.refresh_token.clone() else {
return Err(ApiError::ExpiredOAuthToken);
};
let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url());
let refreshed = client_runtime_block_on(async {
client
.refresh_oauth_token(
config,
&OAuthRefreshRequest::from_config(
config,
refresh_token,
Some(token_set.scopes.clone()),
),
)
.await
})?;
let resolved = OAuthTokenSet {
access_token: refreshed.access_token,
refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
expires_at: refreshed.expires_at,
scopes: refreshed.scopes,
};
save_oauth_credentials(&runtime::OAuthTokenSet {
access_token: resolved.access_token.clone(),
refresh_token: resolved.refresh_token.clone(),
expires_at: resolved.expires_at,
scopes: resolved.scopes.clone(),
})
.map_err(ApiError::from)?;
Ok(resolved)
}
fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
where
F: std::future::Future<Output = Result<T, ApiError>>,
{
tokio::runtime::Runtime::new()
.map_err(ApiError::from)?
.block_on(future)
}
fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
let token_set = load_oauth_credentials().map_err(ApiError::from)?;
Ok(token_set.map(|token_set| OAuthTokenSet {
access_token: token_set.access_token,
refresh_token: token_set.refresh_token,
expires_at: token_set.expires_at,
scopes: token_set.scopes,
}))
}
fn now_unix_timestamp() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_or(0, |duration| duration.as_secs())
}
fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
match std::env::var(key) {
Ok(value) if !value.is_empty() => Ok(Some(value)),
@@ -314,7 +506,7 @@ fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
#[cfg(test)]
fn read_api_key() -> Result<String, ApiError> {
let auth = AuthSource::from_env()?;
let auth = AuthSource::from_env_or_saved()?;
auth.api_key()
.or_else(|| auth.bearer_token())
.map(ToOwned::to_owned)
@@ -424,10 +616,18 @@ struct AnthropicErrorBody {
#[cfg(test)]
mod tests {
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
use std::io::{Read, Write};
use std::net::TcpListener;
use std::sync::{Mutex, OnceLock};
use std::time::Duration;
use std::thread;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::client::{AuthSource, OAuthTokenSet};
use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig};
use crate::client::{
now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
};
use crate::types::{ContentBlockDelta, MessageRequest};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
@@ -437,11 +637,53 @@ mod tests {
.expect("env lock")
}
fn temp_config_home() -> std::path::PathBuf {
std::env::temp_dir().join(format!(
"api-oauth-test-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
))
}
fn sample_oauth_config(token_url: String) -> OAuthConfig {
OAuthConfig {
client_id: "runtime-client".to_string(),
authorize_url: "https://console.test/oauth/authorize".to_string(),
token_url,
callback_port: Some(4545),
manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
scopes: vec!["org:read".to_string(), "user:write".to_string()],
}
}
fn spawn_token_server(response_body: &'static str) -> String {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
let address = listener.local_addr().expect("local addr");
thread::spawn(move || {
let (mut stream, _) = listener.accept().expect("accept connection");
let mut buffer = [0_u8; 4096];
let _ = stream.read(&mut buffer).expect("read request");
let response = format!(
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
response_body.len(),
response_body
);
stream
.write_all(response.as_bytes())
.expect("write response");
});
format!("http://{address}/oauth/token")
}
#[test]
fn read_api_key_requires_presence() {
let _guard = env_lock();
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
std::env::remove_var("CLAUDE_CONFIG_HOME");
let error = super::read_api_key().expect_err("missing key should error");
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
}
@@ -453,6 +695,7 @@ mod tests {
std::env::remove_var("ANTHROPIC_API_KEY");
let error = super::read_api_key().expect_err("empty key should error");
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
}
#[test]
@@ -500,6 +743,166 @@ mod tests {
std::env::remove_var("ANTHROPIC_API_KEY");
}
#[test]
fn auth_source_from_saved_oauth_when_env_absent() {
let _guard = env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
save_oauth_credentials(&runtime::OAuthTokenSet {
access_token: "saved-access-token".to_string(),
refresh_token: Some("refresh".to_string()),
expires_at: Some(now_unix_timestamp() + 300),
scopes: vec!["scope:a".to_string()],
})
.expect("save oauth credentials");
let auth = AuthSource::from_env_or_saved().expect("saved auth");
assert_eq!(auth.bearer_token(), Some("saved-access-token"));
clear_oauth_credentials().expect("clear credentials");
std::env::remove_var("CLAUDE_CONFIG_HOME");
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
}
#[test]
fn oauth_token_expiry_uses_expires_at_timestamp() {
assert!(oauth_token_is_expired(&OAuthTokenSet {
access_token: "access-token".to_string(),
refresh_token: None,
expires_at: Some(1),
scopes: Vec::new(),
}));
assert!(!oauth_token_is_expired(&OAuthTokenSet {
access_token: "access-token".to_string(),
refresh_token: None,
expires_at: Some(now_unix_timestamp() + 60),
scopes: Vec::new(),
}));
}
#[test]
fn resolve_saved_oauth_token_refreshes_expired_credentials() {
let _guard = env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
save_oauth_credentials(&runtime::OAuthTokenSet {
access_token: "expired-access-token".to_string(),
refresh_token: Some("refresh-token".to_string()),
expires_at: Some(1),
scopes: vec!["scope:a".to_string()],
})
.expect("save expired oauth credentials");
let token_url = spawn_token_server(
"{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
);
let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
.expect("resolve refreshed token")
.expect("token set present");
assert_eq!(resolved.access_token, "refreshed-token");
let stored = runtime::load_oauth_credentials()
.expect("load stored credentials")
.expect("stored token set");
assert_eq!(stored.access_token, "refreshed-token");
clear_oauth_credentials().expect("clear credentials");
std::env::remove_var("CLAUDE_CONFIG_HOME");
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
}
#[test]
fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
let _guard = env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
save_oauth_credentials(&runtime::OAuthTokenSet {
access_token: "saved-access-token".to_string(),
refresh_token: Some("refresh".to_string()),
expires_at: Some(now_unix_timestamp() + 300),
scopes: vec!["scope:a".to_string()],
})
.expect("save oauth credentials");
let auth = resolve_startup_auth_source(|| panic!("config should not be loaded"))
.expect("startup auth");
assert_eq!(auth.bearer_token(), Some("saved-access-token"));
clear_oauth_credentials().expect("clear credentials");
std::env::remove_var("CLAUDE_CONFIG_HOME");
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
}
#[test]
fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
let _guard = env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
save_oauth_credentials(&runtime::OAuthTokenSet {
access_token: "expired-access-token".to_string(),
refresh_token: Some("refresh-token".to_string()),
expires_at: Some(1),
scopes: vec!["scope:a".to_string()],
})
.expect("save expired oauth credentials");
let error =
resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error");
assert!(
matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing"))
);
let stored = runtime::load_oauth_credentials()
.expect("load stored credentials")
.expect("stored token set");
assert_eq!(stored.access_token, "expired-access-token");
assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
clear_oauth_credentials().expect("clear credentials");
std::env::remove_var("CLAUDE_CONFIG_HOME");
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
}
#[test]
fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
let _guard = env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
save_oauth_credentials(&runtime::OAuthTokenSet {
access_token: "expired-access-token".to_string(),
refresh_token: Some("refresh-token".to_string()),
expires_at: Some(1),
scopes: vec!["scope:a".to_string()],
})
.expect("save expired oauth credentials");
let token_url = spawn_token_server(
"{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
);
let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
.expect("resolve refreshed token")
.expect("token set present");
assert_eq!(resolved.access_token, "refreshed-token");
assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token"));
let stored = runtime::load_oauth_credentials()
.expect("load stored credentials")
.expect("stored token set");
assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
clear_oauth_credentials().expect("clear credentials");
std::env::remove_var("CLAUDE_CONFIG_HOME");
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
}
#[test]
fn message_request_stream_helper_sets_stream_true() {
let request = MessageRequest {
@@ -517,7 +920,7 @@ mod tests {
#[test]
fn backoff_doubles_until_maximum() {
let client = super::AnthropicClient::new("test-key").with_retry_policy(
let client = AnthropicClient::new("test-key").with_retry_policy(
3,
Duration::from_millis(10),
Duration::from_millis(25),

View File

@@ -5,6 +5,8 @@ use std::time::Duration;
#[derive(Debug)]
pub enum ApiError {
MissingApiKey,
ExpiredOAuthToken,
Auth(String),
InvalidApiKeyEnv(VarError),
Http(reqwest::Error),
Io(std::io::Error),
@@ -35,6 +37,8 @@ impl ApiError {
Self::Api { retryable, .. } => *retryable,
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
Self::MissingApiKey
| Self::ExpiredOAuthToken
| Self::Auth(_)
| Self::InvalidApiKeyEnv(_)
| Self::Io(_)
| Self::Json(_)
@@ -53,6 +57,13 @@ impl Display for ApiError {
"ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY is not set; export one before calling the Anthropic API"
)
}
Self::ExpiredOAuthToken => {
write!(
f,
"saved OAuth token is expired and no refresh token is available"
)
}
Self::Auth(message) => write!(f, "auth error: {message}"),
Self::InvalidApiKeyEnv(error) => {
write!(
f,

View File

@@ -3,7 +3,10 @@ mod error;
mod sse;
mod types;
pub use client::{AnthropicClient, AuthSource, MessageStream, OAuthTokenSet};
pub use client::{
oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source,
AnthropicClient, AuthSource, MessageStream, OAuthTokenSet,
};
pub use error::ApiError;
pub use sse::{parse_frame, SseParser};
pub use types::{

View File

@@ -84,7 +84,7 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[
SlashCommandSpec {
name: "resume",
summary: "Load a saved session into the REPL",
argument_hint: Some("<session-path>"),
argument_hint: Some("<session-id-or-path>"),
resume_supported: false,
},
SlashCommandSpec {
@@ -129,6 +129,12 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[
argument_hint: Some("[list|switch <session-id>]"),
resume_supported: false,
},
SlashCommandSpec {
name: "sessions",
summary: "List recent managed local sessions",
argument_hint: None,
resume_supported: false,
},
];
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -163,6 +169,7 @@ pub enum SlashCommand {
action: Option<String>,
target: Option<String>,
},
Sessions,
Unknown(String),
}
@@ -207,6 +214,7 @@ impl SlashCommand {
action: parts.next().map(ToOwned::to_owned),
target: parts.next().map(ToOwned::to_owned),
},
"sessions" => Self::Sessions,
other => Self::Unknown(other.to_string()),
})
}
@@ -291,6 +299,7 @@ pub fn handle_slash_command(
| SlashCommand::Version
| SlashCommand::Export { .. }
| SlashCommand::Session { .. }
| SlashCommand::Sessions
| SlashCommand::Unknown(_) => None,
}
}
@@ -365,6 +374,10 @@ mod tests {
target: Some("abc123".to_string())
})
);
assert_eq!(
SlashCommand::parse("/sessions"),
Some(SlashCommand::Sessions)
);
}
#[test]
@@ -378,7 +391,7 @@ mod tests {
assert!(help.contains("/permissions [read-only|workspace-write|danger-full-access]"));
assert!(help.contains("/clear [--confirm]"));
assert!(help.contains("/cost"));
assert!(help.contains("/resume <session-path>"));
assert!(help.contains("/resume <session-id-or-path>"));
assert!(help.contains("/config [env|hooks|model]"));
assert!(help.contains("/memory"));
assert!(help.contains("/init"));
@@ -386,7 +399,8 @@ mod tests {
assert!(help.contains("/version"));
assert!(help.contains("/export [file]"));
assert!(help.contains("/session [list|switch <session-id>]"));
assert_eq!(slash_command_specs().len(), 15);
assert!(help.contains("/sessions"));
assert_eq!(slash_command_specs().len(), 16);
assert_eq!(resume_supported_slash_commands().len(), 11);
}
@@ -404,6 +418,7 @@ mod tests {
text: "recent".to_string(),
}]),
],
metadata: None,
};
let result = handle_slash_command(
@@ -468,5 +483,6 @@ mod tests {
assert!(
handle_slash_command("/session list", &session, CompactionConfig::default()).is_none()
);
assert!(handle_slash_command("/sessions", &session, CompactionConfig::default()).is_none());
}
}

View File

@@ -105,6 +105,7 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio
compacted_session: Session {
version: session.version,
messages: compacted_messages,
metadata: session.metadata.clone(),
},
removed_message_count: removed.len(),
}
@@ -393,6 +394,7 @@ mod tests {
let session = Session {
version: 1,
messages: vec![ConversationMessage::user_text("hello")],
metadata: None,
};
let result = compact_session(&session, CompactionConfig::default());
@@ -420,6 +422,7 @@ mod tests {
usage: None,
},
],
metadata: None,
};
let result = compact_session(

View File

@@ -14,6 +14,13 @@ pub enum ConfigSource {
Local,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResolvedPermissionMode {
ReadOnly,
WorkspaceWrite,
DangerFullAccess,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConfigEntry {
pub source: ConfigSource,
@@ -31,6 +38,8 @@ pub struct RuntimeConfig {
pub struct RuntimeFeatureConfig {
mcp: McpConfigCollection,
oauth: Option<OAuthConfig>,
model: Option<String>,
permission_mode: Option<ResolvedPermissionMode>,
}
#[derive(Debug, Clone, PartialEq, Eq, Default)]
@@ -165,11 +174,23 @@ impl ConfigLoader {
#[must_use]
pub fn discover(&self) -> Vec<ConfigEntry> {
let user_legacy_path = self.config_home.parent().map_or_else(
|| PathBuf::from(".claude.json"),
|parent| parent.join(".claude.json"),
);
vec![
ConfigEntry {
source: ConfigSource::User,
path: user_legacy_path,
},
ConfigEntry {
source: ConfigSource::User,
path: self.config_home.join("settings.json"),
},
ConfigEntry {
source: ConfigSource::Project,
path: self.cwd.join(".claude.json"),
},
ConfigEntry {
source: ConfigSource::Project,
path: self.cwd.join(".claude").join("settings.json"),
@@ -195,14 +216,15 @@ impl ConfigLoader {
loaded_entries.push(entry);
}
let merged_value = JsonValue::Object(merged.clone());
let feature_config = RuntimeFeatureConfig {
mcp: McpConfigCollection {
servers: mcp_servers,
},
oauth: parse_optional_oauth_config(
&JsonValue::Object(merged.clone()),
"merged settings.oauth",
)?,
oauth: parse_optional_oauth_config(&merged_value, "merged settings.oauth")?,
model: parse_optional_model(&merged_value),
permission_mode: parse_optional_permission_mode(&merged_value)?,
};
Ok(RuntimeConfig {
@@ -257,6 +279,16 @@ impl RuntimeConfig {
pub fn oauth(&self) -> Option<&OAuthConfig> {
self.feature_config.oauth.as_ref()
}
#[must_use]
pub fn model(&self) -> Option<&str> {
self.feature_config.model.as_deref()
}
#[must_use]
pub fn permission_mode(&self) -> Option<ResolvedPermissionMode> {
self.feature_config.permission_mode
}
}
impl RuntimeFeatureConfig {
@@ -269,6 +301,16 @@ impl RuntimeFeatureConfig {
pub fn oauth(&self) -> Option<&OAuthConfig> {
self.oauth.as_ref()
}
#[must_use]
pub fn model(&self) -> Option<&str> {
self.model.as_deref()
}
#[must_use]
pub fn permission_mode(&self) -> Option<ResolvedPermissionMode> {
self.permission_mode
}
}
impl McpConfigCollection {
@@ -307,6 +349,7 @@ impl McpServerConfig {
fn read_optional_json_object(
path: &Path,
) -> Result<Option<BTreeMap<String, JsonValue>>, ConfigError> {
let is_legacy_config = path.file_name().and_then(|name| name.to_str()) == Some(".claude.json");
let contents = match fs::read_to_string(path) {
Ok(contents) => contents,
Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(None),
@@ -317,14 +360,20 @@ fn read_optional_json_object(
return Ok(Some(BTreeMap::new()));
}
let parsed = JsonValue::parse(&contents)
.map_err(|error| ConfigError::Parse(format!("{}: {error}", path.display())))?;
let object = parsed.as_object().ok_or_else(|| {
ConfigError::Parse(format!(
let parsed = match JsonValue::parse(&contents) {
Ok(parsed) => parsed,
Err(error) if is_legacy_config => return Ok(None),
Err(error) => return Err(ConfigError::Parse(format!("{}: {error}", path.display()))),
};
let Some(object) = parsed.as_object() else {
if is_legacy_config {
return Ok(None);
}
return Err(ConfigError::Parse(format!(
"{}: top-level settings value must be a JSON object",
path.display()
))
})?;
)));
};
Ok(Some(object.clone()))
}
@@ -355,6 +404,47 @@ fn merge_mcp_servers(
Ok(())
}
fn parse_optional_model(root: &JsonValue) -> Option<String> {
root.as_object()
.and_then(|object| object.get("model"))
.and_then(JsonValue::as_str)
.map(ToOwned::to_owned)
}
fn parse_optional_permission_mode(
root: &JsonValue,
) -> Result<Option<ResolvedPermissionMode>, ConfigError> {
let Some(object) = root.as_object() else {
return Ok(None);
};
if let Some(mode) = object.get("permissionMode").and_then(JsonValue::as_str) {
return parse_permission_mode_label(mode, "merged settings.permissionMode").map(Some);
}
let Some(mode) = object
.get("permissions")
.and_then(JsonValue::as_object)
.and_then(|permissions| permissions.get("defaultMode"))
.and_then(JsonValue::as_str)
else {
return Ok(None);
};
parse_permission_mode_label(mode, "merged settings.permissions.defaultMode").map(Some)
}
fn parse_permission_mode_label(
mode: &str,
context: &str,
) -> Result<ResolvedPermissionMode, ConfigError> {
match mode {
"default" | "plan" | "read-only" => Ok(ResolvedPermissionMode::ReadOnly),
"acceptEdits" | "auto" | "workspace-write" => Ok(ResolvedPermissionMode::WorkspaceWrite),
"dontAsk" | "danger-full-access" => Ok(ResolvedPermissionMode::DangerFullAccess),
other => Err(ConfigError::Parse(format!(
"{context}: unsupported permission mode {other}"
))),
}
}
fn parse_optional_oauth_config(
root: &JsonValue,
context: &str,
@@ -594,7 +684,8 @@ fn deep_merge_objects(
#[cfg(test)]
mod tests {
use super::{
ConfigLoader, ConfigSource, McpServerConfig, McpTransport, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
ConfigLoader, ConfigSource, McpServerConfig, McpTransport, ResolvedPermissionMode,
CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
};
use crate::json::JsonValue;
use std::fs;
@@ -635,14 +726,24 @@ mod tests {
fs::create_dir_all(cwd.join(".claude")).expect("project config dir");
fs::create_dir_all(&home).expect("home config dir");
fs::write(
home.parent().expect("home parent").join(".claude.json"),
r#"{"model":"haiku","env":{"A":"1"},"mcpServers":{"home":{"command":"uvx","args":["home"]}}}"#,
)
.expect("write user compat config");
fs::write(
home.join("settings.json"),
r#"{"model":"sonnet","env":{"A":"1"},"hooks":{"PreToolUse":["base"]}}"#,
r#"{"model":"sonnet","env":{"A2":"1"},"hooks":{"PreToolUse":["base"]},"permissions":{"defaultMode":"plan"}}"#,
)
.expect("write user settings");
fs::write(
cwd.join(".claude.json"),
r#"{"model":"project-compat","env":{"B":"2"}}"#,
)
.expect("write project compat config");
fs::write(
cwd.join(".claude").join("settings.json"),
r#"{"env":{"B":"2"},"hooks":{"PostToolUse":["project"]}}"#,
r#"{"env":{"C":"3"},"hooks":{"PostToolUse":["project"]},"mcpServers":{"project":{"command":"uvx","args":["project"]}}}"#,
)
.expect("write project settings");
fs::write(
@@ -656,25 +757,37 @@ mod tests {
.expect("config should load");
assert_eq!(CLAUDE_CODE_SETTINGS_SCHEMA_NAME, "SettingsSchema");
assert_eq!(loaded.loaded_entries().len(), 3);
assert_eq!(loaded.loaded_entries().len(), 5);
assert_eq!(loaded.loaded_entries()[0].source, ConfigSource::User);
assert_eq!(
loaded.get("model"),
Some(&JsonValue::String("opus".to_string()))
);
assert_eq!(loaded.model(), Some("opus"));
assert_eq!(
loaded.permission_mode(),
Some(ResolvedPermissionMode::WorkspaceWrite)
);
assert_eq!(
loaded
.get("env")
.and_then(JsonValue::as_object)
.expect("env object")
.len(),
2
4
);
assert!(loaded
.get("hooks")
.and_then(JsonValue::as_object)
.expect("hooks object")
.contains_key("PreToolUse"));
assert!(loaded
.get("hooks")
.and_then(JsonValue::as_object)
.expect("hooks object")
.contains_key("PostToolUse"));
assert!(loaded.mcp().get("home").is_some());
assert!(loaded.mcp().get("project").is_some());
fs::remove_dir_all(root).expect("cleanup temp dir");
}

View File

@@ -408,7 +408,7 @@ mod tests {
.sum::<i32>();
Ok(total.to_string())
});
let permission_policy = PermissionPolicy::new(PermissionMode::Prompt);
let permission_policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite);
let system_prompt = SystemPromptBuilder::new()
.with_project_context(ProjectContext {
cwd: PathBuf::from("/tmp/project"),
@@ -487,7 +487,7 @@ mod tests {
Session::new(),
SingleCallApiClient,
StaticToolExecutor::new(),
PermissionPolicy::new(PermissionMode::Prompt),
PermissionPolicy::new(PermissionMode::WorkspaceWrite),
vec!["system".to_string()],
);
@@ -536,7 +536,7 @@ mod tests {
session,
SimpleApi,
StaticToolExecutor::new(),
PermissionPolicy::new(PermissionMode::Allow),
PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()],
);
@@ -563,7 +563,7 @@ mod tests {
Session::new(),
SimpleApi,
StaticToolExecutor::new(),
PermissionPolicy::new(PermissionMode::Allow),
PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()],
);
runtime.run_turn("a", None).expect("turn a");

View File

@@ -25,7 +25,8 @@ pub use config::{
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpClaudeAiProxyServerConfig,
McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig,
RuntimeConfig, RuntimeFeatureConfig, ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, ScopedMcpServerConfig,
CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
};
pub use conversation::{
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
@@ -46,14 +47,17 @@ pub use mcp_client::{
};
pub use mcp_stdio::{
spawn_mcp_stdio_process, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
McpInitializeClientInfo, McpInitializeParams, McpInitializeResult, McpInitializeServerInfo,
McpListResourcesParams, McpListResourcesResult, McpListToolsParams, McpListToolsResult,
McpReadResourceParams, McpReadResourceResult, McpResource, McpResourceContents,
McpStdioProcess, McpTool, McpToolCallContent, McpToolCallParams, McpToolCallResult,
ManagedMcpTool, McpInitializeClientInfo, McpInitializeParams, McpInitializeResult,
McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult, McpListToolsParams,
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpResource,
McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess, McpTool,
McpToolCallContent, McpToolCallParams, McpToolCallResult, UnsupportedMcpServer,
};
pub use oauth::{
code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
OAuthAuthorizationRequest, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query,
parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest,
OAuthCallbackParams, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
PkceChallengeMethod, PkceCodePair,
};
pub use permissions::{
@@ -69,7 +73,17 @@ pub use remote::{
RemoteSessionContext, UpstreamProxyBootstrap, UpstreamProxyState, DEFAULT_REMOTE_BASE_URL,
DEFAULT_SESSION_TOKEN_PATH, DEFAULT_SYSTEM_CA_BUNDLE, NO_PROXY_HOSTS, UPSTREAM_PROXY_ENV_KEYS,
};
pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError};
pub use session::{
ContentBlock, ConversationMessage, MessageRole, Session, SessionError, SessionMetadata,
};
pub use usage::{
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
};
#[cfg(test)]
pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
LOCK.get_or_init(|| std::sync::Mutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}

View File

@@ -8,6 +8,8 @@ use serde_json::Value as JsonValue;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use crate::config::{McpTransport, RuntimeConfig, ScopedMcpServerConfig};
use crate::mcp::mcp_tool_name;
use crate::mcp_client::{McpClientBootstrap, McpClientTransport, McpStdioTransport};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
@@ -200,6 +202,374 @@ pub struct McpReadResourceResult {
pub contents: Vec<McpResourceContents>,
}
#[derive(Debug, Clone, PartialEq)]
pub struct ManagedMcpTool {
pub server_name: String,
pub qualified_name: String,
pub raw_name: String,
pub tool: McpTool,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnsupportedMcpServer {
pub server_name: String,
pub transport: McpTransport,
pub reason: String,
}
#[derive(Debug)]
pub enum McpServerManagerError {
Io(io::Error),
JsonRpc {
server_name: String,
method: &'static str,
error: JsonRpcError,
},
InvalidResponse {
server_name: String,
method: &'static str,
details: String,
},
UnknownTool {
qualified_name: String,
},
UnknownServer {
server_name: String,
},
}
impl std::fmt::Display for McpServerManagerError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Io(error) => write!(f, "{error}"),
Self::JsonRpc {
server_name,
method,
error,
} => write!(
f,
"MCP server `{server_name}` returned JSON-RPC error for {method}: {} ({})",
error.message, error.code
),
Self::InvalidResponse {
server_name,
method,
details,
} => write!(
f,
"MCP server `{server_name}` returned invalid response for {method}: {details}"
),
Self::UnknownTool { qualified_name } => {
write!(f, "unknown MCP tool `{qualified_name}`")
}
Self::UnknownServer { server_name } => write!(f, "unknown MCP server `{server_name}`"),
}
}
}
impl std::error::Error for McpServerManagerError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(error) => Some(error),
Self::JsonRpc { .. }
| Self::InvalidResponse { .. }
| Self::UnknownTool { .. }
| Self::UnknownServer { .. } => None,
}
}
}
impl From<io::Error> for McpServerManagerError {
fn from(value: io::Error) -> Self {
Self::Io(value)
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct ToolRoute {
server_name: String,
raw_name: String,
}
#[derive(Debug)]
struct ManagedMcpServer {
bootstrap: McpClientBootstrap,
process: Option<McpStdioProcess>,
initialized: bool,
}
impl ManagedMcpServer {
fn new(bootstrap: McpClientBootstrap) -> Self {
Self {
bootstrap,
process: None,
initialized: false,
}
}
}
#[derive(Debug)]
pub struct McpServerManager {
servers: BTreeMap<String, ManagedMcpServer>,
unsupported_servers: Vec<UnsupportedMcpServer>,
tool_index: BTreeMap<String, ToolRoute>,
next_request_id: u64,
}
impl McpServerManager {
#[must_use]
pub fn from_runtime_config(config: &RuntimeConfig) -> Self {
Self::from_servers(config.mcp().servers())
}
#[must_use]
pub fn from_servers(servers: &BTreeMap<String, ScopedMcpServerConfig>) -> Self {
let mut managed_servers = BTreeMap::new();
let mut unsupported_servers = Vec::new();
for (server_name, server_config) in servers {
if server_config.transport() == McpTransport::Stdio {
let bootstrap = McpClientBootstrap::from_scoped_config(server_name, server_config);
managed_servers.insert(server_name.clone(), ManagedMcpServer::new(bootstrap));
} else {
unsupported_servers.push(UnsupportedMcpServer {
server_name: server_name.clone(),
transport: server_config.transport(),
reason: format!(
"transport {:?} is not supported by McpServerManager",
server_config.transport()
),
});
}
}
Self {
servers: managed_servers,
unsupported_servers,
tool_index: BTreeMap::new(),
next_request_id: 1,
}
}
#[must_use]
pub fn unsupported_servers(&self) -> &[UnsupportedMcpServer] {
&self.unsupported_servers
}
pub async fn discover_tools(&mut self) -> Result<Vec<ManagedMcpTool>, McpServerManagerError> {
let server_names = self.servers.keys().cloned().collect::<Vec<_>>();
let mut discovered_tools = Vec::new();
for server_name in server_names {
self.ensure_server_ready(&server_name).await?;
self.clear_routes_for_server(&server_name);
let mut cursor = None;
loop {
let request_id = self.take_request_id();
let response = {
let server = self.server_mut(&server_name)?;
let process = server.process.as_mut().ok_or_else(|| {
McpServerManagerError::InvalidResponse {
server_name: server_name.clone(),
method: "tools/list",
details: "server process missing after initialization".to_string(),
}
})?;
process
.list_tools(
request_id,
Some(McpListToolsParams {
cursor: cursor.clone(),
}),
)
.await?
};
if let Some(error) = response.error {
return Err(McpServerManagerError::JsonRpc {
server_name: server_name.clone(),
method: "tools/list",
error,
});
}
let result =
response
.result
.ok_or_else(|| McpServerManagerError::InvalidResponse {
server_name: server_name.clone(),
method: "tools/list",
details: "missing result payload".to_string(),
})?;
for tool in result.tools {
let qualified_name = mcp_tool_name(&server_name, &tool.name);
self.tool_index.insert(
qualified_name.clone(),
ToolRoute {
server_name: server_name.clone(),
raw_name: tool.name.clone(),
},
);
discovered_tools.push(ManagedMcpTool {
server_name: server_name.clone(),
qualified_name,
raw_name: tool.name.clone(),
tool,
});
}
match result.next_cursor {
Some(next_cursor) => cursor = Some(next_cursor),
None => break,
}
}
}
Ok(discovered_tools)
}
pub async fn call_tool(
&mut self,
qualified_tool_name: &str,
arguments: Option<JsonValue>,
) -> Result<JsonRpcResponse<McpToolCallResult>, McpServerManagerError> {
let route = self
.tool_index
.get(qualified_tool_name)
.cloned()
.ok_or_else(|| McpServerManagerError::UnknownTool {
qualified_name: qualified_tool_name.to_string(),
})?;
self.ensure_server_ready(&route.server_name).await?;
let request_id = self.take_request_id();
let response =
{
let server = self.server_mut(&route.server_name)?;
let process = server.process.as_mut().ok_or_else(|| {
McpServerManagerError::InvalidResponse {
server_name: route.server_name.clone(),
method: "tools/call",
details: "server process missing after initialization".to_string(),
}
})?;
process
.call_tool(
request_id,
McpToolCallParams {
name: route.raw_name,
arguments,
meta: None,
},
)
.await?
};
Ok(response)
}
pub async fn shutdown(&mut self) -> Result<(), McpServerManagerError> {
let server_names = self.servers.keys().cloned().collect::<Vec<_>>();
for server_name in server_names {
let server = self.server_mut(&server_name)?;
if let Some(process) = server.process.as_mut() {
process.shutdown().await?;
}
server.process = None;
server.initialized = false;
}
Ok(())
}
fn clear_routes_for_server(&mut self, server_name: &str) {
self.tool_index
.retain(|_, route| route.server_name != server_name);
}
fn server_mut(
&mut self,
server_name: &str,
) -> Result<&mut ManagedMcpServer, McpServerManagerError> {
self.servers
.get_mut(server_name)
.ok_or_else(|| McpServerManagerError::UnknownServer {
server_name: server_name.to_string(),
})
}
fn take_request_id(&mut self) -> JsonRpcId {
let id = self.next_request_id;
self.next_request_id = self.next_request_id.saturating_add(1);
JsonRpcId::Number(id)
}
async fn ensure_server_ready(
&mut self,
server_name: &str,
) -> Result<(), McpServerManagerError> {
let needs_spawn = self
.servers
.get(server_name)
.map(|server| server.process.is_none())
.ok_or_else(|| McpServerManagerError::UnknownServer {
server_name: server_name.to_string(),
})?;
if needs_spawn {
let server = self.server_mut(server_name)?;
server.process = Some(spawn_mcp_stdio_process(&server.bootstrap)?);
server.initialized = false;
}
let needs_initialize = self
.servers
.get(server_name)
.map(|server| !server.initialized)
.ok_or_else(|| McpServerManagerError::UnknownServer {
server_name: server_name.to_string(),
})?;
if needs_initialize {
let request_id = self.take_request_id();
let response = {
let server = self.server_mut(server_name)?;
let process = server.process.as_mut().ok_or_else(|| {
McpServerManagerError::InvalidResponse {
server_name: server_name.to_string(),
method: "initialize",
details: "server process missing before initialize".to_string(),
}
})?;
process
.initialize(request_id, default_initialize_params())
.await?
};
if let Some(error) = response.error {
return Err(McpServerManagerError::JsonRpc {
server_name: server_name.to_string(),
method: "initialize",
error,
});
}
if response.result.is_none() {
return Err(McpServerManagerError::InvalidResponse {
server_name: server_name.to_string(),
method: "initialize",
details: "missing result payload".to_string(),
});
}
let server = self.server_mut(server_name)?;
server.initialized = true;
}
Ok(())
}
}
#[derive(Debug)]
pub struct McpStdioProcess {
child: Child,
@@ -385,6 +755,14 @@ impl McpStdioProcess {
pub async fn wait(&mut self) -> io::Result<std::process::ExitStatus> {
self.child.wait().await
}
async fn shutdown(&mut self) -> io::Result<()> {
if self.child.try_wait()?.is_none() {
self.child.kill().await?;
}
let _ = self.child.wait().await?;
Ok(())
}
}
pub fn spawn_mcp_stdio_process(bootstrap: &McpClientBootstrap) -> io::Result<McpStdioProcess> {
@@ -413,6 +791,17 @@ fn encode_frame(payload: &[u8]) -> Vec<u8> {
framed
}
fn default_initialize_params() -> McpInitializeParams {
McpInitializeParams {
protocol_version: "2025-03-26".to_string(),
capabilities: JsonValue::Object(serde_json::Map::new()),
client_info: McpInitializeClientInfo {
name: "runtime".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
},
}
}
#[cfg(test)]
mod tests {
use std::collections::BTreeMap;
@@ -426,15 +815,17 @@ mod tests {
use tokio::runtime::Builder;
use crate::config::{
ConfigSource, McpServerConfig, McpStdioServerConfig, ScopedMcpServerConfig,
ConfigSource, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig,
McpStdioServerConfig, McpWebSocketServerConfig, ScopedMcpServerConfig,
};
use crate::mcp::mcp_tool_name;
use crate::mcp_client::McpClientBootstrap;
use super::{
spawn_mcp_stdio_process, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
McpInitializeClientInfo, McpInitializeParams, McpInitializeResult, McpInitializeServerInfo,
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpStdioProcess, McpTool,
McpToolCallParams,
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpServerManager,
McpServerManagerError, McpStdioProcess, McpTool, McpToolCallParams,
};
fn temp_dir() -> PathBuf {
@@ -628,6 +1019,110 @@ mod tests {
script_path
}
#[allow(clippy::too_many_lines)]
fn write_manager_mcp_server_script() -> PathBuf {
let root = temp_dir();
fs::create_dir_all(&root).expect("temp dir");
let script_path = root.join("manager-mcp-server.py");
let script = [
"#!/usr/bin/env python3",
"import json, os, sys",
"",
"LABEL = os.environ.get('MCP_SERVER_LABEL', 'server')",
"LOG_PATH = os.environ.get('MCP_LOG_PATH')",
"initialize_count = 0",
"",
"def log(method):",
" if LOG_PATH:",
" with open(LOG_PATH, 'a', encoding='utf-8') as handle:",
" handle.write(f'{method}\\n')",
"",
"def read_message():",
" header = b''",
r" while not header.endswith(b'\r\n\r\n'):",
" chunk = sys.stdin.buffer.read(1)",
" if not chunk:",
" return None",
" header += chunk",
" length = 0",
r" for line in header.decode().split('\r\n'):",
r" if line.lower().startswith('content-length:'):",
r" length = int(line.split(':', 1)[1].strip())",
" payload = sys.stdin.buffer.read(length)",
" return json.loads(payload.decode())",
"",
"def send_message(message):",
" payload = json.dumps(message).encode()",
r" sys.stdout.buffer.write(f'Content-Length: {len(payload)}\r\n\r\n'.encode() + payload)",
" sys.stdout.buffer.flush()",
"",
"while True:",
" request = read_message()",
" if request is None:",
" break",
" method = request['method']",
" log(method)",
" if method == 'initialize':",
" initialize_count += 1",
" send_message({",
" 'jsonrpc': '2.0',",
" 'id': request['id'],",
" 'result': {",
" 'protocolVersion': request['params']['protocolVersion'],",
" 'capabilities': {'tools': {}},",
" 'serverInfo': {'name': LABEL, 'version': '1.0.0'}",
" }",
" })",
" elif method == 'tools/list':",
" send_message({",
" 'jsonrpc': '2.0',",
" 'id': request['id'],",
" 'result': {",
" 'tools': [",
" {",
" 'name': 'echo',",
" 'description': f'Echo tool for {LABEL}',",
" 'inputSchema': {",
" 'type': 'object',",
" 'properties': {'text': {'type': 'string'}},",
" 'required': ['text']",
" }",
" }",
" ]",
" }",
" })",
" elif method == 'tools/call':",
" args = request['params'].get('arguments') or {}",
" text = args.get('text', '')",
" send_message({",
" 'jsonrpc': '2.0',",
" 'id': request['id'],",
" 'result': {",
" 'content': [{'type': 'text', 'text': f'{LABEL}:{text}'}],",
" 'structuredContent': {",
" 'server': LABEL,",
" 'echoed': text,",
" 'initializeCount': initialize_count",
" },",
" 'isError': False",
" }",
" })",
" else:",
" send_message({",
" 'jsonrpc': '2.0',",
" 'id': request['id'],",
" 'error': {'code': -32601, 'message': f'unknown method: {method}'},",
" })",
"",
]
.join("\n");
fs::write(&script_path, script).expect("write script");
let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
permissions.set_mode(0o755);
fs::set_permissions(&script_path, permissions).expect("chmod");
script_path
}
fn sample_bootstrap(script_path: &Path) -> McpClientBootstrap {
let config = ScopedMcpServerConfig {
scope: ConfigSource::Local,
@@ -653,6 +1148,27 @@ mod tests {
fs::remove_dir_all(script_path.parent().expect("script parent")).expect("cleanup dir");
}
fn manager_server_config(
script_path: &Path,
label: &str,
log_path: &Path,
) -> ScopedMcpServerConfig {
ScopedMcpServerConfig {
scope: ConfigSource::Local,
config: McpServerConfig::Stdio(McpStdioServerConfig {
command: "python3".to_string(),
args: vec![script_path.to_string_lossy().into_owned()],
env: BTreeMap::from([
("MCP_SERVER_LABEL".to_string(), label.to_string()),
(
"MCP_LOG_PATH".to_string(),
log_path.to_string_lossy().into_owned(),
),
]),
}),
}
}
#[test]
fn spawns_stdio_process_and_round_trips_io() {
let runtime = Builder::new_current_thread()
@@ -935,4 +1451,247 @@ mod tests {
cleanup_script(&script_path);
});
}
#[test]
fn manager_discovers_tools_from_stdio_config() {
let runtime = Builder::new_current_thread()
.enable_all()
.build()
.expect("runtime");
runtime.block_on(async {
let script_path = write_manager_mcp_server_script();
let root = script_path.parent().expect("script parent");
let log_path = root.join("alpha.log");
let servers = BTreeMap::from([(
"alpha".to_string(),
manager_server_config(&script_path, "alpha", &log_path),
)]);
let mut manager = McpServerManager::from_servers(&servers);
let tools = manager.discover_tools().await.expect("discover tools");
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].server_name, "alpha");
assert_eq!(tools[0].raw_name, "echo");
assert_eq!(tools[0].qualified_name, mcp_tool_name("alpha", "echo"));
assert_eq!(tools[0].tool.name, "echo");
assert!(manager.unsupported_servers().is_empty());
manager.shutdown().await.expect("shutdown");
cleanup_script(&script_path);
});
}
#[test]
fn manager_routes_tool_calls_to_correct_server() {
let runtime = Builder::new_current_thread()
.enable_all()
.build()
.expect("runtime");
runtime.block_on(async {
let script_path = write_manager_mcp_server_script();
let root = script_path.parent().expect("script parent");
let alpha_log = root.join("alpha.log");
let beta_log = root.join("beta.log");
let servers = BTreeMap::from([
(
"alpha".to_string(),
manager_server_config(&script_path, "alpha", &alpha_log),
),
(
"beta".to_string(),
manager_server_config(&script_path, "beta", &beta_log),
),
]);
let mut manager = McpServerManager::from_servers(&servers);
let tools = manager.discover_tools().await.expect("discover tools");
assert_eq!(tools.len(), 2);
let alpha = manager
.call_tool(
&mcp_tool_name("alpha", "echo"),
Some(json!({"text": "hello"})),
)
.await
.expect("call alpha tool");
let beta = manager
.call_tool(
&mcp_tool_name("beta", "echo"),
Some(json!({"text": "world"})),
)
.await
.expect("call beta tool");
assert_eq!(
alpha
.result
.as_ref()
.and_then(|result| result.structured_content.as_ref())
.and_then(|value| value.get("server")),
Some(&json!("alpha"))
);
assert_eq!(
beta.result
.as_ref()
.and_then(|result| result.structured_content.as_ref())
.and_then(|value| value.get("server")),
Some(&json!("beta"))
);
manager.shutdown().await.expect("shutdown");
cleanup_script(&script_path);
});
}
#[test]
fn manager_records_unsupported_non_stdio_servers_without_panicking() {
let servers = BTreeMap::from([
(
"http".to_string(),
ScopedMcpServerConfig {
scope: ConfigSource::Local,
config: McpServerConfig::Http(McpRemoteServerConfig {
url: "https://example.test/mcp".to_string(),
headers: BTreeMap::new(),
headers_helper: None,
oauth: None,
}),
},
),
(
"sdk".to_string(),
ScopedMcpServerConfig {
scope: ConfigSource::Local,
config: McpServerConfig::Sdk(McpSdkServerConfig {
name: "sdk-server".to_string(),
}),
},
),
(
"ws".to_string(),
ScopedMcpServerConfig {
scope: ConfigSource::Local,
config: McpServerConfig::Ws(McpWebSocketServerConfig {
url: "wss://example.test/mcp".to_string(),
headers: BTreeMap::new(),
headers_helper: None,
}),
},
),
]);
let manager = McpServerManager::from_servers(&servers);
let unsupported = manager.unsupported_servers();
assert_eq!(unsupported.len(), 3);
assert_eq!(unsupported[0].server_name, "http");
assert_eq!(unsupported[1].server_name, "sdk");
assert_eq!(unsupported[2].server_name, "ws");
}
#[test]
fn manager_shutdown_terminates_spawned_children_and_is_idempotent() {
let runtime = Builder::new_current_thread()
.enable_all()
.build()
.expect("runtime");
runtime.block_on(async {
let script_path = write_manager_mcp_server_script();
let root = script_path.parent().expect("script parent");
let log_path = root.join("alpha.log");
let servers = BTreeMap::from([(
"alpha".to_string(),
manager_server_config(&script_path, "alpha", &log_path),
)]);
let mut manager = McpServerManager::from_servers(&servers);
manager.discover_tools().await.expect("discover tools");
manager.shutdown().await.expect("first shutdown");
manager.shutdown().await.expect("second shutdown");
cleanup_script(&script_path);
});
}
#[test]
fn manager_reuses_spawned_server_between_discovery_and_call() {
let runtime = Builder::new_current_thread()
.enable_all()
.build()
.expect("runtime");
runtime.block_on(async {
let script_path = write_manager_mcp_server_script();
let root = script_path.parent().expect("script parent");
let log_path = root.join("alpha.log");
let servers = BTreeMap::from([(
"alpha".to_string(),
manager_server_config(&script_path, "alpha", &log_path),
)]);
let mut manager = McpServerManager::from_servers(&servers);
manager.discover_tools().await.expect("discover tools");
let response = manager
.call_tool(
&mcp_tool_name("alpha", "echo"),
Some(json!({"text": "reuse"})),
)
.await
.expect("call tool");
assert_eq!(
response
.result
.as_ref()
.and_then(|result| result.structured_content.as_ref())
.and_then(|value| value.get("initializeCount")),
Some(&json!(1))
);
let log = fs::read_to_string(&log_path).expect("read log");
assert_eq!(log.lines().filter(|line| *line == "initialize").count(), 1);
assert_eq!(
log.lines().collect::<Vec<_>>(),
vec!["initialize", "tools/list", "tools/call"]
);
manager.shutdown().await.expect("shutdown");
cleanup_script(&script_path);
});
}
#[test]
fn manager_reports_unknown_qualified_tool_name() {
let runtime = Builder::new_current_thread()
.enable_all()
.build()
.expect("runtime");
runtime.block_on(async {
let script_path = write_manager_mcp_server_script();
let root = script_path.parent().expect("script parent");
let log_path = root.join("alpha.log");
let servers = BTreeMap::from([(
"alpha".to_string(),
manager_server_config(&script_path, "alpha", &log_path),
)]);
let mut manager = McpServerManager::from_servers(&servers);
let error = manager
.call_tool(
&mcp_tool_name("alpha", "missing"),
Some(json!({"text": "nope"})),
)
.await
.expect_err("unknown qualified tool should fail");
match error {
McpServerManagerError::UnknownTool { qualified_name } => {
assert_eq!(qualified_name, mcp_tool_name("alpha", "missing"));
}
other => panic!("expected unknown tool error, got {other:?}"),
}
cleanup_script(&script_path);
});
}
}

View File

@@ -1,12 +1,15 @@
use std::collections::BTreeMap;
use std::fs::File;
use std::fs::{self, File};
use std::io::{self, Read};
use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use sha2::{Digest, Sha256};
use crate::config::OAuthConfig;
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthTokenSet {
pub access_token: String,
pub refresh_token: Option<String>,
@@ -65,6 +68,48 @@ pub struct OAuthRefreshRequest {
pub scopes: Vec<String>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthCallbackParams {
pub code: Option<String>,
pub state: Option<String>,
pub error: Option<String>,
pub error_description: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct StoredOAuthCredentials {
access_token: String,
#[serde(default)]
refresh_token: Option<String>,
#[serde(default)]
expires_at: Option<u64>,
#[serde(default)]
scopes: Vec<String>,
}
impl From<OAuthTokenSet> for StoredOAuthCredentials {
fn from(value: OAuthTokenSet) -> Self {
Self {
access_token: value.access_token,
refresh_token: value.refresh_token,
expires_at: value.expires_at,
scopes: value.scopes,
}
}
}
impl From<StoredOAuthCredentials> for OAuthTokenSet {
fn from(value: StoredOAuthCredentials) -> Self {
Self {
access_token: value.access_token,
refresh_token: value.refresh_token,
expires_at: value.expires_at,
scopes: value.scopes,
}
}
}
impl OAuthAuthorizationRequest {
#[must_use]
pub fn from_config(
@@ -137,7 +182,6 @@ impl OAuthTokenExchangeRequest {
verifier: impl Into<String>,
redirect_uri: impl Into<String>,
) -> Self {
let _ = config;
Self {
grant_type: "authorization_code",
code: code.into(),
@@ -211,12 +255,116 @@ pub fn loopback_redirect_uri(port: u16) -> String {
format!("http://localhost:{port}/callback")
}
pub fn credentials_path() -> io::Result<PathBuf> {
Ok(credentials_home_dir()?.join("credentials.json"))
}
pub fn load_oauth_credentials() -> io::Result<Option<OAuthTokenSet>> {
let path = credentials_path()?;
let root = read_credentials_root(&path)?;
let Some(oauth) = root.get("oauth") else {
return Ok(None);
};
if oauth.is_null() {
return Ok(None);
}
let stored = serde_json::from_value::<StoredOAuthCredentials>(oauth.clone())
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
Ok(Some(stored.into()))
}
pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> {
let path = credentials_path()?;
let mut root = read_credentials_root(&path)?;
root.insert(
"oauth".to_string(),
serde_json::to_value(StoredOAuthCredentials::from(token_set.clone()))
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
);
write_credentials_root(&path, &root)
}
pub fn clear_oauth_credentials() -> io::Result<()> {
let path = credentials_path()?;
let mut root = read_credentials_root(&path)?;
root.remove("oauth");
write_credentials_root(&path, &root)
}
pub fn parse_oauth_callback_request_target(target: &str) -> Result<OAuthCallbackParams, String> {
let (path, query) = target
.split_once('?')
.map_or((target, ""), |(path, query)| (path, query));
if path != "/callback" {
return Err(format!("unexpected callback path: {path}"));
}
parse_oauth_callback_query(query)
}
pub fn parse_oauth_callback_query(query: &str) -> Result<OAuthCallbackParams, String> {
let mut params = BTreeMap::new();
for pair in query.split('&').filter(|pair| !pair.is_empty()) {
let (key, value) = pair
.split_once('=')
.map_or((pair, ""), |(key, value)| (key, value));
params.insert(percent_decode(key)?, percent_decode(value)?);
}
Ok(OAuthCallbackParams {
code: params.get("code").cloned(),
state: params.get("state").cloned(),
error: params.get("error").cloned(),
error_description: params.get("error_description").cloned(),
})
}
fn generate_random_token(bytes: usize) -> io::Result<String> {
let mut buffer = vec![0_u8; bytes];
File::open("/dev/urandom")?.read_exact(&mut buffer)?;
Ok(base64url_encode(&buffer))
}
fn credentials_home_dir() -> io::Result<PathBuf> {
if let Some(path) = std::env::var_os("CLAUDE_CONFIG_HOME") {
return Ok(PathBuf::from(path));
}
let home = std::env::var_os("HOME")
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "HOME is not set"))?;
Ok(PathBuf::from(home).join(".claude"))
}
fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
match fs::read_to_string(path) {
Ok(contents) => {
if contents.trim().is_empty() {
return Ok(Map::new());
}
serde_json::from_str::<Value>(&contents)
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?
.as_object()
.cloned()
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
"credentials file must contain a JSON object",
)
})
}
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()),
Err(error) => Err(error),
}
}
fn write_credentials_root(path: &PathBuf, root: &Map<String, Value>) -> io::Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
let rendered = serde_json::to_string_pretty(&Value::Object(root.clone()))
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
let temp_path = path.with_extension("json.tmp");
fs::write(&temp_path, format!("{rendered}\n"))?;
fs::rename(temp_path, path)
}
fn base64url_encode(bytes: &[u8]) -> String {
const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
let mut output = String::new();
@@ -264,11 +412,49 @@ fn percent_encode(value: &str) -> String {
encoded
}
fn percent_decode(value: &str) -> Result<String, String> {
let mut decoded = Vec::with_capacity(value.len());
let bytes = value.as_bytes();
let mut index = 0;
while index < bytes.len() {
match bytes[index] {
b'%' if index + 2 < bytes.len() => {
let hi = decode_hex(bytes[index + 1])?;
let lo = decode_hex(bytes[index + 2])?;
decoded.push((hi << 4) | lo);
index += 3;
}
b'+' => {
decoded.push(b' ');
index += 1;
}
byte => {
decoded.push(byte);
index += 1;
}
}
}
String::from_utf8(decoded).map_err(|error| error.to_string())
}
fn decode_hex(byte: u8) -> Result<u8, String> {
match byte {
b'0'..=b'9' => Ok(byte - b'0'),
b'a'..=b'f' => Ok(byte - b'a' + 10),
b'A'..=b'F' => Ok(byte - b'A' + 10),
_ => Err(format!("invalid percent-encoding byte: {byte}")),
}
}
#[cfg(test)]
mod tests {
use std::time::{SystemTime, UNIX_EPOCH};
use super::{
code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest,
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query,
parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest,
OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
};
fn sample_config() -> OAuthConfig {
@@ -282,6 +468,21 @@ mod tests {
}
}
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
crate::test_env_lock()
}
fn temp_config_home() -> std::path::PathBuf {
std::env::temp_dir().join(format!(
"runtime-oauth-test-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
))
}
#[test]
fn s256_challenge_matches_expected_vector() {
assert_eq!(
@@ -335,4 +536,54 @@ mod tests {
Some("org:read user:write")
);
}
#[test]
fn oauth_credentials_round_trip_and_clear_preserves_other_fields() {
let _guard = env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
let path = credentials_path().expect("credentials path");
std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent");
std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials");
let token_set = OAuthTokenSet {
access_token: "access-token".to_string(),
refresh_token: Some("refresh-token".to_string()),
expires_at: Some(123),
scopes: vec!["scope:a".to_string()],
};
save_oauth_credentials(&token_set).expect("save credentials");
assert_eq!(
load_oauth_credentials().expect("load credentials"),
Some(token_set)
);
let saved = std::fs::read_to_string(&path).expect("read saved file");
assert!(saved.contains("\"other\": \"value\""));
assert!(saved.contains("\"oauth\""));
clear_oauth_credentials().expect("clear credentials");
assert_eq!(load_oauth_credentials().expect("load cleared"), None);
let cleared = std::fs::read_to_string(&path).expect("read cleared file");
assert!(cleared.contains("\"other\": \"value\""));
assert!(!cleared.contains("\"oauth\""));
std::env::remove_var("CLAUDE_CONFIG_HOME");
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
}
#[test]
fn parses_callback_query_and_target() {
let params =
parse_oauth_callback_query("code=abc123&state=state-1&error_description=needs%20login")
.expect("parse query");
assert_eq!(params.code.as_deref(), Some("abc123"));
assert_eq!(params.state.as_deref(), Some("state-1"));
assert_eq!(params.error_description.as_deref(), Some("needs login"));
let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz")
.expect("parse callback target");
assert_eq!(params.code.as_deref(), Some("abc"));
assert_eq!(params.state.as_deref(), Some("xyz"));
assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err());
}
}

View File

@@ -1,16 +1,29 @@
use std::collections::BTreeMap;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PermissionMode {
Allow,
Deny,
Prompt,
ReadOnly,
WorkspaceWrite,
DangerFullAccess,
}
impl PermissionMode {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::ReadOnly => "read-only",
Self::WorkspaceWrite => "workspace-write",
Self::DangerFullAccess => "danger-full-access",
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PermissionRequest {
pub tool_name: String,
pub input: String,
pub current_mode: PermissionMode,
pub required_mode: PermissionMode,
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -31,31 +44,41 @@ pub enum PermissionOutcome {
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PermissionPolicy {
default_mode: PermissionMode,
tool_modes: BTreeMap<String, PermissionMode>,
active_mode: PermissionMode,
tool_requirements: BTreeMap<String, PermissionMode>,
}
impl PermissionPolicy {
#[must_use]
pub fn new(default_mode: PermissionMode) -> Self {
pub fn new(active_mode: PermissionMode) -> Self {
Self {
default_mode,
tool_modes: BTreeMap::new(),
active_mode,
tool_requirements: BTreeMap::new(),
}
}
#[must_use]
pub fn with_tool_mode(mut self, tool_name: impl Into<String>, mode: PermissionMode) -> Self {
self.tool_modes.insert(tool_name.into(), mode);
pub fn with_tool_requirement(
mut self,
tool_name: impl Into<String>,
required_mode: PermissionMode,
) -> Self {
self.tool_requirements
.insert(tool_name.into(), required_mode);
self
}
#[must_use]
pub fn mode_for(&self, tool_name: &str) -> PermissionMode {
self.tool_modes
pub fn active_mode(&self) -> PermissionMode {
self.active_mode
}
#[must_use]
pub fn required_mode_for(&self, tool_name: &str) -> PermissionMode {
self.tool_requirements
.get(tool_name)
.copied()
.unwrap_or(self.default_mode)
.unwrap_or(PermissionMode::DangerFullAccess)
}
#[must_use]
@@ -65,23 +88,43 @@ impl PermissionPolicy {
input: &str,
mut prompter: Option<&mut dyn PermissionPrompter>,
) -> PermissionOutcome {
match self.mode_for(tool_name) {
PermissionMode::Allow => PermissionOutcome::Allow,
PermissionMode::Deny => PermissionOutcome::Deny {
reason: format!("tool '{tool_name}' denied by permission policy"),
},
PermissionMode::Prompt => match prompter.as_mut() {
Some(prompter) => match prompter.decide(&PermissionRequest {
tool_name: tool_name.to_string(),
input: input.to_string(),
}) {
let current_mode = self.active_mode();
let required_mode = self.required_mode_for(tool_name);
if current_mode >= required_mode {
return PermissionOutcome::Allow;
}
let request = PermissionRequest {
tool_name: tool_name.to_string(),
input: input.to_string(),
current_mode,
required_mode,
};
if current_mode == PermissionMode::WorkspaceWrite
&& required_mode == PermissionMode::DangerFullAccess
{
return match prompter.as_mut() {
Some(prompter) => match prompter.decide(&request) {
PermissionPromptDecision::Allow => PermissionOutcome::Allow,
PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason },
},
None => PermissionOutcome::Deny {
reason: format!("tool '{tool_name}' requires interactive approval"),
reason: format!(
"tool '{tool_name}' requires approval to escalate from {} to {}",
current_mode.as_str(),
required_mode.as_str()
),
},
},
};
}
PermissionOutcome::Deny {
reason: format!(
"tool '{tool_name}' requires {} permission; current mode is {}",
required_mode.as_str(),
current_mode.as_str()
),
}
}
}
@@ -93,25 +136,92 @@ mod tests {
PermissionPrompter, PermissionRequest,
};
struct AllowPrompter;
struct RecordingPrompter {
seen: Vec<PermissionRequest>,
allow: bool,
}
impl PermissionPrompter for AllowPrompter {
impl PermissionPrompter for RecordingPrompter {
fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
assert_eq!(request.tool_name, "bash");
PermissionPromptDecision::Allow
self.seen.push(request.clone());
if self.allow {
PermissionPromptDecision::Allow
} else {
PermissionPromptDecision::Deny {
reason: "not now".to_string(),
}
}
}
}
#[test]
fn uses_tool_specific_overrides() {
let policy = PermissionPolicy::new(PermissionMode::Deny)
.with_tool_mode("bash", PermissionMode::Prompt);
fn allows_tools_when_active_mode_meets_requirement() {
let policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite)
.with_tool_requirement("read_file", PermissionMode::ReadOnly)
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite);
assert_eq!(
policy.authorize("read_file", "{}", None),
PermissionOutcome::Allow
);
assert_eq!(
policy.authorize("write_file", "{}", None),
PermissionOutcome::Allow
);
}
#[test]
fn denies_read_only_escalations_without_prompt() {
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite)
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
let outcome = policy.authorize("bash", "echo hi", Some(&mut AllowPrompter));
assert_eq!(outcome, PermissionOutcome::Allow);
assert!(matches!(
policy.authorize("edit", "x", None),
PermissionOutcome::Deny { .. }
policy.authorize("write_file", "{}", None),
PermissionOutcome::Deny { reason } if reason.contains("requires workspace-write permission")
));
assert!(matches!(
policy.authorize("bash", "{}", None),
PermissionOutcome::Deny { reason } if reason.contains("requires danger-full-access permission")
));
}
#[test]
fn prompts_for_workspace_write_to_danger_full_access_escalation() {
let policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite)
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
let mut prompter = RecordingPrompter {
seen: Vec::new(),
allow: true,
};
let outcome = policy.authorize("bash", "echo hi", Some(&mut prompter));
assert_eq!(outcome, PermissionOutcome::Allow);
assert_eq!(prompter.seen.len(), 1);
assert_eq!(prompter.seen[0].tool_name, "bash");
assert_eq!(
prompter.seen[0].current_mode,
PermissionMode::WorkspaceWrite
);
assert_eq!(
prompter.seen[0].required_mode,
PermissionMode::DangerFullAccess
);
}
#[test]
fn honors_prompt_rejection_reason() {
let policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite)
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
let mut prompter = RecordingPrompter {
seen: Vec::new(),
allow: false,
};
assert!(matches!(
policy.authorize("bash", "echo hi", Some(&mut prompter)),
PermissionOutcome::Deny { reason } if reason == "not now"
));
}
}

View File

@@ -201,6 +201,7 @@ fn discover_instruction_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> {
dir.join("CLAUDE.md"),
dir.join("CLAUDE.local.md"),
dir.join(".claude").join("CLAUDE.md"),
dir.join(".claude").join("instructions.md"),
] {
push_context_file(&mut files, candidate)?;
}
@@ -468,6 +469,10 @@ mod tests {
std::env::temp_dir().join(format!("runtime-prompt-{nanos}"))
}
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
crate::test_env_lock()
}
#[test]
fn discovers_instruction_files_from_ancestor_chain() {
let root = temp_dir();
@@ -477,10 +482,21 @@ mod tests {
fs::write(root.join("CLAUDE.local.md"), "local instructions")
.expect("write local instructions");
fs::create_dir_all(root.join("apps")).expect("apps dir");
fs::create_dir_all(root.join("apps").join(".claude")).expect("apps claude dir");
fs::write(root.join("apps").join("CLAUDE.md"), "apps instructions")
.expect("write apps instructions");
fs::write(
root.join("apps").join(".claude").join("instructions.md"),
"apps dot claude instructions",
)
.expect("write apps dot claude instructions");
fs::write(nested.join(".claude").join("CLAUDE.md"), "nested rules")
.expect("write nested rules");
fs::write(
nested.join(".claude").join("instructions.md"),
"nested instructions",
)
.expect("write nested instructions");
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
let contents = context
@@ -495,7 +511,9 @@ mod tests {
"root instructions",
"local instructions",
"apps instructions",
"nested rules"
"apps dot claude instructions",
"nested rules",
"nested instructions"
]
);
fs::remove_dir_all(root).expect("cleanup temp dir");
@@ -574,7 +592,12 @@ mod tests {
)
.expect("write settings");
let _guard = env_lock();
let previous = std::env::current_dir().expect("cwd");
let original_home = std::env::var("HOME").ok();
let original_claude_home = std::env::var("CLAUDE_CONFIG_HOME").ok();
std::env::set_var("HOME", &root);
std::env::set_var("CLAUDE_CONFIG_HOME", root.join("missing-home"));
std::env::set_current_dir(&root).expect("change cwd");
let prompt = super::load_system_prompt(&root, "2026-03-31", "linux", "6.8")
.expect("system prompt should load")
@@ -584,6 +607,16 @@ mod tests {
",
);
std::env::set_current_dir(previous).expect("restore cwd");
if let Some(value) = original_home {
std::env::set_var("HOME", value);
} else {
std::env::remove_var("HOME");
}
if let Some(value) = original_claude_home {
std::env::set_var("CLAUDE_CONFIG_HOME", value);
} else {
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
assert!(prompt.contains("Project rules"));
assert!(prompt.contains("permissionMode"));
@@ -631,6 +664,29 @@ mod tests {
assert!(rendered.chars().count() <= 4_000 + "\n\n[truncated]".chars().count());
}
#[test]
fn discovers_dot_claude_instructions_markdown() {
let root = temp_dir();
let nested = root.join("apps").join("api");
fs::create_dir_all(nested.join(".claude")).expect("nested claude dir");
fs::write(
nested.join(".claude").join("instructions.md"),
"instruction markdown",
)
.expect("write instructions.md");
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
assert!(context
.instruction_files
.iter()
.any(|file| file.path.ends_with(".claude/instructions.md")));
assert!(
render_instruction_files(&context.instruction_files).contains("instruction markdown")
);
fs::remove_dir_all(root).expect("cleanup temp dir");
}
#[test]
fn renders_instruction_file_metadata() {
let rendered = render_instruction_files(&[ContextFile {

View File

@@ -39,10 +39,19 @@ pub struct ConversationMessage {
pub usage: Option<TokenUsage>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionMetadata {
pub started_at: String,
pub model: String,
pub message_count: u32,
pub last_prompt: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Session {
pub version: u32,
pub messages: Vec<ConversationMessage>,
pub metadata: Option<SessionMetadata>,
}
#[derive(Debug)]
@@ -82,6 +91,7 @@ impl Session {
Self {
version: 1,
messages: Vec::new(),
metadata: None,
}
}
@@ -111,6 +121,9 @@ impl Session {
.collect(),
),
);
if let Some(metadata) = &self.metadata {
object.insert("metadata".to_string(), metadata.to_json());
}
JsonValue::Object(object)
}
@@ -131,7 +144,15 @@ impl Session {
.iter()
.map(ConversationMessage::from_json)
.collect::<Result<Vec<_>, _>>()?;
Ok(Self { version, messages })
let metadata = object
.get("metadata")
.map(SessionMetadata::from_json)
.transpose()?;
Ok(Self {
version,
messages,
metadata,
})
}
}
@@ -141,6 +162,41 @@ impl Default for Session {
}
}
impl SessionMetadata {
#[must_use]
pub fn to_json(&self) -> JsonValue {
let mut object = BTreeMap::new();
object.insert(
"started_at".to_string(),
JsonValue::String(self.started_at.clone()),
);
object.insert("model".to_string(), JsonValue::String(self.model.clone()));
object.insert(
"message_count".to_string(),
JsonValue::Number(i64::from(self.message_count)),
);
if let Some(last_prompt) = &self.last_prompt {
object.insert(
"last_prompt".to_string(),
JsonValue::String(last_prompt.clone()),
);
}
JsonValue::Object(object)
}
fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
let object = value.as_object().ok_or_else(|| {
SessionError::Format("session metadata must be an object".to_string())
})?;
Ok(Self {
started_at: required_string(object, "started_at")?,
model: required_string(object, "model")?,
message_count: required_u32(object, "message_count")?,
last_prompt: optional_string(object, "last_prompt"),
})
}
}
impl ConversationMessage {
#[must_use]
pub fn user_text(text: impl Into<String>) -> Self {
@@ -368,6 +424,13 @@ fn required_string(
.ok_or_else(|| SessionError::Format(format!("missing {key}")))
}
fn optional_string(object: &BTreeMap<String, JsonValue>, key: &str) -> Option<String> {
object
.get(key)
.and_then(JsonValue::as_str)
.map(ToOwned::to_owned)
}
fn required_u32(object: &BTreeMap<String, JsonValue>, key: &str) -> Result<u32, SessionError> {
let value = object
.get(key)
@@ -378,7 +441,8 @@ fn required_u32(object: &BTreeMap<String, JsonValue>, key: &str) -> Result<u32,
#[cfg(test)]
mod tests {
use super::{ContentBlock, ConversationMessage, MessageRole, Session};
use super::{ContentBlock, ConversationMessage, MessageRole, Session, SessionMetadata};
use crate::json::JsonValue;
use crate::usage::TokenUsage;
use std::fs;
use std::time::{SystemTime, UNIX_EPOCH};
@@ -386,6 +450,12 @@ mod tests {
#[test]
fn persists_and_restores_session_json() {
let mut session = Session::new();
session.metadata = Some(SessionMetadata {
started_at: "2026-04-01T00:00:00Z".to_string(),
model: "claude-sonnet".to_string(),
message_count: 3,
last_prompt: Some("hello".to_string()),
});
session
.messages
.push(ConversationMessage::user_text("hello"));
@@ -428,5 +498,23 @@ mod tests {
restored.messages[1].usage.expect("usage").total_tokens(),
17
);
assert_eq!(restored.metadata, session.metadata);
}
#[test]
fn loads_legacy_session_without_metadata() {
let legacy = r#"{
"version": 1,
"messages": [
{
"role": "user",
"blocks": [{"type": "text", "text": "hello"}]
}
]
}"#;
let restored = Session::from_json(&JsonValue::parse(legacy).expect("legacy json"))
.expect("legacy session should parse");
assert_eq!(restored.messages.len(), 1);
assert!(restored.metadata.is_none());
}
}

View File

@@ -300,6 +300,7 @@ mod tests {
cache_read_input_tokens: 0,
}),
}],
metadata: None,
};
let tracker = UsageTracker::from_session(&session);

View File

@@ -31,6 +31,10 @@ pub enum Command {
DumpManifests,
/// Print the current bootstrap phase skeleton
BootstrapPlan,
/// Start the OAuth login flow
Login,
/// Clear saved OAuth credentials
Logout,
/// Run a non-interactive prompt and exit
Prompt { prompt: Vec<String> },
}
@@ -86,4 +90,13 @@ mod tests {
})
);
}
#[test]
fn parses_login_and_logout_commands() {
let login = Cli::parse_from(["rusty-claude-cli", "login"]);
assert_eq!(login.command, Some(Command::Login));
let logout = Cli::parse_from(["rusty-claude-cli", "logout"]);
assert_eq!(logout.command, Some(Command::Logout));
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -6,7 +6,7 @@ use std::time::{Duration, Instant};
use reqwest::blocking::Client;
use runtime::{
edit_file, execute_bash, glob_search, grep_search, read_file, write_file, BashCommandInput,
GrepSearchInput,
GrepSearchInput, PermissionMode,
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
@@ -45,6 +45,7 @@ pub struct ToolSpec {
pub name: &'static str,
pub description: &'static str,
pub input_schema: Value,
pub required_permission: PermissionMode,
}
#[must_use]
@@ -66,6 +67,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["command"],
"additionalProperties": false
}),
required_permission: PermissionMode::DangerFullAccess,
},
ToolSpec {
name: "read_file",
@@ -80,6 +82,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["path"],
"additionalProperties": false
}),
required_permission: PermissionMode::ReadOnly,
},
ToolSpec {
name: "write_file",
@@ -93,6 +96,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["path", "content"],
"additionalProperties": false
}),
required_permission: PermissionMode::WorkspaceWrite,
},
ToolSpec {
name: "edit_file",
@@ -108,6 +112,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["path", "old_string", "new_string"],
"additionalProperties": false
}),
required_permission: PermissionMode::WorkspaceWrite,
},
ToolSpec {
name: "glob_search",
@@ -121,6 +126,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["pattern"],
"additionalProperties": false
}),
required_permission: PermissionMode::ReadOnly,
},
ToolSpec {
name: "grep_search",
@@ -146,6 +152,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["pattern"],
"additionalProperties": false
}),
required_permission: PermissionMode::ReadOnly,
},
ToolSpec {
name: "WebFetch",
@@ -160,6 +167,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["url", "prompt"],
"additionalProperties": false
}),
required_permission: PermissionMode::ReadOnly,
},
ToolSpec {
name: "WebSearch",
@@ -180,6 +188,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["query"],
"additionalProperties": false
}),
required_permission: PermissionMode::ReadOnly,
},
ToolSpec {
name: "TodoWrite",
@@ -207,6 +216,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["todos"],
"additionalProperties": false
}),
required_permission: PermissionMode::WorkspaceWrite,
},
ToolSpec {
name: "Skill",
@@ -220,6 +230,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["skill"],
"additionalProperties": false
}),
required_permission: PermissionMode::ReadOnly,
},
ToolSpec {
name: "Agent",
@@ -236,6 +247,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["description", "prompt"],
"additionalProperties": false
}),
required_permission: PermissionMode::DangerFullAccess,
},
ToolSpec {
name: "ToolSearch",
@@ -249,6 +261,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["query"],
"additionalProperties": false
}),
required_permission: PermissionMode::ReadOnly,
},
ToolSpec {
name: "NotebookEdit",
@@ -265,6 +278,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["notebook_path"],
"additionalProperties": false
}),
required_permission: PermissionMode::WorkspaceWrite,
},
ToolSpec {
name: "Sleep",
@@ -277,6 +291,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["duration_ms"],
"additionalProperties": false
}),
required_permission: PermissionMode::ReadOnly,
},
ToolSpec {
name: "SendUserMessage",
@@ -297,6 +312,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["message", "status"],
"additionalProperties": false
}),
required_permission: PermissionMode::ReadOnly,
},
ToolSpec {
name: "Config",
@@ -312,6 +328,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["setting"],
"additionalProperties": false
}),
required_permission: PermissionMode::WorkspaceWrite,
},
ToolSpec {
name: "StructuredOutput",
@@ -320,6 +337,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"type": "object",
"additionalProperties": true
}),
required_permission: PermissionMode::ReadOnly,
},
ToolSpec {
name: "REPL",
@@ -334,6 +352,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["code", "language"],
"additionalProperties": false
}),
required_permission: PermissionMode::DangerFullAccess,
},
ToolSpec {
name: "PowerShell",
@@ -349,6 +368,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
"required": ["command"],
"additionalProperties": false
}),
required_permission: PermissionMode::DangerFullAccess,
},
]
}
@@ -2349,8 +2369,10 @@ fn parse_skill_description(contents: &str) -> Option<String> {
#[cfg(test)]
mod tests {
use std::fs;
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpListener};
use std::path::PathBuf;
use std::sync::{Arc, Mutex, OnceLock};
use std::thread;
use std::time::Duration;
@@ -2363,6 +2385,14 @@ mod tests {
LOCK.get_or_init(|| Mutex::new(()))
}
fn temp_path(name: &str) -> PathBuf {
let unique = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos();
std::env::temp_dir().join(format!("clawd-tools-{unique}-{name}"))
}
#[test]
fn exposes_mvp_tools() {
let names = mvp_tool_specs()
@@ -2432,6 +2462,40 @@ mod tests {
assert!(titled_summary.contains("Title: Ignored"));
}
#[test]
fn web_fetch_supports_plain_text_and_rejects_invalid_url() {
let server = TestServer::spawn(Arc::new(|request_line: &str| {
assert!(request_line.starts_with("GET /plain "));
HttpResponse::text(200, "OK", "plain text response")
}));
let result = execute_tool(
"WebFetch",
&json!({
"url": format!("http://{}/plain", server.addr()),
"prompt": "Show me the content"
}),
)
.expect("WebFetch should succeed for text content");
let output: serde_json::Value = serde_json::from_str(&result).expect("valid json");
assert_eq!(output["url"], format!("http://{}/plain", server.addr()));
assert!(output["result"]
.as_str()
.expect("result")
.contains("plain text response"));
let error = execute_tool(
"WebFetch",
&json!({
"url": "not a url",
"prompt": "Summarize"
}),
)
.expect_err("invalid URL should fail");
assert!(error.contains("relative URL without a base") || error.contains("invalid"));
}
#[test]
fn web_search_extracts_and_filters_results() {
let server = TestServer::spawn(Arc::new(|request_line: &str| {
@@ -2476,15 +2540,63 @@ mod tests {
assert_eq!(content[0]["url"], "https://docs.rs/reqwest");
}
#[test]
fn web_search_handles_generic_links_and_invalid_base_url() {
let _guard = env_lock()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let server = TestServer::spawn(Arc::new(|request_line: &str| {
assert!(request_line.contains("GET /fallback?q=generic+links "));
HttpResponse::html(
200,
"OK",
r#"
<html><body>
<a href="https://example.com/one">Example One</a>
<a href="https://example.com/one">Duplicate Example One</a>
<a href="https://docs.rs/tokio">Tokio Docs</a>
</body></html>
"#,
)
}));
std::env::set_var(
"CLAWD_WEB_SEARCH_BASE_URL",
format!("http://{}/fallback", server.addr()),
);
let result = execute_tool(
"WebSearch",
&json!({
"query": "generic links"
}),
)
.expect("WebSearch fallback parsing should succeed");
std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL");
let output: serde_json::Value = serde_json::from_str(&result).expect("valid json");
let results = output["results"].as_array().expect("results array");
let search_result = results
.iter()
.find(|item| item.get("content").is_some())
.expect("search result block present");
let content = search_result["content"].as_array().expect("content array");
assert_eq!(content.len(), 2);
assert_eq!(content[0]["url"], "https://example.com/one");
assert_eq!(content[1]["url"], "https://docs.rs/tokio");
std::env::set_var("CLAWD_WEB_SEARCH_BASE_URL", "://bad-base-url");
let error = execute_tool("WebSearch", &json!({ "query": "generic links" }))
.expect_err("invalid base URL should fail");
std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL");
assert!(error.contains("relative URL without a base") || error.contains("empty host"));
}
#[test]
fn todo_write_persists_and_returns_previous_state() {
let path = std::env::temp_dir().join(format!(
"clawd-tools-todos-{}.json",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
let _guard = env_lock()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let path = temp_path("todos.json");
std::env::set_var("CLAWD_TODO_STORE", &path);
let first = execute_tool(
@@ -2526,6 +2638,59 @@ mod tests {
assert!(second_output["verificationNudgeNeeded"].is_null());
}
#[test]
fn todo_write_rejects_invalid_payloads_and_sets_verification_nudge() {
let _guard = env_lock()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let path = temp_path("todos-errors.json");
std::env::set_var("CLAWD_TODO_STORE", &path);
let empty = execute_tool("TodoWrite", &json!({ "todos": [] }))
.expect_err("empty todos should fail");
assert!(empty.contains("todos must not be empty"));
let too_many_active = execute_tool(
"TodoWrite",
&json!({
"todos": [
{"content": "One", "activeForm": "Doing one", "status": "in_progress"},
{"content": "Two", "activeForm": "Doing two", "status": "in_progress"}
]
}),
)
.expect_err("multiple in-progress todos should fail");
assert!(too_many_active.contains("zero or one todo items may be in_progress"));
let blank_content = execute_tool(
"TodoWrite",
&json!({
"todos": [
{"content": " ", "activeForm": "Doing it", "status": "pending"}
]
}),
)
.expect_err("blank content should fail");
assert!(blank_content.contains("todo content must not be empty"));
let nudge = execute_tool(
"TodoWrite",
&json!({
"todos": [
{"content": "Write tests", "activeForm": "Writing tests", "status": "completed"},
{"content": "Fix errors", "activeForm": "Fixing errors", "status": "completed"},
{"content": "Ship branch", "activeForm": "Shipping branch", "status": "completed"}
]
}),
)
.expect("completed todos should succeed");
std::env::remove_var("CLAWD_TODO_STORE");
let _ = fs::remove_file(path);
let output: serde_json::Value = serde_json::from_str(&nudge).expect("valid json");
assert_eq!(output["verificationNudgeNeeded"], true);
}
#[test]
fn skill_loads_local_skill_prompt() {
let result = execute_tool(
@@ -2599,13 +2764,10 @@ mod tests {
#[test]
fn agent_persists_handoff_metadata() {
let dir = std::env::temp_dir().join(format!(
"clawd-agent-store-{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
let _guard = env_lock()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let dir = temp_path("agent-store");
std::env::set_var("CLAWD_AGENT_STORE", &dir);
let result = execute_tool(
@@ -2661,15 +2823,32 @@ mod tests {
let _ = std::fs::remove_dir_all(dir);
}
#[test]
fn agent_rejects_blank_required_fields() {
let missing_description = execute_tool(
"Agent",
&json!({
"description": " ",
"prompt": "Inspect"
}),
)
.expect_err("blank description should fail");
assert!(missing_description.contains("description must not be empty"));
let missing_prompt = execute_tool(
"Agent",
&json!({
"description": "Inspect branch",
"prompt": " "
}),
)
.expect_err("blank prompt should fail");
assert!(missing_prompt.contains("prompt must not be empty"));
}
#[test]
fn notebook_edit_replaces_inserts_and_deletes_cells() {
let path = std::env::temp_dir().join(format!(
"clawd-notebook-{}.ipynb",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
let path = temp_path("notebook.ipynb");
std::fs::write(
&path,
r#"{
@@ -2747,6 +2926,270 @@ mod tests {
let _ = std::fs::remove_file(path);
}
#[test]
fn notebook_edit_rejects_invalid_inputs() {
let text_path = temp_path("notebook.txt");
fs::write(&text_path, "not a notebook").expect("write text file");
let wrong_extension = execute_tool(
"NotebookEdit",
&json!({
"notebook_path": text_path.display().to_string(),
"new_source": "print(1)\n"
}),
)
.expect_err("non-ipynb file should fail");
assert!(wrong_extension.contains("Jupyter notebook"));
let _ = fs::remove_file(&text_path);
let empty_notebook = temp_path("empty.ipynb");
fs::write(
&empty_notebook,
r#"{"cells":[],"metadata":{"kernelspec":{"language":"python"}},"nbformat":4,"nbformat_minor":5}"#,
)
.expect("write empty notebook");
let missing_source = execute_tool(
"NotebookEdit",
&json!({
"notebook_path": empty_notebook.display().to_string(),
"edit_mode": "insert"
}),
)
.expect_err("insert without source should fail");
assert!(missing_source.contains("new_source is required"));
let missing_cell = execute_tool(
"NotebookEdit",
&json!({
"notebook_path": empty_notebook.display().to_string(),
"edit_mode": "delete"
}),
)
.expect_err("delete on empty notebook should fail");
assert!(missing_cell.contains("Notebook has no cells to edit"));
let _ = fs::remove_file(empty_notebook);
}
#[test]
fn bash_tool_reports_success_exit_failure_timeout_and_background() {
let success = execute_tool("bash", &json!({ "command": "printf 'hello'" }))
.expect("bash should succeed");
let success_output: serde_json::Value = serde_json::from_str(&success).expect("json");
assert_eq!(success_output["stdout"], "hello");
assert_eq!(success_output["interrupted"], false);
let failure = execute_tool("bash", &json!({ "command": "printf 'oops' >&2; exit 7" }))
.expect("bash failure should still return structured output");
let failure_output: serde_json::Value = serde_json::from_str(&failure).expect("json");
assert_eq!(failure_output["returnCodeInterpretation"], "exit_code:7");
assert!(failure_output["stderr"]
.as_str()
.expect("stderr")
.contains("oops"));
let timeout = execute_tool("bash", &json!({ "command": "sleep 1", "timeout": 10 }))
.expect("bash timeout should return output");
let timeout_output: serde_json::Value = serde_json::from_str(&timeout).expect("json");
assert_eq!(timeout_output["interrupted"], true);
assert_eq!(timeout_output["returnCodeInterpretation"], "timeout");
assert!(timeout_output["stderr"]
.as_str()
.expect("stderr")
.contains("Command exceeded timeout"));
let background = execute_tool(
"bash",
&json!({ "command": "sleep 1", "run_in_background": true }),
)
.expect("bash background should succeed");
let background_output: serde_json::Value = serde_json::from_str(&background).expect("json");
assert!(background_output["backgroundTaskId"].as_str().is_some());
assert_eq!(background_output["noOutputExpected"], true);
}
#[test]
fn file_tools_cover_read_write_and_edit_behaviors() {
let _guard = env_lock()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let root = temp_path("fs-suite");
fs::create_dir_all(&root).expect("create root");
let original_dir = std::env::current_dir().expect("cwd");
std::env::set_current_dir(&root).expect("set cwd");
let write_create = execute_tool(
"write_file",
&json!({ "path": "nested/demo.txt", "content": "alpha\nbeta\nalpha\n" }),
)
.expect("write create should succeed");
let write_create_output: serde_json::Value =
serde_json::from_str(&write_create).expect("json");
assert_eq!(write_create_output["type"], "create");
assert!(root.join("nested/demo.txt").exists());
let write_update = execute_tool(
"write_file",
&json!({ "path": "nested/demo.txt", "content": "alpha\nbeta\ngamma\n" }),
)
.expect("write update should succeed");
let write_update_output: serde_json::Value =
serde_json::from_str(&write_update).expect("json");
assert_eq!(write_update_output["type"], "update");
assert_eq!(write_update_output["originalFile"], "alpha\nbeta\nalpha\n");
let read_full = execute_tool("read_file", &json!({ "path": "nested/demo.txt" }))
.expect("read full should succeed");
let read_full_output: serde_json::Value = serde_json::from_str(&read_full).expect("json");
assert_eq!(read_full_output["file"]["content"], "alpha\nbeta\ngamma");
assert_eq!(read_full_output["file"]["startLine"], 1);
let read_slice = execute_tool(
"read_file",
&json!({ "path": "nested/demo.txt", "offset": 1, "limit": 1 }),
)
.expect("read slice should succeed");
let read_slice_output: serde_json::Value = serde_json::from_str(&read_slice).expect("json");
assert_eq!(read_slice_output["file"]["content"], "beta");
assert_eq!(read_slice_output["file"]["startLine"], 2);
let read_past_end = execute_tool(
"read_file",
&json!({ "path": "nested/demo.txt", "offset": 50 }),
)
.expect("read past EOF should succeed");
let read_past_end_output: serde_json::Value =
serde_json::from_str(&read_past_end).expect("json");
assert_eq!(read_past_end_output["file"]["content"], "");
assert_eq!(read_past_end_output["file"]["startLine"], 4);
let read_error = execute_tool("read_file", &json!({ "path": "missing.txt" }))
.expect_err("missing file should fail");
assert!(!read_error.is_empty());
let edit_once = execute_tool(
"edit_file",
&json!({ "path": "nested/demo.txt", "old_string": "alpha", "new_string": "omega" }),
)
.expect("single edit should succeed");
let edit_once_output: serde_json::Value = serde_json::from_str(&edit_once).expect("json");
assert_eq!(edit_once_output["replaceAll"], false);
assert_eq!(
fs::read_to_string(root.join("nested/demo.txt")).expect("read file"),
"omega\nbeta\ngamma\n"
);
execute_tool(
"write_file",
&json!({ "path": "nested/demo.txt", "content": "alpha\nbeta\nalpha\n" }),
)
.expect("reset file");
let edit_all = execute_tool(
"edit_file",
&json!({
"path": "nested/demo.txt",
"old_string": "alpha",
"new_string": "omega",
"replace_all": true
}),
)
.expect("replace all should succeed");
let edit_all_output: serde_json::Value = serde_json::from_str(&edit_all).expect("json");
assert_eq!(edit_all_output["replaceAll"], true);
assert_eq!(
fs::read_to_string(root.join("nested/demo.txt")).expect("read file"),
"omega\nbeta\nomega\n"
);
let edit_same = execute_tool(
"edit_file",
&json!({ "path": "nested/demo.txt", "old_string": "omega", "new_string": "omega" }),
)
.expect_err("identical old/new should fail");
assert!(edit_same.contains("must differ"));
let edit_missing = execute_tool(
"edit_file",
&json!({ "path": "nested/demo.txt", "old_string": "missing", "new_string": "omega" }),
)
.expect_err("missing substring should fail");
assert!(edit_missing.contains("old_string not found"));
std::env::set_current_dir(&original_dir).expect("restore cwd");
let _ = fs::remove_dir_all(root);
}
#[test]
fn glob_and_grep_tools_cover_success_and_errors() {
let _guard = env_lock()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let root = temp_path("search-suite");
fs::create_dir_all(root.join("nested")).expect("create root");
let original_dir = std::env::current_dir().expect("cwd");
std::env::set_current_dir(&root).expect("set cwd");
fs::write(
root.join("nested/lib.rs"),
"fn main() {}\nlet alpha = 1;\nlet alpha = 2;\n",
)
.expect("write rust file");
fs::write(root.join("nested/notes.txt"), "alpha\nbeta\n").expect("write txt file");
let globbed = execute_tool("glob_search", &json!({ "pattern": "nested/*.rs" }))
.expect("glob should succeed");
let globbed_output: serde_json::Value = serde_json::from_str(&globbed).expect("json");
assert_eq!(globbed_output["numFiles"], 1);
assert!(globbed_output["filenames"][0]
.as_str()
.expect("filename")
.ends_with("nested/lib.rs"));
let glob_error = execute_tool("glob_search", &json!({ "pattern": "[" }))
.expect_err("invalid glob should fail");
assert!(!glob_error.is_empty());
let grep_content = execute_tool(
"grep_search",
&json!({
"pattern": "alpha",
"path": "nested",
"glob": "*.rs",
"output_mode": "content",
"-n": true,
"head_limit": 1,
"offset": 1
}),
)
.expect("grep content should succeed");
let grep_content_output: serde_json::Value =
serde_json::from_str(&grep_content).expect("json");
assert_eq!(grep_content_output["numFiles"], 0);
assert!(grep_content_output["appliedLimit"].is_null());
assert_eq!(grep_content_output["appliedOffset"], 1);
assert!(grep_content_output["content"]
.as_str()
.expect("content")
.contains("let alpha = 2;"));
let grep_count = execute_tool(
"grep_search",
&json!({ "pattern": "alpha", "path": "nested", "output_mode": "count" }),
)
.expect("grep count should succeed");
let grep_count_output: serde_json::Value = serde_json::from_str(&grep_count).expect("json");
assert_eq!(grep_count_output["numMatches"], 3);
let grep_error = execute_tool(
"grep_search",
&json!({ "pattern": "(alpha", "path": "nested" }),
)
.expect_err("invalid regex should fail");
assert!(!grep_error.is_empty());
std::env::set_current_dir(&original_dir).expect("restore cwd");
let _ = fs::remove_dir_all(root);
}
#[test]
fn sleep_waits_and_reports_duration() {
let started = std::time::Instant::now();
@@ -3038,6 +3481,15 @@ printf 'pwsh:%s' "$1"
}
}
fn text(status: u16, reason: &'static str, body: &str) -> Self {
Self {
status,
reason,
content_type: "text/plain; charset=utf-8",
body: body.to_string(),
}
}
fn to_bytes(&self) -> Vec<u8> {
format!(
"HTTP/1.1 {} {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",