mirror of
https://github.com/instructkr/claw-code.git
synced 2026-04-05 16:39:04 +08:00
This adds crate-level and type-level Rustdoc to the runtime crate's core exported types so downstream crates and contributors can understand the session, prompt, permission, OAuth, usage, and tool I/O primitives without spelunking every implementation file. Constraint: The docs pass needed to stay focused on public runtime types without changing behavior Rejected: Add blanket docs to every public item in one sweep | larger churn than needed for a targeted docs pass Confidence: high Scope-risk: narrow Reversibility: clean Directive: When exporting new runtime primitives from lib.rs, add a short Rustdoc summary in the defining module at the same time Tested: cargo build --workspace; cargo test --workspace Not-tested: rustdoc HTML rendering beyond doc-test coverage
1691 lines
56 KiB
Rust
1691 lines
56 KiB
Rust
use std::collections::BTreeMap;
|
|
use std::fmt::{Display, Formatter};
|
|
|
|
use serde_json::{Map, Value};
|
|
use telemetry::SessionTracer;
|
|
|
|
use crate::compact::{
|
|
compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
|
|
};
|
|
use crate::config::RuntimeFeatureConfig;
|
|
use crate::hooks::{HookAbortSignal, HookProgressReporter, HookRunResult, HookRunner};
|
|
use crate::permissions::{
|
|
PermissionContext, PermissionOutcome, PermissionPolicy, PermissionPrompter,
|
|
};
|
|
use crate::session::{ContentBlock, ConversationMessage, Session};
|
|
use crate::usage::{TokenUsage, UsageTracker};
|
|
|
|
const DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD: u32 = 100_000;
|
|
const AUTO_COMPACTION_THRESHOLD_ENV_VAR: &str = "CLAUDE_CODE_AUTO_COMPACT_INPUT_TOKENS";
|
|
|
|
/// Fully assembled request payload sent to the upstream model client.
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct ApiRequest {
|
|
pub system_prompt: Vec<String>,
|
|
pub messages: Vec<ConversationMessage>,
|
|
}
|
|
|
|
/// Streamed events emitted while processing a single assistant turn.
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub enum AssistantEvent {
|
|
TextDelta(String),
|
|
ToolUse {
|
|
id: String,
|
|
name: String,
|
|
input: String,
|
|
},
|
|
Usage(TokenUsage),
|
|
PromptCache(PromptCacheEvent),
|
|
MessageStop,
|
|
}
|
|
|
|
/// Prompt-cache telemetry captured from the provider response stream.
///
/// The runtime does not interpret these values; it only collects the events
/// into [`TurnSummary::prompt_cache_events`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCacheEvent {
    /// Flag set by the producer of the event.
    // NOTE(review): presumably marks a cache miss that was not anticipated — confirm with the API client.
    pub unexpected: bool,
    /// Human-readable explanation supplied by the producer.
    pub reason: String,
    /// Cache-read input tokens reported before this event.
    pub previous_cache_read_input_tokens: u32,
    /// Cache-read input tokens reported with this event.
    pub current_cache_read_input_tokens: u32,
    /// Reported size of the drop.
    // NOTE(review): presumably `previous - current`; this type does not enforce that.
    pub token_drop: u32,
}
|
|
|
|
/// Minimal streaming API contract required by [`ConversationRuntime`].
|
|
pub trait ApiClient {
|
|
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
|
|
}
|
|
|
|
/// Trait implemented by tool dispatchers that execute model-requested tools.
|
|
pub trait ToolExecutor {
|
|
fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError>;
|
|
}
|
|
|
|
/// Error returned when a tool invocation fails locally.
///
/// Carries a single human-readable message, which [`Display`] prints verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolError {
    /// Description of what went wrong.
    message: String,
}

impl ToolError {
    /// Creates an error from anything convertible into a `String`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for ToolError {
    /// Writes the stored message without any decoration.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for ToolError {}
|
|
|
|
/// Error returned when a conversation turn cannot be completed.
///
/// Carries a single human-readable message, which [`Display`] prints verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RuntimeError {
    /// Description of why the turn failed.
    message: String,
}

