Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions llama-index-integrations/tools/llama-index-tools-mcp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,17 @@ workflow = LoudWorkflow()
mcp = workflow_as_mcp(workflow, start_event_model=RunEvent)
```

If your workflow stores request-specific mutable state on the workflow instance,
pass a factory so each MCP tool call gets a fresh workflow:

```python
mcp = workflow_as_mcp(
workflow_factory=LoudWorkflow,
workflow_name="LoudWorkflow",
start_event_model=RunEvent,
)
```

Then, you can launch the MCP server (assuming you have the `mcp[cli]` extra installed):

```bash
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

from mcp.client.session import ClientSession
from mcp.server.fastmcp import FastMCP, Context
Expand Down Expand Up @@ -75,10 +75,11 @@ async def aget_tools_from_mcp_url(


def workflow_as_mcp(
workflow: Workflow,
workflow: Optional[Workflow] = None,
workflow_name: Optional[str] = None,
workflow_description: Optional[str] = None,
start_event_model: Optional[BaseModel] = None,
workflow_factory: Optional[Callable[[], Workflow]] = None,
**fastmcp_init_kwargs: Any,
) -> FastMCP:
"""
Expand All @@ -88,47 +89,84 @@ def workflow_as_mcp(
within MCP, which will

Args:
workflow:
The workflow to convert.
workflow (optional):
The workflow instance to convert. This instance is reused for every MCP tool call.
workflow_name (optional):
The name of the workflow. Defaults to the workflow class name.
workflow_description (optional):
The description of the workflow. Defaults to the workflow docstring.
start_event_model (optional):
The start event model of the workflow. Can be a `BaseModel` or a `StartEvent` class.
Defaults to the workflow's custom `StartEvent` class.
workflow_factory (optional):
Factory that creates a fresh workflow instance for each MCP tool call.
**fastmcp_init_kwargs:
Additional keyword arguments to pass to the FastMCP constructor.

Returns:
The MCP app object.

"""
if workflow is None and workflow_factory is None:
raise ValueError("Must provide either workflow or workflow_factory.")
if workflow is not None and workflow_factory is not None:
raise ValueError("Provide either workflow or workflow_factory, not both.")

app = FastMCP(**fastmcp_init_kwargs)

# Dynamically get the start event class -- this is a bit of a hack
StartEventCLS = start_event_model or workflow._start_event_class
if start_event_model is None:
if workflow is None:
raise ValueError(
"Must provide start_event_model when using workflow_factory without a workflow instance."
)
StartEventCLS = workflow._start_event_class
else:
StartEventCLS = start_event_model

if StartEventCLS == StartEvent:
raise ValueError(
"Must declare a custom StartEvent class in your workflow or provide a start_event_model."
)

# Get the workflow name and description
workflow_name = workflow_name or workflow.__class__.__name__
workflow_description = workflow_description or workflow.__doc__
if workflow_name is None:
workflow_name = (
workflow.__class__.__name__
if workflow is not None
else getattr(workflow_factory, "__name__", None)
)
if workflow_name is None:
raise ValueError("Must provide workflow_name when it cannot be inferred.")

workflow_description = (
workflow_description
if workflow_description is not None
else workflow.__doc__
if workflow is not None
else None
)

@app.tool(name=workflow_name, description=workflow_description)
async def _workflow_tool(run_args: StartEventCLS, context: Context) -> Any:
# Handle edge cases where the start event is an Event or a BaseModel
# If the workflow does not have a custom StartEvent class, then we need to handle the event differently

if isinstance(run_args, Event) and workflow._start_event_class != StartEvent:
handler = workflow.run(start_event=run_args)
active_workflow = (
workflow_factory() if workflow_factory is not None else workflow
)
if active_workflow is None:
raise ValueError("Must provide either workflow or workflow_factory.")

if (
isinstance(run_args, Event)
and active_workflow._start_event_class != StartEvent
):
handler = active_workflow.run(start_event=run_args)
elif isinstance(run_args, BaseModel):
handler = workflow.run(**run_args.model_dump())
handler = active_workflow.run(**run_args.model_dump())
elif isinstance(run_args, dict):
start_event = StartEventCLS.model_validate(run_args)
handler = workflow.run(start_event=start_event)
handler = active_workflow.run(start_event=start_event)
else:
raise ValueError(f"Invalid start event type: {type(run_args)}")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import json
from typing import Any

import pytest

from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
from llama_index.tools.mcp.utils import workflow_as_mcp


class TenantStart(StartEvent):
tenant_id: str


class CountingWorkflow(Workflow):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.call_count = 0
self.history: list[str] = []

@step
async def echo(self, ctx: Context, ev: TenantStart) -> StopEvent:
self.call_count += 1
self.history.append(ev.tenant_id)
return StopEvent(
result={
"call_index": self.call_count,
"history_visible": list(self.history),
}
)


@pytest.mark.asyncio
async def test_workflow_as_mcp_factory_creates_workflow_per_call() -> None:
created_workflows: list[CountingWorkflow] = []

def workflow_factory() -> CountingWorkflow:
workflow = CountingWorkflow(timeout=5)
created_workflows.append(workflow)
return workflow

app = workflow_as_mcp(
workflow_factory=workflow_factory,
workflow_name="CountingWorkflow",
start_event_model=TenantStart,
)

alice_result = await app.call_tool(
"CountingWorkflow", {"run_args": {"tenant_id": "alice"}}
)
bob_result = await app.call_tool(
"CountingWorkflow", {"run_args": {"tenant_id": "bob"}}
)

alice_payload = json.loads(alice_result[0].text)
bob_payload = json.loads(bob_result[0].text)

assert alice_payload == {"call_index": 1, "history_visible": ["alice"]}
assert bob_payload == {"call_index": 1, "history_visible": ["bob"]}
assert [workflow.history for workflow in created_workflows] == [["alice"], ["bob"]]