Skip to content

Convert MCP schemas to strict where possible #414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast

from typing_extensions import TypeAlias, TypedDict
from typing_extensions import NotRequired, TypeAlias, TypedDict

from .guardrail import InputGuardrail, OutputGuardrail
from .handoffs import Handoff
Expand Down Expand Up @@ -53,6 +53,15 @@ class StopAtTools(TypedDict):
"""A list of tool names, any of which will stop the agent from running further."""


class MCPConfig(TypedDict):
"""Configuration for MCP servers."""

convert_schemas_to_strict: NotRequired[bool]
Copy link
Contributor

@pakrym-oai pakrym-oai Apr 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It slightly feels like the wrong place to do this. IMO the MCP server should be responsible to declaring itself as strict.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming you mean the actual MCP server, then yes I agree. That is a more involved change though, since MCP servers adhere to a published spec.

I didn't add it to the MCP server class in the repo, because I think those should be very thin wrappers around the actual MCP server, in order to be reused in other places outside the agents SDK internals.

Gonna merge this - but if you think we should implement it differently ping me and we can revert/re-do it!

"""If True, we will attempt to convert the MCP schemas to strict-mode schemas. This is a
best-effort conversion, so some schemas may not be convertible. Defaults to False.
"""


@dataclass
class Agent(Generic[TContext]):
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
Expand Down Expand Up @@ -119,6 +128,9 @@ class Agent(Generic[TContext]):
longer needed.
"""

mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
"""Configuration for MCP servers."""

input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
"""A list of checks that run in parallel to the agent's execution, before generating a
response. Runs only if the agent is the first agent in the chain.
Expand Down Expand Up @@ -224,7 +236,8 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s

async def get_mcp_tools(self) -> list[Tool]:
"""Fetches the available tools from the MCP servers."""
return await MCPUtil.get_all_function_tools(self.mcp_servers)
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)

async def get_all_tools(self) -> list[Tool]:
"""All agent tools, including MCP tools and function tools."""
Expand Down
30 changes: 23 additions & 7 deletions src/agents/mcp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import json
from typing import TYPE_CHECKING, Any

from agents.strict_schema import ensure_strict_json_schema

from .. import _debug
from ..exceptions import AgentsException, ModelBehaviorError, UserError
from ..logger import logger
Expand All @@ -19,12 +21,14 @@ class MCPUtil:
"""Set of utilities for interop between MCP and Agents SDK tools."""

@classmethod
async def get_all_function_tools(cls, servers: list["MCPServer"]) -> list[Tool]:
async def get_all_function_tools(
cls, servers: list["MCPServer"], convert_schemas_to_strict: bool
) -> list[Tool]:
"""Get all function tools from a list of MCP servers."""
tools = []
tool_names: set[str] = set()
for server in servers:
server_tools = await cls.get_function_tools(server)
server_tools = await cls.get_function_tools(server, convert_schemas_to_strict)
server_tool_names = {tool.name for tool in server_tools}
if len(server_tool_names & tool_names) > 0:
raise UserError(
Expand All @@ -37,25 +41,37 @@ async def get_all_function_tools(cls, servers: list["MCPServer"]) -> list[Tool]:
return tools

@classmethod
async def get_function_tools(cls, server: "MCPServer") -> list[Tool]:
async def get_function_tools(
cls, server: "MCPServer", convert_schemas_to_strict: bool
) -> list[Tool]:
"""Get all function tools from a single MCP server."""

with mcp_tools_span(server=server.name) as span:
tools = await server.list_tools()
span.span_data.result = [tool.name for tool in tools]

return [cls.to_function_tool(tool, server) for tool in tools]
return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools]

@classmethod
def to_function_tool(cls, tool: "MCPTool", server: "MCPServer") -> FunctionTool:
def to_function_tool(
cls, tool: "MCPTool", server: "MCPServer", convert_schemas_to_strict: bool
) -> FunctionTool:
"""Convert an MCP tool to an Agents SDK function tool."""
invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool)
schema, is_strict = tool.inputSchema, False
if convert_schemas_to_strict:
try:
schema = ensure_strict_json_schema(schema)
is_strict = True
except Exception as e:
logger.info(f"Error converting MCP schema to strict mode: {e}")