impl RuntimeError {
    /// Creates an error from anything convertible into a `String`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for RuntimeError {
    /// Writes the stored message without any decoration.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for RuntimeError {}
|
|
|
|
/// Summary of one completed runtime turn, including tool results and usage.
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct TurnSummary {
|
|
pub assistant_messages: Vec<ConversationMessage>,
|
|
pub tool_results: Vec<ConversationMessage>,
|
|
pub prompt_cache_events: Vec<PromptCacheEvent>,
|
|
pub iterations: usize,
|
|
pub usage: TokenUsage,
|
|
pub auto_compaction: Option<AutoCompactionEvent>,
|
|
}
|
|
|
|
/// Details about automatic session compaction applied during a turn.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AutoCompactionEvent {
    /// Number of messages the compaction removed; always non-zero when the
    /// runtime emits this event.
    pub removed_message_count: usize,
}
|
|
|
|
/// Coordinates the model loop, tool execution, hooks, and session updates.
|
|
pub struct ConversationRuntime<C, T> {
|
|
session: Session,
|
|
api_client: C,
|
|
tool_executor: T,
|
|
permission_policy: PermissionPolicy,
|
|
system_prompt: Vec<String>,
|
|
max_iterations: usize,
|
|
usage_tracker: UsageTracker,
|
|
hook_runner: HookRunner,
|
|
auto_compaction_input_tokens_threshold: u32,
|
|
hook_abort_signal: HookAbortSignal,
|
|
hook_progress_reporter: Option<Box<dyn HookProgressReporter>>,
|
|
session_tracer: Option<SessionTracer>,
|
|
}
|
|
|
|
impl<C, T> ConversationRuntime<C, T>
|
|
where
|
|
C: ApiClient,
|
|
T: ToolExecutor,
|
|
{
|
|
#[must_use]
|
|
pub fn new(
|
|
session: Session,
|
|
api_client: C,
|
|
tool_executor: T,
|
|
permission_policy: PermissionPolicy,
|
|
system_prompt: Vec<String>,
|
|
) -> Self {
|
|
Self::new_with_features(
|
|
session,
|
|
api_client,
|
|
tool_executor,
|
|
permission_policy,
|
|
system_prompt,
|
|
&RuntimeFeatureConfig::default(),
|
|
)
|
|
}
|
|
|
|
#[must_use]
|
|
#[allow(clippy::needless_pass_by_value)]
|
|
pub fn new_with_features(
|
|
session: Session,
|
|
api_client: C,
|
|
tool_executor: T,
|
|
permission_policy: PermissionPolicy,
|
|
system_prompt: Vec<String>,
|
|
feature_config: &RuntimeFeatureConfig,
|
|
) -> Self {
|
|
let usage_tracker = UsageTracker::from_session(&session);
|
|
Self {
|
|
session,
|
|
api_client,
|
|
tool_executor,
|
|
permission_policy,
|
|
system_prompt,
|
|
max_iterations: usize::MAX,
|
|
usage_tracker,
|
|
hook_runner: HookRunner::from_feature_config(feature_config),
|
|
auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(),
|
|
hook_abort_signal: HookAbortSignal::default(),
|
|
hook_progress_reporter: None,
|
|
session_tracer: None,
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
|
|
self.max_iterations = max_iterations;
|
|
self
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn with_auto_compaction_input_tokens_threshold(mut self, threshold: u32) -> Self {
|
|
self.auto_compaction_input_tokens_threshold = threshold;
|
|
self
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn with_hook_abort_signal(mut self, hook_abort_signal: HookAbortSignal) -> Self {
|
|
self.hook_abort_signal = hook_abort_signal;
|
|
self
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn with_hook_progress_reporter(
|
|
mut self,
|
|
hook_progress_reporter: Box<dyn HookProgressReporter>,
|
|
) -> Self {
|
|
self.hook_progress_reporter = Some(hook_progress_reporter);
|
|
self
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn with_session_tracer(mut self, session_tracer: SessionTracer) -> Self {
|
|
self.session_tracer = Some(session_tracer);
|
|
self
|
|
}
|
|
|
|
/// Runs the pre-tool-use hooks for `tool_name`, passing the abort signal and
/// the progress reporter when one is attached.
fn run_pre_tool_use_hook(&mut self, tool_name: &str, input: &str) -> HookRunResult {
    let abort_signal = Some(&self.hook_abort_signal);
    // Two arms are kept so the reporter can be reborrowed as a trait object
    // only when it exists.
    match self.hook_progress_reporter.as_mut() {
        Some(reporter) => self.hook_runner.run_pre_tool_use_with_context(
            tool_name,
            input,
            abort_signal,
            Some(reporter.as_mut()),
        ),
        None => self.hook_runner.run_pre_tool_use_with_context(
            tool_name,
            input,
            abort_signal,
            None,
        ),
    }
}
|
|
|
|
fn run_post_tool_use_hook(
|
|
&mut self,
|
|
tool_name: &str,
|
|
input: &str,
|
|
output: &str,
|
|
is_error: bool,
|
|
) -> HookRunResult {
|
|
if let Some(reporter) = self.hook_progress_reporter.as_mut() {
|
|
self.hook_runner.run_post_tool_use_with_context(
|
|
tool_name,
|
|
input,
|
|
output,
|
|
is_error,
|
|
Some(&self.hook_abort_signal),
|
|
Some(reporter.as_mut()),
|
|
)
|
|
} else {
|
|
self.hook_runner.run_post_tool_use_with_context(
|
|
tool_name,
|
|
input,
|
|
output,
|
|
is_error,
|
|
Some(&self.hook_abort_signal),
|
|
None,
|
|
)
|
|
}
|
|
}
|
|
|
|
fn run_post_tool_use_failure_hook(
|
|
&mut self,
|
|
tool_name: &str,
|
|
input: &str,
|
|
output: &str,
|
|
) -> HookRunResult {
|
|
if let Some(reporter) = self.hook_progress_reporter.as_mut() {
|
|
self.hook_runner.run_post_tool_use_failure_with_context(
|
|
tool_name,
|
|
input,
|
|
output,
|
|
Some(&self.hook_abort_signal),
|
|
Some(reporter.as_mut()),
|
|
)
|
|
} else {
|
|
self.hook_runner.run_post_tool_use_failure_with_context(
|
|
tool_name,
|
|
input,
|
|
output,
|
|
Some(&self.hook_abort_signal),
|
|
None,
|
|
)
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::too_many_lines)]
|
|
pub fn run_turn(
|
|
&mut self,
|
|
user_input: impl Into<String>,
|
|
mut prompter: Option<&mut dyn PermissionPrompter>,
|
|
) -> Result<TurnSummary, RuntimeError> {
|
|
let user_input = user_input.into();
|
|
self.record_turn_started(&user_input);
|
|
self.session
|
|
.push_user_text(user_input)
|
|
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
|
|
|
let mut assistant_messages = Vec::new();
|
|
let mut tool_results = Vec::new();
|
|
let mut prompt_cache_events = Vec::new();
|
|
let mut iterations = 0;
|
|
|
|
loop {
|
|
iterations += 1;
|
|
if iterations > self.max_iterations {
|
|
let error = RuntimeError::new(
|
|
"conversation loop exceeded the maximum number of iterations",
|
|
);
|
|
self.record_turn_failed(iterations, &error);
|
|
return Err(error);
|
|
}
|
|
|
|
let request = ApiRequest {
|
|
system_prompt: self.system_prompt.clone(),
|
|
messages: self.session.messages.clone(),
|
|
};
|
|
let events = match self.api_client.stream(request) {
|
|
Ok(events) => events,
|
|
Err(error) => {
|
|
self.record_turn_failed(iterations, &error);
|
|
return Err(error);
|
|
}
|
|
};
|
|
let (assistant_message, usage, turn_prompt_cache_events) =
|
|
match build_assistant_message(events) {
|
|
Ok(result) => result,
|
|
Err(error) => {
|
|
self.record_turn_failed(iterations, &error);
|
|
return Err(error);
|
|
}
|
|
};
|
|
if let Some(usage) = usage {
|
|
self.usage_tracker.record(usage);
|
|
}
|
|
prompt_cache_events.extend(turn_prompt_cache_events);
|
|
let pending_tool_uses = assistant_message
|
|
.blocks
|
|
.iter()
|
|
.filter_map(|block| match block {
|
|
ContentBlock::ToolUse { id, name, input } => {
|
|
Some((id.clone(), name.clone(), input.clone()))
|
|
}
|
|
_ => None,
|
|
})
|
|
.collect::<Vec<_>>();
|
|
self.record_assistant_iteration(
|
|
iterations,
|
|
&assistant_message,
|
|
pending_tool_uses.len(),
|
|
);
|
|
|
|
self.session
|
|
.push_message(assistant_message.clone())
|
|
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
|
assistant_messages.push(assistant_message);
|
|
|
|
if pending_tool_uses.is_empty() {
|
|
break;
|
|
}
|
|
|
|
for (tool_use_id, tool_name, input) in pending_tool_uses {
|
|
let pre_hook_result = self.run_pre_tool_use_hook(&tool_name, &input);
|
|
let effective_input = pre_hook_result
|
|
.updated_input()
|
|
.map_or_else(|| input.clone(), ToOwned::to_owned);
|
|
let permission_context = PermissionContext::new(
|
|
pre_hook_result.permission_override(),
|
|
pre_hook_result.permission_reason().map(ToOwned::to_owned),
|
|
);
|
|
|
|
let permission_outcome = if pre_hook_result.is_cancelled() {
|
|
PermissionOutcome::Deny {
|
|
reason: format_hook_message(
|
|
&pre_hook_result,
|
|
&format!("PreToolUse hook cancelled tool `{tool_name}`"),
|
|
),
|
|
}
|
|
} else if pre_hook_result.is_failed() {
|
|
PermissionOutcome::Deny {
|
|
reason: format_hook_message(
|
|
&pre_hook_result,
|
|
&format!("PreToolUse hook failed for tool `{tool_name}`"),
|
|
),
|
|
}
|
|
} else if pre_hook_result.is_denied() {
|
|
PermissionOutcome::Deny {
|
|
reason: format_hook_message(
|
|
&pre_hook_result,
|
|
&format!("PreToolUse hook denied tool `{tool_name}`"),
|
|
),
|
|
}
|
|
} else if let Some(prompt) = prompter.as_mut() {
|
|
self.permission_policy.authorize_with_context(
|
|
&tool_name,
|
|
&effective_input,
|
|
&permission_context,
|
|
Some(*prompt),
|
|
)
|
|
} else {
|
|
self.permission_policy.authorize_with_context(
|
|
&tool_name,
|
|
&effective_input,
|
|
&permission_context,
|
|
None,
|
|
)
|
|
};
|
|
|
|
let result_message = match permission_outcome {
|
|
PermissionOutcome::Allow => {
|
|
self.record_tool_started(iterations, &tool_name);
|
|
let (mut output, mut is_error) =
|
|
match self.tool_executor.execute(&tool_name, &effective_input) {
|
|
Ok(output) => (output, false),
|
|
Err(error) => (error.to_string(), true),
|
|
};
|
|
output = merge_hook_feedback(pre_hook_result.messages(), output, false);
|
|
|
|
let post_hook_result = if is_error {
|
|
self.run_post_tool_use_failure_hook(
|
|
&tool_name,
|
|
&effective_input,
|
|
&output,
|
|
)
|
|
} else {
|
|
self.run_post_tool_use_hook(
|
|
&tool_name,
|
|
&effective_input,
|
|
&output,
|
|
false,
|
|
)
|
|
};
|
|
if post_hook_result.is_denied()
|
|
|| post_hook_result.is_failed()
|
|
|| post_hook_result.is_cancelled()
|
|
{
|
|
is_error = true;
|
|
}
|
|
output = merge_hook_feedback(
|
|
post_hook_result.messages(),
|
|
output,
|
|
post_hook_result.is_denied()
|
|
|| post_hook_result.is_failed()
|
|
|| post_hook_result.is_cancelled(),
|
|
);
|
|
|
|
ConversationMessage::tool_result(tool_use_id, tool_name, output, is_error)
|
|
}
|
|
PermissionOutcome::Deny { reason } => ConversationMessage::tool_result(
|
|
tool_use_id,
|
|
tool_name,
|
|
merge_hook_feedback(pre_hook_result.messages(), reason, true),
|
|
true,
|
|
),
|
|
};
|
|
self.session
|
|
.push_message(result_message.clone())
|
|
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
|
self.record_tool_finished(iterations, &result_message);
|
|
tool_results.push(result_message);
|
|
}
|
|
}
|
|
|
|
let auto_compaction = self.maybe_auto_compact();
|
|
|
|
let summary = TurnSummary {
|
|
assistant_messages,
|
|
tool_results,
|
|
prompt_cache_events,
|
|
iterations,
|
|
usage: self.usage_tracker.cumulative_usage(),
|
|
auto_compaction,
|
|
};
|
|
self.record_turn_completed(&summary);
|
|
|
|
Ok(summary)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
|
|
compact_session(&self.session, config)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn estimated_tokens(&self) -> usize {
|
|
estimate_session_tokens(&self.session)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn usage(&self) -> &UsageTracker {
|
|
&self.usage_tracker
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn session(&self) -> &Session {
|
|
&self.session
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn fork_session(&self, branch_name: Option<String>) -> Session {
|
|
self.session.fork(branch_name)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn into_session(self) -> Session {
|
|
self.session
|
|
}
|
|
|
|
fn maybe_auto_compact(&mut self) -> Option<AutoCompactionEvent> {
|
|
if self.usage_tracker.cumulative_usage().input_tokens
|
|
< self.auto_compaction_input_tokens_threshold
|
|
{
|
|
return None;
|
|
}
|
|
|
|
let result = compact_session(
|
|
&self.session,
|
|
CompactionConfig {
|
|
max_estimated_tokens: 0,
|
|
..CompactionConfig::default()
|
|
},
|
|
);
|
|
|
|
if result.removed_message_count == 0 {
|
|
return None;
|
|
}
|
|
|
|
self.session = result.compacted_session;
|
|
Some(AutoCompactionEvent {
|
|
removed_message_count: result.removed_message_count,
|
|
})
|
|
}
|
|
|
|
fn record_turn_started(&self, user_input: &str) {
|
|
let Some(session_tracer) = &self.session_tracer else {
|
|
return;
|
|
};
|
|
|
|
let mut attributes = Map::new();
|
|
attributes.insert(
|
|
"user_input".to_string(),
|
|
Value::String(user_input.to_string()),
|
|
);
|
|
session_tracer.record("turn_started", attributes);
|
|
}
|
|
|
|
fn record_assistant_iteration(
|
|
&self,
|
|
iteration: usize,
|
|
assistant_message: &ConversationMessage,
|
|
pending_tool_use_count: usize,
|
|
) {
|
|
let Some(session_tracer) = &self.session_tracer else {
|
|
return;
|
|
};
|
|
|
|
let mut attributes = Map::new();
|
|
attributes.insert("iteration".to_string(), Value::from(iteration as u64));
|
|
attributes.insert(
|
|
"assistant_blocks".to_string(),
|
|
Value::from(assistant_message.blocks.len() as u64),
|
|
);
|
|
attributes.insert(
|
|
"pending_tool_use_count".to_string(),
|
|
Value::from(pending_tool_use_count as u64),
|
|
);
|
|
session_tracer.record("assistant_iteration_completed", attributes);
|
|
}
|
|
|
|
fn record_tool_started(&self, iteration: usize, tool_name: &str) {
|
|
let Some(session_tracer) = &self.session_tracer else {
|
|
return;
|
|
};
|
|
|
|
let mut attributes = Map::new();
|
|
attributes.insert("iteration".to_string(), Value::from(iteration as u64));
|
|
attributes.insert(
|
|
"tool_name".to_string(),
|
|
Value::String(tool_name.to_string()),
|
|
);
|
|
session_tracer.record("tool_execution_started", attributes);
|
|
}
|
|
|
|
fn record_tool_finished(&self, iteration: usize, result_message: &ConversationMessage) {
|
|
let Some(session_tracer) = &self.session_tracer else {
|
|
return;
|
|
};
|
|
|
|
let Some(ContentBlock::ToolResult {
|
|
tool_name,
|
|
is_error,
|
|
..
|
|
}) = result_message.blocks.first()
|
|
else {
|
|
return;
|
|
};
|
|
|
|
let mut attributes = Map::new();
|
|
attributes.insert("iteration".to_string(), Value::from(iteration as u64));
|
|
attributes.insert("tool_name".to_string(), Value::String(tool_name.clone()));
|
|
attributes.insert("is_error".to_string(), Value::Bool(*is_error));
|
|
session_tracer.record("tool_execution_finished", attributes);
|
|
}
|
|
|
|
fn record_turn_completed(&self, summary: &TurnSummary) {
|
|
let Some(session_tracer) = &self.session_tracer else {
|
|
return;
|
|
};
|
|
|
|
let mut attributes = Map::new();
|
|
attributes.insert(
|
|
"iterations".to_string(),
|
|
Value::from(summary.iterations as u64),
|
|
);
|
|
attributes.insert(
|
|
"assistant_messages".to_string(),
|
|
Value::from(summary.assistant_messages.len() as u64),
|
|
);
|
|
attributes.insert(
|
|
"tool_results".to_string(),
|
|
Value::from(summary.tool_results.len() as u64),
|
|
);
|
|
attributes.insert(
|
|
"prompt_cache_events".to_string(),
|
|
Value::from(summary.prompt_cache_events.len() as u64),
|
|
);
|
|
session_tracer.record("turn_completed", attributes);
|
|
}
|
|
|
|
fn record_turn_failed(&self, iteration: usize, error: &RuntimeError) {
|
|
let Some(session_tracer) = &self.session_tracer else {
|
|
return;
|
|
};
|
|
|
|
let mut attributes = Map::new();
|
|
attributes.insert("iteration".to_string(), Value::from(iteration as u64));
|
|
attributes.insert("error".to_string(), Value::String(error.to_string()));
|
|
session_tracer.record("turn_failed", attributes);
|
|
}
|
|
}
|
|
|
|
/// Reads the automatic compaction threshold from the environment.
|
|
#[must_use]
|
|
pub fn auto_compaction_threshold_from_env() -> u32 {
|
|
parse_auto_compaction_threshold(
|
|
std::env::var(AUTO_COMPACTION_THRESHOLD_ENV_VAR)
|
|
.ok()
|
|
.as_deref(),
|
|
)
|
|
}
|
|
|
|
#[must_use]
|
|
fn parse_auto_compaction_threshold(value: Option<&str>) -> u32 {
|
|
value
|
|
.and_then(|raw| raw.trim().parse::<u32>().ok())
|
|
.filter(|threshold| *threshold > 0)
|
|
.unwrap_or(DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD)
|
|
}
|
|
|
|
fn build_assistant_message(
|
|
events: Vec<AssistantEvent>,
|
|
) -> Result<
|
|
(
|
|
ConversationMessage,
|
|
Option<TokenUsage>,
|
|
Vec<PromptCacheEvent>,
|
|
),
|
|
RuntimeError,
|
|
> {
|
|
let mut text = String::new();
|
|
let mut blocks = Vec::new();
|
|
let mut prompt_cache_events = Vec::new();
|
|
let mut finished = false;
|
|
let mut usage = None;
|
|
|
|
for event in events {
|
|
match event {
|
|
AssistantEvent::TextDelta(delta) => text.push_str(&delta),
|
|
AssistantEvent::ToolUse { id, name, input } => {
|
|
flush_text_block(&mut text, &mut blocks);
|
|
blocks.push(ContentBlock::ToolUse { id, name, input });
|
|
}
|
|
AssistantEvent::Usage(value) => usage = Some(value),
|
|
AssistantEvent::PromptCache(event) => prompt_cache_events.push(event),
|
|
AssistantEvent::MessageStop => {
|
|
finished = true;
|
|
}
|
|
}
|
|
}
|
|
|
|
flush_text_block(&mut text, &mut blocks);
|
|
|
|
if !finished {
|
|
return Err(RuntimeError::new(
|
|
"assistant stream ended without a message stop event",
|
|
));
|
|
}
|
|
if blocks.is_empty() {
|
|
return Err(RuntimeError::new("assistant stream produced no content"));
|
|
}
|
|
|
|
Ok((
|
|
ConversationMessage::assistant_with_usage(blocks, usage),
|
|
usage,
|
|
prompt_cache_events,
|
|
))
|
|
}
|
|
|
|
fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
|
|
if !text.is_empty() {
|
|
blocks.push(ContentBlock::Text {
|
|
text: std::mem::take(text),
|
|
});
|
|
}
|
|
}
|
|
|
|
fn format_hook_message(result: &HookRunResult, fallback: &str) -> String {
|
|
if result.messages().is_empty() {
|
|
fallback.to_string()
|
|
} else {
|
|
result.messages().join("\n")
|
|
}
|
|
}
|
|
|
|
/// Appends hook messages to a tool output as a labelled section.
///
/// With no messages the output is returned unchanged. Otherwise the messages
/// are joined by newlines under a `Hook feedback` heading (`Hook feedback
/// (error)` when `is_error` is set) and placed after the output, separated
/// by a blank line. An output that is empty or only whitespace is dropped.
fn merge_hook_feedback(messages: &[String], output: String, is_error: bool) -> String {
    if messages.is_empty() {
        return output;
    }

    let label = if is_error {
        "Hook feedback (error)"
    } else {
        "Hook feedback"
    };
    let feedback = format!("{label}:\n{}", messages.join("\n"));

    if output.trim().is_empty() {
        feedback
    } else {
        format!("{output}\n\n{feedback}")
    }
}
|
|
|
|
type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
|
|
|
|
/// Simple in-memory tool executor for tests and lightweight integrations.
|
|
#[derive(Default)]
|
|
pub struct StaticToolExecutor {
|
|
handlers: BTreeMap<String, ToolHandler>,
|
|
}
|
|
|
|
impl StaticToolExecutor {
|
|
#[must_use]
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn register(
|
|
mut self,
|
|
tool_name: impl Into<String>,
|
|
handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
|
|
) -> Self {
|
|
self.handlers.insert(tool_name.into(), Box::new(handler));
|
|
self
|
|
}
|
|
}
|
|
|
|
impl ToolExecutor for StaticToolExecutor {
|
|
fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
|
|
self.handlers
|
|
.get_mut(tool_name)
|
|
.ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::{
|
|
build_assistant_message, parse_auto_compaction_threshold, ApiClient, ApiRequest,
|
|
AssistantEvent, AutoCompactionEvent, ConversationRuntime, PromptCacheEvent, RuntimeError,
|
|
StaticToolExecutor, ToolExecutor, DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD,
|
|
};
|
|
use crate::compact::CompactionConfig;
|
|
use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
|
|
use crate::permissions::{
|
|
PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
|
|
PermissionRequest,
|
|
};
|
|
use crate::prompt::{ProjectContext, SystemPromptBuilder};
|
|
use crate::session::{ContentBlock, MessageRole, Session};
|
|
use crate::usage::TokenUsage;
|
|
use crate::ToolError;
|
|
use std::fs;
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
use std::time::{SystemTime, UNIX_EPOCH};
|
|
use telemetry::{MemoryTelemetrySink, SessionTracer, TelemetryEvent};
|
|
|
|
struct ScriptedApiClient {
|
|
call_count: usize,
|
|
}
|
|
|
|
impl ApiClient for ScriptedApiClient {
|
|
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
|
self.call_count += 1;
|
|
match self.call_count {
|
|
1 => {
|
|
assert!(request
|
|
.messages
|
|
.iter()
|
|
.any(|message| message.role == MessageRole::User));
|
|
Ok(vec![
|
|
AssistantEvent::TextDelta("Let me calculate that.".to_string()),
|
|
AssistantEvent::ToolUse {
|
|
id: "tool-1".to_string(),
|
|
name: "add".to_string(),
|
|
input: "2,2".to_string(),
|
|
},
|
|
AssistantEvent::Usage(TokenUsage {
|
|
input_tokens: 20,
|
|
output_tokens: 6,
|
|
cache_creation_input_tokens: 1,
|
|
cache_read_input_tokens: 2,
|
|
}),
|
|
AssistantEvent::MessageStop,
|
|
])
|
|
}
|
|
2 => {
|
|
let last_message = request
|
|
.messages
|
|
.last()
|
|
.expect("tool result should be present");
|
|
assert_eq!(last_message.role, MessageRole::Tool);
|
|
Ok(vec![
|
|
AssistantEvent::TextDelta("The answer is 4.".to_string()),
|
|
AssistantEvent::Usage(TokenUsage {
|
|
input_tokens: 24,
|
|
output_tokens: 4,
|
|
cache_creation_input_tokens: 1,
|
|
cache_read_input_tokens: 3,
|
|
}),
|
|
AssistantEvent::PromptCache(PromptCacheEvent {
|
|
unexpected: true,
|
|
reason:
|
|
"cache read tokens dropped while prompt fingerprint remained stable"
|
|
.to_string(),
|
|
previous_cache_read_input_tokens: 6_000,
|
|
current_cache_read_input_tokens: 1_000,
|
|
token_drop: 5_000,
|
|
}),
|
|
AssistantEvent::MessageStop,
|
|
])
|
|
}
|
|
_ => unreachable!("extra API call"),
|
|
}
|
|
}
|
|
}
|
|
|
|
struct PromptAllowOnce;
|
|
|
|
impl PermissionPrompter for PromptAllowOnce {
|
|
fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
|
|
assert_eq!(request.tool_name, "add");
|
|
PermissionPromptDecision::Allow
|
|
}
|
|
}
|
|
|
|
/// Drives a full turn (user input, tool call, tool result, final answer) and
/// checks the summary, the stored session messages, and the tracked usage.
#[test]
fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
    // given
    let adder = StaticToolExecutor::new().register("add", |input| {
        let total = input
            .split(',')
            .map(|part| part.parse::<i32>().expect("input must be valid integer"))
            .sum::<i32>();
        Ok(total.to_string())
    });
    let project_context = ProjectContext {
        cwd: PathBuf::from("/tmp/project"),
        current_date: "2026-03-31".to_string(),
        git_status: None,
        git_diff: None,
        instruction_files: Vec::new(),
    };
    let system_prompt = SystemPromptBuilder::new()
        .with_project_context(project_context)
        .with_os("linux", "6.8")
        .build();
    let mut runtime = ConversationRuntime::new(
        Session::new(),
        ScriptedApiClient { call_count: 0 },
        adder,
        PermissionPolicy::new(PermissionMode::WorkspaceWrite),
        system_prompt,
    );

    // when
    let summary = runtime
        .run_turn("what is 2 + 2?", Some(&mut PromptAllowOnce))
        .expect("conversation loop should succeed");

    // then
    assert_eq!(summary.iterations, 2);
    assert_eq!(summary.assistant_messages.len(), 2);
    assert_eq!(summary.tool_results.len(), 1);
    assert_eq!(summary.prompt_cache_events.len(), 1);
    assert_eq!(runtime.session().messages.len(), 4);
    assert_eq!(summary.usage.output_tokens, 10);
    assert_eq!(summary.auto_compaction, None);

    let stored_messages = &runtime.session().messages;
    assert!(matches!(
        stored_messages[1].blocks[1],
        ContentBlock::ToolUse { .. }
    ));
    assert!(matches!(
        stored_messages[2].blocks[0],
        ContentBlock::ToolResult {
            is_error: false,
            ..
        }
    ));
}
|
|
|
|
/// Checks that a successful turn emits every expected session trace event.
#[test]
fn records_runtime_session_trace_events() {
    // given
    let sink = Arc::new(MemoryTelemetrySink::default());
    let tracer = SessionTracer::new("session-runtime", sink.clone());
    let tool_executor = StaticToolExecutor::new().register("add", |_input| Ok("4".to_string()));
    let mut runtime = ConversationRuntime::new(
        Session::new(),
        ScriptedApiClient { call_count: 0 },
        tool_executor,
        PermissionPolicy::new(PermissionMode::WorkspaceWrite),
        vec!["system".to_string()],
    )
    .with_session_tracer(tracer);

    // when
    runtime
        .run_turn("what is 2 + 2?", Some(&mut PromptAllowOnce))
        .expect("conversation loop should succeed");

    // then
    let events = sink.events();
    let mut trace_names = Vec::new();
    for event in events.iter() {
        if let TelemetryEvent::SessionTrace(trace) = event {
            trace_names.push(trace.name.as_str());
        }
    }

    for expected in [
        "turn_started",
        "assistant_iteration_completed",
        "tool_execution_started",
        "tool_execution_finished",
        "turn_completed",
    ] {
        assert!(trace_names.contains(&expected));
    }
}
|
|
|
|
/// Checks that a prompter rejection becomes an error tool result and that
/// the conversation continues afterwards.
#[test]
fn records_denied_tool_results_when_prompt_rejects() {
    /// Prompter double that refuses every request.
    struct RejectPrompter;
    impl PermissionPrompter for RejectPrompter {
        fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
            let reason = "not now".to_string();
            PermissionPromptDecision::Deny { reason }
        }
    }

    /// Requests the `blocked` tool until a tool result shows up in the history.
    struct SingleCallApiClient;
    impl ApiClient for SingleCallApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let saw_tool_result = request
                .messages
                .iter()
                .any(|message| message.role == MessageRole::Tool);
            if saw_tool_result {
                Ok(vec![
                    AssistantEvent::TextDelta("I could not use the tool.".to_string()),
                    AssistantEvent::MessageStop,
                ])
            } else {
                Ok(vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "blocked".to_string(),
                        input: "secret".to_string(),
                    },
                    AssistantEvent::MessageStop,
                ])
            }
        }
    }

    // given
    let mut runtime = ConversationRuntime::new(
        Session::new(),
        SingleCallApiClient,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::WorkspaceWrite),
        vec!["system".to_string()],
    );

    // when
    let summary = runtime
        .run_turn("use the tool", Some(&mut RejectPrompter))
        .expect("conversation should continue after denied tool");

    // then
    assert_eq!(summary.tool_results.len(), 1);
    let denied_block = &summary.tool_results[0].blocks[0];
    assert!(matches!(
        denied_block,
        ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
    ));
}
|
|
|
|
/// Checks that a pre-tool hook exiting with status 2 prevents the tool from
/// running and yields an error tool result.
#[test]
fn denies_tool_use_when_pre_tool_hook_blocks() {
    /// Requests the `blocked` tool until a tool result shows up in the history.
    struct SingleCallApiClient;
    impl ApiClient for SingleCallApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let saw_tool_result = request
                .messages
                .iter()
                .any(|message| message.role == MessageRole::Tool);
            if saw_tool_result {
                Ok(vec![
                    AssistantEvent::TextDelta("blocked".to_string()),
                    AssistantEvent::MessageStop,
                ])
            } else {
                Ok(vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "blocked".to_string(),
                        input: r#"{"path":"secret.txt"}"#.to_string(),
                    },
                    AssistantEvent::MessageStop,
                ])
            }
        }
    }

    // given
    let hook_config = RuntimeHookConfig::new(
        vec![shell_snippet("printf 'blocked by hook'; exit 2")],
        Vec::new(),
        Vec::new(),
    );
    let feature_config = RuntimeFeatureConfig::default().with_hooks(hook_config);
    let tool_executor = StaticToolExecutor::new().register("blocked", |_input| {
        panic!("tool should not execute when hook denies")
    });
    let mut runtime = ConversationRuntime::new_with_features(
        Session::new(),
        SingleCallApiClient,
        tool_executor,
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
        &feature_config,
    );

    // when
    let summary = runtime
        .run_turn("use the tool", None)
        .expect("conversation should continue after hook denial");

    // then
    assert_eq!(summary.tool_results.len(), 1);
    let ContentBlock::ToolResult {
        is_error, output, ..
    } = &summary.tool_results[0].blocks[0]
    else {
        panic!("expected tool result block");
    };
    assert!(
        *is_error,
        "hook denial should produce an error result: {output}"
    );
    let mentions_denial = output.contains("denied tool") || output.contains("blocked by hook");
    assert!(
        mentions_denial,
        "unexpected hook denial output: {output:?}"
    );
}
|
|
|
|
/// Checks that a pre-tool hook exiting with status 1 prevents the tool from
/// running and yields an error tool result.
#[test]
fn denies_tool_use_when_pre_tool_hook_fails() {
    /// Requests the `blocked` tool until a tool result shows up in the history.
    struct SingleCallApiClient;
    impl ApiClient for SingleCallApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let saw_tool_result = request
                .messages
                .iter()
                .any(|message| message.role == MessageRole::Tool);
            if saw_tool_result {
                Ok(vec![
                    AssistantEvent::TextDelta("failed".to_string()),
                    AssistantEvent::MessageStop,
                ])
            } else {
                Ok(vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "blocked".to_string(),
                        input: r#"{"path":"secret.txt"}"#.to_string(),
                    },
                    AssistantEvent::MessageStop,
                ])
            }
        }
    }

    // given
    let hook_config = RuntimeHookConfig::new(
        vec![shell_snippet("printf 'broken hook'; exit 1")],
        Vec::new(),
        Vec::new(),
    );
    let feature_config = RuntimeFeatureConfig::default().with_hooks(hook_config);
    let tool_executor = StaticToolExecutor::new().register("blocked", |_input| {
        panic!("tool should not execute when hook fails")
    });
    let mut runtime = ConversationRuntime::new_with_features(
        Session::new(),
        SingleCallApiClient,
        tool_executor,
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
        &feature_config,
    );

    // when
    let summary = runtime
        .run_turn("use the tool", None)
        .expect("conversation should continue after hook failure");

    // then
    assert_eq!(summary.tool_results.len(), 1);
    let ContentBlock::ToolResult {
        is_error, output, ..
    } = &summary.tool_results[0].blocks[0]
    else {
        panic!("expected tool result block");
    };
    assert!(
        *is_error,
        "hook failure should produce an error result: {output}"
    );
    let mentions_failure =
        output.contains("exited with status 1") || output.contains("broken hook");
    assert!(
        mentions_failure,
        "unexpected hook failure output: {output:?}"
    );
}
|
|
|
|
/// Text printed by the two configured hook commands shows up in the tool
/// result next to the tool's own output, and the result stays a non-error.
#[test]
fn appends_post_tool_hook_feedback_to_tool_result() {
    /// Scripted client: first call requests the `add` tool, second call
    /// finishes with text, any further call is a test bug.
    struct TwoCallApiClient {
        calls: usize,
    }

    impl ApiClient for TwoCallApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            self.calls += 1;
            if self.calls == 1 {
                return Ok(vec![
                    AssistantEvent::ToolUse {
                        id: String::from("tool-1"),
                        name: String::from("add"),
                        input: String::from(r#"{"lhs":2,"rhs":2}"#),
                    },
                    AssistantEvent::MessageStop,
                ]);
            }
            if self.calls > 2 {
                unreachable!("extra API call");
            }
            // The follow-up request must already carry the tool's result.
            assert!(request
                .messages
                .iter()
                .any(|message| message.role == MessageRole::Tool));
            Ok(vec![
                AssistantEvent::TextDelta(String::from("done")),
                AssistantEvent::MessageStop,
            ])
        }
    }

    // given: hook commands in the first two hook lists, none in the third
    let hooks = RuntimeHookConfig::new(
        vec![shell_snippet("printf 'pre hook ran'")],
        vec![shell_snippet("printf 'post hook ran'")],
        Vec::new(),
    );
    let features = RuntimeFeatureConfig::default().with_hooks(hooks);
    let executor = StaticToolExecutor::new().register("add", |_input| Ok("4".to_string()));
    let mut runtime = ConversationRuntime::new_with_features(
        Session::new(),
        TwoCallApiClient { calls: 0 },
        executor,
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec![String::from("system")],
        &features,
    );

    // when
    let turn = runtime.run_turn("use add", None);
    let summary = turn.expect("tool loop succeeds");

    // then
    assert_eq!(summary.tool_results.len(), 1);
    let (is_error, output) = match &summary.tool_results[0].blocks[0] {
        ContentBlock::ToolResult {
            is_error, output, ..
        } => (*is_error, output),
        _ => panic!("expected tool result block"),
    };
    assert!(
        !is_error,
        "post hook should preserve non-error result: {output:?}"
    );
    let has_value = output.contains('4');
    assert!(has_value, "tool output missing value: {output:?}");
    let has_pre_feedback = output.contains("pre hook ran");
    assert!(
        has_pre_feedback,
        "tool output missing pre hook feedback: {output:?}"
    );
    let has_post_feedback = output.contains("post hook ran");
    assert!(
        has_post_feedback,
        "tool output missing post hook feedback: {output:?}"
    );
}
|
|
|
|
/// When the tool itself returns an error, the failure hook's output is added
/// to the error result and the regular post hook's output is absent.
#[test]
fn appends_post_tool_use_failure_hook_feedback_to_tool_result() {
    /// Scripted client: first call requests the `fail` tool, second call
    /// finishes with text, any further call is a test bug.
    struct TwoCallApiClient {
        calls: usize,
    }

    impl ApiClient for TwoCallApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            self.calls += 1;
            if self.calls == 1 {
                return Ok(vec![
                    AssistantEvent::ToolUse {
                        id: String::from("tool-1"),
                        name: String::from("fail"),
                        input: String::from(r#"{"path":"README.md"}"#),
                    },
                    AssistantEvent::MessageStop,
                ]);
            }
            if self.calls > 2 {
                unreachable!("extra API call");
            }
            // The follow-up request must already carry the tool's result.
            assert!(request
                .messages
                .iter()
                .any(|message| message.role == MessageRole::Tool));
            Ok(vec![
                AssistantEvent::TextDelta(String::from("done")),
                AssistantEvent::MessageStop,
            ])
        }
    }

    // given: a tool that always errors, plus hook commands in the second and
    // third hook lists (the first list is empty)
    let hooks = RuntimeHookConfig::new(
        Vec::new(),
        vec![shell_snippet("printf 'post hook should not run'")],
        vec![shell_snippet("printf 'failure hook ran'")],
    );
    let features = RuntimeFeatureConfig::default().with_hooks(hooks);
    let executor = StaticToolExecutor::new()
        .register("fail", |_input| Err(ToolError::new("tool exploded")));
    let mut runtime = ConversationRuntime::new_with_features(
        Session::new(),
        TwoCallApiClient { calls: 0 },
        executor,
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec![String::from("system")],
        &features,
    );

    // when
    let turn = runtime.run_turn("use fail", None);
    let summary = turn.expect("tool loop succeeds");

    // then
    assert_eq!(summary.tool_results.len(), 1);
    let (is_error, output) = match &summary.tool_results[0].blocks[0] {
        ContentBlock::ToolResult {
            is_error, output, ..
        } => (*is_error, output),
        _ => panic!("expected tool result block"),
    };
    assert!(
        is_error,
        "failure hook path should preserve error result: {output:?}"
    );
    let has_reason = output.contains("tool exploded");
    assert!(
        has_reason,
        "tool output missing failure reason: {output:?}"
    );
    let has_failure_feedback = output.contains("failure hook ran");
    assert!(
        has_failure_feedback,
        "tool output missing failure hook feedback: {output:?}"
    );
    let post_hook_leaked = output.contains("post hook should not run");
    assert!(
        !post_hook_leaked,
        "normal post hook should not run on tool failure: {output:?}"
    );
}
|
|
|
|
/// A runtime built from a session that already contains an assistant message
/// with recorded usage reports that usage without running any turn.
#[test]
fn reconstructs_usage_tracker_from_restored_session() {
    /// Client that always answers with a short text reply; never called here.
    struct SimpleApi;

    impl ApiClient for SimpleApi {
        fn stream(&mut self, _request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let events = vec![
                AssistantEvent::TextDelta(String::from("done")),
                AssistantEvent::MessageStop,
            ];
            Ok(events)
        }
    }

    // given: a session holding one earlier assistant message with usage attached
    let recorded_usage = TokenUsage {
        input_tokens: 11,
        output_tokens: 7,
        cache_creation_input_tokens: 2,
        cache_read_input_tokens: 1,
    };
    let earlier_reply = crate::session::ConversationMessage::assistant_with_usage(
        vec![ContentBlock::Text {
            text: String::from("earlier"),
        }],
        Some(recorded_usage),
    );
    let mut session = Session::new();
    session.messages.push(earlier_reply);

    // when
    let runtime = ConversationRuntime::new(
        session,
        SimpleApi,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec![String::from("system")],
    );

    // then: one turn is counted, and 21 equals 11 + 7 + 2 + 1 from the usage above
    assert_eq!(runtime.usage().turns(), 1);
    assert_eq!(runtime.usage().cumulative_usage().total_tokens(), 21);
}
|
|
|
|
/// After three turns, an explicit `compact` call yields a summary, a leading
/// system message, the same session id, and compaction metadata.
#[test]
fn compacts_session_after_turns() {
    /// Client that always answers with a short text reply.
    struct SimpleApi;

    impl ApiClient for SimpleApi {
        fn stream(&mut self, _request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let events = vec![
                AssistantEvent::TextDelta(String::from("done")),
                AssistantEvent::MessageStop,
            ];
            Ok(events)
        }
    }

    // given: a runtime that has completed three turns
    let mut runtime = ConversationRuntime::new(
        Session::new(),
        SimpleApi,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec![String::from("system")],
    );
    for (prompt, failure) in [("a", "turn a"), ("b", "turn b"), ("c", "turn c")] {
        runtime.run_turn(prompt, None).expect(failure);
    }

    // when
    let config = CompactionConfig {
        preserve_recent_messages: 2,
        max_estimated_tokens: 1,
    };
    let result = runtime.compact(config);

    // then
    assert!(result.summary.contains("Conversation summary"));
    let compacted = &result.compacted_session;
    assert_eq!(compacted.messages[0].role, MessageRole::System);
    assert_eq!(compacted.session_id, runtime.session().session_id);
    assert!(result.compacted_session.compaction.is_some());
}
|
|
|
|
/// A turn run on a session with a persistence path can be reloaded from that
/// path with the user and assistant messages and the same session id.
#[test]
fn persists_conversation_turn_messages_to_jsonl_session() {
    /// Client that always answers with a short text reply.
    struct SimpleApi;

    impl ApiClient for SimpleApi {
        fn stream(&mut self, _request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let events = vec![
                AssistantEvent::TextDelta(String::from("done")),
                AssistantEvent::MessageStop,
            ];
            Ok(events)
        }
    }

    // given: a session bound to a fresh temp file
    let path = temp_session_path("persisted-turn");
    let mut runtime = ConversationRuntime::new(
        Session::new().with_persistence_path(path.clone()),
        SimpleApi,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec![String::from("system")],
    );

    // when
    let turn = runtime.run_turn("persist this turn", None);
    turn.expect("turn should succeed");

    // then: reload from disk, then delete the file before asserting so a
    // failed assertion does not leave it behind
    let restored = Session::load_from_path(&path).expect("persisted session should reload");
    fs::remove_file(&path).expect("temp session file should be removable");

    let persisted = &restored.messages;
    assert_eq!(persisted.len(), 2);
    assert_eq!(persisted[0].role, MessageRole::User);
    assert_eq!(persisted[1].role, MessageRole::Assistant);
    assert_eq!(restored.session_id, runtime.session().session_id);
}
|
|
|
|
/// Forking copies the messages into a session with a new id and fork
/// metadata, and leaves the runtime's own session without fork metadata.
#[test]
fn forks_runtime_session_without_mutating_original() {
    // given: a runtime built from a clone of a one-message session
    let mut session = Session::new();
    let appended = session.push_user_text("branch me");
    appended.expect("message should append");

    let runtime = ConversationRuntime::new(
        session.clone(),
        ScriptedApiClient { call_count: 0 },
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec![String::from("system")],
    );

    // when
    let branch = Some(String::from("alt-path"));
    let forked = runtime.fork_session(branch);

    // then
    assert_eq!(forked.messages, session.messages);
    assert_ne!(forked.session_id, session.session_id);
    let fork_origin = forked
        .fork
        .as_ref()
        .map(|fork| (fork.parent_session_id.as_str(), fork.branch_name.as_deref()));
    let expected_origin = Some((session.session_id.as_str(), Some("alt-path")));
    assert_eq!(fork_origin, expected_origin);
    assert!(runtime.session().fork.is_none());
}
|
|
|
|
/// Builds a path inside the system temp directory named
/// `runtime-conversation-{label}-{nanos}.json`, where `nanos` is the current
/// time in nanoseconds since the Unix epoch.
///
/// Panics if the system clock reports a time before the Unix epoch.
fn temp_session_path(label: &str) -> PathBuf {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system time should be after epoch");
    let file_name = format!(
        "runtime-conversation-{label}-{}.json",
        since_epoch.as_nanos()
    );
    let mut path = std::env::temp_dir();
    path.push(file_name);
    path
}
|
|
|
|
/// Adapts a hook script literal to the host platform.
///
/// On Windows every single quote is replaced with a double quote; on all
/// other targets the script is returned unchanged.
// NOTE(review): presumably the Windows shell that runs hooks does not treat
// single quotes as quoting characters — confirm against the hook runner.
fn shell_snippet(script: &str) -> String {
    if cfg!(windows) {
        script.replace('\'', "\"")
    } else {
        script.to_string()
    }
}
|
|
|
|
/// A turn reporting 120_000 input tokens against a 100_000 threshold triggers
/// auto-compaction: two messages are removed and a system message leads.
#[test]
fn auto_compacts_when_cumulative_input_threshold_is_crossed() {
    /// Client that replies with text and reports 120_000 input tokens.
    struct SimpleApi;

    impl ApiClient for SimpleApi {
        fn stream(&mut self, _request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let reported_usage = TokenUsage {
                input_tokens: 120_000,
                output_tokens: 4,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            };
            Ok(vec![
                AssistantEvent::TextDelta(String::from("done")),
                AssistantEvent::Usage(reported_usage),
                AssistantEvent::MessageStop,
            ])
        }
    }

    // given: a session pre-populated with four alternating messages
    let history = vec![
        crate::session::ConversationMessage::user_text("one"),
        crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
            text: String::from("two"),
        }]),
        crate::session::ConversationMessage::user_text("three"),
        crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
            text: String::from("four"),
        }]),
    ];
    let mut session = Session::new();
    session.messages = history;

    let mut runtime = ConversationRuntime::new(
        session,
        SimpleApi,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec![String::from("system")],
    )
    .with_auto_compaction_input_tokens_threshold(100_000);

    // when
    let turn = runtime.run_turn("trigger", None);
    let summary = turn.expect("turn should succeed");

    // then
    let expected_event = Some(AutoCompactionEvent {
        removed_message_count: 2,
    });
    assert_eq!(summary.auto_compaction, expected_event);
    assert_eq!(runtime.session().messages[0].role, MessageRole::System);
}
|
|
|
|
/// A turn reporting 99_999 input tokens against a 100_000 threshold reports
/// no auto-compaction and leaves exactly two messages in the session.
#[test]
fn skips_auto_compaction_below_threshold() {
    /// Client that replies with text and reports 99_999 input tokens.
    struct SimpleApi;

    impl ApiClient for SimpleApi {
        fn stream(&mut self, _request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let reported_usage = TokenUsage {
                input_tokens: 99_999,
                output_tokens: 4,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            };
            Ok(vec![
                AssistantEvent::TextDelta(String::from("done")),
                AssistantEvent::Usage(reported_usage),
                AssistantEvent::MessageStop,
            ])
        }
    }

    // given
    let mut runtime = ConversationRuntime::new(
        Session::new(),
        SimpleApi,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec![String::from("system")],
    )
    .with_auto_compaction_input_tokens_threshold(100_000);

    // when
    let turn = runtime.run_turn("trigger", None);
    let summary = turn.expect("turn should succeed");

    // then
    assert_eq!(summary.auto_compaction, None);
    let message_count = runtime.session().messages.len();
    assert_eq!(message_count, 2);
}
|
|
|
|
/// Missing, zero, and non-numeric inputs yield the default threshold; a
/// positive number is returned as parsed.
#[test]
fn auto_compaction_threshold_defaults_and_parses_values() {
    // Each row pairs a raw input with the threshold it must produce.
    let cases = [
        (None, DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD),
        (Some("4321"), 4321),
        (Some("0"), DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD),
        (
            Some("not-a-number"),
            DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD,
        ),
    ];

    for (raw, expected) in cases {
        assert_eq!(parse_auto_compaction_threshold(raw), expected);
    }
}
|
|
|
|
/// An event list with text but no `MessageStop` is rejected with a
/// "without a message stop event" error.
#[test]
fn build_assistant_message_requires_message_stop_event() {
    // given: a text delta that is never followed by a stop event
    let text_only = vec![AssistantEvent::TextDelta(String::from("hello"))];

    // when
    let outcome = build_assistant_message(text_only);
    let error = outcome.expect_err("assistant messages should require a stop event");

    // then
    assert!(error
        .to_string()
        .contains("assistant stream ended without a message stop event"));
}
|
|
|
|
/// An event list holding only `MessageStop` is rejected with a
/// "produced no content" error.
#[test]
fn build_assistant_message_requires_content() {
    // given: a stop event with nothing before it
    let stop_only = vec![AssistantEvent::MessageStop];

    // when
    let outcome = build_assistant_message(stop_only);
    let error = outcome.expect_err("assistant messages should require content");

    // then
    assert!(error
        .to_string()
        .contains("assistant stream produced no content"));
}
|
|
|
|
/// Executing a tool name that was never registered fails with
/// `unknown tool: <name>`.
#[test]
fn static_tool_executor_rejects_unknown_tools() {
    // given: an executor with no registered tools
    let mut executor = StaticToolExecutor::new();

    // when
    let outcome = executor.execute("missing", "{}");
    let error = outcome.expect_err("unregistered tools should fail");

    // then
    let rendered = error.to_string();
    assert_eq!(rendered, "unknown tool: missing");
}
|
|
|
|
/// With the iteration limit set to 1, a client that requests a tool on every
/// call makes `run_turn` fail with the max-iterations error.
#[test]
fn run_turn_errors_when_max_iterations_is_exceeded() {
    /// Client that asks for the `echo` tool on every call and never finishes
    /// with plain text.
    struct LoopingApi;

    impl ApiClient for LoopingApi {
        fn stream(&mut self, _request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let tool_request = AssistantEvent::ToolUse {
                id: String::from("tool-1"),
                name: String::from("echo"),
                input: String::from("payload"),
            };
            Ok(vec![tool_request, AssistantEvent::MessageStop])
        }
    }

    // given
    let executor = StaticToolExecutor::new().register("echo", |input| Ok(input.to_string()));
    let mut runtime = ConversationRuntime::new(
        Session::new(),
        LoopingApi,
        executor,
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec![String::from("system")],
    )
    .with_max_iterations(1);

    // when
    let outcome = runtime.run_turn("loop", None);
    let error =
        outcome.expect_err("conversation loop should stop after the configured limit");

    // then
    assert!(error
        .to_string()
        .contains("conversation loop exceeded the maximum number of iterations"));
}
|
|
|
|
/// An error returned by the API client comes back from `run_turn` with its
/// message unchanged.
#[test]
fn run_turn_propagates_api_errors() {
    /// Client whose every call fails with "upstream failed".
    struct FailingApi;

    impl ApiClient for FailingApi {
        fn stream(&mut self, _request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let failure = RuntimeError::new("upstream failed");
            Err(failure)
        }
    }

    // given
    let mut runtime = ConversationRuntime::new(
        Session::new(),
        FailingApi,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec![String::from("system")],
    );

    // when
    let outcome = runtime.run_turn("hello", None);
    let error = outcome.expect_err("API failures should propagate");

    // then
    let rendered = error.to_string();
    assert_eq!(rendered, "upstream failed");
}
|
|
}
|