From 010022777b3d37ab085cb8101a01fc1d37cca387 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Tue, 25 Mar 2025 18:01:00 -0400 Subject: [PATCH] [5/n] MCP tracing ## Summary: Adds tracing and tests for tracing. - Tools are added to the agents - Theres a span for the mcp tools lookup - Functions have MCP data ## Test Plan: Unit tests . --- examples/mcp/filesystem_example/main.py | 9 +- src/agents/__init__.py | 4 + src/agents/agent.py | 3 +- src/agents/mcp/server.py | 38 ++++- src/agents/mcp/util.py | 27 +++- src/agents/run.py | 17 +- src/agents/tracing/__init__.py | 4 + src/agents/tracing/create.py | 29 ++++ src/agents/tracing/span_data.py | 34 +++- src/agents/voice/models/openai_stt.py | 3 +- tests/mcp/helpers.py | 4 + tests/mcp/test_mcp_tracing.py | 198 ++++++++++++++++++++++++ tests/mcp/test_server_errors.py | 4 + 13 files changed, 352 insertions(+), 22 deletions(-) create mode 100644 tests/mcp/test_mcp_tracing.py diff --git a/examples/mcp/filesystem_example/main.py b/examples/mcp/filesystem_example/main.py index 0ba2b675..ae6fadd2 100644 --- a/examples/mcp/filesystem_example/main.py +++ b/examples/mcp/filesystem_example/main.py @@ -2,7 +2,7 @@ import os import shutil -from agents import Agent, Runner, trace +from agents import Agent, Runner, gen_trace_id, trace from agents.mcp import MCPServer, MCPServerStdio @@ -37,12 +37,15 @@ async def main(): samples_dir = os.path.join(current_dir, "sample_files") async with MCPServerStdio( + name="Filesystem Server, via npx", params={ "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], - } + }, ) as server: - with trace(workflow_name="MCP Filesystem Example"): + trace_id = gen_trace_id() + with trace(workflow_name="MCP Filesystem Example", trace_id=trace_id): + print(f"View trace: https://platform.openai.com/traces/{trace_id}\n") await run(server) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 47bb2649..242f5649 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -70,6 +70,7 @@ GenerationSpanData, GuardrailSpanData, HandoffSpanData, + MCPListToolsSpanData, Span, SpanData, SpanError, @@ -89,6 +90,7 @@ get_current_trace, guardrail_span, handoff_span, + mcp_tools_span, set_trace_processors, set_tracing_disabled, set_tracing_export_api_key, @@ -220,6 +222,7 @@ def enable_verbose_stdout_logging(): "speech_group_span", "transcription_span", "speech_span", + "mcp_tools_span", "trace", "Trace", "TracingProcessor", @@ -234,6 +237,7 @@ def enable_verbose_stdout_logging(): "HandoffSpanData", "SpeechGroupSpanData", "SpeechSpanData", + "MCPListToolsSpanData", "TranscriptionSpanData", "set_default_openai_key", "set_default_openai_client", diff --git a/src/agents/agent.py b/src/agents/agent.py index b31f00b1..13bb464e 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -228,4 +228,5 @@ async def get_mcp_tools(self) -> list[Tool]: async def get_all_tools(self) -> list[Tool]: """All agent tools, including MCP tools and function tools.""" - return await MCPUtil.get_all_function_tools(self.mcp_servers) + self.tools + mcp_tools = await self.get_mcp_tools() + return mcp_tools + self.tools diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 91af31db..e70d7ce6 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -27,6 +27,12 @@ async def connect(self): """ pass + @property + @abc.abstractmethod + def name(self) -> str: + """A readable name for the server.""" + pass + @abc.abstractmethod async def cleanup(self): """Cleanup the server. 
For example, this might mean closing a subprocess or @@ -171,7 +177,12 @@ class MCPServerStdio(_MCPServerWithClientSession): details. """ - def __init__(self, params: MCPServerStdioParams, cache_tools_list: bool = False): + def __init__( + self, + params: MCPServerStdioParams, + cache_tools_list: bool = False, + name: str | None = None, + ): """Create a new MCP server based on the stdio transport. Args: @@ -185,6 +196,8 @@ def __init__(self, params: MCPServerStdioParams, cache_tools_list: bool = False) invalidated by calling `invalidate_tools_cache()`. You should set this to `True` if you know the server will not change its tools list, because it can drastically improve latency (by avoiding a round-trip to the server every time). + name: A readable name for the server. If not provided, we'll create one from the + command. """ super().__init__(cache_tools_list) @@ -197,6 +210,8 @@ def __init__(self, params: MCPServerStdioParams, cache_tools_list: bool = False) encoding_error_handler=params.get("encoding_error_handler", "strict"), ) + self._name = name or f"stdio: {self.params.command}" + def create_streams( self, ) -> AbstractAsyncContextManager[ @@ -208,6 +223,11 @@ def create_streams( """Create the streams for the server.""" return stdio_client(self.params) + @property + def name(self) -> str: + """A readable name for the server.""" + return self._name + class MCPServerSseParams(TypedDict): """Mirrors the params in`mcp.client.sse.sse_client`.""" @@ -231,7 +251,12 @@ class MCPServerSse(_MCPServerWithClientSession): for details. """ - def __init__(self, params: MCPServerSseParams, cache_tools_list: bool = False): + def __init__( + self, + params: MCPServerSseParams, + cache_tools_list: bool = False, + name: str | None = None, + ): """Create a new MCP server based on the HTTP with SSE transport. Args: @@ -245,10 +270,14 @@ def __init__(self, params: MCPServerSseParams, cache_tools_list: bool = False): invalidated by calling `invalidate_tools_cache()`. You should set this to `True` if you know the server will not change its tools list, because it can drastically improve latency (by avoiding a round-trip to the server every time). + + name: A readable name for the server. If not provided, we'll create one from the + URL. 
""" super().__init__(cache_tools_list) self.params = params + self._name = name or f"sse: {self.params['url']}" def create_streams( self, @@ -265,3 +294,8 @@ def create_streams( timeout=self.params.get("timeout", 5), sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5), ) + + @property + def name(self) -> str: + """A readable name for the server.""" + return self._name diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 038c4fec..36c18bea 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -7,6 +7,7 @@ from ..logger import logger from ..run_context import RunContextWrapper from ..tool import FunctionTool, Tool +from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span if TYPE_CHECKING: from mcp.types import Tool as MCPTool @@ -38,7 +39,11 @@ async def get_all_function_tools(cls, servers: list["MCPServer"]) -> list[Tool]: @classmethod async def get_function_tools(cls, server: "MCPServer") -> list[Tool]: """Get all function tools from a single MCP server.""" - tools = await server.list_tools() + + with mcp_tools_span(server=server.name) as span: + tools = await server.list_tools() + span.span_data.result = [tool.name for tool in tools] + return [cls.to_function_tool(tool, server) for tool in tools] @classmethod @@ -88,9 +93,23 @@ async def invoke_mcp_tool( # The MCP tool result is a list of content items, whereas OpenAI tool outputs are a single # string. We'll try to convert. if len(result.content) == 1: - return result.content[0].model_dump_json() + tool_output = result.content[0].model_dump_json() elif len(result.content) > 1: - return json.dumps([item.model_dump() for item in result.content]) + tool_output = json.dumps([item.model_dump() for item in result.content]) else: logger.error(f"Errored MCP tool result: {result}") - return "Error running tool." + tool_output = "Error running tool." + + current_span = get_current_span() + if current_span: + if isinstance(current_span.span_data, FunctionSpanData): + current_span.span_data.output = tool_output + current_span.span_data.mcp_data = { + "server": server.name, + } + else: + logger.warning( + f"Current span is not a FunctionSpanData, skipping tool output: {current_span}" + ) + + return tool_output diff --git a/src/agents/run.py b/src/agents/run.py index 5c21b709..0159822a 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -7,8 +7,6 @@ from openai.types.responses import ResponseCompletedEvent -from agents.tool import Tool - from ._run_impl import ( AgentToolUseTracker, NextStepFinalOutput, @@ -40,6 +38,7 @@ from .result import RunResult, RunResultStreaming from .run_context import RunContextWrapper, TContext from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent +from .tool import Tool from .tracing import Span, SpanError, agent_span, get_current_trace, trace from .tracing.span_data import AgentSpanData from .usage import Usage @@ -182,8 +181,6 @@ async def run( # agent changes, or if the agent loop ends. 
if current_span is None: handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] - all_tools = await cls._get_all_tools(current_agent) - tool_names = [t.name for t in all_tools] if output_schema := cls._get_output_schema(current_agent): output_type_name = output_schema.output_type_name() else: @@ -192,11 +189,13 @@ async def run( current_span = agent_span( name=current_agent.name, handoffs=handoff_names, - tools=tool_names, output_type=output_type_name, ) current_span.start(mark_as_current=True) + all_tools = await cls._get_all_tools(current_agent) + current_span.span_data.tools = [t.name for t in all_tools] + current_turn += 1 if current_turn > max_turns: _error_tracing.attach_error_to_span( @@ -504,7 +503,6 @@ async def _run_streamed_impl( # agent changes, or if the agent loop ends. if current_span is None: handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] - tool_names = [t.name for t in current_agent.tools] if output_schema := cls._get_output_schema(current_agent): output_type_name = output_schema.output_type_name() else: @@ -513,11 +511,13 @@ async def _run_streamed_impl( current_span = agent_span( name=current_agent.name, handoffs=handoff_names, - tools=tool_names, output_type=output_type_name, ) current_span.start(mark_as_current=True) + all_tools = await cls._get_all_tools(current_agent) + tool_names = [t.name for t in all_tools] + current_span.span_data.tools = tool_names current_turn += 1 streamed_result.current_turn = current_turn @@ -553,6 +553,7 @@ async def _run_streamed_impl( run_config, should_run_agent_start_hooks, tool_use_tracker, + all_tools, ) should_run_agent_start_hooks = False @@ -621,6 +622,7 @@ async def _run_single_turn_streamed( run_config: RunConfig, should_run_agent_start_hooks: bool, tool_use_tracker: AgentToolUseTracker, + all_tools: list[Tool], ) -> SingleStepResult: if should_run_agent_start_hooks: await asyncio.gather( @@ -640,7 +642,6 @@ async def _run_single_turn_streamed( system_prompt = await agent.get_system_prompt(context_wrapper) handoffs = cls._get_handoffs(agent) - all_tools = await cls._get_all_tools(agent) model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) diff --git a/src/agents/tracing/__init__.py b/src/agents/tracing/__init__.py index dc7c7cfd..9df94426 100644 --- a/src/agents/tracing/__init__.py +++ b/src/agents/tracing/__init__.py @@ -9,6 +9,7 @@ get_current_trace, guardrail_span, handoff_span, + mcp_tools_span, response_span, speech_group_span, speech_span, @@ -25,6 +26,7 @@ GenerationSpanData, GuardrailSpanData, HandoffSpanData, + MCPListToolsSpanData, ResponseSpanData, SpanData, SpeechGroupSpanData, @@ -59,6 +61,7 @@ "GenerationSpanData", "GuardrailSpanData", "HandoffSpanData", + "MCPListToolsSpanData", "ResponseSpanData", "SpeechGroupSpanData", "SpeechSpanData", @@ -69,6 +72,7 @@ "speech_group_span", "speech_span", "transcription_span", + "mcp_tools_span", ] diff --git a/src/agents/tracing/create.py b/src/agents/tracing/create.py index af2f156f..b6fe4610 100644 --- a/src/agents/tracing/create.py +++ b/src/agents/tracing/create.py @@ -12,6 +12,7 @@ GenerationSpanData, GuardrailSpanData, HandoffSpanData, + MCPListToolsSpanData, ResponseSpanData, SpeechGroupSpanData, SpeechSpanData, @@ -424,3 +425,31 @@ def speech_group_span( parent=parent, disabled=disabled, ) + + +def mcp_tools_span( + server: str | None = None, + result: list[str] | None = None, 
+ span_id: str | None = None, + parent: Trace | Span[Any] | None = None, + disabled: bool = False, +) -> Span[MCPListToolsSpanData]: + """Create a new MCP list tools span. The span will not be started automatically, you should + either do `with mcp_tools_span() ...` or call `span.start()` + `span.finish()` manually. + + Args: + server: The name of the MCP server. + result: The result of the MCP list tools call. + span_id: The ID of the span. Optional. If not provided, we will generate an ID. We + recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are + correctly formatted. + parent: The parent span or trace. If not provided, we will automatically use the current + trace/span as the parent. + disabled: If True, we will return a Span but the Span will not be recorded. + """ + return GLOBAL_TRACE_PROVIDER.create_span( + span_data=MCPListToolsSpanData(server=server, result=result), + span_id=span_id, + parent=parent, + disabled=disabled, + ) diff --git a/src/agents/tracing/span_data.py b/src/agents/tracing/span_data.py index 95e7fe0f..ed2a3f2d 100644 --- a/src/agents/tracing/span_data.py +++ b/src/agents/tracing/span_data.py @@ -49,12 +49,19 @@ def export(self) -> dict[str, Any]: class FunctionSpanData(SpanData): - __slots__ = ("name", "input", "output") + __slots__ = ("name", "input", "output", "mcp_data") - def __init__(self, name: str, input: str | None, output: Any | None): + def __init__( + self, + name: str, + input: str | None, + output: Any | None, + mcp_data: dict[str, Any] | None = None, + ): self.name = name self.input = input self.output = output + self.mcp_data = mcp_data @property def type(self) -> str: @@ -66,6 +73,7 @@ def export(self) -> dict[str, Any]: "name": self.name, "input": self.input, "output": str(self.output) if self.output else None, + "mcp_data": self.mcp_data, } @@ -282,3 +290,25 @@ def export(self) -> dict[str, Any]: "type": self.type, "input": self.input, } + + +class MCPListToolsSpanData(SpanData): + __slots__ = ( + "server", + "result", + ) + + def __init__(self, server: str | None = None, result: list[str] | None = None): + self.server = server + self.result = result + + @property + def type(self) -> str: + return "mcp_tools" + + def export(self) -> dict[str, Any]: + return { + "type": self.type, + "server": self.server, + "result": self.result, + } diff --git a/src/agents/voice/models/openai_stt.py b/src/agents/voice/models/openai_stt.py index a5cf8acd..1ae4ea14 100644 --- a/src/agents/voice/models/openai_stt.py +++ b/src/agents/voice/models/openai_stt.py @@ -10,9 +10,8 @@ from openai import AsyncOpenAI -from agents.exceptions import AgentsException - from ... 
import _debug +from ...exceptions import AgentsException from ...logger import logger from ...tracing import Span, SpanError, TranscriptionSpanData, transcription_span from ..exceptions import STTWebsocketConnectionError diff --git a/tests/mcp/helpers.py b/tests/mcp/helpers.py index 952b3ea7..8ff153c1 100644 --- a/tests/mcp/helpers.py +++ b/tests/mcp/helpers.py @@ -52,3 +52,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C return CallToolResult( content=[TextContent(text=self.tool_results[-1], type="text")], ) + + @property + def name(self) -> str: + return "fake_mcp_server" diff --git a/tests/mcp/test_mcp_tracing.py b/tests/mcp/test_mcp_tracing.py new file mode 100644 index 00000000..b71954b5 --- /dev/null +++ b/tests/mcp/test_mcp_tracing.py @@ -0,0 +1,198 @@ +import pytest +from inline_snapshot import snapshot + +from agents import Agent, Runner + +from ..fake_model import FakeModel +from ..test_responses import get_function_tool, get_function_tool_call, get_text_message +from ..testing_processor import SPAN_PROCESSOR_TESTING, fetch_normalized_spans +from .helpers import FakeMCPServer + + +@pytest.mark.asyncio +async def test_mcp_tracing(): + model = FakeModel() + server = FakeMCPServer() + server.add_tool("test_tool_1", {}) + agent = Agent( + name="test", + model=model, + mcp_servers=[server], + tools=[get_function_tool("non_mcp_tool", "tool_result")], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_1", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + # First run: should list MCP tools before first and second steps + x = Runner.run_streamed(agent, input="first_test") + async for _ in x.stream_events(): + pass + + assert x.final_output == "done" + spans = fetch_normalized_spans() + + # Should have a single tool listing, and the function span should have MCP data + assert spans == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test", + "handoffs": [], + "tools": ["test_tool_1", "non_mcp_tool"], + "output_type": "str", + }, + "children": [ + { + "type": "mcp_tools", + "data": {"server": "fake_mcp_server", "result": ["test_tool_1"]}, + }, + { + "type": "function", + "data": { + "name": "test_tool_1", + "input": "", + "output": '{"type":"text","text":"result_test_tool_1_{}","annotations":null}', # noqa: E501 + "mcp_data": {"server": "fake_mcp_server"}, + }, + }, + ], + } + ], + } + ] + ) + + server.add_tool("test_tool_2", {}) + + SPAN_PROCESSOR_TESTING.clear() + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [ + get_text_message("a_message"), + get_function_tool_call("non_mcp_tool", ""), + get_function_tool_call("test_tool_2", ""), + ], + # Second turn: text message + [get_text_message("done")], + ] + ) + + await Runner.run(agent, input="second_test") + spans = fetch_normalized_spans() + + # Should have a single tool listing, and the function span should have MCP data, and the non-mcp + # tool function span should not have MCP data + assert spans == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test", + "handoffs": [], + "tools": ["test_tool_1", "test_tool_2", "non_mcp_tool"], + "output_type": "str", + }, + "children": [ + { + "type": "mcp_tools", + "data": { + "server": "fake_mcp_server", + "result": ["test_tool_1", "test_tool_2"], + }, + }, + { + 
"type": "function", + "data": { + "name": "non_mcp_tool", + "input": "", + "output": "tool_result", + }, + }, + { + "type": "function", + "data": { + "name": "test_tool_2", + "input": "", + "output": '{"type":"text","text":"result_test_tool_2_{}","annotations":null}', # noqa: E501 + "mcp_data": {"server": "fake_mcp_server"}, + }, + }, + ], + } + ], + } + ] + ) + + SPAN_PROCESSOR_TESTING.clear() + + # Add more tools to the server + server.add_tool("test_tool_3", {}) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [get_text_message("a_message"), get_function_tool_call("test_tool_3", "")], + # Second turn: text message + [get_text_message("done")], + ] + ) + + await Runner.run(agent, input="third_test") + + spans = fetch_normalized_spans() + + # Should have a single tool listing, and the function span should have MCP data, and the non-mcp + # tool function span should not have MCP data + assert spans == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test", + "handoffs": [], + "tools": ["test_tool_1", "test_tool_2", "test_tool_3", "non_mcp_tool"], + "output_type": "str", + }, + "children": [ + { + "type": "mcp_tools", + "data": { + "server": "fake_mcp_server", + "result": ["test_tool_1", "test_tool_2", "test_tool_3"], + }, + }, + { + "type": "function", + "data": { + "name": "test_tool_3", + "input": "", + "output": '{"type":"text","text":"result_test_tool_3_{}","annotations":null}', # noqa: E501 + "mcp_data": {"server": "fake_mcp_server"}, + }, + }, + ], + } + ], + } + ] + ) diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py index 5c6432bc..bdca7ce6 100644 --- a/tests/mcp/test_server_errors.py +++ b/tests/mcp/test_server_errors.py @@ -16,6 +16,10 @@ async def cleanup(self): self.cleanup_called = True await super().cleanup() + @property + def name(self) -> str: + return "crashing_client_session_server" + @pytest.mark.asyncio async def test_server_errors_cause_error_and_cleanup_called():