return FunctionTool(
name=tool.name,
description=tool.description or "",
params_json_schema=tool.inputSchema,
params_json_schema=schema,
on_invoke_tool=invoke_func,
strict_json_schema=False,
strict_json_schema=is_strict,
)

@classmethod
Expand Down
161 changes: 157 additions & 4 deletions tests/mcp/test_mcp_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from typing import Any

import pytest
from inline_snapshot import snapshot
from mcp.types import Tool as MCPTool
from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter

from agents import FunctionTool, RunContextWrapper
from agents import Agent, FunctionTool, RunContextWrapper
from agents.exceptions import AgentsException, ModelBehaviorError
from agents.mcp import MCPServer, MCPUtil

Expand All @@ -18,7 +19,16 @@ class Foo(BaseModel):


class Bar(BaseModel):
qux: str
qux: dict[str, str]


Baz = TypeAdapter(dict[str, str])


def _convertible_schema() -> dict[str, Any]:
schema = Foo.model_json_schema()
schema["additionalProperties"] = False
return schema


@pytest.mark.asyncio
Expand Down Expand Up @@ -47,7 +57,7 @@ async def test_get_all_function_tools():
server3.add_tool(names[4], schemas[4])

servers: list[MCPServer] = [server1, server2, server3]
tools = await MCPUtil.get_all_function_tools(servers)
tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=False)
assert len(tools) == 5
assert all(tool.name in names for tool in tools)

Expand All @@ -56,6 +66,11 @@ async def test_get_all_function_tools():
assert tool.params_json_schema == schemas[idx]
assert tool.name == names[idx]

# Also make sure it works with strict schemas
tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=True)
assert len(tools) == 5
assert all(tool.name in names for tool in tools)


@pytest.mark.asyncio
async def test_invoke_mcp_tool():
Expand Down Expand Up @@ -107,3 +122,141 @@ async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixtur
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "")

assert "Error invoking MCP tool test_tool_1" in caplog.text


@pytest.mark.asyncio
async def test_agent_convert_schemas_true():
"""Test that setting convert_schemas_to_strict to True converts non-strict schemas to strict.
- 'foo' tool is already strict and remains strict.
- 'bar' tool is non-strict and becomes strict (additionalProperties set to False, etc).
"""
strict_schema = Foo.model_json_schema()
non_strict_schema = Baz.json_schema()
possible_to_convert_schema = _convertible_schema()

server = FakeMCPServer()
server.add_tool("foo", strict_schema)
server.add_tool("bar", non_strict_schema)
server.add_tool("baz", possible_to_convert_schema)
agent = Agent(
name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": True}
)
tools = await agent.get_mcp_tools()

foo_tool = next(tool for tool in tools if tool.name == "foo")
assert isinstance(foo_tool, FunctionTool)
bar_tool = next(tool for tool in tools if tool.name == "bar")
assert isinstance(bar_tool, FunctionTool)
baz_tool = next(tool for tool in tools if tool.name == "baz")
assert isinstance(baz_tool, FunctionTool)

# Checks that additionalProperties is set to False
assert foo_tool.params_json_schema == snapshot(
{
"properties": {
"bar": {"title": "Bar", "type": "string"},
"baz": {"title": "Baz", "type": "integer"},
},
"required": ["bar", "baz"],
"title": "Foo",
"type": "object",
"additionalProperties": False,
}
)
assert foo_tool.strict_json_schema is True, "foo_tool should be strict"

# Checks that additionalProperties is set to False
assert bar_tool.params_json_schema == snapshot(
{
"type": "object",
"additionalProperties": {"type": "string"},
}
)
assert bar_tool.strict_json_schema is False, "bar_tool should not be strict"

# Checks that additionalProperties is set to False
assert baz_tool.params_json_schema == snapshot(
{
"properties": {
"bar": {"title": "Bar", "type": "string"},
"baz": {"title": "Baz", "type": "integer"},
},
"required": ["bar", "baz"],
"title": "Foo",
"type": "object",
"additionalProperties": False,
}
)
assert baz_tool.strict_json_schema is True, "baz_tool should be strict"


