Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/copilot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@
create_session_fs_adapter,
)
from .tools import (
AbortSignal,
Tool,
ToolBinaryResult,
ToolInvocation,
Expand All @@ -168,6 +169,7 @@
"AutoModeSwitchHandler",
"AutoModeSwitchRequest",
"AutoModeSwitchResponse",
"AbortSignal",
"BUILTIN_TOOLS_ISOLATED",
"CanvasAction",
"CanvasDeclaration",
Expand Down
56 changes: 55 additions & 1 deletion python/copilot/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
from .generated.session_events import (
ReasoningSummary as _RpcReasoningSummary,
)
from .tools import Tool, ToolHandler, ToolInvocation, ToolResult
from .tools import AbortController, Tool, ToolHandler, ToolInvocation, ToolResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1190,6 +1190,7 @@ def __init__(
self._event_handlers_lock = threading.Lock()
self._tool_handlers: dict[str, ToolHandler] = {}
self._tool_handlers_lock = threading.Lock()
self._in_flight_tool_calls: dict[str, AbortController] = {}
self._permission_handler: _PermissionHandlerFn | None = None
self._permission_handler_lock = threading.Lock()
self._user_input_handler: UserInputHandler | None = None
Expand Down Expand Up @@ -1655,12 +1656,15 @@ async def _execute_tool_and_respond(
tracestate: str | None = None,
) -> None:
"""Execute a tool handler and send the result back via HandlePendingToolCall RPC."""
abort_controller = AbortController()
self._in_flight_tool_calls[tool_call_id] = abort_controller
try:
invocation = ToolInvocation(
session_id=self.session_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
arguments=arguments,
signal=abort_controller.signal,
)
Comment thread
gimenete marked this conversation as resolved.

with trace_context(traceparent, tracestate):
Expand Down Expand Up @@ -1745,6 +1749,11 @@ async def _execute_tool_and_respond(
)
except (JsonRpcError, ProcessExitedError, OSError):
pass # Connection lost or RPC error — nothing we can do
finally:
# Only clear if this is still the controller for this toolCallId;
# guards against a recycled toolCallId from a later invocation.
if self._in_flight_tool_calls.get(tool_call_id) is abort_controller:
del self._in_flight_tool_calls[tool_call_id]

async def _execute_permission_and_respond(
self,
Expand Down Expand Up @@ -2421,6 +2430,9 @@ async def disconnect(self) -> None:
self._destroyed = True

try:
# Abort any in-flight tool handlers so they can release resources.
self._abort_in_flight_tool_calls()
self._in_flight_tool_calls.clear()
await self._client.request("session.destroy", {"sessionId": self.session_id})
finally:
# Clear handlers even if the request fails.
Expand Down Expand Up @@ -2460,6 +2472,12 @@ async def abort(self) -> None:
"""
Abort the currently processing message in this session.

This cancels the agentic loop and propagates cancellation to all
in-flight tool handlers via their :class:`~copilot.AbortSignal`
(passed in :attr:`~copilot.ToolInvocation.signal`). Tool handlers can
check the signal with ``invocation.signal.is_aborted`` or await
``invocation.signal.wait()``.

Use this to cancel a long-running request. The session remains valid
and can continue to be used for new messages.

Expand All @@ -2476,8 +2494,44 @@ async def abort(self) -> None:
>>> await asyncio.sleep(5)
>>> await session.abort()
"""
# Abort all in-flight tool handlers
self._abort_in_flight_tool_calls()
await self._client.request("session.abort", {"sessionId": self.session_id})

def cancel_tool_call(self, tool_call_id: str) -> bool:
"""
Cancel a single in-flight tool handler without aborting the agentic loop.

Signals only the handler identified by *tool_call_id* via its
:class:`~copilot.AbortSignal`; all other concurrent handlers are
unaffected. The session remains valid and the agentic loop continues.

Args:
tool_call_id: The ``tool_call_id`` of the in-flight tool call to cancel.

Returns:
``True`` if a matching in-flight call was found and its signal was
triggered; ``False`` if no call with that ID is currently running.

Example::

session.on(lambda event: (
session.cancel_tool_call(event.data.tool_call_id)
if event.type.value == "tool.execution_start"
else None
))
"""
controller = self._in_flight_tool_calls.get(tool_call_id)
if not controller:
return False
controller.abort()
return True

def _abort_in_flight_tool_calls(self) -> None:
"""Abort the AbortSignal for every in-flight tool handler."""
for controller in self._in_flight_tool_calls.values():
controller.abort()

async def set_model(
self,
model: str,
Expand Down
71 changes: 71 additions & 0 deletions python/copilot/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from __future__ import annotations

import asyncio
import inspect
import json
from collections.abc import Awaitable, Callable
Expand All @@ -18,6 +19,63 @@
ToolResultType = Literal["success", "failure", "rejected", "denied", "timeout"]


class AbortSignal:
"""
A signal object that allows monitoring whether an abort has been requested.

Passed to tool handlers via :attr:`ToolInvocation.signal` so they can
cooperatively cancel in-flight work when :meth:`~copilot.CopilotSession.abort`
is called.

Example::

@define_tool(description="Fetch remote data")
async def fetch_data(params: Params, inv: ToolInvocation) -> str:
if inv.signal.is_aborted:
return "cancelled"
data = await fetch_with_signal(params.url, inv.signal)
return data

"""

def __init__(self) -> None:
self._event: asyncio.Event = asyncio.Event()

@property
def is_aborted(self) -> bool:
"""``True`` if :meth:`~AbortController.abort` has been called."""
return self._event.is_set()

async def wait(self) -> None:
"""Coroutine that completes when the signal is aborted."""
await self._event.wait()

def _abort(self) -> None:
"""Internal: trigger the signal. Called by :class:`AbortController`."""
self._event.set()


class AbortController:
"""
A controller that creates and manages an :class:`AbortSignal`.

Call :meth:`abort` to cancel all in-flight tool handlers that hold the
associated :attr:`signal`.
"""

def __init__(self) -> None:
self._signal: AbortSignal = AbortSignal()

@property
def signal(self) -> AbortSignal:
"""The :class:`AbortSignal` managed by this controller."""
return self._signal

def abort(self) -> None:
"""Trigger the signal, notifying all handlers that abort has been requested."""
self._signal._abort()


@dataclass
class ToolBinaryResult:
"""Binary content returned by a tool."""
Expand Down Expand Up @@ -49,6 +107,19 @@ class ToolInvocation:
tool_call_id: str = ""
tool_name: str = ""
arguments: Any = None
signal: AbortSignal | None = None
"""Optional AbortSignal for cooperative cancellation.

When a ``ToolInvocation`` is constructed by :class:`~copilot.CopilotSession`
during tool dispatch, this field is set to the signal managed by the
session — it is triggered when :meth:`~copilot.CopilotSession.abort` or
:meth:`~copilot.CopilotSession.cancel_tool_call` is called.

When a ``ToolInvocation`` is constructed manually without a signal (e.g. in
tests), this defaults to ``None``. Handlers that consume the signal should
guard for ``None`` (``if inv.signal and inv.signal.is_aborted``) or inject
one explicitly: ``ToolInvocation(..., signal=my_controller.signal)``.
"""


ToolHandler = Callable[[ToolInvocation], ToolResult | Awaitable[ToolResult]]
Expand Down
117 changes: 117 additions & 0 deletions python/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,120 @@ def test_call_tool_result_dict_is_json_serialized_by_normalize(self):
result = _normalize_result({"content": [{"type": "text", "text": "hello"}]})
parsed = json.loads(result.text_result_for_llm)
assert parsed == {"content": [{"type": "text", "text": "hello"}]}


class TestAbortSignal:
def test_not_aborted_initially(self):
from copilot.tools import AbortController

controller = AbortController()
assert controller.signal.is_aborted is False

def test_aborted_after_abort_called(self):
from copilot.tools import AbortController

controller = AbortController()
controller.abort()
assert controller.signal.is_aborted is True

async def test_wait_returns_after_abort(self):
from copilot.tools import AbortController

controller = AbortController()
signal = controller.signal

async def aborter():
import asyncio

await asyncio.sleep(0)
controller.abort()

import asyncio

await asyncio.gather(signal.wait(), aborter())
assert signal.is_aborted is True

def test_tool_invocation_signal_defaults_to_none(self):
inv = ToolInvocation(session_id="s1", tool_call_id="c1", tool_name="t1")
assert inv.signal is None

def test_tool_invocation_accepts_injected_signal(self):
from copilot.tools import AbortController, AbortSignal

controller = AbortController()
inv = ToolInvocation(
session_id="s1", tool_call_id="c1", tool_name="t1", signal=controller.signal
)
assert isinstance(inv.signal, AbortSignal)
assert inv.signal.is_aborted is False
controller.abort()
assert inv.signal.is_aborted is True

async def test_handler_receives_signal_via_invocation(self):
from copilot.tools import AbortController

received_signal = None
controller = AbortController()

class Params(BaseModel):
pass

@define_tool("test_signal", description="Test signal propagation")
def test_tool(params: Params, inv: ToolInvocation) -> str:
nonlocal received_signal
received_signal = inv.signal
return "ok"

invocation = ToolInvocation(
session_id="s1",
tool_call_id="c1",
tool_name="test_signal",
arguments={},
signal=controller.signal,
)

await test_tool.handler(invocation)

assert received_signal is controller.signal
assert received_signal.is_aborted is False

controller.abort()
assert received_signal.is_aborted is True


class TestCancelToolCall:
def test_returns_false_for_unknown_id(self):
from copilot.session import CopilotSession

session = CopilotSession("sess-1", client=None)
assert session.cancel_tool_call("nonexistent") is False

def test_returns_true_and_aborts_signal_for_known_id(self):
from copilot.session import CopilotSession
from copilot.tools import AbortController

session = CopilotSession("sess-1", client=None)
controller = AbortController()
session._in_flight_tool_calls["call-1"] = controller

result = session.cancel_tool_call("call-1")

assert result is True
assert controller.signal.is_aborted is True

def test_cancels_only_targeted_handler(self):
"""cancel_tool_call aborts only the targeted handler; others are unaffected."""
from copilot.session import CopilotSession
from copilot.tools import AbortController

session = CopilotSession("sess-1", client=None)
controller_a = AbortController()
controller_b = AbortController()
session._in_flight_tool_calls["call-a"] = controller_a
session._in_flight_tool_calls["call-b"] = controller_b

result = session.cancel_tool_call("call-a")

assert result is True
assert controller_a.signal.is_aborted is True
assert controller_b.signal.is_aborted is False