diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index ff2562d68..464c0a06b 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -147,6 +147,7 @@ create_session_fs_adapter, ) from .tools import ( + AbortSignal, Tool, ToolBinaryResult, ToolInvocation, @@ -168,6 +169,7 @@ "AutoModeSwitchHandler", "AutoModeSwitchRequest", "AutoModeSwitchResponse", + "AbortSignal", "BUILTIN_TOOLS_ISOLATED", "CanvasAction", "CanvasDeclaration", diff --git a/python/copilot/session.py b/python/copilot/session.py index 3720af05d..42aef4922 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -77,7 +77,7 @@ from .generated.session_events import ( ReasoningSummary as _RpcReasoningSummary, ) -from .tools import Tool, ToolHandler, ToolInvocation, ToolResult +from .tools import AbortController, Tool, ToolHandler, ToolInvocation, ToolResult logger = logging.getLogger(__name__) @@ -1190,6 +1190,7 @@ def __init__( self._event_handlers_lock = threading.Lock() self._tool_handlers: dict[str, ToolHandler] = {} self._tool_handlers_lock = threading.Lock() + self._in_flight_tool_calls: dict[str, AbortController] = {} self._permission_handler: _PermissionHandlerFn | None = None self._permission_handler_lock = threading.Lock() self._user_input_handler: UserInputHandler | None = None @@ -1655,12 +1656,15 @@ async def _execute_tool_and_respond( tracestate: str | None = None, ) -> None: """Execute a tool handler and send the result back via HandlePendingToolCall RPC.""" + abort_controller = AbortController() + self._in_flight_tool_calls[tool_call_id] = abort_controller try: invocation = ToolInvocation( session_id=self.session_id, tool_call_id=tool_call_id, tool_name=tool_name, arguments=arguments, + signal=abort_controller.signal, ) with trace_context(traceparent, tracestate): @@ -1745,6 +1749,11 @@ async def _execute_tool_and_respond( ) except (JsonRpcError, ProcessExitedError, OSError): pass # Connection lost or RPC error — nothing we can do + finally: + # Only clear if this is still the controller for this toolCallId; + # guards against a recycled toolCallId from a later invocation. + if self._in_flight_tool_calls.get(tool_call_id) is abort_controller: + del self._in_flight_tool_calls[tool_call_id] async def _execute_permission_and_respond( self, @@ -2421,6 +2430,9 @@ async def disconnect(self) -> None: self._destroyed = True try: + # Abort any in-flight tool handlers so they can release resources. + self._abort_in_flight_tool_calls() + self._in_flight_tool_calls.clear() await self._client.request("session.destroy", {"sessionId": self.session_id}) finally: # Clear handlers even if the request fails. @@ -2460,6 +2472,12 @@ async def abort(self) -> None: """ Abort the currently processing message in this session. + This cancels the agentic loop and propagates cancellation to all + in-flight tool handlers via their :class:`~copilot.AbortSignal` + (passed in :attr:`~copilot.ToolInvocation.signal`). Tool handlers can + check the signal with ``invocation.signal.is_aborted`` or await + ``invocation.signal.wait()``. + Use this to cancel a long-running request. The session remains valid and can continue to be used for new messages. @@ -2476,8 +2494,44 @@ async def abort(self) -> None: >>> await asyncio.sleep(5) >>> await session.abort() """ + # Abort all in-flight tool handlers + self._abort_in_flight_tool_calls() await self._client.request("session.abort", {"sessionId": self.session_id}) + def cancel_tool_call(self, tool_call_id: str) -> bool: + """ + Cancel a single in-flight tool handler without aborting the agentic loop. + + Signals only the handler identified by *tool_call_id* via its + :class:`~copilot.AbortSignal`; all other concurrent handlers are + unaffected. The session remains valid and the agentic loop continues. + + Args: + tool_call_id: The ``tool_call_id`` of the in-flight tool call to cancel. + + Returns: + ``True`` if a matching in-flight call was found and its signal was + triggered; ``False`` if no call with that ID is currently running. + + Example:: + + session.on(lambda event: ( + session.cancel_tool_call(event.data.tool_call_id) + if event.type.value == "tool.execution_start" + else None + )) + """ + controller = self._in_flight_tool_calls.get(tool_call_id) + if not controller: + return False + controller.abort() + return True + + def _abort_in_flight_tool_calls(self) -> None: + """Abort the AbortSignal for every in-flight tool handler.""" + for controller in self._in_flight_tool_calls.values(): + controller.abort() + async def set_model( self, model: str, diff --git a/python/copilot/tools.py b/python/copilot/tools.py index a82a48b1e..2a6843400 100644 --- a/python/copilot/tools.py +++ b/python/copilot/tools.py @@ -7,6 +7,7 @@ from __future__ import annotations +import asyncio import inspect import json from collections.abc import Awaitable, Callable @@ -18,6 +19,63 @@ ToolResultType = Literal["success", "failure", "rejected", "denied", "timeout"] +class AbortSignal: + """ + A signal object that allows monitoring whether an abort has been requested. + + Passed to tool handlers via :attr:`ToolInvocation.signal` so they can + cooperatively cancel in-flight work when :meth:`~copilot.CopilotSession.abort` + is called. + + Example:: + + @define_tool(description="Fetch remote data") + async def fetch_data(params: Params, inv: ToolInvocation) -> str: + if inv.signal.is_aborted: + return "cancelled" + data = await fetch_with_signal(params.url, inv.signal) + return data + + """ + + def __init__(self) -> None: + self._event: asyncio.Event = asyncio.Event() + + @property + def is_aborted(self) -> bool: + """``True`` if :meth:`~AbortController.abort` has been called.""" + return self._event.is_set() + + async def wait(self) -> None: + """Coroutine that completes when the signal is aborted.""" + await self._event.wait() + + def _abort(self) -> None: + """Internal: trigger the signal. Called by :class:`AbortController`.""" + self._event.set() + + +class AbortController: + """ + A controller that creates and manages an :class:`AbortSignal`. + + Call :meth:`abort` to cancel all in-flight tool handlers that hold the + associated :attr:`signal`. + """ + + def __init__(self) -> None: + self._signal: AbortSignal = AbortSignal() + + @property + def signal(self) -> AbortSignal: + """The :class:`AbortSignal` managed by this controller.""" + return self._signal + + def abort(self) -> None: + """Trigger the signal, notifying all handlers that abort has been requested.""" + self._signal._abort() + + @dataclass class ToolBinaryResult: """Binary content returned by a tool.""" @@ -49,6 +107,19 @@ class ToolInvocation: tool_call_id: str = "" tool_name: str = "" arguments: Any = None + signal: AbortSignal | None = None + """Optional AbortSignal for cooperative cancellation. + + When a ``ToolInvocation`` is constructed by :class:`~copilot.CopilotSession` + during tool dispatch, this field is set to the signal managed by the + session — it is triggered when :meth:`~copilot.CopilotSession.abort` or + :meth:`~copilot.CopilotSession.cancel_tool_call` is called. + + When a ``ToolInvocation`` is constructed manually without a signal (e.g. in + tests), this defaults to ``None``. Handlers that consume the signal should + guard for ``None`` (``if inv.signal and inv.signal.is_aborted``) or inject + one explicitly: ``ToolInvocation(..., signal=my_controller.signal)``. + """ ToolHandler = Callable[[ToolInvocation], ToolResult | Awaitable[ToolResult]] diff --git a/python/test_tools.py b/python/test_tools.py index d583b59c0..1189bfa4b 100644 --- a/python/test_tools.py +++ b/python/test_tools.py @@ -427,3 +427,120 @@ def test_call_tool_result_dict_is_json_serialized_by_normalize(self): result = _normalize_result({"content": [{"type": "text", "text": "hello"}]}) parsed = json.loads(result.text_result_for_llm) assert parsed == {"content": [{"type": "text", "text": "hello"}]} + + +class TestAbortSignal: + def test_not_aborted_initially(self): + from copilot.tools import AbortController + + controller = AbortController() + assert controller.signal.is_aborted is False + + def test_aborted_after_abort_called(self): + from copilot.tools import AbortController + + controller = AbortController() + controller.abort() + assert controller.signal.is_aborted is True + + async def test_wait_returns_after_abort(self): + from copilot.tools import AbortController + + controller = AbortController() + signal = controller.signal + + async def aborter(): + import asyncio + + await asyncio.sleep(0) + controller.abort() + + import asyncio + + await asyncio.gather(signal.wait(), aborter()) + assert signal.is_aborted is True + + def test_tool_invocation_signal_defaults_to_none(self): + inv = ToolInvocation(session_id="s1", tool_call_id="c1", tool_name="t1") + assert inv.signal is None + + def test_tool_invocation_accepts_injected_signal(self): + from copilot.tools import AbortController, AbortSignal + + controller = AbortController() + inv = ToolInvocation( + session_id="s1", tool_call_id="c1", tool_name="t1", signal=controller.signal + ) + assert isinstance(inv.signal, AbortSignal) + assert inv.signal.is_aborted is False + controller.abort() + assert inv.signal.is_aborted is True + + async def test_handler_receives_signal_via_invocation(self): + from copilot.tools import AbortController + + received_signal = None + controller = AbortController() + + class Params(BaseModel): + pass + + @define_tool("test_signal", description="Test signal propagation") + def test_tool(params: Params, inv: ToolInvocation) -> str: + nonlocal received_signal + received_signal = inv.signal + return "ok" + + invocation = ToolInvocation( + session_id="s1", + tool_call_id="c1", + tool_name="test_signal", + arguments={}, + signal=controller.signal, + ) + + await test_tool.handler(invocation) + + assert received_signal is controller.signal + assert received_signal.is_aborted is False + + controller.abort() + assert received_signal.is_aborted is True + + +class TestCancelToolCall: + def test_returns_false_for_unknown_id(self): + from copilot.session import CopilotSession + + session = CopilotSession("sess-1", client=None) + assert session.cancel_tool_call("nonexistent") is False + + def test_returns_true_and_aborts_signal_for_known_id(self): + from copilot.session import CopilotSession + from copilot.tools import AbortController + + session = CopilotSession("sess-1", client=None) + controller = AbortController() + session._in_flight_tool_calls["call-1"] = controller + + result = session.cancel_tool_call("call-1") + + assert result is True + assert controller.signal.is_aborted is True + + def test_cancels_only_targeted_handler(self): + """cancel_tool_call aborts only the targeted handler; others are unaffected.""" + from copilot.session import CopilotSession + from copilot.tools import AbortController + + session = CopilotSession("sess-1", client=None) + controller_a = AbortController() + controller_b = AbortController() + session._in_flight_tool_calls["call-a"] = controller_a + session._in_flight_tool_calls["call-b"] = controller_b + + result = session.cancel_tool_call("call-a") + + assert result is True + assert controller_a.signal.is_aborted is True + assert controller_b.signal.is_aborted is False