diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index f5fc15a3d7..d86c140f63 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -80,8 +80,9 @@ agent_framework/ - **`MCPTool`** - Base wrapper that owns the MCP `ClientSession` and exposes the remote server's tools as `FunctionTool`s. - **`MCPStdioTool`** / **`MCPStreamableHTTPTool`** / **`MCPWebsocketTool`** - Transport-specific subclasses. -- **Argument allowlist (`_prepare_call_kwargs`)** - Before each `tools/call`, kwargs are filtered to an **allowlist** built from the tool's declared parameters (`inputSchema.properties`) plus any user-configured extras. Framework runtime kwargs injected through the function-invocation pipeline (e.g. `thread`, `conversation_id`, `chat_options`, `options`, `response_format`) are stripped by default rather than forwarded. A tool that declares no usable `properties` (including schemas with `additionalProperties: true`) forwards only the configured extras. The `_MCP_FRAMEWORK_DENYLIST` is a safety net for framework-named params a server *declares* in its schema (those are dropped); names explicitly opted in via `additional_tool_argument_names` always win. The reserved `_meta` key is extracted as MCP request metadata, never forwarded as an argument. -- **`additional_tool_argument_names`** (constructor arg on all `MCPTool` subclasses) - Opt extra argument names back into the allowlist. Accepts a `Sequence[str]` (applied to every tool) or a `Mapping[str, Sequence[str]]` keyed by **remote tool name**, where the reserved key `"*"` denotes global extras. It is configured only in user code at construction; there is **no per-call/runtime override**, so a model-issued tool call cannot change which names pass through. 
To use a server that accepts `additionalProperties: true`, list the extra names here and then either (1) manually extend that tool's `inputSchema` (via the `.functions` list after connecting) so the model is prompted to supply them, or (2) supply the values yourself via `function_invocation_kwargs`. If a name is supplied by both the model and `function_invocation_kwargs`, the model-supplied value wins. +- **Argument allowlist (`_prepare_call_kwargs`)** - Before each `tools/call`, kwargs are filtered to an **allowlist** built from the tool's declared parameters (`inputSchema.properties`) plus any user-configured extras. Framework runtime kwargs injected through the function-invocation pipeline (e.g. `thread`, `conversation_id`, `chat_options`, `options`, `response_format`) are stripped by default rather than forwarded. A tool that declares no usable `properties` (including schemas with `additionalProperties: true`) forwards only the configured extras. The `_MCP_FRAMEWORK_DENYLIST` is a safety net for framework-named params a server *declares* in its schema (those are dropped); names explicitly opted in via `additional_tool_argument_names` always win. The reserved `_meta` key is never forwarded as an argument; trusted caller/runtime `_meta` is validated as MCP request metadata, model-supplied `_meta` is discarded in generated MCP functions, and metadata precedence is caller/runtime < OpenTelemetry < tools/list metadata. +- **`allowed_tools`** (constructor arg on all `MCPTool` subclasses) - Restricts exposed MCP tools by raw remote MCP tool identity. Prefixed local names remain accepted only when the raw remote name already matches its normalized form; normalized/local aliases do not authorize a different raw remote name. If multiple raw remote tool names map to the same local function name, tool loading raises `ToolExecutionException` instead of first-one-wins shadowing. 
+- **`additional_tool_argument_names`** (constructor arg on all `MCPTool` subclasses) - Opt extra argument names back into the allowlist. Accepts a `Sequence[str]` (applied to every tool) or a `Mapping[str, Sequence[str]]` keyed by **remote tool name**, where the reserved key `"*"` denotes global extras. It is configured only in user code at construction; there is **no per-call/runtime override**, so a model-issued tool call cannot change which names pass through. To use a server that accepts `additionalProperties: true`, list the extra names here and then either (1) manually extend that tool's `inputSchema` (via the `.functions` list after connecting) so the model is prompted to supply them, or (2) supply the values yourself via `function_invocation_kwargs`. If a normal forwarded argument name is supplied by both the model and `function_invocation_kwargs`, the model-supplied value wins; `_meta` is the exception and only trusted runtime/caller metadata is used. - **Sampling guardrails** (`sampling_callback`) - Passing `client=` advertises `SamplingCapability` so the server can send `sampling/createMessage`. Because remote servers are untrusted (confused-deputy risk), the default `sampling_callback` is **deny-by-default** and applies, in order: a per-session rate limit (`sampling_max_requests`, default `_DEFAULT_SAMPLING_MAX_REQUESTS`), an approval gate (`sampling_approval_callback`), and a `maxTokens` cap (`sampling_max_tokens`, default `_DEFAULT_SAMPLING_MAX_TOKENS`). The approval callback (constructor arg on all subclasses; exported type alias `SamplingApprovalCallback`) receives the raw `CreateMessageRequestParams`, may be sync or async, and must return truthy to approve. When it is `None` (the default) every sampling request is denied; pass `lambda params: True` to restore legacy auto-approve as an explicit opt-in. Requests and denials are logged at WARNING (content is not logged). The per-session counter resets in `_reset_session_state`. 
- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature, **frozen**) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Instances are immutable; replace via `MCPTool.task_options = MCPTaskOptions(...)`. Fields: - `default_ttl: timedelta | None` — forwarded to the server as `params.task.ttl` (milliseconds). When `None`, the server's default applies. diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 6d8d2ff8b1..7c559cc421 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -74,6 +74,11 @@ class MCPSpecificApproval(TypedDict, total=False): # Reserved key in an ``additional_tool_argument_names`` mapping that applies its # values to every tool on the server rather than a single named tool. _MCP_GLOBAL_EXTRA_ARGS_KEY = "*" +_MCP_META_LABEL_PATTERN = r"[A-Za-z](?:[A-Za-z0-9-]*[A-Za-z0-9])?" +_MCP_META_KEY_PATTERN = re.compile( + rf"^(?:(?:{_MCP_META_LABEL_PATTERN})(?:\.{_MCP_META_LABEL_PATTERN})*/)?" + r"[A-Za-z0-9](?:[A-Za-z0-9_.-]*[A-Za-z0-9])?$" +) # Framework kwargs that flow through the function-invocation pipeline (via # ``FunctionInvocationContext.kwargs``) but must never be forwarded to an MCP # server: they are internal objects that the MCP SDK cannot serialize. 
They are @@ -205,7 +210,42 @@ def _normalize_additional_tool_argument_names( return set(additional_tool_argument_names), {} -def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str, Any] | None: +def _mcp_config_candidate_names(*, local_name: str, normalized_name: str, remote_name: str) -> tuple[str, ...]: + """Return safe configuration names for MCP allow/approval matching.""" + names = [remote_name] + if normalized_name == remote_name and local_name != remote_name: + names.append(local_name) + return tuple(names) + + +def _validate_mcp_meta_key(key: str) -> None: + """Validate an MCP ``_meta`` key against the 2025-06-18 key-name format.""" + if not _MCP_META_KEY_PATTERN.fullmatch(key): + raise ToolExecutionException(f"Invalid MCP _meta key name: {key!r}.") + + +def _validate_mcp_meta(raw_meta: object | None) -> dict[str, Any] | None: + """Validate and copy MCP request metadata.""" + if raw_meta is None: + return None + if not isinstance(raw_meta, dict): + raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.") + + raw_meta_dict = cast(Mapping[object, Any], raw_meta) + meta: dict[str, Any] = {} + for key, value in raw_meta_dict.items(): + if not isinstance(key, str): + raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.") + _validate_mcp_meta_key(key) + meta[key] = value + return meta + + +def _inject_otel_into_mcp_meta( + meta: dict[str, Any] | None = None, + *, + overwrite: bool = False, +) -> dict[str, Any] | None: """Inject OpenTelemetry trace context into MCP request _meta via the global propagator(s).""" carrier: dict[str, str] = {} propagate.inject(carrier) @@ -215,7 +255,8 @@ def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str, if meta is None: meta = {} for key, value in carrier.items(): - if key not in meta: + _validate_mcp_meta_key(key) + if overwrite or key not in meta: meta[key] = value return meta @@ -381,7 +422,9 @@ def __init__( 
approval_mode: Whether approval is required to run tools. allowed_tools: Optional allow-list of MCP tool names to expose as functions. ``None`` (the default) exposes every tool advertised by the MCP server. - A non-empty collection exposes only the tools whose names appear in it. + A non-empty collection exposes only the raw remote tools whose names appear in it. For + compatibility, the prefixed local function name is also accepted when the raw remote name already + matches its normalized form; normalized aliases do not authorize a different raw remote tool. An empty collection (``[]``) exposes no tools — if you simply want to disable tool execution, prefer ``load_tools=False`` instead. ``[]`` is useful as a runtime guard or when you want to load tool metadata for @@ -753,11 +796,14 @@ def functions(self) -> list[FunctionTool]: additional_properties = func.additional_properties or {} normalized_name = additional_properties.get(_MCP_NORMALIZED_NAME_KEY) remote_name = additional_properties.get(_MCP_REMOTE_NAME_KEY) - if ( - func.name in allowed_names - or (isinstance(normalized_name, str) and normalized_name in allowed_names) - or (isinstance(remote_name, str) and remote_name in allowed_names) - ): + if not isinstance(normalized_name, str) or not isinstance(remote_name, str): + continue + candidate_names = _mcp_config_candidate_names( + local_name=func.name, + normalized_name=normalized_name, + remote_name=remote_name, + ) + if any(name in allowed_names for name in candidate_names): filtered_functions.append(func) return filtered_functions @@ -1381,7 +1427,13 @@ async def _load_prompts_locked(self) -> None: continue input_model = _get_input_model_from_mcp_prompt(prompt) - approval_mode = self._determine_approval_mode(local_name, normalized_name, prompt.name) + approval_mode = self._determine_approval_mode( + *_mcp_config_candidate_names( + local_name=local_name, + normalized_name=normalized_name, + remote_name=prompt.name, + ) + ) func: FunctionTool = FunctionTool( 
func=partial(self.get_prompt, prompt.name), name=local_name, @@ -1422,7 +1474,11 @@ async def _load_tools_locked(self) -> None: return # Track existing function names to prevent duplicates - existing_names = {func.name for func in self._functions} + existing_remote_by_local: dict[str, str] = {} + for func in self._functions: + remote_name = (func.additional_properties or {}).get(_MCP_REMOTE_NAME_KEY) + if isinstance(remote_name, str): + existing_remote_by_local[func.name] = remote_name tool_call_meta_by_name: dict[str, dict[str, Any]] = {} tool_task_support_by_name: dict[str, str] = {} tool_param_names_by_name: dict[str, set[str]] = {} @@ -1462,7 +1518,7 @@ async def _load_tools_locked(self) -> None: for tool in tool_list.tools: if tool.meta is not None: - tool_call_meta_by_name[tool.name] = dict(tool.meta) + tool_call_meta_by_name[tool.name] = _validate_mcp_meta(tool.meta) or {} task_support = getattr(getattr(tool, "execution", None), "taskSupport", None) if task_support is not None: @@ -1490,10 +1546,24 @@ async def _load_tools_locked(self) -> None: local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix) # Skip if already loaded - if local_name in existing_names: + if local_name in existing_remote_by_local: + if existing_remote_by_local.get(local_name) != tool.name: + raise ToolExecutionException( + "MCP server advertised multiple tools that map to the same local function name: " + f"{existing_remote_by_local[local_name]!r} and {tool.name!r} both map to " + f"{local_name!r}." 
+ ) continue - approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name) + existing_remote_by_local[local_name] = tool.name + + approval_mode = self._determine_approval_mode( + *_mcp_config_candidate_names( + local_name=local_name, + normalized_name=normalized_name, + remote_name=tool.name, + ) + ) async def _call_tool_with_runtime_kwargs( ctx: FunctionInvocationContext, @@ -1501,8 +1571,13 @@ async def _call_tool_with_runtime_kwargs( _remote_tool_name: str = tool.name, **kwargs: Any, ) -> str | list[Content]: + trusted_meta = ctx.kwargs.get("_meta") call_kwargs = dict(ctx.kwargs) call_kwargs.update(kwargs) + if trusted_meta is not None: + call_kwargs["_meta"] = trusted_meta + else: + call_kwargs.pop("_meta", None) return await self.call_tool(_remote_tool_name, **call_kwargs) # Create FunctionTools out of each tool @@ -1518,7 +1593,6 @@ async def _call_tool_with_runtime_kwargs( }, ) self._functions.append(func) - existing_names.add(local_name) # Check if there are more pages if not tool_list.nextCursor: @@ -1636,8 +1710,8 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]: Keyword Args: _meta: Optional ``dict[str, Any]`` of MCP request metadata. This reserved key is passed as the ``meta`` parameter of the underlying ``session.call_tool`` call rather than as a tool argument. - User-supplied keys override metadata from ``tools/list``; OpenTelemetry propagation fills in - non-conflicting keys. + OpenTelemetry propagation overrides caller-supplied keys, and metadata from ``tools/list`` + overrides both. kwargs: Remaining arguments to pass to the tool. 
Returns: @@ -1746,17 +1820,7 @@ def _prepare_call_kwargs( self, tool_name: str, kwargs: dict[str, Any] ) -> tuple[dict[str, Any], dict[str, Any] | None]: """Filter kwargs down to the tool's arguments and build the merged MCP request metadata.""" - raw_user_meta: object | None = kwargs.get("_meta") - user_meta: dict[str, Any] | None = None - if raw_user_meta is not None and not isinstance(raw_user_meta, dict): - raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.") - if isinstance(raw_user_meta, dict): - raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta) - user_meta = {} - for key, value in raw_user_meta_dict.items(): - if not isinstance(key, str): - raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.") - user_meta[key] = value + user_meta = _validate_mcp_meta(kwargs.get("_meta")) # Allowlist: forward only the tool's declared parameters (from inputSchema.properties) # plus any user-configured extra argument names. Everything else - notably the @@ -1783,12 +1847,12 @@ def _prepare_call_kwargs( } # Some MCP proxies require their tools/list metadata to be echoed on tools/call. - tool_meta = self._tool_call_meta_by_name.get(tool_name) - request_meta = dict(tool_meta) if tool_meta is not None else None - if user_meta is not None: - request_meta = {**(request_meta or {}), **user_meta} - meta = _inject_otel_into_mcp_meta(request_meta) - return filtered_kwargs, meta + request_meta = dict(user_meta) if user_meta is not None else None + request_meta = _inject_otel_into_mcp_meta(request_meta, overwrite=True) + tool_meta = _validate_mcp_meta(self._tool_call_meta_by_name.get(tool_name)) + if tool_meta is not None: + request_meta = {**(request_meta or {}), **tool_meta} + return filtered_kwargs, request_meta async def call_tool_as_task(self, tool_name: str, **kwargs: Any) -> str | list[Content]: """Call an MCP tool via the long-running task lifecycle (SEP-2663). 
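The `_mcp.py` changes above combine three rules: which configuration names may match a tool, which `_meta` keys are accepted, and which metadata source wins on conflict (caller/runtime < OpenTelemetry < `tools/list`). The following is a minimal standalone sketch of those rules, not the framework's code. The function names `candidate_names`, `is_valid_meta_key`, `validate_meta`, and `merge_request_meta` are hypothetical. It raises `ValueError` in place of `ToolExecutionException`, and the OpenTelemetry carrier is passed in as a plain dict rather than filled by `propagate.inject`. Only the two regular expressions are copied from the patch.

```python
from __future__ import annotations

import re
from typing import Any

# Key-name patterns copied from the patch: an optional dotted prefix ending in "/",
# followed by a name that starts and ends with an alphanumeric character.
_LABEL = r"[A-Za-z](?:[A-Za-z0-9-]*[A-Za-z0-9])?"
_META_KEY = re.compile(
    rf"^(?:(?:{_LABEL})(?:\.{_LABEL})*/)?"
    r"[A-Za-z0-9](?:[A-Za-z0-9_.-]*[A-Za-z0-9])?$"
)


def candidate_names(local_name: str, normalized_name: str, remote_name: str) -> tuple[str, ...]:
    """Names that an allowlist or approval rule may use to match one tool.

    The raw remote name always matches. The local (prefixed) name is accepted
    only when normalization did not change the remote name, so an alias can
    never authorize a different raw remote tool.
    """
    names = [remote_name]
    if normalized_name == remote_name and local_name != remote_name:
        names.append(local_name)
    return tuple(names)


def is_valid_meta_key(key: object) -> bool:
    """Return True when key is a string matching the key-name pattern."""
    return isinstance(key, str) and _META_KEY.fullmatch(key) is not None


def validate_meta(raw_meta: object | None) -> dict[str, Any] | None:
    """Validate and copy a metadata dict (hypothetical stand-in; raises ValueError)."""
    if raw_meta is None:
        return None
    if not isinstance(raw_meta, dict):
        raise ValueError("_meta must be a dict")
    for key in raw_meta:
        if not is_valid_meta_key(key):
            raise ValueError(f"invalid _meta key: {key!r}")
    return dict(raw_meta)


def merge_request_meta(
    caller_meta: dict[str, Any] | None,
    otel_carrier: dict[str, str],
    tool_list_meta: dict[str, Any] | None,
) -> dict[str, Any] | None:
    """Merge metadata so that later, more trusted sources overwrite earlier ones."""
    merged: dict[str, Any] = dict(validate_meta(caller_meta) or {})
    merged.update(validate_meta(dict(otel_carrier)) or {})  # overrides caller keys
    merged.update(validate_meta(tool_list_meta) or {})  # overrides both
    # Simplification for this sketch: an empty result is reported as None.
    return merged or None
```

With these definitions, an allowlist entry `delete-file` does not match a remote tool named `delete/file`, because `candidate_names("delete-file", "delete-file", "delete/file")` yields only the raw remote name. A `traceparent` key supplied by the caller is replaced by the OpenTelemetry value, which is in turn replaced by a `traceparent` advertised in `tools/list`.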
diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index f00fdd3aab..c9322101f7 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -122,6 +122,113 @@ async def test_load_tools_with_tool_name_prefix_preserves_matching_configuration assert tool.functions[0].approval_mode == "always_require" +async def test_allowed_tools_does_not_authorize_normalized_remote_name_collision() -> None: + """A normalized/local allowlist match must not authorize a different raw remote tool.""" + tool = MCPTool(name="test_server", allowed_tools=["delete-file"]) # type: ignore[abstract] + + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + page = Mock() + page.tools = [ + types.Tool( + name="delete/file", + description="Delete a file", + inputSchema={"type": "object", "properties": {}}, + ), + ] + page.nextCursor = None + mock_session.list_tools = AsyncMock(return_value=page) + + await tool.load_tools() + + assert [function.name for function in tool._functions] == ["delete-file"] + assert tool.functions == [] + + +async def test_load_tools_rejects_colliding_normalized_tool_names() -> None: + """A remote MCP server must not choose which raw tool backs a colliding local name.""" + tool = MCPTool(name="test_server", allowed_tools=["delete-file"]) # type: ignore[abstract] + + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + page = Mock() + page.tools = [ + types.Tool( + name="delete/file", + description="Unauthorized tool", + inputSchema={"type": "object", "properties": {}}, + ), + types.Tool( + name="delete-file", + description="Authorized tool", + inputSchema={"type": "object", "properties": {}}, + ), + ] + page.nextCursor = None + mock_session.list_tools = AsyncMock(return_value=page) + + with pytest.raises(ToolExecutionException, match="map to the same local function name"): + await tool.load_tools() 
+ + +async def test_allowed_tools_exact_raw_name_allows_normalized_function_name() -> None: + """An exact raw remote allowlist entry still exposes that raw tool, regardless of local normalization.""" + tool = MCPTool(name="test_server", allowed_tools=["delete/file"]) # type: ignore[abstract] + + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + page = Mock() + page.tools = [ + types.Tool( + name="delete/file", + description="Delete a file", + inputSchema={"type": "object", "properties": {}}, + ), + ] + page.nextCursor = None + mock_session.list_tools = AsyncMock(return_value=page) + + await tool.load_tools() + + assert [function.name for function in tool.functions] == ["delete-file"] + assert tool.functions[0].additional_properties is not None + assert tool.functions[0].additional_properties["_mcp_remote_name"] == "delete/file" + + +async def test_approval_mode_does_not_match_normalized_colliding_name() -> None: + """Approval rules should not apply to a different raw remote tool through normalization.""" + tool = MCPTool( # type: ignore[abstract] + name="test_server", + approval_mode={"always_require_approval": ["delete-file"]}, + ) + + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + page = Mock() + page.tools = [ + types.Tool( + name="delete/file", + description="Delete a file", + inputSchema={"type": "object", "properties": {}}, + ), + ] + page.nextCursor = None + mock_session.list_tools = AsyncMock(return_value=page) + + await tool.load_tools() + + assert tool._functions[0].name == "delete-file" + assert tool._functions[0].approval_mode == "never_require" + + async def test_load_prompts_with_tool_name_prefix() -> None: """Prefixed MCP prompt names should be exposed with the configured prefix.""" tool = MCPTool(name="docs", tool_name_prefix="docs") # type: ignore[abstract] @@ -3339,6 +3446,7 @@ async def test_load_tools_adds_properties_to_zero_arg_tool_schema(): 
none_schema_tool.name = "none_schema_tool" none_schema_tool.description = "A tool with None inputSchema" none_schema_tool.inputSchema = None + none_schema_tool.meta = None page.tools.append(none_schema_tool) page.nextCursor = None @@ -4777,7 +4885,7 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: async def test_mcp_tool_call_tool_user_meta_merges_with_tool_list_meta(): - """User-provided _meta should be sent as MCP request metadata, not tool arguments.""" + """Tools/list _meta should win over caller-provided _meta on conflicts.""" from opentelemetry import trace tool_meta = {"from_tool": "tool-value", "shared": "tool-value"} @@ -4817,11 +4925,153 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: assert call_kwargs["meta"] == { "from_tool": "tool-value", "from_user": "user-value", - "shared": "user-value", + "shared": "tool-value", } assert user_meta == {"from_user": "user-value", "shared": "user-value"} +async def test_mcp_tool_function_invocation_strips_model_supplied_meta() -> None: + """Model-supplied _meta should not become MCP request metadata.""" + from opentelemetry import trace + + class TestServer(MCPTool): + async def connect(self) -> None: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="Test tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="result")]) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] + + server = TestServer(name="test_server") + async with server: + await 
server.load_tools() + + with ( + trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)), + patch("agent_framework._mcp.propagate.inject", side_effect=lambda carrier: None), + ): + await server.functions[0].invoke( + arguments={"param": "test_value", "_meta": {"attacker.example/route": "evil"}} + ) + + call_kwargs = server.session.call_tool.call_args.kwargs # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert call_kwargs["arguments"] == {"param": "test_value"} + assert call_kwargs["meta"] is None + + +async def test_mcp_tool_function_invocation_preserves_trusted_meta_over_model_meta() -> None: + """Trusted function-invocation _meta should be restored after model arguments are merged.""" + from opentelemetry import trace + + trusted_meta = {"trusted.example/route": "trusted"} + + class TestServer(MCPTool): + async def connect(self) -> None: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="Test tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="result")]) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] + + server = TestServer(name="test_server") + async with server: + await server.load_tools() + + context = FunctionInvocationContext( + function=server.functions[0], + arguments={}, + kwargs={"_meta": trusted_meta}, + ) + with ( + trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)), + patch("agent_framework._mcp.propagate.inject", side_effect=lambda carrier: None), + ): + await 
server.functions[0].invoke( + arguments={"param": "test_value", "_meta": {"attacker.example/route": "evil"}}, + context=context, + ) + + call_kwargs = server.session.call_tool.call_args.kwargs # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert call_kwargs["arguments"] == {"param": "test_value"} + assert call_kwargs["meta"] == trusted_meta + + +async def test_mcp_tool_call_tool_otel_meta_overrides_user_meta_but_not_tool_list_meta() -> None: + """OpenTelemetry should override caller metadata while tools/list metadata remains most trusted.""" + from opentelemetry import trace + + tool_meta = {"traceparent": "tool-traceparent", "from_tool": "tool-value"} + user_meta = {"traceparent": "user-traceparent", "from_user": "user-value"} + + class TestServer(MCPTool): + async def connect(self) -> None: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="Test tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + _meta=tool_meta, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="result")]) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] + + server = TestServer(name="test_server") + async with server: + await server.load_tools() + + with ( + trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)), + patch( + "agent_framework._mcp.propagate.inject", + side_effect=lambda carrier: carrier.update({"traceparent": "otel-traceparent"}), + ), + ): + await server.call_tool("test_tool", param="test_value", _meta=user_meta) + + call_kwargs = server.session.call_tool.call_args.kwargs # type: 
ignore[union-attr] # ty: ignore[unresolved-attribute] + assert call_kwargs["meta"] == { + "traceparent": "tool-traceparent", + "from_tool": "tool-value", + "from_user": "user-value", + } + + async def test_mcp_streamable_http_tool_hook_not_duplicated_on_repeated_get_mcp_client(): """Test that calling get_mcp_client multiple times does not accumulate duplicate hooks.""" tool = MCPStreamableHTTPTool( @@ -6475,6 +6725,30 @@ def test_prepare_call_kwargs_extracts_meta() -> None: assert meta.get("trace") == "abc" +@pytest.mark.parametrize( + "key", + [ + "", + "_leading-underscore", + "trailing-underscore_", + "abc/", + "1bad.example/name", + "bad..example/name", + "bad.example/_name", + "bad.example/name_", + ], +) +def test_prepare_call_kwargs_rejects_invalid_meta_key_names(key: str) -> None: + server = MCPTool(name="test_server") # type: ignore[abstract] + server._tool_param_names_by_name = {"test_tool": {"param"}} + + with pytest.raises(ToolExecutionException, match="Invalid MCP _meta key name"): + server._prepare_call_kwargs( + "test_tool", + {"param": "v", "_meta": {key: "value"}}, + ) + + async def test_call_tool_forwards_only_declared_arguments() -> None: """End-to-end: framework runtime kwargs are stripped before reaching the server.""" diff --git a/python/packages/hyperlight/agent_framework_hyperlight/_execute_code_tool.py b/python/packages/hyperlight/agent_framework_hyperlight/_execute_code_tool.py index 304cdc095f..e2c5feb709 100644 --- a/python/packages/hyperlight/agent_framework_hyperlight/_execute_code_tool.py +++ b/python/packages/hyperlight/agent_framework_hyperlight/_execute_code_tool.py @@ -86,6 +86,10 @@ def cache_key(self) -> tuple[Any, ...]: ) +class _OutputDirectory(Protocol): + name: str + + class SandboxRuntime(Protocol): def execute(self, *, config: _RunConfig, code: str) -> list[Content]: ... 
@@ -725,7 +729,7 @@ def _collect_output_relative_paths(*, sandbox: Any, root: Path) -> set[str]:
 def _parse_output_files(
     *,
     sandbox: Any,
-    output_dir: TemporaryDirectory[str] | None,
+    output_dir: _OutputDirectory | None,
     expect_output_files: bool,
 ) -> list[Content]:
     if output_dir is None: