diff --git a/llama-index-core/llama_index/core/agent/workflow/base_agent.py b/llama-index-core/llama_index/core/agent/workflow/base_agent.py index 015d964e495..987e3a5c728 100644 --- a/llama-index-core/llama_index/core/agent/workflow/base_agent.py +++ b/llama-index-core/llama_index/core/agent/workflow/base_agent.py @@ -76,6 +76,35 @@ ) +class _ToolCallContext: + """Delegate Context calls while scoping default HITL waiter ids per tool.""" + + def __init__(self, ctx: Context, tool_id: str) -> None: + self._ctx = ctx + self._tool_id = tool_id + + def __getattr__(self, name: str) -> Any: + return getattr(self._ctx, name) + + async def wait_for_event( + self, + event_type: Type[Any], + waiter_event: Any = None, + waiter_id: Optional[str] = None, + requirements: Optional[dict[str, Any]] = None, + timeout: Optional[float] = 2000, + ) -> Any: + if waiter_id is None: + waiter_id = f"agent_tool_call:{self._tool_id}" + return await self._ctx.wait_for_event( + event_type, + waiter_event=waiter_event, + waiter_id=waiter_id, + requirements=requirements, + timeout=timeout, + ) + + def get_default_llm() -> LLM: return Settings.llm @@ -349,8 +378,10 @@ async def _call_tool( ctx: Context, tool: AsyncBaseTool, tool_input: dict, + tool_id: Optional[str] = None, ) -> ToolOutput: """Call the given tool with the given input.""" + tool_id = tool_id or tool.metadata.get_name() try: if ( isinstance(tool, FunctionTool) @@ -358,7 +389,7 @@ async def _call_tool( and tool.ctx_param_name is not None ): new_tool_input = {**tool_input} - new_tool_input[tool.ctx_param_name] = ctx + new_tool_input[tool.ctx_param_name] = _ToolCallContext(ctx, tool_id) tool_output = await tool.acall(**new_tool_input) else: tool_output = await tool.acall(**tool_input) @@ -645,7 +676,7 @@ async def call_tool(self, ctx: Context, ev: ToolCall) -> ToolCallResult: ) else: tool = tools_by_name[ev.tool_name] - result = await self._call_tool(ctx, tool, ev.tool_kwargs) + result = await self._call_tool(ctx, tool, ev.tool_kwargs, ev.tool_id) result_ev = ToolCallResult( tool_name=ev.tool_name, diff --git a/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py index 3f3d28984ab..db5f5d87d1d 100644 --- a/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py +++ b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py @@ -24,6 +24,7 @@ DEFAULT_AGENT_NAME, DEFAULT_AGENT_DESCRIPTION, DEFAULT_MAX_ITERATIONS, + _ToolCallContext, _get_waiting_for_event_exception, ) from llama_index.core.agent.workflow.prompts import DEFAULT_EARLY_STOPPING_PROMPT @@ -350,8 +351,10 @@ async def _call_tool( ctx: Context, tool: AsyncBaseTool, tool_input: dict, + tool_id: Optional[str] = None, ) -> ToolOutput: """Call the given tool with the given input.""" + tool_id = tool_id or tool.metadata.get_name() try: if ( isinstance(tool, FunctionTool) @@ -359,7 +362,7 @@ async def _call_tool( and tool.ctx_param_name is not None ): new_tool_input = {**tool_input} - new_tool_input[tool.ctx_param_name] = ctx + new_tool_input[tool.ctx_param_name] = _ToolCallContext(ctx, tool_id) tool_output = await tool.acall(**new_tool_input) else: tool_output = await tool.acall(**tool_input) @@ -654,7 +657,7 @@ async def call_tool(self, ctx: Context, ev: ToolCall) -> ToolCallResult: ) else: tool = tools_by_name[ev.tool_name] - result = await self._call_tool(ctx, tool, ev.tool_kwargs) + result = await self._call_tool(ctx, tool, ev.tool_kwargs, ev.tool_id) result_ev = ToolCallResult( tool_name=ev.tool_name, diff --git a/llama-index-core/tests/agent/workflow/test_multi_agent_workflow.py b/llama-index-core/tests/agent/workflow/test_multi_agent_workflow.py index 295c507302e..5b0e4939bf7 100644 --- a/llama-index-core/tests/agent/workflow/test_multi_agent_workflow.py +++ b/llama-index-core/tests/agent/workflow/test_multi_agent_workflow.py @@ -1,3 +1,4 @@ +import asyncio from typing import List import pytest @@ -424,6 +425,69 @@ async def hitl(ctx: Context): assert "HITL successful" in str(response) +@pytest.mark.asyncio +async def test_parallel_hitl_tool_calls_have_scoped_waiters(): + """Parallel tools that wait for HITL should not overwrite each other.""" + + async def hitl(ctx: Context, label: str): + resp = await ctx.wait_for_event( + HumanResponseEvent, + waiter_event=InputRequiredEvent(prefix=f"approve {label}?"), + timeout=5, + ) + return f"{label}: {resp.response}" + + agent = FunctionAgent( + name="agent", + description="test", + tools=[hitl], + llm=MockFunctionCallingLLM( + response_generator=_response_generator_from_list( + [ + ChatMessage( + role=MessageRole.ASSISTANT, + content="need approval", + additional_kwargs={ + "tool_calls": [ + ToolSelection( + tool_id="call_one", + tool_name="hitl", + tool_kwargs={"label": "one"}, + ), + ToolSelection( + tool_id="call_two", + tool_name="hitl", + tool_kwargs={"label": "two"}, + ), + ] + }, + ), + ChatMessage(role=MessageRole.ASSISTANT, content="done"), + ] + ) + ), + ) + + workflow = AgentWorkflow(agents=[agent], root_agent="agent", timeout=10) + handler = workflow.run(user_msg="test") + + input_required_events = [] + async for ev in handler.stream_events(): + if isinstance(ev, InputRequiredEvent): + input_required_events.append(ev) + handler.ctx.send_event( + HumanResponseEvent(response=f"ok {len(input_required_events)}") + ) + if len(input_required_events) == 2: + break + + response = await asyncio.wait_for(handler, timeout=10) + + assert len(input_required_events) == 2 + assert response is not None + assert "done" in str(response) + + @pytest.mark.asyncio async def test_max_iterations(): """Test max iterations."""