diff --git a/pyproject.toml b/pyproject.toml index 667ab355..3678c714 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "typing-extensions>=4.12.2, <5", "requests>=2.0, <3", "types-requests>=2.0, <3", + "mcp; python_version >= '3.10'", ] classifiers = [ "Typing :: Typed", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 2849538d..02e3bf58 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -50,7 +50,7 @@ from .models.interface import ModelTracing from .run_context import RunContextWrapper, TContext from .stream_events import RunItemStreamEvent, StreamEvent -from .tool import ComputerTool, FunctionTool, FunctionToolResult +from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool from .tracing import ( SpanError, Trace, @@ -301,6 +301,7 @@ def process_model_response( cls, *, agent: Agent[Any], + all_tools: list[Tool], response: ModelResponse, output_schema: AgentOutputSchema | None, handoffs: list[Handoff], @@ -312,8 +313,8 @@ def process_model_response( computer_actions = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} - function_map = {tool.name: tool for tool in agent.tools if isinstance(tool, FunctionTool)} - computer_tool = next((tool for tool in agent.tools if isinstance(tool, ComputerTool)), None) + function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} + computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None) for output in response.output: if isinstance(output, ResponseOutputMessage): diff --git a/src/agents/agent.py b/src/agents/agent.py index 2723e678..3258e15a 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -12,6 +12,7 @@ from .handoffs import Handoff from .items import ItemHelpers from .logger import logger +from .mcp import MCPUtil from .model_settings import ModelSettings from .models.interface import Model from .run_context import RunContextWrapper, TContext @@ 
-21,6 +22,7 @@ if TYPE_CHECKING: from .lifecycle import AgentHooks + from .mcp import MCPServer from .result import RunResult @@ -107,6 +109,16 @@ class Agent(Generic[TContext]): tools: list[Tool] = field(default_factory=list) """A list of tools that the agent can use.""" + mcp_servers: list[MCPServer] = field(default_factory=list) + """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that + the agent can use. Every time the agent runs, it will include tools from these servers in the + list of available tools. + + NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call + `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no + longer needed. + """ + input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list) """A list of checks that run in parallel to the agent's execution, before generating a response. Runs only if the agent is the first agent in the chain. @@ -205,3 +217,11 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s logger.error(f"Instructions must be a string or a function, got {self.instructions}") return None + + async def get_mcp_tools(self) -> list[Tool]: + """Fetches the available tools from the MCP servers.""" + return await MCPUtil.get_all_function_tools(self.mcp_servers) + + async def get_all_tools(self) -> list[Tool]: + """All agent tools, including MCP tools and function tools.""" + return await MCPUtil.get_all_function_tools(self.mcp_servers) + self.tools diff --git a/src/agents/mcp/__init__.py b/src/agents/mcp/__init__.py new file mode 100644 index 00000000..1a72a89f --- /dev/null +++ b/src/agents/mcp/__init__.py @@ -0,0 +1,21 @@ +try: + from .server import ( + MCPServer, + MCPServerSse, + MCPServerSseParams, + MCPServerStdio, + MCPServerStdioParams, + ) +except ImportError: + pass + +from .util import MCPUtil + +__all__ = [ + "MCPServer", + "MCPServerSse", + "MCPServerSseParams", + 
"MCPServerStdio", + "MCPServerStdioParams", + "MCPUtil", +] diff --git a/src/agents/mcp/mcp_util.py b/src/agents/mcp/mcp_util.py new file mode 100644 index 00000000..41b4c521 --- /dev/null +++ b/src/agents/mcp/mcp_util.py @@ -0,0 +1,94 @@ +import functools +import json +from typing import Any + +from mcp.types import Tool as MCPTool + +from .. import _debug +from ..exceptions import AgentsException, ModelBehaviorError, UserError +from ..logger import logger +from ..run_context import RunContextWrapper +from ..tool import FunctionTool, Tool +from .server import MCPServer + + +class MCPUtil: + """Set of utilities for interop between MCP and Agents SDK tools.""" + + @classmethod + async def get_all_function_tools(cls, servers: list[MCPServer]) -> list[Tool]: + """Get all function tools from a list of MCP servers.""" + tools = [] + tool_names: set[str] = set() + for server in servers: + server_tools = await cls.get_function_tools(server) + server_tool_names = {tool.name for tool in server_tools} + if len(server_tool_names & tool_names) > 0: + raise UserError( + f"Duplicate tool names found across MCP servers: " + f"{server_tool_names & tool_names}" + ) + tool_names.update(server_tool_names) + tools.extend(server_tools) + + return tools + + @classmethod + async def get_function_tools(cls, server: MCPServer) -> list[Tool]: + """Get all function tools from a single MCP server.""" + tools = await server.list_tools() + return [cls.to_function_tool(tool, server) for tool in tools] + + @classmethod + def to_function_tool(cls, tool: MCPTool, server: MCPServer) -> FunctionTool: + """Convert an MCP tool to an Agents SDK function tool.""" + invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool) + return FunctionTool( + name=tool.name, + description=tool.description or "", + params_json_schema=tool.inputSchema, + on_invoke_tool=invoke_func, + strict_json_schema=False, + ) + + @classmethod + async def invoke_mcp_tool( + cls, server: MCPServer, tool: MCPTool, context: 
RunContextWrapper[Any], input_json: str + ) -> str: + """Invoke an MCP tool and return the result as a string.""" + try: + json_data: dict[str, Any] = json.loads(input_json) if input_json else {} + except Exception as e: + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Invalid JSON input for tool {tool.name}") + else: + logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}") + raise ModelBehaviorError( + f"Invalid JSON input for tool {tool.name}: {input_json}" + ) from e + + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Invoking MCP tool {tool.name}") + else: + logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}") + + try: + result = await server.call_tool(tool.name, json_data) + except Exception as e: + logger.error(f"Error invoking MCP tool {tool.name}: {e}") + raise AgentsException(f"Error invoking MCP tool {tool.name}: {e}") from e + + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"MCP tool {tool.name} completed.") + else: + logger.debug(f"MCP tool {tool.name} returned {result}") + + # The MCP tool result is a list of content items, whereas OpenAI tool outputs are a single + # string. We'll try to convert. + if len(result.content) == 1: + return result.content[0].model_dump_json() + elif len(result.content) > 1: + return json.dumps([item.model_dump() for item in result.content]) + else: + logger.error(f"Errored MCP tool result: {result}") + return "Error running tool." 
class _MCPServerWithClientSession(MCPServer, abc.ABC):
    """Base class for MCP servers that use a `ClientSession` to communicate with the server.

    Subclasses provide the transport by implementing `create_streams()`. This class owns the
    session lifecycle (`connect()` / `cleanup()`), the optional tools-list cache, and the async
    context-manager protocol (`async with server:` connects on entry and cleans up on exit).
    """

    def __init__(self, cache_tools_list: bool):
        """
        Args:
            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
                cached and only fetched from the server once. If `False`, the tools list will be
                fetched from the server on each call to `list_tools()`. The cache can be
                invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
                if you know the server will not change its tools list, because it can drastically
                improve latency (by avoiding a round-trip to the server every time).
        """
        self.session: ClientSession | None = None
        self.exit_stack: AsyncExitStack = AsyncExitStack()
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        self.cache_tools_list = cache_tools_list

        # The cache is always dirty at startup, so that we fetch tools at least once
        self._cache_dirty = True
        self._tools_list: list[MCPTool] | None = None

    @abc.abstractmethod
    def create_streams(
        self,
    ) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Create the streams for the server.

        Returns:
            An async context manager yielding the `(read, write)` stream pair that
            `connect()` hands to `ClientSession`.
        """
        pass

    async def __aenter__(self):
        """Connect to the server and return this instance."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Clean up the server connection when leaving the `async with` block."""
        await self.cleanup()

    def invalidate_tools_cache(self):
        """Invalidate the tools cache so the next `list_tools()` call refetches from the server."""
        self._cache_dirty = True

    async def connect(self):
        """Connect to the server.

        Opens the transport streams and a `ClientSession` on the exit stack and initializes the
        session. On any failure the error is logged, `cleanup()` is run to release whatever was
        already opened, and the original exception is re-raised.
        """
        try:
            transport = await self.exit_stack.enter_async_context(self.create_streams())
            read, write = transport
            session = await self.exit_stack.enter_async_context(ClientSession(read, write))
            await session.initialize()
            self.session = session
        except Exception as e:
            logger.error(f"Error initializing MCP server: {e}")
            await self.cleanup()
            raise

    async def list_tools(self) -> list[MCPTool]:
        """List the tools available on the server.

        Returns:
            The cached list when caching is enabled and the cache is clean; otherwise a list
            freshly fetched from the server.

        Raises:
            UserError: If `connect()` has not been called (there is no session).
        """
        if not self.session:
            raise UserError("Server not initialized. Make sure you call `connect()` first.")

        # Compare against None rather than relying on truthiness: a server that legitimately
        # exposes zero tools returns an empty list, which must still count as a cached value.
        if self.cache_tools_list and not self._cache_dirty and self._tools_list is not None:
            return self._tools_list

        # Fetch first, and only mark the cache clean once the fetch has succeeded. If the
        # request raises, the cache stays dirty and the next call retries instead of serving
        # a stale list.
        tools = (await self.session.list_tools()).tools
        self._tools_list = tools
        self._cache_dirty = False
        return tools

    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
        """Invoke a tool on the server.

        Args:
            tool_name: Name of the tool to call.
            arguments: Arguments forwarded to the tool, or `None`.

        Raises:
            UserError: If `connect()` has not been called (there is no session).
        """
        if not self.session:
            raise UserError("Server not initialized. Make sure you call `connect()` first.")

        return await self.session.call_tool(tool_name, arguments)

    async def cleanup(self):
        """Cleanup the server.

        Closes everything registered on the exit stack. Errors are logged rather than raised
        (cleanup is best-effort), and the session reference is always cleared so the server
        never reports itself as connected after a failed close.
        """
        async with self._cleanup_lock:
            try:
                await self.exit_stack.aclose()
            except Exception as e:
                logger.error(f"Error cleaning up server: {e}")
            finally:
                self.session = None
+ """ + + +class MCPServerStdio(_MCPServerWithClientSession): + """MCP server implementation that uses the stdio transport. See the [spec] + (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for + details. + """ + + def __init__(self, params: MCPServerStdioParams, cache_tools_list: bool = False): + """Create a new MCP server based on the stdio transport. + + Args: + params: The params that configure the server. This includes: + - The command (e.g. `python` or `node`) that starts the server. + - The args to pass to the server command (e.g. `foo.py` or `server.js`). + - The environment variables to set for the server. + cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be + cached and only fetched from the server once. If `False`, the tools list will be + fetched from the server on each call to `list_tools()`. The cache can be + invalidated by calling `invalidate_tools_cache()`. You should set this to `True` + if you know the server will not change its tools list, because it can drastically + improve latency (by avoiding a round-trip to the server every time). 
+ """ + super().__init__(cache_tools_list) + + self.params = StdioServerParameters( + command=params["command"], + args=params.get("args", []), + env=params.get("env"), + cwd=params.get("cwd"), + encoding=params.get("encoding", "utf-8"), + encoding_error_handler=params.get("encoding_error_handler", "strict"), + ) + + def create_streams( + self, + ) -> AbstractAsyncContextManager[ + tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ] + ]: + """Create the streams for the server.""" + return stdio_client(self.params) + + +class MCPServerSseParams(TypedDict): + """Mirrors the params in`mcp.client.sse.sse_client`.""" + + url: str + """The URL of the server.""" + + headers: NotRequired[dict[str, str]] + """The headers to send to the server.""" + + timeout: NotRequired[float] + """The timeout for the HTTP request. Defaults to 5 seconds.""" + + sse_read_timeout: NotRequired[float] + """The timeout for the SSE connection, in seconds. Defaults to 5 minutes.""" + + +class MCPServerSse(_MCPServerWithClientSession): + """MCP server implementation that uses the HTTP with SSE transport. See the [spec] + (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) + for details. + """ + + def __init__(self, params: MCPServerSseParams, cache_tools_list: bool = False): + """Create a new MCP server based on the HTTP with SSE transport. + + Args: + params: The params that configure the server. This includes: + - The URL of the server. + - The headers to send to the server. + - The timeout for the HTTP request. + - The timeout for the SSE connection. + + cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be + cached and only fetched from the server once. If `False`, the tools list will be + fetched from the server on each call to `list_tools()`. The cache can be + invalidated by calling `invalidate_tools_cache()`. 
class MCPUtil:
    """Set of utilities for interop between MCP and Agents SDK tools."""

    @classmethod
    async def get_all_function_tools(cls, servers: list["MCPServer"]) -> list[Tool]:
        """Get all function tools from a list of MCP servers.

        Args:
            servers: The MCP servers to query, in order.

        Returns:
            The converted tools of every server, in server order.

        Raises:
            UserError: If a tool name offered by one server was already offered by an earlier
                server in the list.
        """
        tools: list[Tool] = []
        seen_names: set[str] = set()
        for server in servers:
            server_tools = await cls.get_function_tools(server)
            server_tool_names = {tool.name for tool in server_tools}
            # Compute the overlap once and reuse it for both the check and the message.
            duplicates = server_tool_names & seen_names
            if duplicates:
                raise UserError(
                    f"Duplicate tool names found across MCP servers: "
                    f"{duplicates}"
                )
            seen_names.update(server_tool_names)
            tools.extend(server_tools)

        return tools

    @classmethod
    async def get_function_tools(cls, server: "MCPServer") -> list[Tool]:
        """Get all function tools from a single MCP server."""
        mcp_tools = await server.list_tools()
        return [cls.to_function_tool(tool, server) for tool in mcp_tools]

    @classmethod
    def to_function_tool(cls, tool: "MCPTool", server: "MCPServer") -> FunctionTool:
        """Convert an MCP tool to an Agents SDK function tool.

        The returned tool's `on_invoke_tool` is `invoke_mcp_tool` with `server` and `tool`
        pre-bound, so it only needs the run context and the JSON input string.
        """
        invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool)
        return FunctionTool(
            name=tool.name,
            description=tool.description or "",
            params_json_schema=tool.inputSchema,
            on_invoke_tool=invoke_func,
            strict_json_schema=False,
        )

    @classmethod
    async def invoke_mcp_tool(
        cls, server: "MCPServer", tool: "MCPTool", context: RunContextWrapper[Any], input_json: str
    ) -> str:
        """Invoke an MCP tool and return the result as a string.

        Args:
            server: The server that hosts the tool.
            tool: The MCP tool to call.
            context: The run context; not used by this method.
            input_json: The tool arguments as a JSON string. An empty string means no arguments.

        Returns:
            The single content item as JSON, a JSON array when there are several content items,
            or the literal "Error running tool." when the result has no content.

        Raises:
            ModelBehaviorError: If `input_json` is not valid JSON, or is valid JSON that is
                neither an object nor `null`.
            AgentsException: If the server raises while running the tool.
        """
        invalid_input_message = f"Invalid JSON input for tool {tool.name}: {input_json}"

        try:
            json_data = json.loads(input_json) if input_json else {}
        except Exception as e:
            if _debug.DONT_LOG_TOOL_DATA:
                logger.debug(f"Invalid JSON input for tool {tool.name}")
            else:
                logger.debug(invalid_input_message)
            raise ModelBehaviorError(invalid_input_message) from e

        # Tool arguments must be a JSON object (or null). A list, number or string parses
        # fine but is not something `call_tool` can accept, so reject it as a model error
        # here instead of forwarding it to the server.
        if json_data is not None and not isinstance(json_data, dict):
            if _debug.DONT_LOG_TOOL_DATA:
                logger.debug(f"Invalid JSON input for tool {tool.name}")
            else:
                logger.debug(invalid_input_message)
            raise ModelBehaviorError(invalid_input_message)

        if _debug.DONT_LOG_TOOL_DATA:
            logger.debug(f"Invoking MCP tool {tool.name}")
        else:
            logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}")

        try:
            result = await server.call_tool(tool.name, json_data)
        except Exception as e:
            logger.error(f"Error invoking MCP tool {tool.name}: {e}")
            raise AgentsException(f"Error invoking MCP tool {tool.name}: {e}") from e

        if _debug.DONT_LOG_TOOL_DATA:
            logger.debug(f"MCP tool {tool.name} completed.")
        else:
            logger.debug(f"MCP tool {tool.name} returned {result}")

        # The MCP tool result is a list of content items, whereas OpenAI tool outputs are a single
        # string. We'll try to convert.
        content = result.content
        if len(content) == 1:
            return content[0].model_dump_json()
        if len(content) > 1:
            # mode="json" makes pydantic emit JSON-native values; the default python mode can
            # leave objects (e.g. URL fields) that json.dumps cannot serialize.
            return json.dumps([item.model_dump(mode="json") for item in content])

        logger.error(f"Errored MCP tool result: {result}")
        return "Error running tool."
