Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/copilot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@
create_session_fs_adapter,
)
from .tools import (
AbortSignal,
Tool,
ToolBinaryResult,
ToolInvocation,
Expand All @@ -168,6 +169,7 @@
"AutoModeSwitchHandler",
"AutoModeSwitchRequest",
"AutoModeSwitchResponse",
"AbortSignal",
"BUILTIN_TOOLS_ISOLATED",
"CanvasAction",
"CanvasDeclaration",
Expand Down
14 changes: 13 additions & 1 deletion python/copilot/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
from .generated.session_events import (
ReasoningSummary as _RpcReasoningSummary,
)
from .tools import Tool, ToolHandler, ToolInvocation, ToolResult
from .tools import AbortController, Tool, ToolHandler, ToolInvocation, ToolResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1190,6 +1190,7 @@ def __init__(
self._event_handlers_lock = threading.Lock()
self._tool_handlers: dict[str, ToolHandler] = {}
self._tool_handlers_lock = threading.Lock()
self._tool_abort_controller: AbortController = AbortController()
Comment thread
gimenete marked this conversation as resolved.
Outdated
self._permission_handler: _PermissionHandlerFn | None = None
self._permission_handler_lock = threading.Lock()
self._user_input_handler: UserInputHandler | None = None
Expand Down Expand Up @@ -1661,6 +1662,7 @@ async def _execute_tool_and_respond(
tool_call_id=tool_call_id,
tool_name=tool_name,
arguments=arguments,
signal=self._tool_abort_controller.signal,
)
Comment thread
gimenete marked this conversation as resolved.

with trace_context(traceparent, tracestate):
Expand Down Expand Up @@ -2460,6 +2462,12 @@ async def abort(self) -> None:
"""
Abort the currently processing message in this session.

This cancels the agentic loop and propagates cancellation to all
in-flight tool handlers via their :class:`~copilot.AbortSignal`
(passed in :attr:`~copilot.ToolInvocation.signal`). Tool handlers can
check the signal with ``invocation.signal.is_aborted`` or await
``invocation.signal.wait()``.

Use this to cancel a long-running request. The session remains valid
and can continue to be used for new messages.

Expand All @@ -2476,6 +2484,10 @@ async def abort(self) -> None:
>>> await asyncio.sleep(5)
>>> await session.abort()
"""
# Abort all in-flight tool handlers
self._tool_abort_controller.abort()
# Create a new controller for future tool calls
self._tool_abort_controller = AbortController()
Comment thread
gimenete marked this conversation as resolved.
Outdated
await self._client.request("session.abort", {"sessionId": self.session_id})

async def set_model(
Expand Down
60 changes: 60 additions & 0 deletions python/copilot/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from __future__ import annotations

import asyncio
import inspect
import json
from collections.abc import Awaitable, Callable
Expand All @@ -18,6 +19,63 @@
ToolResultType = Literal["success", "failure", "rejected", "denied", "timeout"]


class AbortSignal:
"""
A signal object that allows monitoring whether an abort has been requested.

Passed to tool handlers via :attr:`ToolInvocation.signal` so they can
cooperatively cancel in-flight work when :meth:`~copilot.CopilotSession.abort`
is called.

Example::

@define_tool(description="Fetch remote data")
async def fetch_data(params: Params, inv: ToolInvocation) -> str:
if inv.signal.is_aborted:
return "cancelled"
data = await fetch_with_signal(params.url, inv.signal)
return data

"""

def __init__(self) -> None:
self._event: asyncio.Event = asyncio.Event()

@property
def is_aborted(self) -> bool:
"""``True`` if :meth:`~AbortController.abort` has been called."""
return self._event.is_set()

async def wait(self) -> None:
"""Coroutine that completes when the signal is aborted."""
await self._event.wait()

def _abort(self) -> None:
"""Internal: trigger the signal. Called by :class:`AbortController`."""
self._event.set()


class AbortController:
"""
A controller that creates and manages an :class:`AbortSignal`.

Call :meth:`abort` to cancel all in-flight tool handlers that hold the
associated :attr:`signal`.
"""

def __init__(self) -> None:
self._signal: AbortSignal = AbortSignal()

@property
def signal(self) -> AbortSignal:
"""The :class:`AbortSignal` managed by this controller."""
return self._signal

def abort(self) -> None:
"""Trigger the signal, notifying all handlers that abort has been requested."""
self._signal._abort()


@dataclass
class ToolBinaryResult:
"""Binary content returned by a tool."""
Expand Down Expand Up @@ -49,6 +107,8 @@ class ToolInvocation:
tool_call_id: str = ""
tool_name: str = ""
arguments: Any = None
signal: AbortSignal = field(default_factory=AbortSignal)
"""AbortSignal that is triggered when :meth:`~copilot.CopilotSession.abort` is called."""
Comment thread
gimenete marked this conversation as resolved.
Outdated


ToolHandler = Callable[[ToolInvocation], ToolResult | Awaitable[ToolResult]]
Expand Down
70 changes: 70 additions & 0 deletions python/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,73 @@ def test_call_tool_result_dict_is_json_serialized_by_normalize(self):
result = _normalize_result({"content": [{"type": "text", "text": "hello"}]})
parsed = json.loads(result.text_result_for_llm)
assert parsed == {"content": [{"type": "text", "text": "hello"}]}


class TestAbortSignal:
def test_not_aborted_initially(self):
from copilot.tools import AbortController

controller = AbortController()
assert controller.signal.is_aborted is False

def test_aborted_after_abort_called(self):
from copilot.tools import AbortController

controller = AbortController()
controller.abort()
assert controller.signal.is_aborted is True

async def test_wait_returns_after_abort(self):
from copilot.tools import AbortController

controller = AbortController()
signal = controller.signal

async def aborter():
import asyncio

await asyncio.sleep(0)
controller.abort()

import asyncio

await asyncio.gather(signal.wait(), aborter())
assert signal.is_aborted is True

def test_tool_invocation_has_abort_signal(self):
from copilot.tools import AbortSignal

inv = ToolInvocation(session_id="s1", tool_call_id="c1", tool_name="t1")
assert isinstance(inv.signal, AbortSignal)
assert inv.signal.is_aborted is False

async def test_handler_receives_signal_via_invocation(self):
from copilot.tools import AbortController

received_signal = None
controller = AbortController()

class Params(BaseModel):
pass

@define_tool("test_signal", description="Test signal propagation")
def test_tool(params: Params, inv: ToolInvocation) -> str:
nonlocal received_signal
received_signal = inv.signal
return "ok"

invocation = ToolInvocation(
session_id="s1",
tool_call_id="c1",
tool_name="test_signal",
arguments={},
signal=controller.signal,
)

await test_tool.handler(invocation)

assert received_signal is controller.signal
assert received_signal.is_aborted is False

controller.abort()
assert received_signal.is_aborted is True