From ddd8d8069f92f61a5d2cdb65201a8ecb2f9e6ceb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=86=AF=E5=9F=BA=E9=AD=81?= <1412414664@qq.com> Date: Mon, 22 Jun 2026 22:22:28 +0800 Subject: [PATCH] fix: support isolated workflow MCP calls --- .../tools/llama-index-tools-mcp/README.md | 11 ++++ .../llama_index/tools/mcp/utils.py | 62 +++++++++++++++---- .../tests/test_workflow_as_mcp.py | 59 ++++++++++++++++++ 3 files changed, 120 insertions(+), 12 deletions(-) create mode 100644 llama-index-integrations/tools/llama-index-tools-mcp/tests/test_workflow_as_mcp.py diff --git a/llama-index-integrations/tools/llama-index-tools-mcp/README.md b/llama-index-integrations/tools/llama-index-tools-mcp/README.md index fb749af3b44..63055bc5a55 100644 --- a/llama-index-integrations/tools/llama-index-tools-mcp/README.md +++ b/llama-index-integrations/tools/llama-index-tools-mcp/README.md @@ -94,6 +94,17 @@ workflow = LoudWorkflow() mcp = workflow_as_mcp(workflow, start_event_model=RunEvent) ``` +If your workflow stores request-specific mutable state on the workflow instance, +pass a factory so each MCP tool call gets a fresh workflow: + +```python +mcp = workflow_as_mcp( + workflow_factory=LoudWorkflow, + workflow_name="LoudWorkflow", + start_event_model=RunEvent, +) +``` + Then, you can launch the MCP server (assuming you have the `mcp[cli]` extra installed): ```bash diff --git a/llama-index-integrations/tools/llama-index-tools-mcp/llama_index/tools/mcp/utils.py b/llama-index-integrations/tools/llama-index-tools-mcp/llama_index/tools/mcp/utils.py index fbc06127528..ee1d3fde954 100644 --- a/llama-index-integrations/tools/llama-index-tools-mcp/llama_index/tools/mcp/utils.py +++ b/llama-index-integrations/tools/llama-index-tools-mcp/llama_index/tools/mcp/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional from mcp.client.session import ClientSession from mcp.server.fastmcp import FastMCP, Context @@ -75,10 +75,11 @@ async def aget_tools_from_mcp_url( def workflow_as_mcp( - workflow: Workflow, + workflow: Optional[Workflow] = None, workflow_name: Optional[str] = None, workflow_description: Optional[str] = None, start_event_model: Optional[BaseModel] = None, + workflow_factory: Optional[Callable[[], Workflow]] = None, **fastmcp_init_kwargs: Any, ) -> FastMCP: """ @@ -88,8 +89,8 @@ def workflow_as_mcp( within MCP, which will Args: - workflow: - The workflow to convert. + workflow (optional): + The workflow instance to convert. This instance is reused for every MCP tool call. workflow_name (optional): The name of the workflow. Defaults to the workflow class name. workflow_description (optional): @@ -97,6 +98,8 @@ def workflow_as_mcp( start_event_model (optional): The start event model of the workflow. Can be a `BaseModel` or a `StartEvent` class. Defaults to the workflow's custom `StartEvent` class. + workflow_factory (optional): + Factory that creates a fresh workflow instance for each MCP tool call. **fastmcp_init_kwargs: Additional keyword arguments to pass to the FastMCP constructor. @@ -104,31 +107,66 @@ def workflow_as_mcp( The MCP app object. """ + if workflow is None and workflow_factory is None: + raise ValueError("Must provide either workflow or workflow_factory.") + if workflow is not None and workflow_factory is not None: + raise ValueError("Provide either workflow or workflow_factory, not both.") + app = FastMCP(**fastmcp_init_kwargs) # Dynamically get the start event class -- this is a bit of a hack - StartEventCLS = start_event_model or workflow._start_event_class + if start_event_model is None: + if workflow is None: + raise ValueError( + "Must provide start_event_model when using workflow_factory without a workflow instance." + ) + StartEventCLS = workflow._start_event_class + else: + StartEventCLS = start_event_model + if StartEventCLS == StartEvent: raise ValueError( "Must declare a custom StartEvent class in your workflow or provide a start_event_model." ) # Get the workflow name and description - workflow_name = workflow_name or workflow.__class__.__name__ - workflow_description = workflow_description or workflow.__doc__ + if workflow_name is None: + workflow_name = ( + workflow.__class__.__name__ + if workflow is not None + else getattr(workflow_factory, "__name__", None) + ) + if workflow_name is None: + raise ValueError("Must provide workflow_name when it cannot be inferred.") + + workflow_description = ( + workflow_description + if workflow_description is not None + else workflow.__doc__ + if workflow is not None + else None + ) @app.tool(name=workflow_name, description=workflow_description) async def _workflow_tool(run_args: StartEventCLS, context: Context) -> Any: # Handle edge cases where the start event is an Event or a BaseModel # If the workflow does not have a custom StartEvent class, then we need to handle the event differently - - if isinstance(run_args, Event) and workflow._start_event_class != StartEvent: - handler = workflow.run(start_event=run_args) + active_workflow = ( + workflow_factory() if workflow_factory is not None else workflow + ) + if active_workflow is None: + raise ValueError("Must provide either workflow or workflow_factory.") + + if ( + isinstance(run_args, Event) + and active_workflow._start_event_class != StartEvent + ): + handler = active_workflow.run(start_event=run_args) elif isinstance(run_args, BaseModel): - handler = workflow.run(**run_args.model_dump()) + handler = active_workflow.run(**run_args.model_dump()) elif isinstance(run_args, dict): start_event = StartEventCLS.model_validate(run_args) - handler = workflow.run(start_event=start_event) + handler = active_workflow.run(start_event=start_event) else: raise ValueError(f"Invalid start event type: {type(run_args)}") diff --git a/llama-index-integrations/tools/llama-index-tools-mcp/tests/test_workflow_as_mcp.py b/llama-index-integrations/tools/llama-index-tools-mcp/tests/test_workflow_as_mcp.py new file mode 100644 index 00000000000..3255f116b29 --- /dev/null +++ b/llama-index-integrations/tools/llama-index-tools-mcp/tests/test_workflow_as_mcp.py @@ -0,0 +1,59 @@ +import json +from typing import Any + +import pytest + +from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step +from llama_index.tools.mcp.utils import workflow_as_mcp + + +class TenantStart(StartEvent): + tenant_id: str + + +class CountingWorkflow(Workflow): + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.call_count = 0 + self.history: list[str] = [] + + @step + async def echo(self, ctx: Context, ev: TenantStart) -> StopEvent: + self.call_count += 1 + self.history.append(ev.tenant_id) + return StopEvent( + result={ + "call_index": self.call_count, + "history_visible": list(self.history), + } + ) + + +@pytest.mark.asyncio +async def test_workflow_as_mcp_factory_creates_workflow_per_call() -> None: + created_workflows: list[CountingWorkflow] = [] + + def workflow_factory() -> CountingWorkflow: + workflow = CountingWorkflow(timeout=5) + created_workflows.append(workflow) + return workflow + + app = workflow_as_mcp( + workflow_factory=workflow_factory, + workflow_name="CountingWorkflow", + start_event_model=TenantStart, + ) + + alice_result = await app.call_tool( + "CountingWorkflow", {"run_args": {"tenant_id": "alice"}} + ) + bob_result = await app.call_tool( + "CountingWorkflow", {"run_args": {"tenant_id": "bob"}} + ) + + alice_payload = json.loads(alice_result[0].text) + bob_payload = json.loads(bob_result[0].text) + + assert alice_payload == {"call_index": 1, "history_visible": ["alice"]} + assert bob_payload == {"call_index": 1, "history_visible": ["bob"]} + assert [workflow.history for workflow in created_workflows] == [["alice"], ["bob"]]