pre_step_items=streamed_result.new_items, new_response=final_response, output_schema=output_schema, + all_tools=all_tools, handoffs=handoffs, hooks=hooks, context_wrapper=context_wrapper, @@ -691,6 +697,7 @@ async def _run_single_turn( cls, *, agent: Agent[TContext], + all_tools: list[Tool], original_input: str | list[TResponseInputItem], generated_items: list[RunItem], hooks: RunHooks[TContext], @@ -721,6 +728,7 @@ async def _run_single_turn( system_prompt, input, output_schema, + all_tools, handoffs, context_wrapper, run_config, @@ -732,6 +740,7 @@ async def _run_single_turn( pre_step_items=generated_items, new_response=new_response, output_schema=output_schema, + all_tools=all_tools, handoffs=handoffs, hooks=hooks, context_wrapper=context_wrapper, @@ -743,6 +752,7 @@ async def _get_single_step_result_from_response( cls, *, agent: Agent[TContext], + all_tools: list[Tool], original_input: str | list[TResponseInputItem], pre_step_items: list[RunItem], new_response: ModelResponse, @@ -754,6 +764,7 @@ async def _get_single_step_result_from_response( ) -> SingleStepResult: processed_response = RunImpl.process_model_response( agent=agent, + all_tools=all_tools, response=new_response, output_schema=output_schema, handoffs=handoffs, @@ -853,6 +864,7 @@ async def _get_new_response( system_prompt: str | None, input: list[TResponseInputItem], output_schema: AgentOutputSchema | None, + all_tools: list[Tool], handoffs: list[Handoff], context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, @@ -863,7 +875,7 @@ async def _get_new_response( system_instructions=system_prompt, input=input, model_settings=model_settings, - tools=agent.tools, + tools=all_tools, output_schema=output_schema, handoffs=handoffs, tracing=get_model_tracing_impl( @@ -892,6 +904,10 @@ def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]: handoffs.append(handoff(handoff_item)) return handoffs + @classmethod + async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]: + return await 
def pytest_ignore_collect(collection_path, config):
    """Skip collecting the MCP tests on interpreters older than Python 3.10.

    The `mcp` dependency is only installed for `python_version >= '3.10'`, so the test modules
    in this directory cannot be imported on older interpreters.

    Args:
        collection_path: The path pytest is about to collect.
        config: The pytest config object (unused).

    Returns:
        `True` to ignore `collection_path` when it is this directory or lies inside it and the
        interpreter is older than 3.10; `None` otherwise, which leaves the decision to pytest.
    """
    from pathlib import Path

    if sys.version_info >= (3, 10):
        return None

    this_dir = Path(__file__).resolve().parent
    candidate = Path(collection_path).resolve()

    # Compare path components instead of string prefixes, so that a sibling directory whose
    # name merely starts with this directory's name is not skipped as well.
    if candidate == this_dir or this_dir in candidate.parents:
        return True
    return None
@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_server_caching_works(
    mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
    """Test that if we turn caching on, the list of tools is cached and not fetched from the server
    on each call to `list_tools()`.
    """
    server = MCPServerStdio(
        params={
            "command": tee,
        },
        cache_tools_list=True,
    )

    expected_tools = [
        MCPTool(name="tool1", inputSchema={}),
        MCPTool(name="tool2", inputSchema={}),
    ]

    mock_list_tools.return_value = ListToolsResult(tools=expected_tools)

    async with server:
        # First call has to hit the server. The result is kept in its own variable so the
        # comparison against `expected_tools` is a real check, not `x == x`.
        fetched = await server.list_tools()
        assert fetched == expected_tools

        assert mock_list_tools.call_count == 1, "list_tools() should have been called once"

        # Call list_tools() again, should return the cached value
        fetched = await server.list_tools()
        assert fetched == expected_tools

        assert mock_list_tools.call_count == 1, "list_tools() should not have been called again"

        # Invalidate the cache and call list_tools() again
        server.invalidate_tools_cache()
        fetched = await server.list_tools()
        assert fetched == expected_tools

        assert mock_list_tools.call_count == 2, "list_tools() should be called again"

        # Without invalidating the cache, calling list_tools() again should return the cached value
        fetched = await server.list_tools()
        assert fetched == expected_tools

        assert mock_list_tools.call_count == 2, "list_tools() should not have been called again"
@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_manual_connect_disconnect_works(
    mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
    """Test that calling `connect()` and `cleanup()` by hand opens and closes the session."""
    server = MCPServerStdio(
        params={
            "command": tee,
        },
        cache_tools_list=True,
    )

    # list_tools() is never called in this test, so no return value is configured on the mock.
    assert server.session is None, "Server should not be connected"

    await server.connect()
    assert server.session is not None, "Server should be connected"

    await server.cleanup()
    assert server.session is None, "Server should be disconnected"
@pytest.mark.asyncio
async def test_mcp_invoke_bad_json_errors(caplog: pytest.LogCaptureFixture):
    """Test that bad JSON input errors are logged and re-raised."""
    # The docstring has to be the first statement in the function; placed after this call it
    # would only be a discarded string expression.
    caplog.set_level(logging.DEBUG)

    server = FakeMCPServer()
    server.add_tool("test_tool_1", {})

    ctx = RunContextWrapper(context=None)
    tool = MCPTool(name="test_tool_1", inputSchema={})

    with pytest.raises(ModelBehaviorError):
        await MCPUtil.invoke_mcp_tool(server, tool, ctx, "not_json")

    assert "Invalid JSON input for tool test_tool_1" in caplog.text
tool_name: str, arguments: dict[str, Any] | None): + raise Exception("Crash!") + + +@pytest.mark.asyncio +async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG) + + """Test that a crash in an MCP tool call is logged and raised as an AgentsException.""" + server = CrashingFakeMCPServer() + server.add_tool("test_tool_1", {}) + + ctx = RunContextWrapper(context=None) + tool = MCPTool(name="test_tool_1", inputSchema={}) + + with pytest.raises(AgentsException): + await MCPUtil.invoke_mcp_tool(server, tool, ctx, "") + + assert "Error invoking MCP tool test_tool_1" in caplog.text diff --git a/tests/mcp/test_runner_calls_mcp.py b/tests/mcp/test_runner_calls_mcp.py new file mode 100644 index 00000000..3319c097 --- /dev/null +++ b/tests/mcp/test_runner_calls_mcp.py @@ -0,0 +1,197 @@ +import json + +import pytest +from pydantic import BaseModel + +from agents import Agent, ModelBehaviorError, Runner, UserError + +from ..fake_model import FakeModel +from ..test_responses import get_function_tool_call, get_text_message +from .helpers import FakeMCPServer + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_calls_mcp_tool(streaming: bool): + """Test that the runner calls an MCP tool when the model produces a tool call.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + server.add_tool("test_tool_2", {}) + server.add_tool("test_tool_3", {}) + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_2", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + assert server.tool_calls ==
["test_tool_2"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_asserts_when_mcp_tool_not_found(streaming: bool): + """Test that the runner asserts when an MCP tool is not found.""" + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + server.add_tool("test_tool_2", {}) + server.add_tool("test_tool_3", {}) + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_doesnt_exist", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + with pytest.raises(ModelBehaviorError): + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_works_with_multiple_mcp_servers(streaming: bool): + """Test that the runner works with multiple MCP servers.""" + server1 = FakeMCPServer() + server1.add_tool("test_tool_1", {}) + + server2 = FakeMCPServer() + server2.add_tool("test_tool_2", {}) + server2.add_tool("test_tool_3", {}) + + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server1, server2], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_2", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + assert server1.tool_calls == [] + assert server2.tool_calls == ["test_tool_2"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, 
True]) +async def test_runner_errors_when_mcp_tools_clash(streaming: bool): + """Test that the runner errors when multiple servers have the same tool name.""" + server1 = FakeMCPServer() + server1.add_tool("test_tool_1", {}) + server1.add_tool("test_tool_2", {}) + + server2 = FakeMCPServer() + server2.add_tool("test_tool_2", {}) + server2.add_tool("test_tool_3", {}) + + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server1, server2], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_3", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + with pytest.raises(UserError): + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + +class Foo(BaseModel): + bar: str + baz: int + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_runner_calls_mcp_tool_with_args(streaming: bool): + """Test that the runner passes the model's JSON arguments through to the MCP tool.""" + server = FakeMCPServer() + await server.connect() + server.add_tool("test_tool_1", {}) + server.add_tool("test_tool_2", Foo.model_json_schema()) + server.add_tool("test_tool_3", {}) + model = FakeModel() + agent = Agent( + name="test", + model=model, + mcp_servers=[server], + ) + + json_args = json.dumps(Foo(bar="baz", baz=1).model_dump()) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_2", json_args)], + # Second turn: text message + [get_text_message("done")], + ] + ) + + if streaming: + result = Runner.run_streamed(agent, input="user_message") + async for _ in result.stream_events(): + pass + else: + await Runner.run(agent, input="user_message") + + assert server.tool_calls ==
["test_tool_2"] + assert server.tool_results == [f"result_test_tool_2_{json_args}"] + + await server.cleanup() diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py new file mode 100644 index 00000000..5c6432bc --- /dev/null +++ b/tests/mcp/test_server_errors.py @@ -0,0 +1,38 @@ +import pytest + +from agents.exceptions import UserError +from agents.mcp.server import _MCPServerWithClientSession + + +class CrashingClientSessionServer(_MCPServerWithClientSession): + def __init__(self): + super().__init__(cache_tools_list=False) + self.cleanup_called = False + + def create_streams(self): + raise ValueError("Crash!") + + async def cleanup(self): + self.cleanup_called = True + await super().cleanup() + + +@pytest.mark.asyncio +async def test_server_errors_cause_error_and_cleanup_called(): + server = CrashingClientSessionServer() + + with pytest.raises(ValueError): + await server.connect() + + assert server.cleanup_called + + +@pytest.mark.asyncio +async def test_not_calling_connect_causes_error(): + server = CrashingClientSessionServer() + + with pytest.raises(UserError): + await server.list_tools() + + with pytest.raises(UserError): + await server.call_tool("foo", {}) diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py index 2d581bf6..16c62c84 100644 --- a/tests/test_run_step_execution.py +++ b/tests/test_run_step_execution.py @@ -290,6 +290,7 @@ async def get_execute_result( processed_response = RunImpl.process_model_response( agent=agent, + all_tools=await agent.get_all_tools(), response=response, output_schema=output_schema, handoffs=handoffs, diff --git a/tests/test_run_step_processing.py b/tests/test_run_step_processing.py index 24f9e8e3..2a6634ac 100644 --- a/tests/test_run_step_processing.py +++ b/tests/test_run_step_processing.py @@ -43,7 +43,11 @@ def test_empty_response(): ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + 
response=response, + output_schema=None, + handoffs=[], + all_tools=[], ) assert not result.handoffs assert not result.functions @@ -57,13 +61,14 @@ def test_no_tool_calls(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, response=response, output_schema=None, handoffs=[], all_tools=[] ) assert not result.handoffs assert not result.functions -def test_single_tool_call(): +@pytest.mark.asyncio +async def test_single_tool_call(): agent = Agent(name="test", tools=[get_function_tool(name="test")]) response = ModelResponse( output=[ @@ -74,7 +79,11 @@ def test_single_tool_call(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert not result.handoffs assert result.functions and len(result.functions) == 1 @@ -84,7 +93,8 @@ def test_single_tool_call(): assert func.tool_call.arguments == "" -def test_missing_tool_call_raises_error(): +@pytest.mark.asyncio +async def test_missing_tool_call_raises_error(): agent = Agent(name="test", tools=[get_function_tool(name="test")]) response = ModelResponse( output=[ @@ -97,11 +107,16 @@ def test_missing_tool_call_raises_error(): with pytest.raises(ModelBehaviorError): RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) -def test_multiple_tool_calls(): +@pytest.mark.asyncio +async def test_multiple_tool_calls(): agent = Agent( name="test", tools=[ @@ -121,7 +136,11 @@ def test_multiple_tool_calls(): ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + 
all_tools=await agent.get_all_tools(), ) assert not result.handoffs assert result.functions and len(result.functions) == 2 @@ -146,7 +165,11 @@ async def test_handoffs_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent_3, response=response, output_schema=None, handoffs=[] + agent=agent_3, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent_3.get_all_tools(), ) assert not result.handoffs, "Shouldn't have a handoff here" @@ -160,6 +183,7 @@ async def test_handoffs_parsed_correctly(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) assert len(result.handoffs) == 1, "Should have a handoff here" handoff = result.handoffs[0] @@ -189,10 +213,12 @@ async def test_missing_handoff_fails(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) -def test_multiple_handoffs_doesnt_error(): +@pytest.mark.asyncio +async def test_multiple_handoffs_doesnt_error(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2]) @@ -210,6 +236,7 @@ def test_multiple_handoffs_doesnt_error(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) assert len(result.handoffs) == 2, "Should have multiple handoffs here" @@ -218,7 +245,8 @@ class Foo(BaseModel): bar: str -def test_final_output_parsed_correctly(): +@pytest.mark.asyncio +async def test_final_output_parsed_correctly(): agent = Agent(name="test", output_type=Foo) response = ModelResponse( output=[ @@ -234,10 +262,12 @@ def test_final_output_parsed_correctly(): response=response, output_schema=Runner._get_output_schema(agent), handoffs=[], + all_tools=await agent.get_all_tools(), ) -def test_file_search_tool_call_parsed_correctly(): +@pytest.mark.asyncio +async def 
test_file_search_tool_call_parsed_correctly(): # Ensure that a ResponseFileSearchToolCall output is parsed into a ToolCallItem and that no tool # runs are scheduled. @@ -254,7 +284,11 @@ def test_file_search_tool_call_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) # The final item should be a ToolCallItem for the file search call assert any( @@ -265,7 +299,8 @@ def test_file_search_tool_call_parsed_correctly(): assert not result.handoffs -def test_function_web_search_tool_call_parsed_correctly(): +@pytest.mark.asyncio +async def test_function_web_search_tool_call_parsed_correctly(): agent = Agent(name="test") web_search_call = ResponseFunctionWebSearch(id="w1", status="completed", type="web_search_call") response = ModelResponse( @@ -274,7 +309,11 @@ def test_function_web_search_tool_call_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert any( isinstance(item, ToolCallItem) and item.raw_item is web_search_call @@ -284,7 +323,8 @@ def test_function_web_search_tool_call_parsed_correctly(): assert not result.handoffs -def test_reasoning_item_parsed_correctly(): +@pytest.mark.asyncio +async def test_reasoning_item_parsed_correctly(): # Verify that a Reasoning output item is converted into a ReasoningItem. 
reasoning = ResponseReasoningItem( @@ -296,7 +336,11 @@ def test_reasoning_item_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=Agent(name="test"), response=response, output_schema=None, handoffs=[] + agent=Agent(name="test"), + response=response, + output_schema=None, + handoffs=[], + all_tools=await Agent(name="test").get_all_tools(), ) assert any( isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items @@ -342,7 +386,8 @@ def drag(self, path: list[tuple[int, int]]) -> None: return None # pragma: no cover -def test_computer_tool_call_without_computer_tool_raises_error(): +@pytest.mark.asyncio +async def test_computer_tool_call_without_computer_tool_raises_error(): # If the agent has no ComputerTool in its tools, process_model_response should raise a # ModelBehaviorError when encountering a ResponseComputerToolCall. computer_call = ResponseComputerToolCall( @@ -360,11 +405,16 @@ def test_computer_tool_call_without_computer_tool_raises_error(): ) with pytest.raises(ModelBehaviorError): RunImpl.process_model_response( - agent=Agent(name="test"), response=response, output_schema=None, handoffs=[] + agent=Agent(name="test"), + response=response, + output_schema=None, + handoffs=[], + all_tools=await Agent(name="test").get_all_tools(), ) -def test_computer_tool_call_with_computer_tool_parsed_correctly(): +@pytest.mark.asyncio +async def test_computer_tool_call_with_computer_tool_parsed_correctly(): # If the agent contains a ComputerTool, ensure that a ResponseComputerToolCall is parsed into a # ToolCallItem and scheduled to run in computer_actions. 
dummy_computer = DummyComputer() @@ -383,7 +433,11 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly(): referenceable_id=None, ) result = RunImpl.process_model_response( - agent=agent, response=response, output_schema=None, handoffs=[] + agent=agent, + response=response, + output_schema=None, + handoffs=[], + all_tools=await agent.get_all_tools(), ) assert any( isinstance(item, ToolCallItem) and item.raw_item is computer_call @@ -392,7 +446,8 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly(): assert result.computer_actions and result.computer_actions[0].tool_call == computer_call -def test_tool_and_handoff_parsed_correctly(): +@pytest.mark.asyncio +async def test_tool_and_handoff_parsed_correctly(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent( @@ -413,6 +468,7 @@ def test_tool_and_handoff_parsed_correctly(): response=response, output_schema=None, handoffs=Runner._get_handoffs(agent_3), + all_tools=await agent_3.get_all_tools(), ) assert result.functions and len(result.functions) == 1 assert len(result.handoffs) == 1, "Should have a handoff here" diff --git a/tests/voice/test_openai_stt.py b/tests/voice/test_openai_stt.py index 75559232..89b5cca7 100644 --- a/tests/voice/test_openai_stt.py +++ b/tests/voice/test_openai_stt.py @@ -269,7 +269,7 @@ def fake_time_func(): async for _ in turns: pass - assert "Timeout waiting for transcription_session.created event" in str(exc_info.value) + assert "Timeout waiting for transcription_session.created event" in str(exc_info.value) await session.close() @@ -302,13 +302,11 @@ async def test_session_error_event(): trace_include_sensitive_audio_data=False, ) - with pytest.raises(STTWebsocketConnectionError) as exc_info: + with pytest.raises(STTWebsocketConnectionError): turns = session.transcribe_turns() async for _ in turns: pass - assert "Simulated server error!" 
in str(exc_info.value) - await session.close() @@ -362,8 +360,8 @@ async def test_inactivity_timeout(): async for turn in session.transcribe_turns(): collected_turns.append(turn) - assert "Timeout waiting for transcription_session" in str(exc_info.value) + assert "Timeout waiting for transcription_session" in str(exc_info.value) - assert len(collected_turns) == 0, "No transcripts expected, but we got something?" + assert len(collected_turns) == 0, "No transcripts expected, but we got something?" await session.close() diff --git a/uv.lock b/uv.lock index a9c79e21..d6eba43f 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.10'", @@ -459,6 +458,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 }, ] +[[package]] +name = "httpx-sse" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/60/8f4281fa9bbf3c8034fd54c0e7412e66edbab6bc74c4996bd616f8d0406e/httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721", size = 12624 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/9b/a181f281f65d776426002f330c31849b86b31fc9d848db62e16f03ff739f/httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f", size = 7819 }, +] + [[package]] name = "idna" version = "3.10" @@ -699,6 +707,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/73/085399401383ce949f727afec55ec3abd76648d04b9f22e1c0e99cb4bec3/MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a", size = 15506 }, ] +[[package]] +name = "mcp" 
+version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "python_full_version >= '3.10'" }, + { name = "httpx", marker = "python_full_version >= '3.10'" }, + { name = "httpx-sse", marker = "python_full_version >= '3.10'" }, + { name = "pydantic", marker = "python_full_version >= '3.10'" }, + { name = "pydantic-settings", marker = "python_full_version >= '3.10'" }, + { name = "sse-starlette", marker = "python_full_version >= '3.10'" }, + { name = "starlette", marker = "python_full_version >= '3.10'" }, + { name = "uvicorn", marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/c9/c55764824e893fdebe777ac7223200986a275c3191dba9169f8eb6d7c978/mcp-1.5.0.tar.gz", hash = "sha256:5b2766c05e68e01a2034875e250139839498c61792163a7b221fc170c12f5aa9", size = 159128 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/d1/3ff566ecf322077d861f1a68a1ff025cad337417bd66ad22a7c6f7dfcfaf/mcp-1.5.0-py3-none-any.whl", hash = "sha256:51c3f35ce93cb702f7513c12406bbea9665ef75a08db909200b07da9db641527", size = 73734 }, +] + [[package]] name = "mdit-py-plugins" version = "0.4.2" @@ -1054,6 +1081,7 @@ version = "0.0.6" source = { editable = "." 
} dependencies = [ { name = "griffe" }, + { name = "mcp", marker = "python_full_version >= '3.10'" }, { name = "openai" }, { name = "pydantic" }, { name = "requests" }, @@ -1091,6 +1119,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "griffe", specifier = ">=1.5.6,<2" }, + { name = "mcp", marker = "python_full_version >= '3.10'" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, { name = "openai", specifier = ">=1.66.5" }, { name = "pydantic", specifier = ">=2.10,<3" }, @@ -1099,7 +1128,6 @@ requires-dist = [ { name = "typing-extensions", specifier = ">=4.12.2,<5" }, { name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<16" }, ] -provides-extras = ["voice"] [package.metadata.requires-dev] dev = [ @@ -1305,6 +1333,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a1/0c/c5c5cd3689c32ed1fe8c5d234b079c12c281c051759770c05b8bed6412b5/pydantic_core-2.27.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7d0c8399fcc1848491f00e0314bd59fb34a9c008761bcb422a057670c3f65e35", size = 2004961 }, ] +[[package]] +name = "pydantic-settings" +version = "2.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic", marker = "python_full_version >= '3.10'" }, + { name = "python-dotenv", marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/88/82/c79424d7d8c29b994fb01d277da57b0a9b09cc03c3ff875f9bd8a86b2145/pydantic_settings-2.8.1.tar.gz", hash = "sha256:d5c663dfbe9db9d5e1c646b2e161da12f0d734d422ee56f567d0ea2cee4e8585", size = 83550 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/53/a64f03044927dc47aafe029c42a5b7aabc38dfb813475e0e1bf71c4a59d0/pydantic_settings-2.8.1-py3-none-any.whl", hash = "sha256:81942d5ac3d905f7f3ee1a70df5dfb62d5569c12f51a5a647defc1c3d9ee2e9c", size = 30839 }, +] + [[package]] name = "pyee" version = "12.1.1" @@ -1496,6 +1537,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892 }, ] +[[package]] +name = "python-dotenv" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 }, +] + [[package]] name = "python-xlib" version = "0.33" @@ -1660,6 +1710,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/9b/15217b04f3b36d30de55fef542389d722de63f1ad81f9c72d8afc98cb6ab/sounddevice-0.5.1-py3-none-win_amd64.whl", hash = "sha256:4313b63f2076552b23ac3e0abd3bcfc0c1c6a696fc356759a13bd113c9df90f1", size = 363634 }, ] +[[package]] +name = "sse-starlette" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "python_full_version >= '3.10'" }, + { name = "starlette", marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/a4/80d2a11af59fe75b48230846989e93979c892d3a20016b42bb44edb9e398/sse_starlette-2.2.1.tar.gz", hash = "sha256:54470d5f19274aeed6b2d473430b08b4b379ea851d953b11d7f1c4a2c118b419", size = 17376 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/e0/5b8bd393f27f4a62461c5cf2479c75a2cc2ffa330976f9f00f5f6e4f50eb/sse_starlette-2.2.1-py3-none-any.whl", hash = "sha256:6410a3d3ba0c89e7675d4c273a301d64649c03a5ef1ca101f10b47f895fd0e99", 
size = 10120 }, +] + +[[package]] +name = "starlette" +version = "0.46.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/04/1b/52b27f2e13ceedc79a908e29eac426a63465a1a01248e5f24aa36a62aeb3/starlette-0.46.1.tar.gz", hash = "sha256:3c88d58ee4bd1bb807c0d1acb381838afc7752f9ddaec81bbe4383611d833230", size = 2580102 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/4b/528ccf7a982216885a1ff4908e886b8fb5f19862d1962f56a3fce2435a70/starlette-0.46.1-py3-none-any.whl", hash = "sha256:77c74ed9d2720138b25875133f3a2dae6d854af2ec37dceb56aef370c1d8a227", size = 71995 }, +] + [[package]] name = "textual" version = "2.1.2" @@ -1774,6 +1849,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 }, ] +[[package]] +name = "uvicorn" +version = "0.34.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click", marker = "python_full_version >= '3.10'" }, + { name = "h11", marker = "python_full_version >= '3.10'" }, + { name = "typing-extensions", marker = "python_full_version == '3.10.*'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4b/4d/938bd85e5bf2edeec766267a5015ad969730bb91e31b44021dfe8b22df6c/uvicorn-0.34.0.tar.gz", hash = "sha256:404051050cd7e905de2c9a7e61790943440b3416f49cb409f965d9dcd0fa73e9", size = 76568 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/14/33a3a1352cfa71812a3a21e8c9bfb83f60b0011f5e36f2b1399d51928209/uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4", size = 62315 }, +] + [[package]] name = "watchdog" version = "6.0.0"