Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions go/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,31 @@ session, _ := client.CreateSession(context.Background(), &copilot.SessionConfig{

When the model selects a tool, the SDK automatically runs your handler (in parallel with other calls) and responds to the CLI's `tool.call` with the handler's result.

#### Cooperative Cancellation via session.Abort

`ToolInvocation.Context` is a `context.Context` that is cancelled when `session.Abort` is called. Pass it to any cancellable operation (HTTP requests, DB queries, sleeps) so the handler stops promptly when the session is aborted:

```go
lookupIssue := copilot.DefineTool("lookup_issue", "Fetch issue details from our tracker",
func(params LookupIssueParams, inv copilot.ToolInvocation) (any, error) {
// Pass inv.Context so the HTTP request is cancelled on session.Abort.
req, err := http.NewRequestWithContext(inv.Context, "GET",
"https://api.example.com/issues/"+params.ID, nil)
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err // returns context.Canceled when aborted
}
defer resp.Body.Close()
// ...
return summary, nil
})
```

Handlers that don't use `inv.Context` are unaffected; they run to completion as before.

#### Overriding Built-in Tools

If you register a tool with the same name as a built-in CLI tool (e.g. `edit_file`, `read_file`), the SDK will throw an error unless you explicitly opt in by setting `OverridesBuiltInTool = true`. This flag signals that you intend to replace the built-in tool with your custom implementation.
Expand Down
44 changes: 40 additions & 4 deletions go/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ type Session struct {
capabilities SessionCapabilities
capabilitiesMu sync.RWMutex

// toolCallCancels tracks cancel functions for in-flight tool calls so that
// Abort can propagate cancellation into handler contexts.
toolCallCancels map[string]context.CancelFunc
toolCallCancelsMu sync.Mutex

// eventCh serializes user event handler dispatch. dispatchEvent enqueues;
// a single goroutine (processEvents) dequeues and invokes handlers in FIFO order.
eventCh chan SessionEvent
Expand Down Expand Up @@ -1337,11 +1342,35 @@ func (s *Session) handleBroadcastEvent(event SessionEvent) {

// executeToolAndRespond executes a tool handler and sends the result back via RPC.
func (s *Session) executeToolAndRespond(requestID, toolName, toolCallID string, arguments any, handler ToolHandler, traceparent, tracestate string) {
ctx := contextWithTraceParent(context.Background(), traceparent, tracestate)
// traceCtx carries OTel trace propagation but is not subject to abort cancellation.
// It is used for administrative RPC calls that must complete regardless of abort.
traceCtx := contextWithTraceParent(context.Background(), traceparent, tracestate)
// ctx is passed to the tool handler and is cancelled when session.Abort is called,
// giving handlers a cooperative cancellation signal.
ctx, cancel := context.WithCancel(traceCtx)

s.toolCallCancelsMu.Lock()
if s.toolCallCancels == nil {
s.toolCallCancels = make(map[string]context.CancelFunc)
}
s.toolCallCancels[toolCallID] = cancel
s.toolCallCancelsMu.Unlock()

// Cleanup runs last (registered first). Removes the cancel from the in-flight map
// and releases context resources.
defer func() {
s.toolCallCancelsMu.Lock()
delete(s.toolCallCancels, toolCallID)
s.toolCallCancelsMu.Unlock()
cancel()
}()

// Panic recovery runs first (registered second, LIFO). Uses traceCtx to ensure
// the error response is sent even if ctx was already cancelled by Abort.
defer func() {
if r := recover(); r != nil {
errMsg := fmt.Sprintf("tool panic: %v", r)
s.RPC.Tools.HandlePendingToolCall(ctx, &rpc.HandlePendingToolCallRequest{
s.RPC.Tools.HandlePendingToolCall(traceCtx, &rpc.HandlePendingToolCallRequest{
RequestID: requestID,
Error: &errMsg,
})
Expand All @@ -1353,13 +1382,14 @@ func (s *Session) executeToolAndRespond(requestID, toolName, toolCallID string,
ToolCallID: toolCallID,
ToolName: toolName,
Arguments: arguments,
Context: ctx,
TraceContext: ctx,
}
Comment thread
gimenete marked this conversation as resolved.

result, err := handler(invocation)
if err != nil {
errMsg := err.Error()
s.RPC.Tools.HandlePendingToolCall(ctx, &rpc.HandlePendingToolCallRequest{
s.RPC.Tools.HandlePendingToolCall(traceCtx, &rpc.HandlePendingToolCallRequest{
RequestID: requestID,
Error: &errMsg,
})
Expand Down Expand Up @@ -1389,7 +1419,7 @@ func (s *Session) executeToolAndRespond(requestID, toolName, toolCallID string,
if result.Error != "" {
rpcResult.Error = &result.Error
}
s.RPC.Tools.HandlePendingToolCall(ctx, &rpc.HandlePendingToolCallRequest{
s.RPC.Tools.HandlePendingToolCall(traceCtx, &rpc.HandlePendingToolCallRequest{
RequestID: requestID,
Result: rpcResult,
})
Expand Down Expand Up @@ -1555,6 +1585,12 @@ func (s *Session) Abort(ctx context.Context) error {
return fmt.Errorf("failed to abort session: %w", err)
}

s.toolCallCancelsMu.Lock()
for _, cancel := range s.toolCallCancels {
cancel()
}
s.toolCallCancelsMu.Unlock()

return nil
}

Expand Down
201 changes: 201 additions & 0 deletions go/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1031,3 +1031,204 @@ func TestSession_ElicitationRequestSchema(t *testing.T) {
}
})
}

// TestToolInvocation_ContextCancelledOnAbort verifies that the context passed to a
// tool handler is cancelled when the in-flight cancel func (as used by Abort) fires.
func TestToolInvocation_ContextCancelledOnAbort(t *testing.T) {
stdinR, stdinW := io.Pipe()
stdoutR, stdoutW := io.Pipe()
defer stdinR.Close()
defer stdinW.Close()
defer stdoutR.Close()
defer stdoutW.Close()

client := jsonrpc2.NewClient(stdinW, stdoutR)
client.Start()
defer client.Stop()

session := &Session{
SessionID: "session-abort-test",
client: client,
RPC: rpc.NewSessionRPC(client, "session-abort-test"),
}

// Drain the RPC responses from the mock server side.
go func() {
scanner := bufio.NewScanner(stdinR)
for scanner.Scan() {
// read Content-Length header
line := scanner.Text()
if !strings.HasPrefix(line, "Content-Length:") {
continue
}
var contentLen int
fmt.Sscanf(line, "Content-Length: %d", &contentLen)
// skip blank separator
scanner.Scan()
body := make([]byte, contentLen)
io.ReadFull(stdinR, body)
Comment thread
gimenete marked this conversation as resolved.
Outdated

var req struct {
ID json.RawMessage `json:"id"`
Method string `json:"method"`
}
if err := json.Unmarshal(body, &req); err != nil || req.ID == nil {
continue
}
resp, _ := json.Marshal(map[string]any{
"jsonrpc": "2.0",
"id": json.RawMessage(req.ID),
"result": map[string]any{},
})
fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(resp), resp)
}
}()

// Channel to receive the invocation context from the handler.
ctxCh := make(chan context.Context, 1)

// The handler blocks until its context is cancelled, then reports.
handler := ToolHandler(func(inv ToolInvocation) (ToolResult, error) {
ctxCh <- inv.Context
<-inv.Context.Done()
return ToolResult{TextResultForLLM: "cancelled"}, nil
})

done := make(chan struct{})
go func() {
defer close(done)
session.executeToolAndRespond("req-1", "my_tool", "tc-1", nil, handler, "", "")
}()

// Wait for the handler to start and capture its context.
var handlerCtx context.Context
select {
case handlerCtx = <-ctxCh:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for handler to start")
}

// Verify the context is not yet cancelled.
if handlerCtx.Err() != nil {
t.Fatalf("expected context to be active, got %v", handlerCtx.Err())
}

// Simulate what Abort() does: cancel all in-flight tool call contexts.
session.toolCallCancelsMu.Lock()
for _, cancel := range session.toolCallCancels {
cancel()
}
session.toolCallCancelsMu.Unlock()

// Wait for the handler to finish.
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for handler to finish after cancellation")
}

// The handler context must be cancelled.
if handlerCtx.Err() == nil {
t.Fatal("expected handler context to be cancelled after abort")
}

// The cancel func must have been removed from the map.
session.toolCallCancelsMu.Lock()
remaining := len(session.toolCallCancels)
session.toolCallCancelsMu.Unlock()
if remaining != 0 {
t.Fatalf("expected toolCallCancels to be empty after execution, got %d entries", remaining)
}
}

// TestToolInvocation_ContextPopulated verifies that executeToolAndRespond sets
// both Context and TraceContext on the ToolInvocation passed to the handler.
func TestToolInvocation_ContextPopulated(t *testing.T) {
stdinR, stdinW := io.Pipe()
stdoutR, stdoutW := io.Pipe()
defer stdinR.Close()
defer stdinW.Close()
defer stdoutR.Close()
defer stdoutW.Close()

client := jsonrpc2.NewClient(stdinW, stdoutR)
client.Start()
defer client.Stop()

session := &Session{
SessionID: "session-ctx-test",
client: client,
RPC: rpc.NewSessionRPC(client, "session-ctx-test"),
}

// Drain RPC responses.
go func() {
scanner := bufio.NewScanner(stdinR)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "Content-Length:") {
continue
}
var contentLen int
fmt.Sscanf(line, "Content-Length: %d", &contentLen)
scanner.Scan()
body := make([]byte, contentLen)
io.ReadFull(stdinR, body)

var req struct {
ID json.RawMessage `json:"id"`
Method string `json:"method"`
}
if err := json.Unmarshal(body, &req); err != nil || req.ID == nil {
continue
}
resp, _ := json.Marshal(map[string]any{
"jsonrpc": "2.0",
"id": json.RawMessage(req.ID),
"result": map[string]any{},
})
fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(resp), resp)
}
}()

invCh := make(chan ToolInvocation, 1)
handler := ToolHandler(func(inv ToolInvocation) (ToolResult, error) {
invCh <- inv
return ToolResult{TextResultForLLM: "ok"}, nil
})

done := make(chan struct{})
go func() {
defer close(done)
session.executeToolAndRespond("req-2", "check_tool", "tc-2", map[string]any{"x": 1}, handler, "", "")
}()

var inv ToolInvocation
select {
case inv = <-invCh:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for handler invocation")
}

if inv.Context == nil {
t.Fatal("expected ToolInvocation.Context to be set")
}
if inv.TraceContext == nil {
t.Fatal("expected ToolInvocation.TraceContext to be set")
}
if inv.Context != inv.TraceContext {
t.Error("expected Context and TraceContext to be the same value")
}
if inv.SessionID != "session-ctx-test" {
t.Errorf("expected SessionID session-ctx-test, got %q", inv.SessionID)
}
if inv.ToolCallID != "tc-2" {
t.Errorf("expected ToolCallID tc-2, got %q", inv.ToolCallID)
}

select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for executeToolAndRespond to complete")
}
}
14 changes: 13 additions & 1 deletion go/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -1154,10 +1154,22 @@ type ToolInvocation struct {
ToolName string
Arguments any

// Context is the primary context for this tool invocation. It carries
// W3C Trace Context propagation (for OpenTelemetry) and is cancelled
// when session.Abort is called, allowing handlers to cooperatively stop
// in-flight work (e.g. pass to http.NewRequestWithContext, sql.QueryContext).
//
// Handlers that do not inspect the context continue to work unchanged.
Context context.Context

// TraceContext is deprecated: use Context instead.
// TraceContext carries the W3C Trace Context propagated from the CLI's
// execute_tool span. Pass this to OpenTelemetry-aware code so that
// execute_tool span. Pass this to OpenTelemetry-aware code so that
// child spans created inside the handler are parented to the CLI span.
// When no trace context is available this will be context.Background().
//
// Deprecated: Use Context, which carries the same trace information and
// is additionally cancelled when session.Abort is called.
TraceContext context.Context
}

Expand Down