Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion rust/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ pub struct Session {
/// via [`Session::cancellation_token`] to bind their own work to
/// the session lifetime.
shutdown: CancellationToken,
/// Cancellation token broadcast to all in-flight tool handlers.
///
/// [`Session::abort`] cancels the current token (signalling all running
/// handlers) and then replaces it with a fresh child of `shutdown` so
/// subsequent tool calls are not pre-cancelled. Shared between the
/// `Session` handle and the event loop via `Arc<ParkingLotMutex<…>>`.
tool_abort: Arc<ParkingLotMutex<CancellationToken>>,
/// Only populated while a `send_and_wait` call is in flight.
///
/// Sync `parking_lot::Mutex` because the lock is never held across an
Expand Down Expand Up @@ -500,12 +507,25 @@ impl Session {

/// Abort the current agent turn.
///
/// Cancels the agentic loop and propagates cancellation to all in-flight
/// tool handlers via the [`CancellationToken`] on each
/// [`ToolInvocation`](crate::types::ToolInvocation). Handlers can check
/// [`is_cancelled()`](CancellationToken::is_cancelled) or `select!` on
/// [`cancelled()`](CancellationToken::cancelled) to stop early.
///
/// # Cancel safety
///
/// **Cancel-safe.** Single `session.abort` RPC; the underlying
/// [`Client::call`](crate::Client::call) is cancel-safe via the
/// writer-actor.
pub async fn abort(&self) -> Result<(), Error> {
// Signal all in-flight tool handlers before sending the RPC so that
// handlers can begin cleanup while the network round-trip is in flight.
{
let mut guard = self.tool_abort.lock();
guard.cancel();
*guard = self.shutdown.child_token();
}
self.client
Comment thread
gimenete marked this conversation as resolved.
.call(
"session.abort",
Expand Down Expand Up @@ -916,6 +936,7 @@ impl Client {
let idle_waiter = Arc::new(ParkingLotMutex::new(None));
let open_canvases = Arc::new(parking_lot::RwLock::new(Vec::new()));
let shutdown = CancellationToken::new();
let tool_abort = Arc::new(ParkingLotMutex::new(shutdown.child_token()));
let (event_tx, _) = tokio::sync::broadcast::channel(512);

// For cloud sessions (use_server_generated_id), defer session
Expand Down Expand Up @@ -1017,6 +1038,7 @@ impl Client {
open_canvases.clone(),
event_tx.clone(),
shutdown.clone(),
tool_abort.clone(),
);
tracing::debug!(
elapsed_ms = setup_start.elapsed().as_millis(),
Expand All @@ -1041,6 +1063,7 @@ impl Client {
client: self.clone(),
event_loop: ParkingLotMutex::new(Some(event_loop)),
shutdown,
tool_abort,
idle_waiter,
capabilities,
open_canvases,
Expand Down Expand Up @@ -1173,6 +1196,7 @@ impl Client {
let idle_waiter = Arc::new(ParkingLotMutex::new(None));
let open_canvases = Arc::new(parking_lot::RwLock::new(Vec::new()));
let shutdown = CancellationToken::new();
let tool_abort = Arc::new(ParkingLotMutex::new(shutdown.child_token()));
let (event_tx, _) = tokio::sync::broadcast::channel(512);
let event_loop = spawn_event_loop(
session_id.clone(),
Expand All @@ -1189,6 +1213,7 @@ impl Client {
open_canvases.clone(),
event_tx.clone(),
shutdown.clone(),
tool_abort.clone(),
);
let mut registration =
PendingSessionRegistration::new(self.clone(), session_id.clone(), shutdown.clone());
Expand Down Expand Up @@ -1284,6 +1309,7 @@ impl Client {
client: self.clone(),
event_loop: ParkingLotMutex::new(Some(event_loop)),
shutdown,
tool_abort,
idle_waiter,
capabilities,
open_canvases,
Expand Down Expand Up @@ -1397,6 +1423,7 @@ fn spawn_event_loop(
open_canvases: Arc<parking_lot::RwLock<Vec<OpenCanvasInstance>>>,
event_tx: tokio::sync::broadcast::Sender<SessionEvent>,
shutdown: CancellationToken,
tool_abort: Arc<ParkingLotMutex<CancellationToken>>,
) -> JoinHandle<()> {
let crate::router::SessionChannels {
mut notifications,
Expand All @@ -1421,7 +1448,7 @@ fn spawn_event_loop(
_ = shutdown.cancelled() => break,
Some(notification) = notifications.recv() => {
handle_notification(
&session_id, &client, &handlers, &command_handlers, notification, &idle_waiter, &capabilities, &open_canvases, &event_tx,
&session_id, &client, &handlers, &command_handlers, notification, &idle_waiter, &capabilities, &open_canvases, &event_tx, &tool_abort,
).await;
}
Some(request) = requests.recv() => {
Expand Down Expand Up @@ -1494,6 +1521,7 @@ async fn handle_notification(
capabilities: &Arc<parking_lot::RwLock<SessionCapabilities>>,
open_canvases: &Arc<parking_lot::RwLock<Vec<OpenCanvasInstance>>>,
event_tx: &tokio::sync::broadcast::Sender<SessionEvent>,
tool_abort: &Arc<ParkingLotMutex<CancellationToken>>,
) {
let dispatch_start = Instant::now();
let event = notification.event.clone();
Expand Down Expand Up @@ -1741,6 +1769,7 @@ async fn handle_notification(
session_id = %sid,
request_id = %request_id
);
let tool_abort = tool_abort.clone();
tokio::spawn(
async move {
// `tool_name.is_empty()` would have produced a `None`
Expand Down Expand Up @@ -1770,13 +1799,15 @@ async fn handle_notification(
}
let tool_call_id = data.tool_call_id.clone();
let tool_name = data.tool_name.clone();
let cancellation_token = tool_abort.lock().child_token();
let invocation = ToolInvocation {
session_id: sid.clone(),
tool_call_id: data.tool_call_id,
tool_name: data.tool_name,
arguments: data
.arguments
.unwrap_or(Value::Object(serde_json::Map::new())),
cancellation_token,
traceparent: data.traceparent,
tracestate: data.tracestate,
};
Comment thread
gimenete marked this conversation as resolved.
Outdated
Expand Down
51 changes: 41 additions & 10 deletions rust/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,7 @@ mod tests {
tool_call_id: "tc1".to_string(),
tool_name: "echo".to_string(),
arguments: serde_json::json!({"msg": "hello"}),
traceparent: None,
tracestate: None,
..Default::default()
};

let result = tool.call(inv).await.unwrap();
Expand Down Expand Up @@ -606,8 +605,7 @@ mod tests {
tool_call_id: "tc1".to_string(),
tool_name: "weather".to_string(),
arguments: serde_json::json!({"city": "Seattle"}),
traceparent: None,
tracestate: None,
..Default::default()
};
match handler.call(inv).await.unwrap() {
ToolResult::Text(s) => assert_eq!(s, "sunny in Seattle"),
Expand Down Expand Up @@ -688,8 +686,7 @@ mod tests {
tool_call_id: "tc1".to_string(),
tool_name: "get_weather".to_string(),
arguments: serde_json::json!({"city": "Seattle", "unit": "celsius"}),
traceparent: None,
tracestate: None,
..Default::default()
};

let result = tool.call(inv).await.unwrap();
Expand All @@ -707,8 +704,7 @@ mod tests {
tool_call_id: "tc1".to_string(),
tool_name: "get_weather".to_string(),
arguments: serde_json::json!({"wrong_field": 42}),
traceparent: None,
tracestate: None,
..Default::default()
};

let err = tool.call(inv).await.unwrap_err();
Expand All @@ -728,8 +724,7 @@ mod tests {
tool_call_id: "tc1".to_string(),
tool_name: "get_weather".to_string(),
arguments: serde_json::json!({"city": "Portland"}),
traceparent: None,
tracestate: None,
..Default::default()
})
.await
.expect("ToolHandler::call should succeed for matching args");
Expand All @@ -739,4 +734,40 @@ mod tests {
}
}
}

#[tokio::test]
async fn tool_invocation_cancellation_token_fires_on_cancel() {
let token = tokio_util::sync::CancellationToken::new();
let inv = ToolInvocation {
session_id: SessionId::from("s1"),
tool_call_id: "tc1".to_string(),
tool_name: "echo".to_string(),
arguments: serde_json::json!({}),
cancellation_token: token.clone(),
..Default::default()
};

assert!(!inv.cancellation_token.is_cancelled());
token.cancel();
assert!(inv.cancellation_token.is_cancelled());
}

#[tokio::test]
async fn tool_invocation_child_token_cancelled_when_parent_fires() {
let parent = tokio_util::sync::CancellationToken::new();
let child = parent.child_token();

let inv = ToolInvocation {
session_id: SessionId::from("s1"),
tool_call_id: "tc1".to_string(),
tool_name: "echo".to_string(),
arguments: serde_json::json!({}),
cancellation_token: child,
..Default::default()
};

assert!(!inv.cancellation_token.is_cancelled());
parent.cancel();
assert!(inv.cancellation_token.is_cancelled());
}
}
14 changes: 14 additions & 0 deletions rust/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::time::Duration;

use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio_util::sync::CancellationToken;

use crate::canvas::{CanvasDeclaration, CanvasHandler};
use crate::generated::api_types::OpenCanvasInstance;
Expand Down Expand Up @@ -3934,6 +3935,19 @@ pub struct ToolInvocation {
pub tool_name: String,
/// Tool arguments as JSON.
pub arguments: Value,
/// Cancellation signal for this tool invocation.
///
/// Fires when [`Session::abort`](crate::Session::abort) is called while
/// this handler is in flight. Handlers can check
/// [`is_cancelled()`](CancellationToken::is_cancelled) or `select!` on
/// [`cancelled()`](CancellationToken::cancelled) to cooperatively stop
/// work early. Handlers that don't need cancellation can ignore this field.
///
/// The token is already cancelled for handlers that are dispatched after
/// an `abort()` call, so they can check the flag at entry and return
/// immediately if desired.
#[serde(skip)]
pub cancellation_token: CancellationToken,
Comment thread
gimenete marked this conversation as resolved.
Outdated
/// W3C Trace Context `traceparent` header propagated from the CLI's
/// `execute_tool` span. Pass through to OpenTelemetry-aware code so
/// child spans created inside the handler are parented to the CLI
Expand Down