diff --git a/src/agents/agent.py b/src/agents/agent.py index b2b9f7b2..4c6de245 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast -from typing_extensions import TypeAlias, TypedDict +from typing_extensions import NotRequired, TypeAlias, TypedDict from .guardrail import InputGuardrail, OutputGuardrail from .handoffs import Handoff @@ -53,6 +53,15 @@ class StopAtTools(TypedDict): """A list of tool names, any of which will stop the agent from running further.""" +class MCPConfig(TypedDict): + """Configuration for MCP servers.""" + + convert_schemas_to_strict: NotRequired[bool] + """If True, we will attempt to convert the MCP schemas to strict-mode schemas. This is a + best-effort conversion, so some schemas may not be convertible. Defaults to False. + """ + + @dataclass class Agent(Generic[TContext]): """An agent is an AI model configured with instructions, tools, guardrails, handoffs and more. @@ -119,6 +128,9 @@ class Agent(Generic[TContext]): longer needed. """ + mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig()) + """Configuration for MCP servers.""" + input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list) """A list of checks that run in parallel to the agent's execution, before generating a response. Runs only if the agent is the first agent in the chain. @@ -224,7 +236,8 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s async def get_mcp_tools(self) -> list[Tool]: """Fetches the available tools from the MCP servers.""" - return await MCPUtil.get_all_function_tools(self.mcp_servers) + convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) + return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict) async def get_all_tools(self) -> list[Tool]: """All agent tools, including MCP tools and function tools.""" diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 36c18bea..770ae8dd 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -2,6 +2,8 @@ import json from typing import TYPE_CHECKING, Any +from agents.strict_schema import ensure_strict_json_schema + from .. import _debug from ..exceptions import AgentsException, ModelBehaviorError, UserError from ..logger import logger @@ -19,12 +21,14 @@ class MCPUtil: """Set of utilities for interop between MCP and Agents SDK tools.""" @classmethod - async def get_all_function_tools(cls, servers: list["MCPServer"]) -> list[Tool]: + async def get_all_function_tools( + cls, servers: list["MCPServer"], convert_schemas_to_strict: bool + ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" tools = [] tool_names: set[str] = set() for server in servers: - server_tools = await cls.get_function_tools(server) + server_tools = await cls.get_function_tools(server, convert_schemas_to_strict) server_tool_names = {tool.name for tool in server_tools} if len(server_tool_names & tool_names) > 0: raise UserError( @@ -37,25 +41,37 @@ async def get_all_function_tools(cls, servers: list["MCPServer"]) -> list[Tool]: return tools @classmethod - async def get_function_tools(cls, server: "MCPServer") -> list[Tool]: + async def get_function_tools( + cls, server: "MCPServer", convert_schemas_to_strict: bool + ) -> list[Tool]: """Get all function tools from a single MCP server.""" with mcp_tools_span(server=server.name) as span: tools = await server.list_tools() span.span_data.result = [tool.name for tool in tools] - return [cls.to_function_tool(tool, server) for tool in tools] + return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools] @classmethod - def to_function_tool(cls, tool: "MCPTool", server: "MCPServer") -> FunctionTool: + def to_function_tool( + cls, tool: "MCPTool", server: "MCPServer", convert_schemas_to_strict: bool + ) -> FunctionTool: """Convert an MCP tool to an Agents SDK function tool.""" invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool) + schema, is_strict = tool.inputSchema, False + if convert_schemas_to_strict: + try: + schema = ensure_strict_json_schema(schema) + is_strict = True + except Exception as e: + logger.info(f"Error converting MCP schema to strict mode: {e}") + return FunctionTool( name=tool.name, description=tool.description or "", - params_json_schema=tool.inputSchema, + params_json_schema=schema, on_invoke_tool=invoke_func, - strict_json_schema=False, + strict_json_schema=is_strict, ) @classmethod diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py index 345df996..64378b59 100644 --- a/tests/mcp/test_mcp_util.py +++ b/tests/mcp/test_mcp_util.py @@ -2,10 +2,11 @@ from typing import Any import pytest +from inline_snapshot import snapshot from mcp.types import Tool as MCPTool -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter -from agents import FunctionTool, RunContextWrapper +from agents import Agent, FunctionTool, RunContextWrapper from agents.exceptions import AgentsException, ModelBehaviorError from agents.mcp import MCPServer, MCPUtil @@ -18,7 +19,16 @@ class Foo(BaseModel): class Bar(BaseModel): - qux: str + qux: dict[str, str] + + +Baz = TypeAdapter(dict[str, str]) + + +def _convertible_schema() -> dict[str, Any]: + schema = Foo.model_json_schema() + schema["additionalProperties"] = False + return schema @pytest.mark.asyncio @@ -47,7 +57,7 @@ async def test_get_all_function_tools(): server3.add_tool(names[4], schemas[4]) servers: list[MCPServer] = [server1, server2, server3] - tools = await MCPUtil.get_all_function_tools(servers) + tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=False) assert len(tools) == 5 assert all(tool.name in names for tool in tools) @@ -56,6 +66,11 @@ async def test_get_all_function_tools(): assert tool.params_json_schema == schemas[idx] assert tool.name == names[idx] + # Also make sure it works with strict schemas + tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=True) + assert len(tools) == 5 + assert all(tool.name in names for tool in tools) + @pytest.mark.asyncio async def test_invoke_mcp_tool(): @@ -107,3 +122,141 @@ async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixtur await MCPUtil.invoke_mcp_tool(server, tool, ctx, "") assert "Error invoking MCP tool test_tool_1" in caplog.text + + +@pytest.mark.asyncio +async def test_agent_convert_schemas_true(): + """Test that setting convert_schemas_to_strict to True converts non-strict schemas to strict. + - 'foo' tool is already strict and remains strict. + - 'bar' tool is non-strict and becomes strict (additionalProperties set to False, etc). + """ + strict_schema = Foo.model_json_schema() + non_strict_schema = Baz.json_schema() + possible_to_convert_schema = _convertible_schema() + + server = FakeMCPServer() + server.add_tool("foo", strict_schema) + server.add_tool("bar", non_strict_schema) + server.add_tool("baz", possible_to_convert_schema) + agent = Agent( + name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": True} + ) + tools = await agent.get_mcp_tools() + + foo_tool = next(tool for tool in tools if tool.name == "foo") + assert isinstance(foo_tool, FunctionTool) + bar_tool = next(tool for tool in tools if tool.name == "bar") + assert isinstance(bar_tool, FunctionTool) + baz_tool = next(tool for tool in tools if tool.name == "baz") + assert isinstance(baz_tool, FunctionTool) + + # Checks that additionalProperties is set to False + assert foo_tool.params_json_schema == snapshot( + { + "properties": { + "bar": {"title": "Bar", "type": "string"}, + "baz": {"title": "Baz", "type": "integer"}, + }, + "required": ["bar", "baz"], + "title": "Foo", + "type": "object", + "additionalProperties": False, + } + ) + assert foo_tool.strict_json_schema is True, "foo_tool should be strict" + + # Checks that additionalProperties is set to False + assert bar_tool.params_json_schema == snapshot( + { + "type": "object", + "additionalProperties": {"type": "string"}, + } + ) + assert bar_tool.strict_json_schema is False, "bar_tool should not be strict" + + # Checks that additionalProperties is set to False + assert baz_tool.params_json_schema == snapshot( + { + "properties": { + "bar": {"title": "Bar", "type": "string"}, + "baz": {"title": "Baz", "type": "integer"}, + }, + "required": ["bar", "baz"], + "title": "Foo", + "type": "object", + "additionalProperties": False, + } + ) + assert baz_tool.strict_json_schema is True, "baz_tool should be strict" + + +@pytest.mark.asyncio +async def test_agent_convert_schemas_false(): + """Test that setting convert_schemas_to_strict to False leaves tool schemas as non-strict. + - 'foo' tool remains strict. + - 'bar' tool remains non-strict (additionalProperties remains True). + """ + strict_schema = Foo.model_json_schema() + non_strict_schema = Baz.json_schema() + possible_to_convert_schema = _convertible_schema() + + server = FakeMCPServer() + server.add_tool("foo", strict_schema) + server.add_tool("bar", non_strict_schema) + server.add_tool("baz", possible_to_convert_schema) + + agent = Agent( + name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": False} + ) + tools = await agent.get_mcp_tools() + + foo_tool = next(tool for tool in tools if tool.name == "foo") + assert isinstance(foo_tool, FunctionTool) + bar_tool = next(tool for tool in tools if tool.name == "bar") + assert isinstance(bar_tool, FunctionTool) + baz_tool = next(tool for tool in tools if tool.name == "baz") + assert isinstance(baz_tool, FunctionTool) + + assert foo_tool.params_json_schema == strict_schema + assert foo_tool.strict_json_schema is False, "Shouldn't be converted unless specified" + + assert bar_tool.params_json_schema == non_strict_schema + assert bar_tool.strict_json_schema is False + + assert baz_tool.params_json_schema == possible_to_convert_schema + assert baz_tool.strict_json_schema is False, "Shouldn't be converted unless specified" + + +@pytest.mark.asyncio +async def test_agent_convert_schemas_unset(): + """Test that leaving convert_schemas_to_strict unset (defaulting to False) leaves tool schemas + as non-strict. + - 'foo' tool remains strict. + - 'bar' tool remains non-strict. + """ + strict_schema = Foo.model_json_schema() + non_strict_schema = Baz.json_schema() + possible_to_convert_schema = _convertible_schema() + + server = FakeMCPServer() + server.add_tool("foo", strict_schema) + server.add_tool("bar", non_strict_schema) + server.add_tool("baz", possible_to_convert_schema) + agent = Agent(name="test_agent", mcp_servers=[server]) + tools = await agent.get_mcp_tools() + + foo_tool = next(tool for tool in tools if tool.name == "foo") + assert isinstance(foo_tool, FunctionTool) + bar_tool = next(tool for tool in tools if tool.name == "bar") + assert isinstance(bar_tool, FunctionTool) + baz_tool = next(tool for tool in tools if tool.name == "baz") + assert isinstance(baz_tool, FunctionTool) + + assert foo_tool.params_json_schema == strict_schema + assert foo_tool.strict_json_schema is False, "Shouldn't be converted unless specified" + + assert bar_tool.params_json_schema == non_strict_schema + assert bar_tool.strict_json_schema is False + + assert baz_tool.params_json_schema == possible_to_convert_schema + assert baz_tool.strict_json_schema is False, "Shouldn't be converted unless specified" diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index ce0c5804..4f277656 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -642,9 +642,7 @@ async def test_tool_use_behavior_custom_function(): async def test_model_settings_override(): model = FakeModel() agent = Agent( - name="test", - model=model, - model_settings=ModelSettings(temperature=1.0, max_tokens=1000) + name="test", model=model, model_settings=ModelSettings(temperature=1.0, max_tokens=1000) ) model.add_multiple_turn_outputs( diff --git a/tests/test_tracing_errors.py b/tests/test_tracing_errors.py index 6d698bcc..72bd39ed 100644 --- a/tests/test_tracing_errors.py +++ b/tests/test_tracing_errors.py @@ -244,9 +244,10 @@ async def test_multiple_handoff_doesnt_error(): }, }, {"type": "generation"}, - {"type": "handoff", - "data": {"from_agent": "test", "to_agent": "test"}, - "error": { + { + "type": "handoff", + "data": {"from_agent": "test", "to_agent": "test"}, + "error": { "data": { "requested_agents": [ "test", @@ -255,7 +256,7 @@ async def test_multiple_handoff_doesnt_error(): }, "message": "Multiple handoffs requested", }, - }, + }, ], }, { @@ -383,10 +384,7 @@ async def test_handoffs_lead_to_correct_agent_spans(): {"type": "generation"}, { "type": "handoff", - "data": { - "from_agent": "test_agent_3", - "to_agent": "test_agent_1" - }, + "data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"}, "error": { "data": { "requested_agents": [