@pytest.mark.asyncio
async def test_agent_convert_schemas_false():
"""Test that setting convert_schemas_to_strict to False leaves tool schemas as non-strict.
- 'foo' tool remains strict.
- 'bar' tool remains non-strict (additionalProperties remains True).
"""
strict_schema = Foo.model_json_schema()
non_strict_schema = Baz.json_schema()
possible_to_convert_schema = _convertible_schema()

server = FakeMCPServer()
server.add_tool("foo", strict_schema)
server.add_tool("bar", non_strict_schema)
server.add_tool("baz", possible_to_convert_schema)

agent = Agent(
name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": False}
)
tools = await agent.get_mcp_tools()

foo_tool = next(tool for tool in tools if tool.name == "foo")
assert isinstance(foo_tool, FunctionTool)
bar_tool = next(tool for tool in tools if tool.name == "bar")
assert isinstance(bar_tool, FunctionTool)
baz_tool = next(tool for tool in tools if tool.name == "baz")
assert isinstance(baz_tool, FunctionTool)

assert foo_tool.params_json_schema == strict_schema
assert foo_tool.strict_json_schema is False, "Shouldn't be converted unless specified"

assert bar_tool.params_json_schema == non_strict_schema
assert bar_tool.strict_json_schema is False

assert baz_tool.params_json_schema == possible_to_convert_schema
assert baz_tool.strict_json_schema is False, "Shouldn't be converted unless specified"


@pytest.mark.asyncio
async def test_agent_convert_schemas_unset():
"""Test that leaving convert_schemas_to_strict unset (defaulting to False) leaves tool schemas
as non-strict.
- 'foo' tool remains strict.
- 'bar' tool remains non-strict.
"""
strict_schema = Foo.model_json_schema()
non_strict_schema = Baz.json_schema()
possible_to_convert_schema = _convertible_schema()

server = FakeMCPServer()
server.add_tool("foo", strict_schema)
server.add_tool("bar", non_strict_schema)
server.add_tool("baz", possible_to_convert_schema)
agent = Agent(name="test_agent", mcp_servers=[server])
tools = await agent.get_mcp_tools()

foo_tool = next(tool for tool in tools if tool.name == "foo")
assert isinstance(foo_tool, FunctionTool)
bar_tool = next(tool for tool in tools if tool.name == "bar")
assert isinstance(bar_tool, FunctionTool)
baz_tool = next(tool for tool in tools if tool.name == "baz")
assert isinstance(baz_tool, FunctionTool)

assert foo_tool.params_json_schema == strict_schema
assert foo_tool.strict_json_schema is False, "Shouldn't be converted unless specified"

assert bar_tool.params_json_schema == non_strict_schema
assert bar_tool.strict_json_schema is False

assert baz_tool.params_json_schema == possible_to_convert_schema
assert baz_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
4 changes: 1 addition & 3 deletions tests/test_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,9 +642,7 @@ async def test_tool_use_behavior_custom_function():
async def test_model_settings_override():
model = FakeModel()
agent = Agent(
name="test",
model=model,
model_settings=ModelSettings(temperature=1.0, max_tokens=1000)
name="test", model=model, model_settings=ModelSettings(temperature=1.0, max_tokens=1000)
)

model.add_multiple_turn_outputs(
Expand Down
14 changes: 6 additions & 8 deletions tests/test_tracing_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,10 @@ async def test_multiple_handoff_doesnt_error():
},
},
{"type": "generation"},
{"type": "handoff",
"data": {"from_agent": "test", "to_agent": "test"},
"error": {
{
"type": "handoff",
"data": {"from_agent": "test", "to_agent": "test"},
"error": {
"data": {
"requested_agents": [
"test",
Expand All @@ -255,7 +256,7 @@ async def test_multiple_handoff_doesnt_error():
},
"message": "Multiple handoffs requested",
},
},
},
],
},
{
Expand Down Expand Up @@ -383,10 +384,7 @@ async def test_handoffs_lead_to_correct_agent_spans():
{"type": "generation"},
{
"type": "handoff",
"data": {
"from_agent": "test_agent_3",
"to_agent": "test_agent_1"
},
"data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"},
"error": {
"data": {
"requested_agents": [
Expand Down