diff --git a/.github/workflows/issues.yml b/.github/workflows/issues.yml index 6447f83ef..dca717752 100644 --- a/.github/workflows/issues.yml +++ b/.github/workflows/issues.yml @@ -21,6 +21,7 @@ jobs: days-before-pr-stale: 10 days-before-pr-close: 7 stale-pr-label: "stale" + exempt-issue-labels: "skip-stale" stale-pr-message: "This PR is stale because it has been open for 10 days with no activity." close-pr-message: "This PR was closed because it has been inactive for 7 days since being marked as stale." repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/AGENTS.md b/AGENTS.md index ff37db326..291c31837 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -35,6 +35,8 @@ Welcome to the OpenAI Agents SDK repository. This file contains the main points Coverage can be generated with `make coverage`. +All python commands should be run via `uv run python ...` + ## Snapshot tests Some tests rely on inline snapshots. See `tests/README.md` for details on updating them: @@ -64,6 +66,6 @@ Commit messages should be concise and written in the imperative mood. Small, foc ## What reviewers look for - Tests covering new behaviour. -- Consistent style: code formatted with `ruff format`, imports sorted, and type hints passing `mypy`. +- Consistent style: code formatted with `uv run ruff format`, imports sorted, and type hints passing `uv run mypy .`. - Clear documentation for any public API changes. - Clean history and a helpful PR description. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..5e01a1c3d --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +Read the AGENTS.md file for instructions. \ No newline at end of file diff --git a/Makefile b/Makefile index 5c6aba425..9a88f93a1 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,10 @@ format: uv run ruff format uv run ruff check --fix +.PHONY: format-check +format-check: + uv run ruff format --check + .PHONY: lint lint: uv run ruff check @@ -55,5 +59,5 @@ serve-docs: deploy-docs: uv run mkdocs gh-deploy --force --verbose - - +.PHONY: check +check: format-check lint mypy tests diff --git a/README.md b/README.md index 785177916..755c342ae 100644 --- a/README.md +++ b/README.md @@ -20,14 +20,21 @@ Explore the [examples](examples) directory to see the SDK in action, and read ou 1. Set up your Python environment -``` +- Option A: Using venv (traditional method) +```bash python -m venv env -source env/bin/activate +source env/bin/activate # On Windows: env\Scripts\activate +``` + +- Option B: Using uv (recommended) +```bash +uv venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate ``` 2. Install Agents SDK -``` +```bash pip install openai-agents ``` @@ -50,7 +57,7 @@ print(result.final_output) (_If running this, ensure you set the `OPENAI_API_KEY` environment variable_) -(_For Jupyter notebook users, see [hello_world_jupyter.py](examples/basic/hello_world_jupyter.py)_) +(_For Jupyter notebook users, see [hello_world_jupyter.ipynb](examples/basic/hello_world_jupyter.ipynb)_) ## Handoffs example @@ -163,10 +170,16 @@ make sync 2. (After making changes) lint/test +``` +make check # run tests linter and typechecker +``` + +Or to run them individually: ``` make tests # run tests make mypy # run typechecker make lint # run linter +make format-check # run style checker ``` ## Acknowledgements diff --git a/docs/agents.md b/docs/agents.md index 39d4afd57..b11a4dd68 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -6,6 +6,7 @@ Agents are the core building block in your apps. An agent is a large language mo The most common properties of an agent you'll configure are: +- `name`: A required string that identifies your agent. - `instructions`: also known as a developer message or system prompt. - `model`: which LLM to use, and optional `model_settings` to configure model tuning parameters like temperature, top_p, etc. - `tools`: Tools that the agent can use to achieve its tasks. diff --git a/docs/context.md b/docs/context.md index 4176ec51f..6e54565e0 100644 --- a/docs/context.md +++ b/docs/context.md @@ -38,7 +38,8 @@ class UserInfo: # (1)! @function_tool async def fetch_user_age(wrapper: RunContextWrapper[UserInfo]) -> str: # (2)! - return f"User {wrapper.context.name} is 47 years old" + """Fetch the age of the user. Call this function to get user's age information.""" + return f"The user {wrapper.context.name} is 47 years old" async def main(): user_info = UserInfo(name="John", uid=123) diff --git a/docs/guardrails.md b/docs/guardrails.md index 2f0be0f2a..8df904a4c 100644 --- a/docs/guardrails.md +++ b/docs/guardrails.md @@ -23,7 +23,7 @@ Input guardrails run in 3 steps: Output guardrails run in 3 steps: -1. First, the guardrail receives the same input passed to the agent. +1. First, the guardrail receives the output produced by the agent. 2. Next, the guardrail function runs to produce a [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput], which is then wrapped in an [`OutputGuardrailResult`][agents.guardrail.OutputGuardrailResult] 3. Finally, we check if [`.tripwire_triggered`][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] is true. If true, an [`OutputGuardrailTripwireTriggered`][agents.exceptions.OutputGuardrailTripwireTriggered] exception is raised, so you can appropriately respond to the user or handle the exception. diff --git a/docs/ja/guardrails.md b/docs/ja/guardrails.md index e7b02a6ed..b67bb8bad 100644 --- a/docs/ja/guardrails.md +++ b/docs/ja/guardrails.md @@ -4,44 +4,44 @@ search: --- # ガードレール -ガードレールは エージェント と _並列_ に実行され、 ユーザー入力 のチェックとバリデーションを行います。たとえば、顧客からのリクエストを支援するために非常に賢い (そのため遅く / 高価な) モデルを使うエージェントがあるとします。悪意のある ユーザー がモデルに数学の宿題を手伝わせようとするのは避けたいですよね。その場合、 高速 / 低コスト のモデルでガードレールを実行できます。ガードレールが悪意のある利用を検知した場合、即座にエラーを送出して高価なモデルの実行を停止し、時間と費用を節約できます。 +ガードレールは エージェント と _並行して_ 実行され、ユーザー入力のチェックとバリデーションを行えます。例えば、とても賢い(つまり遅く/高価な)モデルを使用してカスタマーリクエストを処理するエージェントがあるとします。悪意のある ユーザー がモデルに数学の宿題を手伝わせようとするのは避けたいでしょう。そこで、速く/安価なモデルで動くガードレールを実行できます。ガードレールが悪意のある利用を検知すると、直ちにエラーを送出して高価なモデルの実行を停止し、時間とコストを節約できます。 -ガードレールには 2 種類あります。 +ガードレールには 2 種類あります: -1. Input ガードレールは最初の ユーザー入力 に対して実行されます -2. Output ガードレールは最終的なエージェント出力に対して実行されます +1. 入力ガードレール は初期 ユーザー 入力に対して実行されます +2. 出力ガードレール は最終的なエージェント出力に対して実行されます -## Input ガードレール +## 入力ガードレール -Input ガードレールは 3 つのステップで実行されます。 +入力ガードレールは 3 ステップで実行されます: 1. まず、ガードレールはエージェントに渡されたものと同じ入力を受け取ります。 -2. 次に、ガードレール関数が実行され [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput] を生成し、それが [`InputGuardrailResult`][agents.guardrail.InputGuardrailResult] でラップされます。 -3. 最後に [`.tripwire_triggered`][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] が true かどうかを確認します。true の場合、[`InputGuardrailTripwireTriggered`][agents.exceptions.InputGuardrailTripwireTriggered] 例外が送出されるので、 ユーザー への適切な応答や例外処理を行えます。 +2. 次に、ガードレール関数が実行され [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput] を生成し、それが [`InputGuardrailResult`][agents.guardrail.InputGuardrailResult] にラップされます。 +3. 最後に [.tripwire_triggered][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] が true かどうかを確認します。true の場合、[`InputGuardrailTripwireTriggered`][agents.exceptions.InputGuardrailTripwireTriggered] 例外が送出されるので、適切に ユーザー に応答したり例外を処理できます。 !!! Note - Input ガードレールは ユーザー入力 に対して実行されることを想定しているため、エージェントのガードレールが実行されるのはそのエージェントが *最初* のエージェントである場合だけです。「なぜ `guardrails` プロパティがエージェントにあり、 `Runner.run` に渡さないのか?」と思うかもしれません。ガードレールは実際の エージェント に密接に関連する場合が多く、エージェントごとに異なるガードレールを実行するため、コードを同じ場所に置くことで可読性が向上するからです。 + 入力ガードレールは ユーザー 入力に対して実行されることを意図しているため、ガードレールは *最初* のエージェントでのみ実行されます。「なぜ `guardrails` プロパティがエージェントにあり、`Runner.run` に渡さないのか」と疑問に思うかもしれません。これは、ガードレールが実際の エージェント と密接に関連していることが多いからです。異なるエージェントには異なるガードレールを実行するため、コードを同じ場所に置くことで可読性が向上します。 -## Output ガードレール +## 出力ガードレール -Output ガードレールは 3 つのステップで実行されます。 +出力ガードレールは 3 ステップで実行されます: -1. まず、ガードレールはエージェントに渡されたものと同じ入力を受け取ります。 -2. 次に、ガードレール関数が実行され [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput] を生成し、それが [`OutputGuardrailResult`][agents.guardrail.OutputGuardrailResult] でラップされます。 -3. 最後に [`.tripwire_triggered`][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] が true かどうかを確認します。true の場合、[`OutputGuardrailTripwireTriggered`][agents.exceptions.OutputGuardrailTripwireTriggered] 例外が送出されるので、 ユーザー への適切な応答や例外処理を行えます。 +1. まず、ガードレールはエージェントが生成した出力を受け取ります。 +2. 次に、ガードレール関数が実行され [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput] を生成し、それが [`OutputGuardrailResult`][agents.guardrail.OutputGuardrailResult] にラップされます。 +3. 最後に [.tripwire_triggered][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] が true かどうかを確認します。true の場合、[`OutputGuardrailTripwireTriggered`][agents.exceptions.OutputGuardrailTripwireTriggered] 例外が送出されるので、適切に ユーザー に応答したり例外を処理できます。 !!! Note - Output ガードレールは最終的なエージェント出力に対して実行されることを想定しているため、エージェントのガードレールが実行されるのはそのエージェントが *最後* のエージェントである場合だけです。Input ガードレール同様、ガードレールは実際の エージェント に密接に関連するため、コードを同じ場所に置くことで可読性が向上します。 + 出力ガードレールは最終的なエージェント出力に対して実行されることを意図しているため、ガードレールは *最後* のエージェントでのみ実行されます。入力ガードレールの場合と同様、ガードレールが実際の エージェント と密接に関連していることが多いため、コードを同じ場所に置くことで可読性が向上します。 -## トリップワイヤ +## トリップワイヤー -入力または出力がガードレールに失敗した場合、ガードレールはトリップワイヤを用いてそれを通知できます。ガードレールがトリップワイヤを発火したことを検知すると、ただちに `{Input,Output}GuardrailTripwireTriggered` 例外を送出してエージェントの実行を停止します。 +入力または出力がガードレールを通過できなかった場合、ガードレールはトリップワイヤーでそれを示すことができます。トリップワイヤーがトリガーされたガードレールを検知した時点で、直ちに `{Input,Output}GuardrailTripwireTriggered` 例外を送出し、エージェントの実行を停止します。 ## ガードレールの実装 -入力を受け取り、[`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput] を返す関数を用意する必要があります。次の例では、内部で エージェント を実行してこれを行います。 +入力を受け取り、[`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput] を返す関数を提供する必要があります。この例では、内部で エージェント を実行してこれを行います。 ```python from pydantic import BaseModel @@ -94,12 +94,12 @@ async def main(): print("Math homework guardrail tripped") ``` -1. この エージェント をガードレール関数内で使用します。 -2. これはエージェントの入力 / コンテキストを受け取り、結果を返すガードレール関数です。 +1. このエージェントをガードレール関数内で使用します。 +2. これはエージェントの入力/コンテキストを受け取り、結果を返すガードレール関数です。 3. ガードレール結果に追加情報を含めることができます。 4. これはワークフローを定義する実際のエージェントです。 -Output ガードレールも同様です。 +出力ガードレールも同様です。 ```python from pydantic import BaseModel @@ -155,4 +155,4 @@ async def main(): 1. これは実際のエージェントの出力型です。 2. これはガードレールの出力型です。 3. これはエージェントの出力を受け取り、結果を返すガードレール関数です。 -4. これはワークフローを定義する実際のエージェントです。 \ No newline at end of file +4. これはワークフローを定義する実際のエージェントです。 \ No newline at end of file diff --git a/docs/ja/mcp.md b/docs/ja/mcp.md index 09804beb2..1e394a5e6 100644 --- a/docs/ja/mcp.md +++ b/docs/ja/mcp.md @@ -23,13 +23,20 @@ Agents SDK は MCP をサポートしており、これにより幅広い MCP たとえば、[公式 MCP filesystem サーバー](https://www.npmjs.com/package/@modelcontextprotocol/server-filesystem)を利用する場合は次のようになります。 ```python +from agents.run_context import RunContextWrapper + async with MCPServerStdio( params={ "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], } ) as server: - tools = await server.list_tools() + # 注意:実際には通常は MCP サーバーをエージェントに追加し、 + # フレームワークがツール一覧の取得を自動的に処理するようにします。 + # list_tools() への直接呼び出しには run_context と agent パラメータが必要です。 + run_context = RunContextWrapper(context=None) + agent = Agent(name="test", instructions="test") + tools = await server.list_tools(run_context, agent) ``` ## MCP サーバーの利用 diff --git a/docs/ja/models.md b/docs/ja/models.md deleted file mode 100644 index 5a76d60ec..000000000 --- a/docs/ja/models.md +++ /dev/null @@ -1,106 +0,0 @@ -# モデル - -Agents SDK には、OpenAI モデルの 2 種類のサポートが標準で用意されています。 - -- **推奨**: [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIResponsesModel] は、新しい [Responses API](https://platform.openai.com/docs/api-reference/responses) を使って OpenAI API を呼び出します。 -- [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] は、[Chat Completions API](https://platform.openai.com/docs/api-reference/chat) を使って OpenAI API を呼び出します。 - -## モデルの組み合わせ - -1 つのワークフロー内で、各エージェントごとに異なるモデルを使いたい場合があります。たとえば、トリアージには小型で高速なモデルを使い、複雑なタスクにはより大きく高性能なモデルを使うことができます。[`Agent`][agents.Agent] を設定する際、以下のいずれかの方法で特定のモデルを選択できます。 - -1. OpenAI モデル名を直接渡す。 -2. 任意のモデル名と、その名前を Model インスタンスにマッピングできる [`ModelProvider`][agents.models.interface.ModelProvider] を渡す。 -3. [`Model`][agents.models.interface.Model] 実装を直接指定する。 - -!!!note - - SDK は [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIResponsesModel] と [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] の両方の形状をサポートしていますが、各ワークフローで 1 つのモデル形状のみを使うことを推奨します。なぜなら、2 つの形状はサポートする機能やツールが異なるためです。ワークフローでモデル形状を組み合わせて使う場合は、利用するすべての機能が両方で利用可能かご確認ください。 - -```python -from agents import Agent, Runner, AsyncOpenAI, OpenAIChatCompletionsModel -import asyncio - -spanish_agent = Agent( - name="Spanish agent", - instructions="You only speak Spanish.", - model="o3-mini", # (1)! -) - -english_agent = Agent( - name="English agent", - instructions="You only speak English", - model=OpenAIChatCompletionsModel( # (2)! - model="gpt-4o", - openai_client=AsyncOpenAI() - ), -) - -triage_agent = Agent( - name="Triage agent", - instructions="Handoff to the appropriate agent based on the language of the request.", - handoffs=[spanish_agent, english_agent], - model="gpt-3.5-turbo", -) - -async def main(): - result = await Runner.run(triage_agent, input="Hola, ¿cómo estás?") - print(result.final_output) -``` - -1. OpenAI モデル名を直接設定します。 -2. [`Model`][agents.models.interface.Model] 実装を指定します。 - -エージェントで使用するモデルをさらに細かく設定したい場合は、[`ModelSettings`][agents.models.interface.ModelSettings] を渡すことができます。これにより、temperature などのオプションのモデル設定パラメーターを指定できます。 - -```python -from agents import Agent, ModelSettings - -english_agent = Agent( - name="English agent", - instructions="You only speak English", - model="gpt-4o", - model_settings=ModelSettings(temperature=0.1), -) -``` - -## 他の LLM プロバイダーの利用 - -他の LLM プロバイダーは、3 つの方法で利用できます([こちら](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/) に code examples があります)。 - -1. [`set_default_openai_client`][agents.set_default_openai_client] は、`AsyncOpenAI` のインスタンスを LLM クライアントとしてグローバルに利用したい場合に便利です。これは、LLM プロバイダーが OpenAI 互換の API エンドポイントを持ち、`base_url` と `api_key` を設定できる場合に使います。[examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py) に設定例があります。 -2. [`ModelProvider`][agents.models.interface.ModelProvider] は `Runner.run` レベルで利用します。これにより、「この実行のすべてのエージェントでカスタムモデルプロバイダーを使う」と指定できます。[examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py) に設定例があります。 -3. [`Agent.model`][agents.agent.Agent.model] で、特定のエージェントインスタンスにモデルを指定できます。これにより、エージェントごとに異なるプロバイダーを組み合わせて使うことができます。[examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py) に設定例があります。 - -`platform.openai.com` の API キーがない場合は、`set_tracing_disabled()` でトレーシングを無効にするか、[別のトレーシングプロセッサー](tracing.md) を設定することを推奨します。 - -!!! note - - これらの code examples では Chat Completions API/モデルを使っています。なぜなら、ほとんどの LLM プロバイダーはまだ Responses API をサポートしていないためです。もし LLM プロバイダーが Responses API をサポートしている場合は、Responses の利用を推奨します。 - -## 他の LLM プロバイダー利用時のよくある問題 - -### Tracing クライアントの 401 エラー - -トレーシングに関連するエラーが発生した場合、これはトレースが OpenAI サーバーにアップロードされるため、OpenAI API キーがないことが原因です。解決方法は 3 つあります。 - -1. トレーシングを完全に無効化する: [`set_tracing_disabled(True)`][agents.set_tracing_disabled]。 -2. トレーシング用の OpenAI キーを設定する: [`set_tracing_export_api_key(...)`][agents.set_tracing_export_api_key]。この API キーはトレースのアップロードのみに使われ、[platform.openai.com](https://platform.openai.com/) のものが必要です。 -3. OpenAI 以外のトレースプロセッサーを使う。[トレーシングのドキュメント](tracing.md#custom-tracing-processors) をご覧ください。 - -### Responses API サポート - -SDK はデフォルトで Responses API を使いますが、ほとんどの他の LLM プロバイダーはまだ対応していません。そのため、404 エラーなどが発生する場合があります。解決方法は 2 つあります。 - -1. [`set_default_openai_api("chat_completions")`][agents.set_default_openai_api] を呼び出します。これは、環境変数で `OPENAI_API_KEY` と `OPENAI_BASE_URL` を設定している場合に有効です。 -2. [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] を使います。[こちら](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/) に code examples があります。 - -### structured outputs サポート - -一部のモデルプロバイダーは [structured outputs](https://platform.openai.com/docs/guides/structured-outputs) をサポートしていません。その場合、次のようなエラーが発生することがあります。 - -``` -BadRequestError: Error code: 400 - {'error': {'message': "'response_format.type' : value is not one of the allowed values ['text','json_object']", 'type': 'invalid_request_error'}} -``` - -これは一部のモデルプロバイダーの制限で、JSON 出力には対応していても、出力に使う `json_schema` を指定できない場合があります。現在この問題の修正に取り組んでいますが、JSON schema 出力をサポートしているプロバイダーの利用を推奨します。そうでない場合、不正な JSON によりアプリが頻繁に動作しなくなる可能性があります。 \ No newline at end of file diff --git a/docs/ja/models/index.md b/docs/ja/models/index.md index a40ae38f6..410c01676 100644 --- a/docs/ja/models/index.md +++ b/docs/ja/models/index.md @@ -4,21 +4,52 @@ search: --- # モデル -Agents SDK には、標準で 2 種類の OpenAI モデルサポートが含まれています。 +Agents SDK は OpenAI モデルを 2 つの形態で即利用できます。 -- **推奨**: [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIResponsesModel] — 新しい [Responses API](https://platform.openai.com/docs/api-reference/responses) を利用して OpenAI API を呼び出します。 -- [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] — [Chat Completions API](https://platform.openai.com/docs/api-reference/chat) を利用して OpenAI API を呼び出します。 +- **推奨**: [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIResponsesModel] は、新しい [Responses API](https://platform.openai.com/docs/api-reference/responses) を使用して OpenAI API を呼び出します。 +- [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] は、[Chat Completions API](https://platform.openai.com/docs/api-reference/chat) を使用して OpenAI API を呼び出します。 + +## 非 OpenAI モデル + +ほとんどの非 OpenAI モデルは [LiteLLM インテグレーション](./litellm.md) 経由で利用できます。まず、litellm 依存グループをインストールします: + +```bash +pip install "openai-agents[litellm]" +``` + +次に、`litellm/` 接頭辞を付けて任意の [サポート対象モデル](https://docs.litellm.ai/docs/providers) を使用します: + +```python +claude_agent = Agent(model="litellm/anthropic/claude-3-5-sonnet-20240620", ...) +gemini_agent = Agent(model="litellm/gemini/gemini-2.5-flash-preview-04-17", ...) +``` + +### 非 OpenAI モデルを利用するその他の方法 + +他の LLM プロバイダーを統合する方法は、あと 3 つあります([こちら](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/) に例があります)。 + +1. [`set_default_openai_client`][agents.set_default_openai_client] + `AsyncOpenAI` インスタンスを LLM クライアントとしてグローバルに使用したい場合に便利です。LLM プロバイダーが OpenAI 互換の API エンドポイントを持ち、`base_url` と `api_key` を設定できる場合に使用します。設定例は [examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py) にあります。 +2. [`ModelProvider`][agents.models.interface.ModelProvider] + `Runner.run` レベルでカスタムモデルプロバイダーを指定できます。これにより「この run のすべてのエージェントでカスタムプロバイダーを使う」と宣言できます。設定例は [examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py) にあります。 +3. [`Agent.model`][agents.agent.Agent.model] + 特定のエージェントインスタンスにモデルを指定できます。エージェントごとに異なるプロバイダーを組み合わせることが可能です。設定例は [examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py) にあります。ほとんどのモデルを簡単に利用する方法として [LiteLLM インテグレーション](./litellm.md) を利用できます。 + +`platform.openai.com` の API キーを持っていない場合は、`set_tracing_disabled()` でトレーシングを無効化するか、[別のトレーシングプロセッサー](../tracing.md) を設定することをお勧めします。 + +!!! note + これらの例では、Responses API をまだサポートしていない LLM プロバイダーが多いため、Chat Completions API/モデルを使用しています。LLM プロバイダーが Responses API をサポートしている場合は、Responses を使用することを推奨します。 ## モデルの組み合わせ -1 つのワークフロー内で、エージェントごとに異なるモデルを使用したい場合があります。たとえば、振り分けには小さく高速なモデルを、複雑なタスクには大きく高性能なモデルを使う、といった使い分けです。[`Agent`][agents.Agent] を設定する際は、以下のいずれかで特定のモデルを指定できます。 +1 つのワークフロー内でエージェントごとに異なるモデルを使用したい場合があります。たとえば、振り分けには小さく高速なモデルを、複雑なタスクには大きく高性能なモデルを使用するといったケースです。[`Agent`][agents.Agent] を設定する際、次のいずれかの方法でモデルを選択できます。 -1. OpenAI モデル名を直接渡す -2. 任意のモデル名と、それを `Model` インスタンスへマッピングできる [`ModelProvider`][agents.models.interface.ModelProvider] を渡す +1. モデル名を直接指定する +2. 任意のモデル名と、その名前を Model インスタンスへマッピングできる [`ModelProvider`][agents.models.interface.ModelProvider] を指定する 3. [`Model`][agents.models.interface.Model] 実装を直接渡す !!!note - SDK は [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIResponsesModel] と [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] の両方の形に対応していますが、ワークフローごとに 1 つのモデル形を使用することを推奨します。2 つの形ではサポートする機能・ツールが異なるためです。どうしても混在させる場合は、利用するすべての機能が両方で利用可能であることを確認してください。 + SDK は [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIResponsesModel] と [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] の両形態をサポートしていますが、各ワークフローで 1 つのモデル形態に統一することを推奨します。2 つの形態はサポートする機能とツールが異なるためです。混在させる場合は、使用する機能が双方で利用可能かを必ず確認してください。 ```python from agents import Agent, Runner, AsyncOpenAI, OpenAIChatCompletionsModel @@ -51,10 +82,10 @@ async def main(): print(result.final_output) ``` -1. OpenAI モデル名を直接指定 +1. OpenAI のモデル名を直接設定 2. [`Model`][agents.models.interface.Model] 実装を提供 -エージェントで使用するモデルをさらに細かく設定したい場合は、`temperature` などのオプションを指定できる [`ModelSettings`][agents.models.interface.ModelSettings] を渡します。 +エージェントで使用するモデルをさらに構成したい場合は、`temperature` などのオプションパラメーターを指定できる [`ModelSettings`][agents.models.interface.ModelSettings] を渡せます。 ```python from agents import Agent, ModelSettings @@ -67,50 +98,58 @@ english_agent = Agent( ) ``` -## 他の LLM プロバイダーの利用 - -他の LLM プロバイダーは 3 通りの方法で利用できます(コード例は [こちら](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/))。 - -1. [`set_default_openai_client`][agents.set_default_openai_client] - OpenAI 互換の API エンドポイントを持つ場合に、`AsyncOpenAI` インスタンスをグローバルに LLM クライアントとして設定できます。`base_url` と `api_key` を設定するケースです。設定例は [examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py)。 +OpenAI の Responses API を使用する場合、`user` や `service_tier` など[その他のオプションパラメーター](https://platform.openai.com/docs/api-reference/responses/create) があります。トップレベルで指定できない場合は、`extra_args` で渡してください。 -2. [`ModelProvider`][agents.models.interface.ModelProvider] - `Runner.run` レベルで「この実行中のすべてのエージェントにカスタムモデルプロバイダーを使う」と宣言できます。設定例は [examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py)。 - -3. [`Agent.model`][agents.agent.Agent.model] - 特定の Agent インスタンスにモデルを指定できます。エージェントごとに異なるプロバイダーを組み合わせられます。設定例は [examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py)。多くのモデルを簡単に使う方法として [LiteLLM 連携](./litellm.md) があります。 - -`platform.openai.com` の API キーを持たない場合は、`set_tracing_disabled()` でトレーシングを無効化するか、[別のトレーシングプロセッサー](../tracing.md) を設定することを推奨します。 +```python +from agents import Agent, ModelSettings -!!! note - これらの例では Chat Completions API/モデルを使用しています。多くの LLM プロバイダーがまだ Responses API をサポートしていないためです。もしプロバイダーが Responses API をサポートしている場合は、Responses の使用を推奨します。 +english_agent = Agent( + name="English agent", + instructions="You only speak English", + model="gpt-4o", + model_settings=ModelSettings( + temperature=0.1, + extra_args={"service_tier": "flex", "user": "user_12345"}, + ), +) +``` -## 他の LLM プロバイダーでよくある問題 +## 他の LLM プロバイダー使用時の一般的な問題 ### Tracing クライアントの 401 エラー -トレースは OpenAI サーバーへアップロードされるため、OpenAI API キーがない場合にエラーになります。解決策は次の 3 つです。 +Tracing 関連のエラーが発生する場合、トレースは OpenAI サーバーへアップロードされるため、OpenAI API キーが必要です。対応方法は次の 3 つです。 1. トレーシングを完全に無効化する: [`set_tracing_disabled(True)`][agents.set_tracing_disabled] -2. トレーシング用の OpenAI キーを設定する: [`set_tracing_export_api_key(...)`][agents.set_tracing_export_api_key] - このキーはトレースのアップロードにのみ使用され、[platform.openai.com](https://platform.openai.com/) のものが必要です。 -3. OpenAI 以外のトレースプロセッサーを使う。詳しくは [tracing ドキュメント](../tracing.md#custom-tracing-processors) を参照してください。 +2. トレース用に OpenAI キーを設定する: [`set_tracing_export_api_key(...)`][agents.set_tracing_export_api_key] + この API キーはトレースのアップロードのみに使用され、[platform.openai.com](https://platform.openai.com/) で取得したものが必要です。 +3. OpenAI 以外のトレースプロセッサーを使用する。詳細は [tracing のドキュメント](../tracing.md#custom-tracing-processors) を参照してください。 -### Responses API サポート +### Responses API のサポート -SDK は既定で Responses API を使用しますが、多くの LLM プロバイダーはまだ対応していません。そのため 404 などのエラーが発生する場合があります。対処方法は 2 つです。 +SDK はデフォルトで Responses API を使用しますが、ほとんどの LLM プロバイダーはまだ非対応です。その結果、404 などのエラーが発生することがあります。対処方法は次の 2 つです。 1. [`set_default_openai_api("chat_completions")`][agents.set_default_openai_api] を呼び出す - 環境変数 `OPENAI_API_KEY` と `OPENAI_BASE_URL` を設定している場合に機能します。 + `OPENAI_API_KEY` と `OPENAI_BASE_URL` を環境変数で設定している場合に有効です。 2. [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel] を使用する - コード例は [こちら](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/) にあります。 + 例は [こちら](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/) にあります。 ### structured outputs のサポート 一部のモデルプロバイダーは [structured outputs](https://platform.openai.com/docs/guides/structured-outputs) をサポートしていません。その場合、次のようなエラーが発生することがあります。 ``` + BadRequestError: Error code: 400 - {'error': {'message': "'response_format.type' : value is not one of the allowed values ['text','json_object']", 'type': 'invalid_request_error'}} + ``` -これは一部プロバイダーの制限で、JSON 出力はサポートしていても `json_schema` を指定できません。現在修正に取り組んでいますが、JSON スキーマ出力をサポートしているプロバイダーを利用することを推奨します。そうでない場合、不正な JSON によりアプリが頻繁に壊れる可能性があります。 \ No newline at end of file +これは一部プロバイダーの制限で、JSON 出力自体はサポートしていても `json_schema` を指定できないことが原因です。修正に向けて取り組んでいますが、JSON スキーマ出力をサポートしているプロバイダーを使用することをお勧めします。そうでないと、不正な JSON が返されてアプリが頻繁に壊れる可能性があります。 + +## プロバイダーを跨いだモデルの組み合わせ + +モデルプロバイダーごとの機能差に注意しないと、エラーが発生します。たとえば OpenAI は structured outputs、マルチモーダル入力、ホスト型の file search や web search をサポートしていますが、多くの他プロバイダーは非対応です。以下の制限に留意してください。 + +- 対応していないプロバイダーには未サポートの `tools` を送らない +- テキストのみのモデルを呼び出す前にマルチモーダル入力を除外する +- structured JSON 出力をサポートしていないプロバイダーでは、不正な JSON が返ることがある点に注意する \ No newline at end of file diff --git a/docs/mcp.md b/docs/mcp.md index 76d142029..eef61a047 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -4,7 +4,7 @@ The [Model context protocol](https://modelcontextprotocol.io/introduction) (aka > MCP is an open protocol that standardizes how applications provide context to LLMs. Think of MCP like a USB-C port for AI applications. Just as USB-C provides a standardized way to connect your devices to various peripherals and accessories, MCP provides a standardized way to connect AI models to different data sources and tools. -The Agents SDK has support for MCP. This enables you to use a wide range of MCP servers to provide tools to your Agents. +The Agents SDK has support for MCP. This enables you to use a wide range of MCP servers to provide tools and prompts to your Agents. ## MCP servers @@ -19,13 +19,20 @@ You can use the [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServe For example, this is how you'd use the [official MCP filesystem server](https://www.npmjs.com/package/@modelcontextprotocol/server-filesystem). ```python +from agents.run_context import RunContextWrapper + async with MCPServerStdio( params={ "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], } ) as server: - tools = await server.list_tools() + # Note: In practice, you typically add the server to an Agent + # and let the framework handle tool listing automatically. + # Direct calls to list_tools() require run_context and agent parameters. + run_context = RunContextWrapper(context=None) + agent = Agent(name="test", instructions="test") + tools = await server.list_tools(run_context, agent) ``` ## Using MCP servers @@ -41,6 +48,125 @@ agent=Agent( ) ``` +## Tool filtering + +You can filter which tools are available to your Agent by configuring tool filters on MCP servers. The SDK supports both static and dynamic tool filtering. + +### Static tool filtering + +For simple allow/block lists, you can use static filtering: + +```python +from agents.mcp import create_static_tool_filter + +# Only expose specific tools from this server +server = MCPServerStdio( + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], + }, + tool_filter=create_static_tool_filter( + allowed_tool_names=["read_file", "write_file"] + ) +) + +# Exclude specific tools from this server +server = MCPServerStdio( + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], + }, + tool_filter=create_static_tool_filter( + blocked_tool_names=["delete_file"] + ) +) + +``` + +**When both `allowed_tool_names` and `blocked_tool_names` are configured, the processing order is:** +1. First apply `allowed_tool_names` (allowlist) - only keep the specified tools +2. Then apply `blocked_tool_names` (blocklist) - exclude specified tools from the remaining tools + +For example, if you configure `allowed_tool_names=["read_file", "write_file", "delete_file"]` and `blocked_tool_names=["delete_file"]`, only `read_file` and `write_file` tools will be available. + +### Dynamic tool filtering + +For more complex filtering logic, you can use dynamic filters with functions: + +```python +from agents.mcp import ToolFilterContext + +# Simple synchronous filter +def custom_filter(context: ToolFilterContext, tool) -> bool: + """Example of a custom tool filter.""" + # Filter logic based on tool name patterns + return tool.name.startswith("allowed_prefix") + +# Context-aware filter +def context_aware_filter(context: ToolFilterContext, tool) -> bool: + """Filter tools based on context information.""" + # Access agent information + agent_name = context.agent.name + + # Access server information + server_name = context.server_name + + # Implement your custom filtering logic here + return some_filtering_logic(agent_name, server_name, tool) + +# Asynchronous filter +async def async_filter(context: ToolFilterContext, tool) -> bool: + """Example of an asynchronous filter.""" + # Perform async operations if needed + result = await some_async_check(context, tool) + return result + +server = MCPServerStdio( + params={ + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", samples_dir], + }, + tool_filter=custom_filter # or context_aware_filter or async_filter +) +``` + +The `ToolFilterContext` provides access to: +- `run_context`: The current run context +- `agent`: The agent requesting the tools +- `server_name`: The name of the MCP server + +## Prompts + +MCP servers can also provide prompts that can be used to dynamically generate agent instructions. This allows you to create reusable instruction templates that can be customized with parameters. + +### Using prompts + +MCP servers that support prompts provide two key methods: + +- `list_prompts()`: Lists all available prompts on the server +- `get_prompt(name, arguments)`: Gets a specific prompt with optional parameters + +```python +# List available prompts +prompts_result = await server.list_prompts() +for prompt in prompts_result.prompts: + print(f"Prompt: {prompt.name} - {prompt.description}") + +# Get a specific prompt with parameters +prompt_result = await server.get_prompt( + "generate_code_review_instructions", + {"focus": "security vulnerabilities", "language": "python"} +) +instructions = prompt_result.messages[0].content.text + +# Use the prompt-generated instructions with an Agent +agent = Agent( + name="Code Reviewer", + instructions=instructions, # Instructions from MCP prompt + mcp_servers=[server] +) +``` + ## Caching Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tool list will not change. diff --git a/docs/models/index.md b/docs/models/index.md index 1c89d778a..b3b2b7f0b 100644 --- a/docs/models/index.md +++ b/docs/models/index.md @@ -93,6 +93,22 @@ english_agent = Agent( ) ``` +Also, when you use OpenAI's Responses API, [there are a few other optional parameters](https://platform.openai.com/docs/api-reference/responses/create) (e.g., `user`, `service_tier`, and so on). If they are not available at the top level, you can use `extra_args` to pass them as well. + +```python +from agents import Agent, ModelSettings + +english_agent = Agent( + name="English agent", + instructions="You only speak English", + model="gpt-4o", + model_settings=ModelSettings( + temperature=0.1, + extra_args={"service_tier": "flex", "user": "user_12345"}, + ), +) +``` + ## Common issues with using other LLM providers ### Tracing client error 401 diff --git a/docs/release.md b/docs/release.md index 3bfc6e9b8..a86103f96 100644 --- a/docs/release.md +++ b/docs/release.md @@ -1,4 +1,4 @@ -# Release process +# Release process/changelog The project follows a slightly modified version of semantic versioning using the form `0.Y.Z`. The leading `0` indicates the SDK is still evolving rapidly. Increment the components as follows: @@ -12,7 +12,13 @@ If you don't want breaking changes, we recommend pinning to `0.0.x` versions in We will increment `Z` for non-breaking changes: -- Bug fixes -- New features -- Changes to private interfaces -- Updates to beta features +- Bug fixes +- New features +- Changes to private interfaces +- Updates to beta features + +## Breaking change changelog + +### 0.1.0 + +In this version, [`MCPServer.list_tools()`][agents.mcp.server.MCPServer] has two new params: `run_context` and `agent`. You'll need to add these params to any classes that subclass `MCPServer`. diff --git a/docs/scripts/translate_docs.py b/docs/scripts/translate_docs.py index b2e8b44fc..5dada2681 100644 --- a/docs/scripts/translate_docs.py +++ b/docs/scripts/translate_docs.py @@ -1,5 +1,7 @@ # ruff: noqa import os +import sys +import argparse from openai import OpenAI from concurrent.futures import ThreadPoolExecutor @@ -263,24 +265,45 @@ def translate_single_source_file(file_path: str) -> None: def main(): - # Traverse the source directory - for root, _, file_names in os.walk(source_dir): - # Skip the target directories - if any(lang in root for lang in languages): - continue - # Increasing this will make the translation faster; you can decide considering the model's capacity - concurrency = 6 - with ThreadPoolExecutor(max_workers=concurrency) as executor: - futures = [] - for file_name in file_names: - filepath = os.path.join(root, file_name) - futures.append(executor.submit(translate_single_source_file, filepath)) - if len(futures) >= concurrency: - for future in futures: - future.result() - futures.clear() - - print("Translation completed.") + parser = argparse.ArgumentParser(description="Translate documentation files") + parser.add_argument("--file", type=str, help="Specific file to translate (relative to docs directory)") + args = parser.parse_args() + + if args.file: + # Translate a single file + # Handle both "foo.md" and "docs/foo.md" formats + if args.file.startswith("docs/"): + # Remove "docs/" prefix if present + relative_file = args.file[5:] + else: + relative_file = args.file + + file_path = os.path.join(source_dir, relative_file) + if os.path.exists(file_path): + translate_single_source_file(file_path) + print(f"Translation completed for {relative_file}") + else: + print(f"Error: File {file_path} does not exist") + sys.exit(1) + else: + # Traverse the source directory (original behavior) + for root, _, file_names in os.walk(source_dir): + # Skip the target directories + if any(lang in root for lang in languages): + continue + # Increasing this will make the translation faster; you can decide considering the model's capacity + concurrency = 6 + with ThreadPoolExecutor(max_workers=concurrency) as executor: + futures = [] + for file_name in file_names: + filepath = os.path.join(root, file_name) + futures.append(executor.submit(translate_single_source_file, filepath)) + if len(futures) >= concurrency: + for future in futures: + future.result() + futures.clear() + + print("Translation completed.") if __name__ == "__main__": diff --git a/examples/basic/agent_lifecycle_example.py b/examples/basic/agent_lifecycle_example.py index 29bb18c96..b4334a83b 100644 --- a/examples/basic/agent_lifecycle_example.py +++ b/examples/basic/agent_lifecycle_example.py @@ -101,12 +101,10 @@ async def main() -> None: ### (Start Agent) 1: Agent Start Agent started ### (Start Agent) 2: Agent Start Agent started tool random_number ### (Start Agent) 3: Agent Start Agent ended tool random_number with result 37 -### (Start Agent) 4: Agent Start Agent started -### (Start Agent) 5: Agent Start Agent handed off to Multiply Agent +### (Start Agent) 4: Agent Start Agent handed off to Multiply Agent ### (Multiply Agent) 1: Agent Multiply Agent started ### (Multiply Agent) 2: Agent Multiply Agent started tool multiply_by_two ### (Multiply Agent) 3: Agent Multiply Agent ended tool multiply_by_two with result 74 -### (Multiply Agent) 4: Agent Multiply Agent started -### (Multiply Agent) 5: Agent Multiply Agent ended with output number=74 +### (Multiply Agent) 4: Agent Multiply Agent ended with output number=74 Done! """ diff --git a/examples/basic/hello_world_jupyter.ipynb b/examples/basic/hello_world_jupyter.ipynb new file mode 100644 index 000000000..42ee8e6a2 --- /dev/null +++ b/examples/basic/hello_world_jupyter.ipynb @@ -0,0 +1,45 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "8a77ee2e-22f2-409c-837d-b994978b0aa2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "A function calls self, \n", + "Unraveling layers deep, \n", + "Base case ends the quest. \n", + "\n", + "Infinite loops lurk, \n", + "Mind the base condition well, \n", + "Or it will not work. \n", + "\n", + "Trees and lists unfold, \n", + "Elegant solutions bloom, \n", + "Recursion's art told.\n" + ] + } + ], + "source": [ + "from agents import Agent, Runner\n", + "\n", + "agent = Agent(name=\"Assistant\", instructions=\"You are a helpful assistant\")\n", + "\n", + "# Intended for Jupyter notebooks where there's an existing event loop\n", + "result = await Runner.run(agent, \"Write a haiku about recursion in programming.\") # type: ignore[top-level-await] # noqa: F704\n", + "print(result.final_output)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/basic/hello_world_jupyter.py b/examples/basic/hello_world_jupyter.py deleted file mode 100644 index c929a7c68..000000000 --- a/examples/basic/hello_world_jupyter.py +++ /dev/null @@ -1,11 +0,0 @@ -from agents import Agent, Runner - -agent = Agent(name="Assistant", instructions="You are a helpful assistant") - -# Intended for Jupyter notebooks where there's an existing event loop -result = await Runner.run(agent, "Write a haiku about recursion in programming.") # type: ignore[top-level-await] # noqa: F704 -print(result.final_output) - -# Code within code loops, -# Infinite mirrors reflect— -# Logic folds on self. diff --git a/examples/basic/lifecycle_example.py b/examples/basic/lifecycle_example.py index 285bfecd6..02ce449f4 100644 --- a/examples/basic/lifecycle_example.py +++ b/examples/basic/lifecycle_example.py @@ -105,14 +105,12 @@ async def main() -> None: Enter a max number: 250 ### 1: Agent Start Agent started. Usage: 0 requests, 0 input tokens, 0 output tokens, 0 total tokens ### 2: Tool random_number started. Usage: 1 requests, 148 input tokens, 15 output tokens, 163 total tokens -### 3: Tool random_number ended with result 101. Usage: 1 requests, 148 input tokens, 15 output tokens, 163 total tokens -### 4: Agent Start Agent started. Usage: 1 requests, 148 input tokens, 15 output tokens, 163 total tokens -### 5: Handoff from Start Agent to Multiply Agent. Usage: 2 requests, 323 input tokens, 30 output tokens, 353 total tokens -### 6: Agent Multiply Agent started. Usage: 2 requests, 323 input tokens, 30 output tokens, 353 total tokens -### 7: Tool multiply_by_two started. Usage: 3 requests, 504 input tokens, 46 output tokens, 550 total tokens -### 8: Tool multiply_by_two ended with result 202. Usage: 3 requests, 504 input tokens, 46 output tokens, 550 total tokens -### 9: Agent Multiply Agent started. Usage: 3 requests, 504 input tokens, 46 output tokens, 550 total tokens -### 10: Agent Multiply Agent ended with output number=202. Usage: 4 requests, 714 input tokens, 63 output tokens, 777 total tokens +### 3: Tool random_number ended with result 101. Usage: 1 requests, 148 input tokens, 15 output tokens, 163 total token +### 4: Handoff from Start Agent to Multiply Agent. Usage: 2 requests, 323 input tokens, 30 output tokens, 353 total tokens +### 5: Agent Multiply Agent started. Usage: 2 requests, 323 input tokens, 30 output tokens, 353 total tokens +### 6: Tool multiply_by_two started. Usage: 3 requests, 504 input tokens, 46 output tokens, 550 total tokens +### 7: Tool multiply_by_two ended with result 202. Usage: 3 requests, 504 input tokens, 46 output tokens, 550 total tokens +### 8: Agent Multiply Agent ended with output number=202. Usage: 4 requests, 714 input tokens, 63 output tokens, 777 total tokens Done! """ diff --git a/examples/mcp/prompt_server/README.md b/examples/mcp/prompt_server/README.md new file mode 100644 index 000000000..c1b1c3b37 --- /dev/null +++ b/examples/mcp/prompt_server/README.md @@ -0,0 +1,29 @@ +# MCP Prompt Server Example + +This example uses a local MCP prompt server in [server.py](server.py). + +Run the example via: + +``` +uv run python examples/mcp/prompt_server/main.py +``` + +## Details + +The example uses the `MCPServerStreamableHttp` class from `agents.mcp`. The server runs in a sub-process at `http://localhost:8000/mcp` and provides user-controlled prompts that generate agent instructions. + +The server exposes prompts like `generate_code_review_instructions` that take parameters such as focus area and programming language. The agent calls these prompts to dynamically generate its system instructions based on user-provided parameters. + +## Workflow + +The example demonstrates two key functions: + +1. **`show_available_prompts`** - Lists all available prompts on the MCP server, showing users what prompts they can select from. This demonstrates the discovery aspect of MCP prompts. + +2. **`demo_code_review`** - Shows the complete user-controlled prompt workflow: + - Calls `generate_code_review_instructions` with specific parameters (focus: "security vulnerabilities", language: "python") + - Uses the generated instructions to create an Agent with specialized code review capabilities + - Runs the agent against vulnerable sample code (command injection via `os.system`) + - The agent analyzes the code and provides security-focused feedback using available tools + +This pattern allows users to dynamically configure agent behavior through MCP prompts rather than hardcoded instructions. \ No newline at end of file diff --git a/examples/mcp/prompt_server/main.py b/examples/mcp/prompt_server/main.py new file mode 100644 index 000000000..8f2991fc0 --- /dev/null +++ b/examples/mcp/prompt_server/main.py @@ -0,0 +1,110 @@ +import asyncio +import os +import shutil +import subprocess +import time +from typing import Any + +from agents import Agent, Runner, gen_trace_id, trace +from agents.mcp import MCPServer, MCPServerStreamableHttp +from agents.model_settings import ModelSettings + + +async def get_instructions_from_prompt(mcp_server: MCPServer, prompt_name: str, **kwargs) -> str: + """Get agent instructions by calling MCP prompt endpoint (user-controlled)""" + print(f"Getting instructions from prompt: {prompt_name}") + + try: + prompt_result = await mcp_server.get_prompt(prompt_name, kwargs) + content = prompt_result.messages[0].content + if hasattr(content, 'text'): + instructions = content.text + else: + instructions = str(content) + print("Generated instructions") + return instructions + except Exception as e: + print(f"Failed to get instructions: {e}") + return f"You are a helpful assistant. Error: {e}" + + +async def demo_code_review(mcp_server: MCPServer): + """Demo: Code review with user-selected prompt""" + print("=== CODE REVIEW DEMO ===") + + # User explicitly selects prompt and parameters + instructions = await get_instructions_from_prompt( + mcp_server, + "generate_code_review_instructions", + focus="security vulnerabilities", + language="python", + ) + + agent = Agent( + name="Code Reviewer Agent", + instructions=instructions, # Instructions from MCP prompt + model_settings=ModelSettings(tool_choice="auto"), + ) + + message = """Please review this code: + +def process_user_input(user_input): + command = f"echo {user_input}" + os.system(command) + return "Command executed" + +""" + + print(f"Running: {message[:60]}...") + result = await Runner.run(starting_agent=agent, input=message) + print(result.final_output) + print("\n" + "=" * 50 + "\n") + + +async def show_available_prompts(mcp_server: MCPServer): + """Show available prompts for user selection""" + print("=== AVAILABLE PROMPTS ===") + + prompts_result = await mcp_server.list_prompts() + print("User can select from these prompts:") + for i, prompt in enumerate(prompts_result.prompts, 1): + print(f" {i}. {prompt.name} - {prompt.description}") + print() + + +async def main(): + async with MCPServerStreamableHttp( + name="Simple Prompt Server", + params={"url": "http://localhost:8000/mcp"}, + ) as server: + trace_id = gen_trace_id() + with trace(workflow_name="Simple Prompt Demo", trace_id=trace_id): + print(f"Trace: https://platform.openai.com/traces/trace?trace_id={trace_id}\n") + + await show_available_prompts(server) + await demo_code_review(server) + + +if __name__ == "__main__": + if not shutil.which("uv"): + raise RuntimeError("uv is not installed") + + process: subprocess.Popen[Any] | None = None + try: + this_dir = os.path.dirname(os.path.abspath(__file__)) + server_file = os.path.join(this_dir, "server.py") + + print("Starting Simple Prompt Server...") + process = subprocess.Popen(["uv", "run", server_file]) + time.sleep(3) + print("Server started\n") + except Exception as e: + print(f"Error starting server: {e}") + exit(1) + + try: + asyncio.run(main()) + finally: + if process: + process.terminate() + print("Server terminated.") diff --git a/examples/mcp/prompt_server/server.py b/examples/mcp/prompt_server/server.py new file mode 100644 index 000000000..01dcbac34 --- /dev/null +++ b/examples/mcp/prompt_server/server.py @@ -0,0 +1,37 @@ +from mcp.server.fastmcp import FastMCP + +# Create server +mcp = FastMCP("Prompt Server") + + +# Instruction-generating prompts (user-controlled) +@mcp.prompt() +def generate_code_review_instructions( + focus: str = "general code quality", language: str = "python" +) -> str: + """Generate agent instructions for code review tasks""" + print(f"[debug-server] generate_code_review_instructions({focus}, {language})") + + return f"""You are a senior {language} code review specialist. Your role is to provide comprehensive code analysis with focus on {focus}. + +INSTRUCTIONS: +- Analyze code for quality, security, performance, and best practices +- Provide specific, actionable feedback with examples +- Identify potential bugs, vulnerabilities, and optimization opportunities +- Suggest improvements with code examples when applicable +- Be constructive and educational in your feedback +- Focus particularly on {focus} aspects + +RESPONSE FORMAT: +1. Overall Assessment +2. Specific Issues Found +3. Security Considerations +4. Performance Notes +5. Recommended Improvements +6. Best Practices Suggestions + +Use the available tools to check current time if you need timestamps for your analysis.""" + + +if __name__ == "__main__": + mcp.run(transport="streamable-http") diff --git a/examples/reasoning_content/__init__.py b/examples/reasoning_content/__init__.py new file mode 100644 index 000000000..f24b2606d --- /dev/null +++ b/examples/reasoning_content/__init__.py @@ -0,0 +1,3 @@ +""" +Examples demonstrating how to use models that provide reasoning content. +""" diff --git a/examples/reasoning_content/main.py b/examples/reasoning_content/main.py new file mode 100644 index 000000000..5f67e1779 --- /dev/null +++ b/examples/reasoning_content/main.py @@ -0,0 +1,124 @@ +""" +Example demonstrating how to use the reasoning content feature with models that support it. + +Some models, like deepseek-reasoner, provide a reasoning_content field in addition to the regular content. +This example shows how to access and use this reasoning content from both streaming and non-streaming responses. + +To run this example, you need to: +1. Set your OPENAI_API_KEY environment variable +2. Use a model that supports reasoning content (e.g., deepseek-reasoner) +""" + +import asyncio +import os +from typing import Any, cast + +from agents import ModelSettings +from agents.models.interface import ModelTracing +from agents.models.openai_provider import OpenAIProvider +from agents.types import ResponseOutputRefusal, ResponseOutputText # type: ignore + +MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or "deepseek-reasoner" + + +async def stream_with_reasoning_content(): + """ + Example of streaming a response from a model that provides reasoning content. + The reasoning content will be emitted as separate events. + """ + provider = OpenAIProvider() + model = provider.get_model(MODEL_NAME) + + print("\n=== Streaming Example ===") + print("Prompt: Write a haiku about recursion in programming") + + reasoning_content = "" + regular_content = "" + + async for event in model.stream_response( + system_instructions="You are a helpful assistant that writes creative content.", + input="Write a haiku about recursion in programming", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + prompt=None + ): + if event.type == "response.reasoning_summary_text.delta": + print( + f"\033[33m{event.delta}\033[0m", end="", flush=True + ) # Yellow for reasoning content + reasoning_content += event.delta + elif event.type == "response.output_text.delta": + print(f"\033[32m{event.delta}\033[0m", end="", flush=True) # Green for regular content + regular_content += event.delta + + print("\n\nReasoning Content:") + print(reasoning_content) + print("\nRegular Content:") + print(regular_content) + print("\n") + + +async def get_response_with_reasoning_content(): + """ + Example of getting a complete response from a model that provides reasoning content. + The reasoning content will be available as a separate item in the response. + """ + provider = OpenAIProvider() + model = provider.get_model(MODEL_NAME) + + print("\n=== Non-streaming Example ===") + print("Prompt: Explain the concept of recursion in programming") + + response = await model.get_response( + system_instructions="You are a helpful assistant that explains technical concepts clearly.", + input="Explain the concept of recursion in programming", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + prompt=None + ) + + # Extract reasoning content and regular content from the response + reasoning_content = None + regular_content = None + + for item in response.output: + if hasattr(item, "type") and item.type == "reasoning": + reasoning_content = item.summary[0].text + elif hasattr(item, "type") and item.type == "message": + if item.content and len(item.content) > 0: + content_item = item.content[0] + if isinstance(content_item, ResponseOutputText): + regular_content = content_item.text + elif isinstance(content_item, ResponseOutputRefusal): + refusal_item = cast(Any, content_item) + regular_content = refusal_item.refusal + + print("\nReasoning Content:") + print(reasoning_content or "No reasoning content provided") + + print("\nRegular Content:") + print(regular_content or "No regular content provided") + + print("\n") + + +async def main(): + try: + await stream_with_reasoning_content() + await get_response_with_reasoning_content() + except Exception as e: + print(f"Error: {e}") + print("\nNote: This example requires a model that supports reasoning content.") + print("You may need to use a specific model like deepseek-reasoner or similar.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/reasoning_content/runner_example.py b/examples/reasoning_content/runner_example.py new file mode 100644 index 000000000..e51f85799 --- /dev/null +++ b/examples/reasoning_content/runner_example.py @@ -0,0 +1,88 @@ +""" +Example demonstrating how to use the reasoning content feature with the Runner API. + +This example shows how to extract and use reasoning content from responses when using +the Runner API, which is the most common way users interact with the Agents library. + +To run this example, you need to: +1. Set your OPENAI_API_KEY environment variable +2. Use a model that supports reasoning content (e.g., deepseek-reasoner) +""" + +import asyncio +import os +from typing import Any + +from agents import Agent, Runner, trace +from agents.items import ReasoningItem + +MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or "deepseek-reasoner" + + +async def main(): + print(f"Using model: {MODEL_NAME}") + + # Create an agent with a model that supports reasoning content + agent = Agent( + name="Reasoning Agent", + instructions="You are a helpful assistant that explains your reasoning step by step.", + model=MODEL_NAME, + ) + + # Example 1: Non-streaming response + with trace("Reasoning Content - Non-streaming"): + print("\n=== Example 1: Non-streaming response ===") + result = await Runner.run( + agent, "What is the square root of 841? Please explain your reasoning." + ) + + # Extract reasoning content from the result items + reasoning_content = None + # RunResult has 'response' attribute which has 'output' attribute + for item in result.response.output: # type: ignore + if isinstance(item, ReasoningItem): + reasoning_content = item.summary[0].text # type: ignore + break + + print("\nReasoning Content:") + print(reasoning_content or "No reasoning content provided") + + print("\nFinal Output:") + print(result.final_output) + + # Example 2: Streaming response + with trace("Reasoning Content - Streaming"): + print("\n=== Example 2: Streaming response ===") + print("\nStreaming response:") + + # Buffers to collect reasoning and regular content + reasoning_buffer = "" + content_buffer = "" + + # RunResultStreaming is async iterable + stream = Runner.run_streamed(agent, "What is 15 x 27? Please explain your reasoning.") + + async for event in stream: # type: ignore + if isinstance(event, ReasoningItem): + # This is reasoning content + reasoning_item: Any = event + reasoning_buffer += reasoning_item.summary[0].text + print( + f"\033[33m{reasoning_item.summary[0].text}\033[0m", end="", flush=True + ) # Yellow for reasoning + elif hasattr(event, "text"): + # This is regular content + content_buffer += event.text + print( + f"\033[32m{event.text}\033[0m", end="", flush=True + ) # Green for regular content + + print("\n\nCollected Reasoning Content:") + print(reasoning_buffer) + + print("\nCollected Final Answer:") + print(content_buffer) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 75e714768..e659348ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-agents" -version = "0.0.19" +version = "0.1.0" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.9" @@ -23,14 +23,14 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", - "Intended Audience :: Developers", + "Programming Language :: Python :: 3.13", "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules", "License :: OSI Approved :: MIT License", ] [project.urls] -Homepage = "https://github.com/openai/openai-agents-python" +Homepage = "https://openai.github.io/openai-agents-python/" Repository = "https://github.com/openai/openai-agents-python" [project.optional-dependencies] diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 3eecaead0..1296b72be 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -206,7 +206,6 @@ def enable_verbose_stdout_logging(): "ToolCallItem", "ToolCallOutputItem", "ReasoningItem", - "ModelResponse", "ItemHelpers", "RunHooks", "AgentHooks", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index a75c5e825..4ac8b316b 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -28,6 +28,9 @@ ActionType, ActionWait, ) +from openai.types.responses.response_input_item_param import ( + ComputerCallOutputAcknowledgedSafetyCheck, +) from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse from openai.types.responses.response_output_item import ( ImageGenerationCall, @@ -67,6 +70,7 @@ from .stream_events import RunItemStreamEvent, StreamEvent from .tool import ( ComputerTool, + ComputerToolSafetyCheckData, FunctionTool, FunctionToolResult, HostedMCPTool, @@ -638,6 +642,29 @@ async def execute_computer_actions( results: list[RunItem] = [] # Need to run these serially, because each action can affect the computer state for action in actions: + acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None + if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check: + acknowledged = [] + for check in action.tool_call.pending_safety_checks: + data = ComputerToolSafetyCheckData( + ctx_wrapper=context_wrapper, + agent=agent, + tool_call=action.tool_call, + safety_check=check, + ) + maybe = action.computer_tool.on_safety_check(data) + ack = await maybe if inspect.isawaitable(maybe) else maybe + if ack: + acknowledged.append( + ComputerCallOutputAcknowledgedSafetyCheck( + id=check.id, + code=check.code, + message=check.message, + ) + ) + else: + raise UserError("Computer tool safety check was not acknowledged") + results.append( await ComputerAction.execute( agent=agent, @@ -645,6 +672,7 @@ async def execute_computer_actions( hooks=hooks, context_wrapper=context_wrapper, config=config, + acknowledged_safety_checks=acknowledged, ) ) @@ -998,6 +1026,7 @@ async def execute( hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], config: RunConfig, + acknowledged_safety_checks: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None, ) -> RunItem: output_func = ( cls._get_screenshot_async(action.computer_tool.computer, action.tool_call) @@ -1036,6 +1065,7 @@ async def execute( "image_url": image_url, }, type="computer_call_output", + acknowledged_safety_checks=acknowledged_safety_checks, ), ) diff --git a/src/agents/agent.py b/src/agents/agent.py index 61a9abe0c..6c87297f1 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -256,14 +256,18 @@ async def get_prompt( """Get the prompt for the agent.""" return await PromptUtil.to_model_input(self.prompt, run_context, self) - async def get_mcp_tools(self) -> list[Tool]: + async def get_mcp_tools( + self, run_context: RunContextWrapper[TContext] + ) -> list[Tool]: """Fetches the available tools from the MCP servers.""" convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) - return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict) + return await MCPUtil.get_all_function_tools( + self.mcp_servers, convert_schemas_to_strict, run_context, self + ) async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]: """All agent tools, including MCP tools and function tools.""" - mcp_tools = await self.get_mcp_tools() + mcp_tools = await self.get_mcp_tools(run_context) async def _check_tool_enabled(tool: Tool) -> bool: if not isinstance(tool, FunctionTool): diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index c58a52dae..a06c61dc3 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -98,7 +98,11 @@ async def get_response( logger.debug("Received model response") else: logger.debug( - f"LLM resp:\n{json.dumps(response.choices[0].message.model_dump(), indent=2)}\n" + f"""LLM resp:\n{ + json.dumps( + response.choices[0].message.model_dump(), indent=2, ensure_ascii=False + ) + }\n""" ) if hasattr(response, "usage"): @@ -269,8 +273,8 @@ async def _fetch_response( else: logger.debug( f"Calling Litellm model: {self.model}\n" - f"{json.dumps(converted_messages, indent=2)}\n" - f"Tools:\n{json.dumps(converted_tools, indent=2)}\n" + f"{json.dumps(converted_messages, indent=2, ensure_ascii=False)}\n" + f"Tools:\n{json.dumps(converted_tools, indent=2, ensure_ascii=False)}\n" f"Stream: {stream}\n" f"Tool choice: {tool_choice}\n" f"Response format: {response_format}\n" diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py index dd1db1fee..e1a91e189 100644 --- a/src/agents/function_schema.py +++ b/src/agents/function_schema.py @@ -337,7 +337,8 @@ def function_schema( # 5. Return as a FuncSchema dataclass return FuncSchema( name=func_name, - description=description_override or doc_info.description if doc_info else None, + # Ensure description_override takes precedence even if docstring info is disabled. + description=description_override or (doc_info.description if doc_info else None), params_pydantic_model=dynamic_model, params_json_schema=json_schema, signature=sig, diff --git a/src/agents/handoffs.py b/src/agents/handoffs.py index 76c93a298..cb2752e4f 100644 --- a/src/agents/handoffs.py +++ b/src/agents/handoffs.py @@ -15,6 +15,7 @@ from .strict_schema import ensure_strict_json_schema from .tracing.spans import SpanError from .util import _error_tracing, _json, _transforms +from .util._types import MaybeAwaitable if TYPE_CHECKING: from .agent import Agent @@ -99,6 +100,11 @@ class Handoff(Generic[TContext]): True, as it increases the likelihood of correct JSON input. """ + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True + """Whether the handoff is enabled. Either a bool or a Callable that takes the run context and + agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable + a handoff based on your context/state.""" + def get_transfer_message(self, agent: Agent[Any]) -> str: return json.dumps({"assistant": agent.name}) @@ -121,6 +127,7 @@ def handoff( tool_name_override: str | None = None, tool_description_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: ... @@ -133,6 +140,7 @@ def handoff( tool_description_override: str | None = None, tool_name_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: ... @@ -144,6 +152,7 @@ def handoff( tool_description_override: str | None = None, tool_name_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: ... @@ -154,6 +163,7 @@ def handoff( on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None, input_type: type[THandoffInput] | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: """Create a handoff from an agent. @@ -166,6 +176,9 @@ def handoff( input_type: the type of the input to the handoff. If provided, the input will be validated against this type. Only relevant if you pass a function that takes an input. input_filter: a function that filters the inputs that are passed to the next agent. + is_enabled: Whether the handoff is enabled. Can be a bool or a callable that takes the run + context and agent and returns whether the handoff is enabled. Disabled handoffs are + hidden from the LLM at runtime. """ assert (on_handoff and input_type) or not (on_handoff and input_type), ( "You must provide either both on_handoff and input_type, or neither" @@ -233,4 +246,5 @@ async def _invoke_handoff( on_invoke_handoff=_invoke_handoff, input_filter=input_filter, agent_name=agent.name, + is_enabled=is_enabled, ) diff --git a/src/agents/mcp/__init__.py b/src/agents/mcp/__init__.py index d4eb8fa68..da5a68b16 100644 --- a/src/agents/mcp/__init__.py +++ b/src/agents/mcp/__init__.py @@ -11,7 +11,14 @@ except ImportError: pass -from .util import MCPUtil +from .util import ( + MCPUtil, + ToolFilter, + ToolFilterCallable, + ToolFilterContext, + ToolFilterStatic, + create_static_tool_filter, +) __all__ = [ "MCPServer", @@ -22,4 +29,9 @@ "MCPServerStreamableHttp", "MCPServerStreamableHttpParams", "MCPUtil", + "ToolFilter", + "ToolFilterCallable", + "ToolFilterContext", + "ToolFilterStatic", + "create_static_tool_filter", ] diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 3d1e17790..4fd606e34 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -2,21 +2,27 @@ import abc import asyncio +import inspect from contextlib import AbstractAsyncContextManager, AsyncExitStack from datetime import timedelta from pathlib import Path -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client from mcp.shared.message import SessionMessage -from mcp.types import CallToolResult, InitializeResult +from mcp.types import CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult from typing_extensions import NotRequired, TypedDict from ..exceptions import UserError from ..logger import logger +from ..run_context import RunContextWrapper +from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic + +if TYPE_CHECKING: + from ..agent import Agent class MCPServer(abc.ABC): @@ -44,7 +50,11 @@ async def cleanup(self): pass @abc.abstractmethod - async def list_tools(self) -> list[MCPTool]: + async def list_tools( + self, + run_context: RunContextWrapper[Any] | None = None, + agent: Agent[Any] | None = None, + ) -> list[MCPTool]: """List the tools available on the server.""" pass @@ -53,11 +63,30 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C """Invoke a tool on the server.""" pass + @abc.abstractmethod + async def list_prompts( + self, + ) -> ListPromptsResult: + """List the prompts available on the server.""" + pass + + @abc.abstractmethod + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> GetPromptResult: + """Get a specific prompt from the server.""" + pass + class _MCPServerWithClientSession(MCPServer, abc.ABC): """Base class for MCP servers that use a `ClientSession` to communicate with the server.""" - def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None): + def __init__( + self, + cache_tools_list: bool, + client_session_timeout_seconds: float | None, + tool_filter: ToolFilter = None, + ): """ Args: cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be @@ -68,6 +97,7 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float (by avoiding a round-trip to the server every time). client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + tool_filter: The tool filter to use for filtering tools. """ self.session: ClientSession | None = None self.exit_stack: AsyncExitStack = AsyncExitStack() @@ -81,6 +111,86 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float self._cache_dirty = True self._tools_list: list[MCPTool] | None = None + self.tool_filter = tool_filter + + async def _apply_tool_filter( + self, + tools: list[MCPTool], + run_context: RunContextWrapper[Any], + agent: Agent[Any], + ) -> list[MCPTool]: + """Apply the tool filter to the list of tools.""" + if self.tool_filter is None: + return tools + + # Handle static tool filter + if isinstance(self.tool_filter, dict): + return self._apply_static_tool_filter(tools, self.tool_filter) + + # Handle callable tool filter (dynamic filter) + else: + return await self._apply_dynamic_tool_filter(tools, run_context, agent) + + def _apply_static_tool_filter( + self, tools: list[MCPTool], static_filter: ToolFilterStatic + ) -> list[MCPTool]: + """Apply static tool filtering based on allowlist and blocklist.""" + filtered_tools = tools + + # Apply allowed_tool_names filter (whitelist) + if "allowed_tool_names" in static_filter: + allowed_names = static_filter["allowed_tool_names"] + filtered_tools = [t for t in filtered_tools if t.name in allowed_names] + + # Apply blocked_tool_names filter (blacklist) + if "blocked_tool_names" in static_filter: + blocked_names = static_filter["blocked_tool_names"] + filtered_tools = [t for t in filtered_tools if t.name not in blocked_names] + + return filtered_tools + + async def _apply_dynamic_tool_filter( + self, + tools: list[MCPTool], + run_context: RunContextWrapper[Any], + agent: Agent[Any], + ) -> list[MCPTool]: + """Apply dynamic tool filtering using a callable filter function.""" + + # Ensure we have a callable filter and cast to help mypy + if not callable(self.tool_filter): + raise ValueError("Tool filter must be callable for dynamic filtering") + tool_filter_func = cast(ToolFilterCallable, self.tool_filter) + + # Create filter context + filter_context = ToolFilterContext( + run_context=run_context, + agent=agent, + server_name=self.name, + ) + + filtered_tools = [] + for tool in tools: + try: + # Call the filter function with context + result = tool_filter_func(filter_context, tool) + + if inspect.isawaitable(result): + should_include = await result + else: + should_include = result + + if should_include: + filtered_tools.append(tool) + except Exception as e: + logger.error( + f"Error applying tool filter to tool '{tool.name}' on server '{self.name}': {e}" + ) + # On error, exclude the tool for safety + continue + + return filtered_tools + @abc.abstractmethod def create_streams( self, @@ -131,21 +241,32 @@ async def connect(self): await self.cleanup() raise - async def list_tools(self) -> list[MCPTool]: + async def list_tools( + self, + run_context: RunContextWrapper[Any] | None = None, + agent: Agent[Any] | None = None, + ) -> list[MCPTool]: """List the tools available on the server.""" if not self.session: raise UserError("Server not initialized. Make sure you call `connect()` first.") # Return from cache if caching is enabled, we have tools, and the cache is not dirty if self.cache_tools_list and not self._cache_dirty and self._tools_list: - return self._tools_list - - # Reset the cache dirty to False - self._cache_dirty = False - - # Fetch the tools from the server - self._tools_list = (await self.session.list_tools()).tools - return self._tools_list + tools = self._tools_list + else: + # Reset the cache dirty to False + self._cache_dirty = False + # Fetch the tools from the server + self._tools_list = (await self.session.list_tools()).tools + tools = self._tools_list + + # Filter tools based on tool_filter + filtered_tools = tools + if self.tool_filter is not None: + if run_context is None or agent is None: + raise UserError("run_context and agent are required for dynamic tool filtering") + filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent) + return filtered_tools async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: """Invoke a tool on the server.""" @@ -154,6 +275,24 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C return await self.session.call_tool(tool_name, arguments) + async def list_prompts( + self, + ) -> ListPromptsResult: + """List the prompts available on the server.""" + if not self.session: + raise UserError("Server not initialized. Make sure you call `connect()` first.") + + return await self.session.list_prompts() + + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> GetPromptResult: + """Get a specific prompt from the server.""" + if not self.session: + raise UserError("Server not initialized. Make sure you call `connect()` first.") + + return await self.session.get_prompt(name, arguments) + async def cleanup(self): """Cleanup the server.""" async with self._cleanup_lock: @@ -206,6 +345,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + tool_filter: ToolFilter = None, ): """Create a new MCP server based on the stdio transport. @@ -223,8 +363,13 @@ def __init__( name: A readable name for the server. If not provided, we'll create one from the command. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + tool_filter: The tool filter to use for filtering tools. """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__( + cache_tools_list, + client_session_timeout_seconds, + tool_filter, + ) self.params = StdioServerParameters( command=params["command"], @@ -283,6 +428,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + tool_filter: ToolFilter = None, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -302,8 +448,13 @@ def __init__( URL. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + tool_filter: The tool filter to use for filtering tools. """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__( + cache_tools_list, + client_session_timeout_seconds, + tool_filter, + ) self.params = params self._name = name or f"sse: {self.params['url']}" @@ -362,6 +513,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + tool_filter: ToolFilter = None, ): """Create a new MCP server based on the Streamable HTTP transport. @@ -382,8 +534,13 @@ def __init__( URL. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + tool_filter: The tool filter to use for filtering tools. """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__( + cache_tools_list, + client_session_timeout_seconds, + tool_filter, + ) self.params = params self._name = name or f"streamable_http: {self.params['url']}" diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 5a963bc01..48da9f841 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -1,6 +1,9 @@ import functools import json -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +from typing_extensions import NotRequired, TypedDict from agents.strict_schema import ensure_strict_json_schema @@ -10,25 +13,102 @@ from ..run_context import RunContextWrapper from ..tool import FunctionTool, Tool from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span +from ..util._types import MaybeAwaitable if TYPE_CHECKING: from mcp.types import Tool as MCPTool + from ..agent import Agent from .server import MCPServer +@dataclass +class ToolFilterContext: + """Context information available to tool filter functions.""" + + run_context: RunContextWrapper[Any] + """The current run context.""" + + agent: "Agent[Any]" + """The agent that is requesting the tool list.""" + + server_name: str + """The name of the MCP server.""" + + +ToolFilterCallable = Callable[["ToolFilterContext", "MCPTool"], MaybeAwaitable[bool]] +"""A function that determines whether a tool should be available. + +Args: + context: The context information including run context, agent, and server name. + tool: The MCP tool to filter. + +Returns: + Whether the tool should be available (True) or filtered out (False). +""" + + +class ToolFilterStatic(TypedDict): + """Static tool filter configuration using allowlists and blocklists.""" + + allowed_tool_names: NotRequired[list[str]] + """Optional list of tool names to allow (whitelist). + If set, only these tools will be available.""" + + blocked_tool_names: NotRequired[list[str]] + """Optional list of tool names to exclude (blacklist). + If set, these tools will be filtered out.""" + + +ToolFilter = Union[ToolFilterCallable, ToolFilterStatic, None] +"""A tool filter that can be either a function, static configuration, or None (no filtering).""" + + +def create_static_tool_filter( + allowed_tool_names: Optional[list[str]] = None, + blocked_tool_names: Optional[list[str]] = None, +) -> Optional[ToolFilterStatic]: + """Create a static tool filter from allowlist and blocklist parameters. + + This is a convenience function for creating a ToolFilterStatic. + + Args: + allowed_tool_names: Optional list of tool names to allow (whitelist). + blocked_tool_names: Optional list of tool names to exclude (blacklist). + + Returns: + A ToolFilterStatic if any filtering is specified, None otherwise. + """ + if allowed_tool_names is None and blocked_tool_names is None: + return None + + filter_dict: ToolFilterStatic = {} + if allowed_tool_names is not None: + filter_dict["allowed_tool_names"] = allowed_tool_names + if blocked_tool_names is not None: + filter_dict["blocked_tool_names"] = blocked_tool_names + + return filter_dict + + class MCPUtil: """Set of utilities for interop between MCP and Agents SDK tools.""" @classmethod async def get_all_function_tools( - cls, servers: list["MCPServer"], convert_schemas_to_strict: bool + cls, + servers: list["MCPServer"], + convert_schemas_to_strict: bool, + run_context: RunContextWrapper[Any], + agent: "Agent[Any]", ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" tools = [] tool_names: set[str] = set() for server in servers: - server_tools = await cls.get_function_tools(server, convert_schemas_to_strict) + server_tools = await cls.get_function_tools( + server, convert_schemas_to_strict, run_context, agent + ) server_tool_names = {tool.name for tool in server_tools} if len(server_tool_names & tool_names) > 0: raise UserError( @@ -42,12 +122,16 @@ async def get_all_function_tools( @classmethod async def get_function_tools( - cls, server: "MCPServer", convert_schemas_to_strict: bool + cls, + server: "MCPServer", + convert_schemas_to_strict: bool, + run_context: RunContextWrapper[Any], + agent: "Agent[Any]", ) -> list[Tool]: """Get all function tools from a single MCP server.""" with mcp_tools_span(server=server.name) as span: - tools = await server.list_tools() + tools = await server.list_tools(run_context, agent) span.span_data.result = [tool.name for tool in tools] return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools] diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index 881390279..26af94ba3 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -1,13 +1,50 @@ from __future__ import annotations import dataclasses +from collections.abc import Mapping from dataclasses import dataclass, fields, replace -from typing import Any, Literal +from typing import Annotated, Any, Literal, Union -from openai._types import Body, Headers, Query +from openai import Omit as _Omit +from openai._types import Body, Query +from openai.types.responses import ResponseIncludable from openai.types.shared import Reasoning -from pydantic import BaseModel - +from pydantic import BaseModel, GetCoreSchemaHandler +from pydantic_core import core_schema +from typing_extensions import TypeAlias + + +class _OmitTypeAnnotation: + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + def validate_from_none(value: None) -> _Omit: + return _Omit() + + from_none_schema = core_schema.chain_schema( + [ + core_schema.none_schema(), + core_schema.no_info_plain_validator_function(validate_from_none), + ] + ) + return core_schema.json_or_python_schema( + json_schema=from_none_schema, + python_schema=core_schema.union_schema( + [ + # check if it's an instance first before doing any further work + core_schema.is_instance_schema(_Omit), + from_none_schema, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: None + ), + ) +Omit = Annotated[_Omit, _OmitTypeAnnotation] +Headers: TypeAlias = Mapping[str, Union[str, Omit]] @dataclass class ModelSettings: @@ -36,8 +73,13 @@ class ModelSettings: """The tool choice to use when calling the model.""" parallel_tool_calls: bool | None = None - """Whether to use parallel tool calls when calling the model. - Defaults to False if not provided.""" + """Controls whether the model can make multiple parallel tool calls in a single turn. + If not provided (i.e., set to None), this behavior defers to the underlying + model provider's default. For most current providers (e.g., OpenAI), this typically + means parallel tool calls are enabled (True). + Set to True to explicitly enable parallel tool calls, or False to restrict the + model to at most one tool call per turn. + """ truncation: Literal["auto", "disabled"] | None = None """The truncation strategy to use when calling the model.""" @@ -61,6 +103,10 @@ class ModelSettings: """Whether to include usage chunk. Defaults to True if not provided.""" + response_include: list[ResponseIncludable] | None = None + """Additional output data to include in the model response. + [include parameter](https://platform.openai.com/docs/api-reference/responses/create#responses-create-include)""" + extra_query: Query | None = None """Additional query fields to provide with the request. Defaults to None if not provided.""" diff --git a/src/agents/models/chatcmpl_converter.py b/src/agents/models/chatcmpl_converter.py index 1d599e8c0..9d0c6cf5e 100644 --- a/src/agents/models/chatcmpl_converter.py +++ b/src/agents/models/chatcmpl_converter.py @@ -19,6 +19,7 @@ ChatCompletionToolMessageParam, ChatCompletionUserMessageParam, ) +from openai.types.chat.chat_completion_content_part_param import File, FileFile from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam from openai.types.chat.completion_create_params import ResponseFormat from openai.types.responses import ( @@ -27,14 +28,17 @@ ResponseFunctionToolCall, ResponseFunctionToolCallParam, ResponseInputContentParam, + ResponseInputFileParam, ResponseInputImageParam, ResponseInputTextParam, ResponseOutputMessage, ResponseOutputMessageParam, ResponseOutputRefusal, ResponseOutputText, + ResponseReasoningItem, ) from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message +from openai.types.responses.response_reasoning_item import Summary from ..agent_output import AgentOutputSchemaBase from ..exceptions import AgentsException, UserError @@ -85,6 +89,16 @@ def convert_response_format( def message_to_output_items(cls, message: ChatCompletionMessage) -> list[TResponseOutputItem]: items: list[TResponseOutputItem] = [] + # Handle reasoning content if available + if hasattr(message, "reasoning_content") and message.reasoning_content: + items.append( + ResponseReasoningItem( + id=FAKE_RESPONSES_ID, + summary=[Summary(text=message.reasoning_content, type="summary_text")], + type="reasoning", + ) + ) + message_item = ResponseOutputMessage( id=FAKE_RESPONSES_ID, content=[], @@ -239,7 +253,19 @@ def extract_all_content( ) ) elif isinstance(c, dict) and c.get("type") == "input_file": - raise UserError(f"File uploads are not supported for chat completions {c}") + casted_file_param = cast(ResponseInputFileParam, c) + if "file_data" not in casted_file_param or not casted_file_param["file_data"]: + raise UserError( + f"Only file_data is supported for input_file {casted_file_param}" + ) + out.append( + File( + type="file", + file=FileFile( + file_data=casted_file_param["file_data"], + ), + ) + ) else: raise UserError(f"Unknown content: {c}") return out diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py index d18f5912a..83fa32abc 100644 --- a/src/agents/models/chatcmpl_stream_handler.py +++ b/src/agents/models/chatcmpl_stream_handler.py @@ -20,21 +20,38 @@ ResponseOutputMessage, ResponseOutputRefusal, ResponseOutputText, + ResponseReasoningItem, + ResponseReasoningSummaryPartAddedEvent, + ResponseReasoningSummaryPartDoneEvent, + ResponseReasoningSummaryTextDeltaEvent, ResponseRefusalDeltaEvent, ResponseTextDeltaEvent, ResponseUsage, ) +from openai.types.responses.response_reasoning_item import Summary +from openai.types.responses.response_reasoning_summary_part_added_event import ( + Part as AddedEventPart, +) +from openai.types.responses.response_reasoning_summary_part_done_event import Part as DoneEventPart from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from ..items import TResponseStreamEvent from .fake_id import FAKE_RESPONSES_ID +# Define a Part class for internal use +class Part: + def __init__(self, text: str, type: str): + self.text = text + self.type = type + + @dataclass class StreamingState: started: bool = False text_content_index_and_output: tuple[int, ResponseOutputText] | None = None refusal_content_index_and_output: tuple[int, ResponseOutputRefusal] | None = None + reasoning_content_index_and_output: tuple[int, ResponseReasoningItem] | None = None function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict) @@ -75,12 +92,65 @@ async def handle_stream( delta = chunk.choices[0].delta - # Handle text - if delta.content: + # Handle reasoning content + if hasattr(delta, "reasoning_content"): + reasoning_content = delta.reasoning_content + if reasoning_content and not state.reasoning_content_index_and_output: + state.reasoning_content_index_and_output = ( + 0, + ResponseReasoningItem( + id=FAKE_RESPONSES_ID, + summary=[Summary(text="", type="summary_text")], + type="reasoning", + ), + ) + yield ResponseOutputItemAddedEvent( + item=ResponseReasoningItem( + id=FAKE_RESPONSES_ID, + summary=[Summary(text="", type="summary_text")], + type="reasoning", + ), + output_index=0, + type="response.output_item.added", + sequence_number=sequence_number.get_and_increment(), + ) + + yield ResponseReasoningSummaryPartAddedEvent( + item_id=FAKE_RESPONSES_ID, + output_index=0, + summary_index=0, + part=AddedEventPart(text="", type="summary_text"), + type="response.reasoning_summary_part.added", + sequence_number=sequence_number.get_and_increment(), + ) + + if reasoning_content and state.reasoning_content_index_and_output: + yield ResponseReasoningSummaryTextDeltaEvent( + delta=reasoning_content, + item_id=FAKE_RESPONSES_ID, + output_index=0, + summary_index=0, + type="response.reasoning_summary_text.delta", + sequence_number=sequence_number.get_and_increment(), + ) + + # Create a new summary with updated text + current_summary = state.reasoning_content_index_and_output[1].summary[0] + updated_text = current_summary.text + reasoning_content + new_summary = Summary(text=updated_text, type="summary_text") + state.reasoning_content_index_and_output[1].summary[0] = new_summary + + # Handle regular content + if delta.content is not None: if not state.text_content_index_and_output: - # Initialize a content tracker for streaming text + content_index = 0 + if state.reasoning_content_index_and_output: + content_index += 1 + if state.refusal_content_index_and_output: + content_index += 1 + state.text_content_index_and_output = ( - 0 if not state.refusal_content_index_and_output else 1, + content_index, ResponseOutputText( text="", type="output_text", @@ -98,14 +168,16 @@ async def handle_stream( # Notify consumers of the start of a new output message + first content part yield ResponseOutputItemAddedEvent( item=assistant_item, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 type="response.output_item.added", sequence_number=sequence_number.get_and_increment(), ) yield ResponseContentPartAddedEvent( content_index=state.text_content_index_and_output[0], item_id=FAKE_RESPONSES_ID, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 part=ResponseOutputText( text="", type="output_text", @@ -119,7 +191,8 @@ async def handle_stream( content_index=state.text_content_index_and_output[0], delta=delta.content, item_id=FAKE_RESPONSES_ID, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 type="response.output_text.delta", sequence_number=sequence_number.get_and_increment(), ) @@ -130,9 +203,14 @@ async def handle_stream( # This is always set by the OpenAI API, but not by others e.g. LiteLLM if hasattr(delta, "refusal") and delta.refusal: if not state.refusal_content_index_and_output: - # Initialize a content tracker for streaming refusal text + refusal_index = 0 + if state.reasoning_content_index_and_output: + refusal_index += 1 + if state.text_content_index_and_output: + refusal_index += 1 + state.refusal_content_index_and_output = ( - 0 if not state.text_content_index_and_output else 1, + refusal_index, ResponseOutputRefusal(refusal="", type="refusal"), ) # Start a new assistant message if one doesn't exist yet (in-progress) @@ -146,14 +224,16 @@ async def handle_stream( # Notify downstream that assistant message + first content part are starting yield ResponseOutputItemAddedEvent( item=assistant_item, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 type="response.output_item.added", sequence_number=sequence_number.get_and_increment(), ) yield ResponseContentPartAddedEvent( content_index=state.refusal_content_index_and_output[0], item_id=FAKE_RESPONSES_ID, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 part=ResponseOutputText( text="", type="output_text", @@ -167,7 +247,8 @@ async def handle_stream( content_index=state.refusal_content_index_and_output[0], delta=delta.refusal, item_id=FAKE_RESPONSES_ID, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 type="response.refusal.delta", sequence_number=sequence_number.get_and_increment(), ) @@ -195,16 +276,39 @@ async def handle_stream( state.function_calls[tc_delta.index].name += ( tc_function.name if tc_function else "" ) or "" - state.function_calls[tc_delta.index].call_id += tc_delta.id or "" + state.function_calls[tc_delta.index].call_id = tc_delta.id or "" + + if state.reasoning_content_index_and_output: + yield ResponseReasoningSummaryPartDoneEvent( + item_id=FAKE_RESPONSES_ID, + output_index=0, + summary_index=0, + part=DoneEventPart( + text=state.reasoning_content_index_and_output[1].summary[0].text, + type="summary_text", + ), + type="response.reasoning_summary_part.done", + sequence_number=sequence_number.get_and_increment(), + ) + yield ResponseOutputItemDoneEvent( + item=state.reasoning_content_index_and_output[1], + output_index=0, + type="response.output_item.done", + sequence_number=sequence_number.get_and_increment(), + ) function_call_starting_index = 0 + if state.reasoning_content_index_and_output: + function_call_starting_index += 1 + if state.text_content_index_and_output: function_call_starting_index += 1 # Send end event for this content part yield ResponseContentPartDoneEvent( content_index=state.text_content_index_and_output[0], item_id=FAKE_RESPONSES_ID, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 part=state.text_content_index_and_output[1], type="response.content_part.done", sequence_number=sequence_number.get_and_increment(), @@ -216,7 +320,8 @@ async def handle_stream( yield ResponseContentPartDoneEvent( content_index=state.refusal_content_index_and_output[0], item_id=FAKE_RESPONSES_ID, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 part=state.refusal_content_index_and_output[1], type="response.content_part.done", sequence_number=sequence_number.get_and_increment(), @@ -261,6 +366,12 @@ async def handle_stream( # Finally, send the Response completed event outputs: list[ResponseOutputItem] = [] + + # include Reasoning item if it exists + if state.reasoning_content_index_and_output: + outputs.append(state.reasoning_content_index_and_output[1]) + + # include text or refusal content if they exist if state.text_content_index_and_output or state.refusal_content_index_and_output: assistant_msg = ResponseOutputMessage( id=FAKE_RESPONSES_ID, @@ -278,7 +389,8 @@ async def handle_stream( # send a ResponseOutputItemDone for the assistant message yield ResponseOutputItemDoneEvent( item=assistant_msg, - output_index=0, + output_index=state.reasoning_content_index_and_output + is not None, # fixed 0 -> 0 or 1 type="response.output_item.done", sequence_number=sequence_number.get_and_increment(), ) diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 08803d8c0..6de431b4d 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -7,7 +7,8 @@ from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream from openai.types import ChatModel -from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice from openai.types.responses import Response from openai.types.responses.response_prompt_param import ResponsePromptParam from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails @@ -74,8 +75,11 @@ async def get_response( prompt=prompt, ) - first_choice = response.choices[0] - message = first_choice.message + message: ChatCompletionMessage | None = None + first_choice: Choice | None = None + if response.choices and len(response.choices) > 0: + first_choice = response.choices[0] + message = first_choice.message if _debug.DONT_LOG_MODEL_DATA: logger.debug("Received model response") @@ -83,13 +87,11 @@ async def get_response( if message is not None: logger.debug( "LLM resp:\n%s\n", - json.dumps(message.model_dump(), indent=2), + json.dumps(message.model_dump(), indent=2, ensure_ascii=False), ) else: - logger.debug( - "LLM resp had no message. finish_reason: %s", - first_choice.finish_reason, - ) + finish_reason = first_choice.finish_reason if first_choice else "-" + logger.debug(f"LLM resp had no message. finish_reason: {finish_reason}") usage = ( Usage( @@ -254,8 +256,8 @@ async def _fetch_response( logger.debug("Calling LLM") else: logger.debug( - f"{json.dumps(converted_messages, indent=2)}\n" - f"Tools:\n{json.dumps(converted_tools, indent=2)}\n" + f"{json.dumps(converted_messages, indent=2, ensure_ascii=False)}\n" + f"Tools:\n{json.dumps(converted_tools, indent=2, ensure_ascii=False)}\n" f"Stream: {stream}\n" f"Tool choice: {tool_choice}\n" f"Response format: {response_format}\n" diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 961f690b0..a7ce62983 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -96,7 +96,13 @@ async def get_response( else: logger.debug( "LLM resp:\n" - f"{json.dumps([x.model_dump() for x in response.output], indent=2)}\n" + f"""{ + json.dumps( + [x.model_dump() for x in response.output], + indent=2, + ensure_ascii=False, + ) + }\n""" ) usage = ( @@ -240,13 +246,17 @@ async def _fetch_response( converted_tools = Converter.convert_tools(tools, handoffs) response_format = Converter.get_response_format(output_schema) + include: list[ResponseIncludable] = converted_tools.includes + if model_settings.response_include is not None: + include = list({*include, *model_settings.response_include}) + if _debug.DONT_LOG_MODEL_DATA: logger.debug("Calling LLM") else: logger.debug( f"Calling LLM {self.model} with input:\n" - f"{json.dumps(list_input, indent=2)}\n" - f"Tools:\n{json.dumps(converted_tools.tools, indent=2)}\n" + f"{json.dumps(list_input, indent=2, ensure_ascii=False)}\n" + f"Tools:\n{json.dumps(converted_tools.tools, indent=2, ensure_ascii=False)}\n" f"Stream: {stream}\n" f"Tool choice: {tool_choice}\n" f"Response format: {response_format}\n" @@ -258,7 +268,7 @@ async def _fetch_response( instructions=self._non_null_or_not_given(system_instructions), model=self.model, input=list_input, - include=converted_tools.includes, + include=include, tools=converted_tools.tools, prompt=self._non_null_or_not_given(prompt), temperature=self._non_null_or_not_given(model_settings.temperature), diff --git a/src/agents/repl.py b/src/agents/repl.py index 9a4f30759..f7142555f 100644 --- a/src/agents/repl.py +++ b/src/agents/repl.py @@ -5,7 +5,7 @@ from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent from .agent import Agent -from .items import ItemHelpers, TResponseInputItem +from .items import TResponseInputItem from .result import RunResultBase from .run import Runner from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunItemStreamEvent @@ -50,9 +50,6 @@ async def run_demo_loop(agent: Agent[Any], *, stream: bool = True) -> None: print("\n[tool called]", flush=True) elif event.item.type == "tool_call_output_item": print(f"\n[tool output: {event.item.output}]", flush=True) - elif event.item.type == "message_output_item": - message = ItemHelpers.text_message_output(event.item) - print(message, end="", flush=True) elif isinstance(event, AgentUpdatedStreamEvent): print(f"\n[Agent updated: {event.new_agent.name}]", flush=True) print() diff --git a/src/agents/run.py b/src/agents/run.py index 8a44a0e54..e5f9378ec 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -2,6 +2,7 @@ import asyncio import copy +import inspect from dataclasses import dataclass, field from typing import Any, Generic, cast @@ -361,7 +362,8 @@ async def run( # agent changes, or if the agent loop ends. if current_span is None: handoff_names = [ - h.agent_name for h in AgentRunner._get_handoffs(current_agent) + h.agent_name + for h in await AgentRunner._get_handoffs(current_agent, context_wrapper) ] if output_schema := AgentRunner._get_output_schema(current_agent): output_type_name = output_schema.name() @@ -641,7 +643,10 @@ async def _start_streaming( # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. if current_span is None: - handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] + handoff_names = [ + h.agent_name + for h in await cls._get_handoffs(current_agent, context_wrapper) + ] if output_schema := cls._get_output_schema(current_agent): output_type_name = output_schema.name() else: @@ -798,7 +803,7 @@ async def _run_single_turn_streamed( agent.get_prompt(context_wrapper), ) - handoffs = cls._get_handoffs(agent) + handoffs = await cls._get_handoffs(agent, context_wrapper) model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) @@ -898,7 +903,7 @@ async def _run_single_turn( ) output_schema = cls._get_output_schema(agent) - handoffs = cls._get_handoffs(agent) + handoffs = await cls._get_handoffs(agent, context_wrapper) input = ItemHelpers.input_to_new_input_list(original_input) input.extend([generated_item.to_input_item() for generated_item in generated_items]) @@ -1091,14 +1096,28 @@ def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: return AgentOutputSchema(agent.output_type) @classmethod - def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]: + async def _get_handoffs( + cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any] + ) -> list[Handoff]: handoffs = [] for handoff_item in agent.handoffs: if isinstance(handoff_item, Handoff): handoffs.append(handoff_item) elif isinstance(handoff_item, Agent): handoffs.append(handoff(handoff_item)) - return handoffs + + async def _check_handoff_enabled(handoff_obj: Handoff) -> bool: + attr = handoff_obj.is_enabled + if isinstance(attr, bool): + return attr + res = attr(context_wrapper, agent) + if inspect.isawaitable(res): + return bool(await res) + return bool(res) + + results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs)) + enabled: list[Handoff] = [h for h, ok in zip(handoffs, results) if ok] + return enabled @classmethod async def _get_all_tools( diff --git a/src/agents/tool.py b/src/agents/tool.py index ce66a53ba..3aab47752 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -7,6 +7,10 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload from openai.types.responses.file_search_tool_param import Filters, RankingOptions +from openai.types.responses.response_computer_tool_call import ( + PendingSafetyCheck, + ResponseComputerToolCall, +) from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp from openai.types.responses.web_search_tool_param import UserLocation @@ -26,6 +30,7 @@ from .util._types import MaybeAwaitable if TYPE_CHECKING: + from .agent import Agent ToolParams = ParamSpec("ToolParams") @@ -141,11 +146,31 @@ class ComputerTool: as well as implements the computer actions like click, screenshot, etc. """ + on_safety_check: Callable[[ComputerToolSafetyCheckData], MaybeAwaitable[bool]] | None = None + """Optional callback to acknowledge computer tool safety checks.""" + @property def name(self): return "computer_use_preview" +@dataclass +class ComputerToolSafetyCheckData: + """Information about a computer tool safety check.""" + + ctx_wrapper: RunContextWrapper[Any] + """The run context.""" + + agent: Agent[Any] + """The agent performing the computer action.""" + + tool_call: ResponseComputerToolCall + """The computer tool call.""" + + safety_check: PendingSafetyCheck + """The pending safety check to acknowledge.""" + + @dataclass class MCPToolApprovalRequest: """A request to approve a tool call.""" diff --git a/src/agents/tracing/__init__.py b/src/agents/tracing/__init__.py index 64b0bd71f..b45c06d75 100644 --- a/src/agents/tracing/__init__.py +++ b/src/agents/tracing/__init__.py @@ -1,7 +1,5 @@ import atexit -from agents.tracing.provider import DefaultTraceProvider, TraceProvider - from .create import ( agent_span, custom_span, @@ -20,6 +18,7 @@ ) from .processor_interface import TracingProcessor from .processors import default_exporter, default_processor +from .provider import DefaultTraceProvider, TraceProvider from .setup import get_trace_provider, set_trace_provider from .span_data import ( AgentSpanData, diff --git a/src/agents/tracing/processor_interface.py b/src/agents/tracing/processor_interface.py index 4dcd897c7..0a05bcae2 100644 --- a/src/agents/tracing/processor_interface.py +++ b/src/agents/tracing/processor_interface.py @@ -23,7 +23,7 @@ def on_trace_end(self, trace: "Trace") -> None: """Called when a trace is finished. Args: - trace: The trace that started. + trace: The trace that finished. """ pass diff --git a/src/agents/voice/pipeline.py b/src/agents/voice/pipeline.py index d1dac57cf..5addd995f 100644 --- a/src/agents/voice/pipeline.py +++ b/src/agents/voice/pipeline.py @@ -125,6 +125,12 @@ async def _run_multi_turn(self, audio_input: StreamedAudioInput) -> StreamedAudi self._get_tts_model(), self.config.tts_settings, self.config ) + try: + async for intro_text in self.workflow.on_start(): + await output._add_text(intro_text) + except Exception as e: + logger.warning(f"on_start() failed: {e}") + transcription_session = await self._get_stt_model().create_session( audio_input, self.config.stt_settings, diff --git a/src/agents/voice/workflow.py b/src/agents/voice/workflow.py index c706ec413..538676ad1 100644 --- a/src/agents/voice/workflow.py +++ b/src/agents/voice/workflow.py @@ -32,6 +32,14 @@ def run(self, transcription: str) -> AsyncIterator[str]: """ pass + async def on_start(self) -> AsyncIterator[str]: + """ + Optional method that runs before any user input is received. Can be used + to deliver a greeting or instruction via TTS. Defaults to doing nothing. + """ + return + yield + class VoiceWorkflowHelper: @classmethod diff --git a/tests/mcp/helpers.py b/tests/mcp/helpers.py index 8ff153c18..31d43c228 100644 --- a/tests/mcp/helpers.py +++ b/tests/mcp/helpers.py @@ -1,11 +1,14 @@ +import asyncio import json import shutil from typing import Any from mcp import Tool as MCPTool -from mcp.types import CallToolResult, TextContent +from mcp.types import CallToolResult, GetPromptResult, ListPromptsResult, PromptMessage, TextContent from agents.mcp import MCPServer +from agents.mcp.server import _MCPServerWithClientSession +from agents.mcp.util import ToolFilter tee = shutil.which("tee") or "" assert tee, "tee not found" @@ -28,11 +31,41 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): pass +class _TestFilterServer(_MCPServerWithClientSession): + """Minimal implementation of _MCPServerWithClientSession for testing tool filtering""" + + def __init__(self, tool_filter: ToolFilter, server_name: str): + # Initialize parent class properly to avoid type errors + super().__init__( + cache_tools_list=False, + client_session_timeout_seconds=None, + tool_filter=tool_filter, + ) + self._server_name: str = server_name + # Override some attributes for test isolation + self.session = None + self._cleanup_lock = asyncio.Lock() + + def create_streams(self): + raise NotImplementedError("Not needed for filtering tests") + + @property + def name(self) -> str: + return self._server_name + + class FakeMCPServer(MCPServer): - def __init__(self, tools: list[MCPTool] | None = None): + def __init__( + self, + tools: list[MCPTool] | None = None, + tool_filter: ToolFilter = None, + server_name: str = "fake_mcp_server", + ): self.tools: list[MCPTool] = tools or [] self.tool_calls: list[str] = [] self.tool_results: list[str] = [] + self.tool_filter = tool_filter + self._server_name = server_name def add_tool(self, name: str, input_schema: dict[str, Any]): self.tools.append(MCPTool(name=name, inputSchema=input_schema)) @@ -43,8 +76,16 @@ async def connect(self): async def cleanup(self): pass - async def list_tools(self): - return self.tools + async def list_tools(self, run_context=None, agent=None): + tools = self.tools + + # Apply tool filtering using the REAL implementation + if self.tool_filter is not None: + # Use the real _MCPServerWithClientSession filtering logic + filter_server = _TestFilterServer(self.tool_filter, self.name) + tools = await filter_server._apply_tool_filter(tools, run_context, agent) + + return tools async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: self.tool_calls.append(tool_name) @@ -53,6 +94,18 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C content=[TextContent(text=self.tool_results[-1], type="text")], ) + async def list_prompts(self, run_context=None, agent=None) -> ListPromptsResult: + """Return empty list of prompts for fake server""" + return ListPromptsResult(prompts=[]) + + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> GetPromptResult: + """Return a simple prompt result for fake server""" + content = f"Fake prompt content for {name}" + message = PromptMessage(role="user", content=TextContent(type="text", text=content)) + return GetPromptResult(description=f"Fake prompt: {name}", messages=[message]) + @property def name(self) -> str: - return "fake_mcp_server" + return self._server_name diff --git a/tests/mcp/test_caching.py b/tests/mcp/test_caching.py index cac409e6e..f31cdf951 100644 --- a/tests/mcp/test_caching.py +++ b/tests/mcp/test_caching.py @@ -3,7 +3,9 @@ import pytest from mcp.types import ListToolsResult, Tool as MCPTool +from agents import Agent from agents.mcp import MCPServerStdio +from agents.run_context import RunContextWrapper from .helpers import DummyStreamsContextManager, tee @@ -33,25 +35,29 @@ async def test_server_caching_works( mock_list_tools.return_value = ListToolsResult(tools=tools) async with server: + # Create test context and agent + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + # Call list_tools() multiple times - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools assert mock_list_tools.call_count == 1, "list_tools() should have been called once" # Call list_tools() again, should return the cached value - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools assert mock_list_tools.call_count == 1, "list_tools() should not have been called again" # Invalidate the cache and call list_tools() again server.invalidate_tools_cache() - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools assert mock_list_tools.call_count == 2, "list_tools() should be called again" # Without invalidating the cache, calling list_tools() again should return the cached value - tools = await server.list_tools() - assert tools == tools + result_tools = await server.list_tools(run_context, agent) + assert result_tools == tools diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py index 74356a16d..3230e63dd 100644 --- a/tests/mcp/test_mcp_util.py +++ b/tests/mcp/test_mcp_util.py @@ -57,7 +57,10 @@ async def test_get_all_function_tools(): server3.add_tool(names[4], schemas[4]) servers: list[MCPServer] = [server1, server2, server3] - tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=False) + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + + tools = await MCPUtil.get_all_function_tools(servers, False, run_context, agent) assert len(tools) == 5 assert all(tool.name in names for tool in tools) @@ -70,7 +73,7 @@ async def test_get_all_function_tools(): assert tool.name == names[idx] # Also make sure it works with strict schemas - tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=True) + tools = await MCPUtil.get_all_function_tools(servers, True, run_context, agent) assert len(tools) == 5 assert all(tool.name in names for tool in tools) @@ -144,7 +147,8 @@ async def test_agent_convert_schemas_true(): agent = Agent( name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": True} ) - tools = await agent.get_mcp_tools() + run_context = RunContextWrapper(context=None) + tools = await agent.get_mcp_tools(run_context) foo_tool = next(tool for tool in tools if tool.name == "foo") assert isinstance(foo_tool, FunctionTool) @@ -208,7 +212,8 @@ async def test_agent_convert_schemas_false(): agent = Agent( name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": False} ) - tools = await agent.get_mcp_tools() + run_context = RunContextWrapper(context=None) + tools = await agent.get_mcp_tools(run_context) foo_tool = next(tool for tool in tools if tool.name == "foo") assert isinstance(foo_tool, FunctionTool) @@ -245,7 +250,8 @@ async def test_agent_convert_schemas_unset(): server.add_tool("bar", non_strict_schema) server.add_tool("baz", possible_to_convert_schema) agent = Agent(name="test_agent", mcp_servers=[server]) - tools = await agent.get_mcp_tools() + run_context = RunContextWrapper(context=None) + tools = await agent.get_mcp_tools(run_context) foo_tool = next(tool for tool in tools if tool.name == "foo") assert isinstance(foo_tool, FunctionTool) @@ -279,7 +285,9 @@ async def test_util_adds_properties(): server = FakeMCPServer() server.add_tool("test_tool", schema) - tools = await MCPUtil.get_all_function_tools([server], convert_schemas_to_strict=False) + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + tools = await MCPUtil.get_all_function_tools([server], False, run_context, agent) tool = next(tool for tool in tools if tool.name == "test_tool") assert isinstance(tool, FunctionTool) diff --git a/tests/mcp/test_prompt_server.py b/tests/mcp/test_prompt_server.py new file mode 100644 index 000000000..15afe28e4 --- /dev/null +++ b/tests/mcp/test_prompt_server.py @@ -0,0 +1,301 @@ +from typing import Any + +import pytest + +from agents import Agent, Runner +from agents.mcp import MCPServer + +from ..fake_model import FakeModel +from ..test_responses import get_text_message + + +class FakeMCPPromptServer(MCPServer): + """Fake MCP server for testing prompt functionality""" + + def __init__(self, server_name: str = "fake_prompt_server"): + self.prompts: list[Any] = [] + self.prompt_results: dict[str, str] = {} + self._server_name = server_name + + def add_prompt(self, name: str, description: str, arguments: dict[str, Any] | None = None): + """Add a prompt to the fake server""" + from mcp.types import Prompt + + prompt = Prompt(name=name, description=description, arguments=[]) + self.prompts.append(prompt) + + def set_prompt_result(self, name: str, result: str): + """Set the result that should be returned for a prompt""" + self.prompt_results[name] = result + + async def connect(self): + pass + + async def cleanup(self): + pass + + async def list_prompts(self, run_context=None, agent=None): + """List available prompts""" + from mcp.types import ListPromptsResult + + return ListPromptsResult(prompts=self.prompts) + + async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None): + """Get a prompt with arguments""" + from mcp.types import GetPromptResult, PromptMessage, TextContent + + if name not in self.prompt_results: + raise ValueError(f"Prompt '{name}' not found") + + content = self.prompt_results[name] + + # If it's a format string, try to format it with arguments + if arguments and "{" in content: + try: + content = content.format(**arguments) + except KeyError: + pass # Use original content if formatting fails + + message = PromptMessage(role="user", content=TextContent(type="text", text=content)) + + return GetPromptResult(description=f"Generated prompt for {name}", messages=[message]) + + async def list_tools(self, run_context=None, agent=None): + return [] + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None): + raise NotImplementedError("This fake server doesn't support tools") + + @property + def name(self) -> str: + return self._server_name + + +@pytest.mark.asyncio +async def test_list_prompts(): + """Test listing available prompts""" + server = FakeMCPPromptServer() + server.add_prompt( + "generate_code_review_instructions", "Generate agent instructions for code review tasks" + ) + + result = await server.list_prompts() + + assert len(result.prompts) == 1 + assert result.prompts[0].name == "generate_code_review_instructions" + assert "code review" in result.prompts[0].description + + +@pytest.mark.asyncio +async def test_get_prompt_without_arguments(): + """Test getting a prompt without arguments""" + server = FakeMCPPromptServer() + server.add_prompt("simple_prompt", "A simple prompt") + server.set_prompt_result("simple_prompt", "You are a helpful assistant.") + + result = await server.get_prompt("simple_prompt") + + assert len(result.messages) == 1 + assert result.messages[0].content.text == "You are a helpful assistant." + + +@pytest.mark.asyncio +async def test_get_prompt_with_arguments(): + """Test getting a prompt with arguments""" + server = FakeMCPPromptServer() + server.add_prompt( + "generate_code_review_instructions", "Generate agent instructions for code review tasks" + ) + server.set_prompt_result( + "generate_code_review_instructions", + "You are a senior {language} code review specialist. Focus on {focus}.", + ) + + result = await server.get_prompt( + "generate_code_review_instructions", + {"focus": "security vulnerabilities", "language": "python"}, + ) + + assert len(result.messages) == 1 + expected_text = ( + "You are a senior python code review specialist. Focus on security vulnerabilities." + ) + assert result.messages[0].content.text == expected_text + + +@pytest.mark.asyncio +async def test_get_prompt_not_found(): + """Test getting a prompt that doesn't exist""" + server = FakeMCPPromptServer() + + with pytest.raises(ValueError, match="Prompt 'nonexistent' not found"): + await server.get_prompt("nonexistent") + + +@pytest.mark.asyncio +async def test_agent_with_prompt_instructions(): + """Test using prompt-generated instructions with an agent""" + server = FakeMCPPromptServer() + server.add_prompt( + "generate_code_review_instructions", "Generate agent instructions for code review tasks" + ) + server.set_prompt_result( + "generate_code_review_instructions", + "You are a code reviewer. Analyze the provided code for security issues.", + ) + + # Get instructions from prompt + prompt_result = await server.get_prompt("generate_code_review_instructions") + instructions = prompt_result.messages[0].content.text + + # Create agent with prompt-generated instructions + model = FakeModel() + agent = Agent(name="prompt_agent", instructions=instructions, model=model, mcp_servers=[server]) + + # Mock model response + model.add_multiple_turn_outputs( + [[get_text_message("Code analysis complete. Found security vulnerability.")]] + ) + + # Run the agent + result = await Runner.run(agent, input="Review this code: def unsafe_exec(cmd): os.system(cmd)") + + assert "Code analysis complete" in result.final_output + assert ( + agent.instructions + == "You are a code reviewer. Analyze the provided code for security issues." + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("streaming", [False, True]) +async def test_agent_with_prompt_instructions_streaming(streaming: bool): + """Test using prompt-generated instructions with streaming and non-streaming""" + server = FakeMCPPromptServer() + server.add_prompt( + "generate_code_review_instructions", "Generate agent instructions for code review tasks" + ) + server.set_prompt_result( + "generate_code_review_instructions", + "You are a {language} code reviewer focusing on {focus}.", + ) + + # Get instructions from prompt with arguments + prompt_result = await server.get_prompt( + "generate_code_review_instructions", {"language": "Python", "focus": "security"} + ) + instructions = prompt_result.messages[0].content.text + + # Create agent + model = FakeModel() + agent = Agent( + name="streaming_prompt_agent", instructions=instructions, model=model, mcp_servers=[server] + ) + + model.add_multiple_turn_outputs([[get_text_message("Security analysis complete.")]]) + + if streaming: + streaming_result = Runner.run_streamed(agent, input="Review code") + async for _ in streaming_result.stream_events(): + pass + final_result = streaming_result.final_output + else: + result = await Runner.run(agent, input="Review code") + final_result = result.final_output + + assert "Security analysis complete" in final_result + assert agent.instructions == "You are a Python code reviewer focusing on security." + + +@pytest.mark.asyncio +async def test_multiple_prompts(): + """Test server with multiple prompts""" + server = FakeMCPPromptServer() + + # Add multiple prompts + server.add_prompt( + "generate_code_review_instructions", "Generate agent instructions for code review tasks" + ) + server.add_prompt( + "generate_testing_instructions", "Generate agent instructions for testing tasks" + ) + + server.set_prompt_result("generate_code_review_instructions", "You are a code reviewer.") + server.set_prompt_result("generate_testing_instructions", "You are a test engineer.") + + # Test listing prompts + prompts_result = await server.list_prompts() + assert len(prompts_result.prompts) == 2 + + prompt_names = [p.name for p in prompts_result.prompts] + assert "generate_code_review_instructions" in prompt_names + assert "generate_testing_instructions" in prompt_names + + # Test getting each prompt + review_result = await server.get_prompt("generate_code_review_instructions") + assert review_result.messages[0].content.text == "You are a code reviewer." + + testing_result = await server.get_prompt("generate_testing_instructions") + assert testing_result.messages[0].content.text == "You are a test engineer." + + +@pytest.mark.asyncio +async def test_prompt_with_complex_arguments(): + """Test prompt with complex argument formatting""" + server = FakeMCPPromptServer() + server.add_prompt( + "generate_detailed_instructions", "Generate detailed instructions with multiple parameters" + ) + server.set_prompt_result( + "generate_detailed_instructions", + "You are a {role} specialist. Your focus is on {focus}. " + + "You work with {language} code. Your experience level is {level}.", + ) + + arguments = { + "role": "security", + "focus": "vulnerability detection", + "language": "Python", + "level": "senior", + } + + result = await server.get_prompt("generate_detailed_instructions", arguments) + + expected = ( + "You are a security specialist. Your focus is on vulnerability detection. " + "You work with Python code. Your experience level is senior." + ) + assert result.messages[0].content.text == expected + + +@pytest.mark.asyncio +async def test_prompt_with_missing_arguments(): + """Test prompt with missing arguments in format string""" + server = FakeMCPPromptServer() + server.add_prompt("incomplete_prompt", "Prompt with missing arguments") + server.set_prompt_result("incomplete_prompt", "You are a {role} working on {task}.") + + # Only provide one of the required arguments + result = await server.get_prompt("incomplete_prompt", {"role": "developer"}) + + # Should return the original string since formatting fails + assert result.messages[0].content.text == "You are a {role} working on {task}." + + +@pytest.mark.asyncio +async def test_prompt_server_cleanup(): + """Test that prompt server cleanup works correctly""" + server = FakeMCPPromptServer() + server.add_prompt("test_prompt", "Test prompt") + server.set_prompt_result("test_prompt", "Test result") + + # Test that server works before cleanup + result = await server.get_prompt("test_prompt") + assert result.messages[0].content.text == "Test result" + + # Cleanup should not raise any errors + await server.cleanup() + + # Server should still work after cleanup (in this fake implementation) + result = await server.get_prompt("test_prompt") + assert result.messages[0].content.text == "Test result" diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py index fbd8db17d..9e0455115 100644 --- a/tests/mcp/test_server_errors.py +++ b/tests/mcp/test_server_errors.py @@ -1,7 +1,9 @@ import pytest +from agents import Agent from agents.exceptions import UserError from agents.mcp.server import _MCPServerWithClientSession +from agents.run_context import RunContextWrapper class CrashingClientSessionServer(_MCPServerWithClientSession): @@ -35,8 +37,11 @@ async def test_server_errors_cause_error_and_cleanup_called(): async def test_not_calling_connect_causes_error(): server = CrashingClientSessionServer() + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + with pytest.raises(UserError): - await server.list_tools() + await server.list_tools(run_context, agent) with pytest.raises(UserError): await server.call_tool("foo", {}) diff --git a/tests/mcp/test_tool_filtering.py b/tests/mcp/test_tool_filtering.py new file mode 100644 index 000000000..c1ffff4b8 --- /dev/null +++ b/tests/mcp/test_tool_filtering.py @@ -0,0 +1,243 @@ +""" +Tool filtering tests use FakeMCPServer instead of real MCPServer implementations to avoid +external dependencies (processes, network connections) and ensure fast, reliable unit tests. +FakeMCPServer delegates filtering logic to the real _MCPServerWithClientSession implementation. +""" +import asyncio + +import pytest +from mcp import Tool as MCPTool + +from agents import Agent +from agents.mcp import ToolFilterContext, create_static_tool_filter +from agents.run_context import RunContextWrapper + +from .helpers import FakeMCPServer + + +def create_test_agent(name: str = "test_agent") -> Agent: + """Create a test agent for filtering tests.""" + return Agent(name=name, instructions="Test agent") + + +def create_test_context() -> RunContextWrapper: + """Create a test run context for filtering tests.""" + return RunContextWrapper(context=None) + + +# === Static Tool Filtering Tests === + +@pytest.mark.asyncio +async def test_static_tool_filtering(): + """Test all static tool filtering scenarios: allowed, blocked, both, none, etc.""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("tool1", {}) + server.add_tool("tool2", {}) + server.add_tool("tool3", {}) + server.add_tool("tool4", {}) + + # Create test context and agent for all calls + run_context = create_test_context() + agent = create_test_agent() + + # Test allowed_tool_names only + server.tool_filter = {"allowed_tool_names": ["tool1", "tool2"]} + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool2"} + + # Test blocked_tool_names only + server.tool_filter = {"blocked_tool_names": ["tool3", "tool4"]} + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool2"} + + # Test both filters together (allowed first, then blocked) + server.tool_filter = { + "allowed_tool_names": ["tool1", "tool2", "tool3"], + "blocked_tool_names": ["tool3"] + } + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool2"} + + # Test no filter + server.tool_filter = None + tools = await server.list_tools(run_context, agent) + assert len(tools) == 4 + + # Test helper function + server.tool_filter = create_static_tool_filter( + allowed_tool_names=["tool1", "tool2"], + blocked_tool_names=["tool2"] + ) + tools = await server.list_tools(run_context, agent) + assert len(tools) == 1 + assert tools[0].name == "tool1" + + +# === Dynamic Tool Filtering Core Tests === + +@pytest.mark.asyncio +async def test_dynamic_filter_sync_and_async(): + """Test both synchronous and asynchronous dynamic filters""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("allowed_tool", {}) + server.add_tool("blocked_tool", {}) + server.add_tool("restricted_tool", {}) + + # Create test context and agent + run_context = create_test_context() + agent = create_test_agent() + + # Test sync filter + def sync_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + return tool.name.startswith("allowed") + + server.tool_filter = sync_filter + tools = await server.list_tools(run_context, agent) + assert len(tools) == 1 + assert tools[0].name == "allowed_tool" + + # Test async filter + async def async_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + await asyncio.sleep(0.001) # Simulate async operation + return "restricted" not in tool.name + + server.tool_filter = async_filter + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"allowed_tool", "blocked_tool"} + + +@pytest.mark.asyncio +async def test_dynamic_filter_context_handling(): + """Test dynamic filters with context access""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("admin_tool", {}) + server.add_tool("user_tool", {}) + server.add_tool("guest_tool", {}) + + # Test context-independent filter + def context_independent_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + return not tool.name.startswith("admin") + + server.tool_filter = context_independent_filter + run_context = create_test_context() + agent = create_test_agent() + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"user_tool", "guest_tool"} + + # Test context-dependent filter (needs context) + def context_dependent_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + assert context is not None + assert context.run_context is not None + assert context.agent is not None + assert context.server_name == "test_server" + + # Only admin tools for agents with "admin" in name + if "admin" in context.agent.name.lower(): + return True + else: + return not tool.name.startswith("admin") + + server.tool_filter = context_dependent_filter + + # Should work with context + run_context = RunContextWrapper(context=None) + regular_agent = create_test_agent("regular_user") + tools = await server.list_tools(run_context, regular_agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"user_tool", "guest_tool"} + + admin_agent = create_test_agent("admin_user") + tools = await server.list_tools(run_context, admin_agent) + assert len(tools) == 3 + + +@pytest.mark.asyncio +async def test_dynamic_filter_error_handling(): + """Test error handling in dynamic filters""" + server = FakeMCPServer(server_name="test_server") + server.add_tool("good_tool", {}) + server.add_tool("error_tool", {}) + server.add_tool("another_good_tool", {}) + + def error_prone_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + if tool.name == "error_tool": + raise ValueError("Simulated filter error") + return True + + server.tool_filter = error_prone_filter + + # Test with server call + run_context = create_test_context() + agent = create_test_agent() + tools = await server.list_tools(run_context, agent) + assert len(tools) == 2 + assert {t.name for t in tools} == {"good_tool", "another_good_tool"} + + +# === Integration Tests === + +@pytest.mark.asyncio +async def test_agent_dynamic_filtering_integration(): + """Test dynamic filtering integration with Agent methods""" + server = FakeMCPServer() + server.add_tool("file_read", {"type": "object", "properties": {"path": {"type": "string"}}}) + server.add_tool( + "file_write", + { + "type": "object", + "properties": {"path": {"type": "string"}, "content": {"type": "string"}}, + }, + ) + server.add_tool( + "database_query", {"type": "object", "properties": {"query": {"type": "string"}}} + ) + server.add_tool( + "network_request", {"type": "object", "properties": {"url": {"type": "string"}}} + ) + + # Role-based filter for comprehensive testing + async def role_based_filter(context: ToolFilterContext, tool: MCPTool) -> bool: + # Simulate async permission check + await asyncio.sleep(0.001) + + agent_name = context.agent.name.lower() + if "admin" in agent_name: + return True + elif "readonly" in agent_name: + return "read" in tool.name or "query" in tool.name + else: + return tool.name.startswith("file_") + + server.tool_filter = role_based_filter + + # Test admin agent + admin_agent = Agent(name="admin_user", instructions="Admin", mcp_servers=[server]) + run_context = RunContextWrapper(context=None) + admin_tools = await admin_agent.get_mcp_tools(run_context) + assert len(admin_tools) == 4 + + # Test readonly agent + readonly_agent = Agent(name="readonly_viewer", instructions="Read-only", mcp_servers=[server]) + readonly_tools = await readonly_agent.get_mcp_tools(run_context) + assert len(readonly_tools) == 2 + assert {t.name for t in readonly_tools} == {"file_read", "database_query"} + + # Test regular agent + regular_agent = Agent(name="regular_user", instructions="Regular", mcp_servers=[server]) + regular_tools = await regular_agent.get_mcp_tools(run_context) + assert len(regular_tools) == 2 + assert {t.name for t in regular_tools} == {"file_read", "file_write"} + + # Test get_all_tools method + all_tools = await regular_agent.get_all_tools(run_context) + mcp_tool_names = { + t.name + for t in all_tools + if t.name in {"file_read", "file_write", "database_query", "network_request"} + } + assert mcp_tool_names == {"file_read", "file_write"} diff --git a/tests/model_settings/test_serialization.py b/tests/model_settings/test_serialization.py index ad4db4019..94d11def3 100644 --- a/tests/model_settings/test_serialization.py +++ b/tests/model_settings/test_serialization.py @@ -2,6 +2,8 @@ from dataclasses import fields from openai.types.shared import Reasoning +from pydantic import TypeAdapter +from pydantic_core import to_json from agents.model_settings import ModelSettings @@ -44,6 +46,7 @@ def test_all_fields_serialization() -> None: metadata={"foo": "bar"}, store=False, include_usage=False, + response_include=["reasoning.encrypted_content"], extra_query={"foo": "bar"}, extra_body={"foo": "bar"}, extra_headers={"foo": "bar"}, @@ -131,3 +134,32 @@ def test_extra_args_resolve_both_none() -> None: assert resolved.extra_args is None assert resolved.temperature == 0.5 assert resolved.top_p == 0.9 + +def test_pydantic_serialization() -> None: + + """Tests whether ModelSettings can be serialized with Pydantic.""" + + # First, lets create a ModelSettings instance + model_settings = ModelSettings( + temperature=0.5, + top_p=0.9, + frequency_penalty=0.0, + presence_penalty=0.0, + tool_choice="auto", + parallel_tool_calls=True, + truncation="auto", + max_tokens=100, + reasoning=Reasoning(), + metadata={"foo": "bar"}, + store=False, + include_usage=False, + extra_query={"foo": "bar"}, + extra_body={"foo": "bar"}, + extra_headers={"foo": "bar"}, + extra_args={"custom_param": "value", "another_param": 42}, + ) + + json = to_json(model_settings) + deserialized = TypeAdapter(ModelSettings).validate_json(json) + + assert model_settings == deserialized diff --git a/tests/test_agent_config.py b/tests/test_agent_config.py index f9423619d..a985fd60d 100644 --- a/tests/test_agent_config.py +++ b/tests/test_agent_config.py @@ -43,7 +43,7 @@ async def test_handoff_with_agents(): handoffs=[agent_1, agent_2], ) - handoffs = AgentRunner._get_handoffs(agent_3) + handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None)) assert len(handoffs) == 2 assert handoffs[0].agent_name == "agent_1" @@ -78,7 +78,7 @@ async def test_handoff_with_handoff_obj(): ], ) - handoffs = AgentRunner._get_handoffs(agent_3) + handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None)) assert len(handoffs) == 2 assert handoffs[0].agent_name == "agent_1" @@ -112,7 +112,7 @@ async def test_handoff_with_handoff_obj_and_agent(): handoffs=[handoff(agent_1), agent_2], ) - handoffs = AgentRunner._get_handoffs(agent_3) + handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None)) assert len(handoffs) == 2 assert handoffs[0].agent_name == "agent_1" diff --git a/tests/test_computer_action.py b/tests/test_computer_action.py index 70dcabd59..a306b1841 100644 --- a/tests/test_computer_action.py +++ b/tests/test_computer_action.py @@ -18,6 +18,7 @@ ActionScroll, ActionType, ActionWait, + PendingSafetyCheck, ResponseComputerToolCall, ) @@ -31,8 +32,9 @@ RunContextWrapper, RunHooks, ) -from agents._run_impl import ComputerAction, ToolRunComputerAction +from agents._run_impl import ComputerAction, RunImpl, ToolRunComputerAction from agents.items import ToolCallOutputItem +from agents.tool import ComputerToolSafetyCheckData class LoggingComputer(Computer): @@ -309,3 +311,44 @@ async def test_execute_invokes_hooks_and_returns_tool_call_output() -> None: assert raw["output"]["type"] == "computer_screenshot" assert "image_url" in raw["output"] assert raw["output"]["image_url"].endswith("xyz") + + +@pytest.mark.asyncio +async def test_pending_safety_check_acknowledged() -> None: + """Safety checks should be acknowledged via the callback.""" + + computer = LoggingComputer(screenshot_return="img") + called: list[ComputerToolSafetyCheckData] = [] + + def on_sc(data: ComputerToolSafetyCheckData) -> bool: + called.append(data) + return True + + tool = ComputerTool(computer=computer, on_safety_check=on_sc) + safety = PendingSafetyCheck(id="sc", code="c", message="m") + tool_call = ResponseComputerToolCall( + id="t1", + type="computer_call", + action=ActionClick(type="click", x=1, y=1, button="left"), + call_id="t1", + pending_safety_checks=[safety], + status="completed", + ) + run_action = ToolRunComputerAction(tool_call=tool_call, computer_tool=tool) + agent = Agent(name="a", tools=[tool]) + ctx = RunContextWrapper(context=None) + + results = await RunImpl.execute_computer_actions( + agent=agent, + actions=[run_action], + hooks=RunHooks[Any](), + context_wrapper=ctx, + config=RunConfig(), + ) + + assert len(results) == 1 + raw = results[0].raw_item + assert isinstance(raw, dict) + assert raw.get("acknowledged_safety_checks") == [{"id": "sc", "code": "c", "message": "m"}] + assert len(called) == 1 + assert called[0].safety_check.id == "sc" diff --git a/tests/test_handoff_tool.py b/tests/test_handoff_tool.py index a1b5b80ba..0f7fc2166 100644 --- a/tests/test_handoff_tool.py +++ b/tests/test_handoff_tool.py @@ -38,16 +38,17 @@ def get_len(data: HandoffInputData) -> int: return input_len + pre_handoff_len + new_items_len -def test_single_handoff_setup(): +@pytest.mark.asyncio +async def test_single_handoff_setup(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2", handoffs=[agent_1]) assert not agent_1.handoffs assert agent_2.handoffs == [agent_1] - assert not AgentRunner._get_handoffs(agent_1) + assert not (await AgentRunner._get_handoffs(agent_1, RunContextWrapper(agent_1))) - handoff_objects = AgentRunner._get_handoffs(agent_2) + handoff_objects = await AgentRunner._get_handoffs(agent_2, RunContextWrapper(agent_2)) assert len(handoff_objects) == 1 obj = handoff_objects[0] assert obj.tool_name == Handoff.default_tool_name(agent_1) @@ -55,7 +56,8 @@ def test_single_handoff_setup(): assert obj.agent_name == agent_1.name -def test_multiple_handoffs_setup(): +@pytest.mark.asyncio +async def test_multiple_handoffs_setup(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2]) @@ -64,7 +66,7 @@ def test_multiple_handoffs_setup(): assert not agent_1.handoffs assert not agent_2.handoffs - handoff_objects = AgentRunner._get_handoffs(agent_3) + handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3)) assert len(handoff_objects) == 2 assert handoff_objects[0].tool_name == Handoff.default_tool_name(agent_1) assert handoff_objects[1].tool_name == Handoff.default_tool_name(agent_2) @@ -76,7 +78,8 @@ def test_multiple_handoffs_setup(): assert handoff_objects[1].agent_name == agent_2.name -def test_custom_handoff_setup(): +@pytest.mark.asyncio +async def test_custom_handoff_setup(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent( @@ -95,7 +98,7 @@ def test_custom_handoff_setup(): assert not agent_1.handoffs assert not agent_2.handoffs - handoff_objects = AgentRunner._get_handoffs(agent_3) + handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3)) assert len(handoff_objects) == 2 first_handoff = handoff_objects[0] @@ -284,3 +287,86 @@ def test_get_transfer_message_is_valid_json() -> None: obj = handoff(agent) transfer = obj.get_transfer_message(agent) assert json.loads(transfer) == {"assistant": agent.name} + + +def test_handoff_is_enabled_bool(): + """Test that handoff respects is_enabled boolean parameter.""" + agent = Agent(name="test") + + # Test enabled handoff (default) + handoff_enabled = handoff(agent) + assert handoff_enabled.is_enabled is True + + # Test explicitly enabled handoff + handoff_explicit_enabled = handoff(agent, is_enabled=True) + assert handoff_explicit_enabled.is_enabled is True + + # Test disabled handoff + handoff_disabled = handoff(agent, is_enabled=False) + assert handoff_disabled.is_enabled is False + + +@pytest.mark.asyncio +async def test_handoff_is_enabled_callable(): + """Test that handoff respects is_enabled callable parameter.""" + agent = Agent(name="test") + + # Test callable that returns True + def always_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: + return True + + handoff_callable_enabled = handoff(agent, is_enabled=always_enabled) + assert callable(handoff_callable_enabled.is_enabled) + result = handoff_callable_enabled.is_enabled(RunContextWrapper(agent), agent) + assert result is True + + # Test callable that returns False + def always_disabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: + return False + + handoff_callable_disabled = handoff(agent, is_enabled=always_disabled) + assert callable(handoff_callable_disabled.is_enabled) + result = handoff_callable_disabled.is_enabled(RunContextWrapper(agent), agent) + assert result is False + + # Test async callable + async def async_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: + return True + + handoff_async_enabled = handoff(agent, is_enabled=async_enabled) + assert callable(handoff_async_enabled.is_enabled) + result = await handoff_async_enabled.is_enabled(RunContextWrapper(agent), agent) # type: ignore + assert result is True + + +@pytest.mark.asyncio +async def test_handoff_is_enabled_filtering_integration(): + """Integration test that disabled handoffs are filtered out by the runner.""" + + # Set up agents + agent_1 = Agent(name="agent_1") + agent_2 = Agent(name="agent_2") + agent_3 = Agent(name="agent_3") + + # Create main agent with mixed enabled/disabled handoffs + main_agent = Agent( + name="main_agent", + handoffs=[ + handoff(agent_1, is_enabled=True), # enabled + handoff(agent_2, is_enabled=False), # disabled + handoff(agent_3, is_enabled=lambda ctx, agent: True), # enabled callable + ], + ) + + context_wrapper = RunContextWrapper(main_agent) + + # Get filtered handoffs using the runner's method + filtered_handoffs = await AgentRunner._get_handoffs(main_agent, context_wrapper) + + # Should only have 2 handoffs (agent_1 and agent_3), agent_2 should be filtered out + assert len(filtered_handoffs) == 2 + + # Check that the correct agents are present + agent_names = {h.agent_name for h in filtered_handoffs} + assert agent_names == {"agent_1", "agent_3"} + assert "agent_2" not in agent_names diff --git a/tests/test_reasoning_content.py b/tests/test_reasoning_content.py new file mode 100644 index 000000000..5160e09c2 --- /dev/null +++ b/tests/test_reasoning_content.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any, cast + +import pytest +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage +from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta +from openai.types.completion_usage import ( + CompletionTokensDetails, + CompletionUsage, + PromptTokensDetails, +) +from openai.types.responses import ( + Response, + ResponseOutputMessage, + ResponseOutputText, + ResponseReasoningItem, +) + +from agents.model_settings import ModelSettings +from agents.models.interface import ModelTracing +from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel +from agents.models.openai_provider import OpenAIProvider + + +# Helper functions to create test objects consistently +def create_content_delta(content: str) -> dict[str, Any]: + """Create a delta dictionary with regular content""" + return { + "content": content, + "role": None, + "function_call": None, + "tool_calls": None + } + +def create_reasoning_delta(content: str) -> dict[str, Any]: + """Create a delta dictionary with reasoning content. The Only difference is reasoning_content""" + return { + "content": None, + "role": None, + "function_call": None, + "tool_calls": None, + "reasoning_content": content + } + + +def create_chunk(delta: dict[str, Any], include_usage: bool = False) -> ChatCompletionChunk: + """Create a ChatCompletionChunk with the given delta""" + # Create a ChoiceDelta object from the dictionary + delta_obj = ChoiceDelta( + content=delta.get("content"), + role=delta.get("role"), + function_call=delta.get("function_call"), + tool_calls=delta.get("tool_calls"), + ) + + # Add reasoning_content attribute dynamically if present in the delta + if "reasoning_content" in delta: + # Use direct assignment for the reasoning_content attribute + delta_obj_any = cast(Any, delta_obj) + delta_obj_any.reasoning_content = delta["reasoning_content"] + + # Create the chunk + chunk = ChatCompletionChunk( + id="chunk-id", + created=1, + model="deepseek is usually expected", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=delta_obj)], + ) + + if include_usage: + chunk.usage = CompletionUsage( + completion_tokens=4, + prompt_tokens=2, + total_tokens=6, + completion_tokens_details=CompletionTokensDetails(reasoning_tokens=2), + prompt_tokens_details=PromptTokensDetails(cached_tokens=0), + ) + + return chunk + + +async def create_fake_stream( + chunks: list[ChatCompletionChunk], +) -> AsyncIterator[ChatCompletionChunk]: + for chunk in chunks: + yield chunk + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_yields_events_for_reasoning_content(monkeypatch) -> None: + """ + Validate that when a model streams reasoning content, + `stream_response` emits the appropriate sequence of events including + `response.reasoning_summary_text.delta` events for each chunk of the reasoning content and + constructs a completed response with a `ResponseReasoningItem` part. + """ + # Create test chunks + chunks = [ + # Reasoning content chunks + create_chunk(create_reasoning_delta("Let me think")), + create_chunk(create_reasoning_delta(" about this")), + # Regular content chunks + create_chunk(create_content_delta("The answer")), + create_chunk(create_content_delta(" is 42"), include_usage=True), + ] + + async def patched_fetch_response(self, *args, **kwargs): + resp = Response( + id="resp-id", + created_at=0, + model="fake-model", + object="response", + output=[], + tool_choice="none", + tools=[], + parallel_tool_calls=False, + ) + return resp, create_fake_stream(chunks) + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider(use_responses=False).get_model("gpt-4") + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + prompt=None, + ): + output_events.append(event) + + # verify reasoning content events were emitted + reasoning_delta_events = [ + e for e in output_events if e.type == "response.reasoning_summary_text.delta" + ] + assert len(reasoning_delta_events) == 2 + assert reasoning_delta_events[0].delta == "Let me think" + assert reasoning_delta_events[1].delta == " about this" + + # verify regular content events were emitted + content_delta_events = [e for e in output_events if e.type == "response.output_text.delta"] + assert len(content_delta_events) == 2 + assert content_delta_events[0].delta == "The answer" + assert content_delta_events[1].delta == " is 42" + + # verify the final response contains both types of content + response_event = output_events[-1] + assert response_event.type == "response.completed" + assert len(response_event.response.output) == 2 + + # first item should be reasoning + assert isinstance(response_event.response.output[0], ResponseReasoningItem) + assert response_event.response.output[0].summary[0].text == "Let me think about this" + + # second item should be message with text + assert isinstance(response_event.response.output[1], ResponseOutputMessage) + assert isinstance(response_event.response.output[1].content[0], ResponseOutputText) + assert response_event.response.output[1].content[0].text == "The answer is 42" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_get_response_with_reasoning_content(monkeypatch) -> None: + """ + Test that when a model returns reasoning content in addition to regular content, + `get_response` properly includes both in the response output. + """ + # create a message with reasoning content + msg = ChatCompletionMessage( + role="assistant", + content="The answer is 42", + ) + # Use dynamic attribute for reasoning_content + # We need to cast to Any to avoid mypy errors since reasoning_content is not a defined attribute + msg_with_reasoning = cast(Any, msg) + msg_with_reasoning.reasoning_content = "Let me think about this question carefully" + + # create a choice with the message + mock_choice = { + "index": 0, + "finish_reason": "stop", + "message": msg_with_reasoning, + "delta": None + } + + chat = ChatCompletion( + id="resp-id", + created=0, + model="deepseek is expected", + object="chat.completion", + choices=[mock_choice], # type: ignore[list-item] + usage=CompletionUsage( + completion_tokens=10, + prompt_tokens=5, + total_tokens=15, + completion_tokens_details=CompletionTokensDetails(reasoning_tokens=6), + prompt_tokens_details=PromptTokensDetails(cached_tokens=0), + ), + ) + + async def patched_fetch_response(self, *args, **kwargs): + return chat + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider(use_responses=False).get_model("gpt-4") + resp = await model.get_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + prompt=None, + ) + + # should have produced a reasoning item and a message with text content + assert len(resp.output) == 2 + + # first output should be the reasoning item + assert isinstance(resp.output[0], ResponseReasoningItem) + assert resp.output[0].summary[0].text == "Let me think about this question carefully" + + # second output should be the message with text content + assert isinstance(resp.output[1], ResponseOutputMessage) + assert isinstance(resp.output[1].content[0], ResponseOutputText) + assert resp.output[1].content[0].text == "The answer is 42" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_with_empty_reasoning_content(monkeypatch) -> None: + """ + Test that when a model streams empty reasoning content, + the response still processes correctly without errors. + """ + # create test chunks with empty reasoning content + chunks = [ + create_chunk(create_reasoning_delta("")), + create_chunk(create_content_delta("The answer is 42"), include_usage=True), + ] + + async def patched_fetch_response(self, *args, **kwargs): + resp = Response( + id="resp-id", + created_at=0, + model="fake-model", + object="response", + output=[], + tool_choice="none", + tools=[], + parallel_tool_calls=False, + ) + return resp, create_fake_stream(chunks) + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider(use_responses=False).get_model("gpt-4") + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + prompt=None + ): + output_events.append(event) + + # verify the final response contains the content + response_event = output_events[-1] + assert response_event.type == "response.completed" + + # should only have the message, not an empty reasoning item + assert len(response_event.response.output) == 1 + assert isinstance(response_event.response.output[0], ResponseOutputMessage) + assert isinstance(response_event.response.output[0].content[0], ResponseOutputText) + assert response_event.response.output[0].content[0].text == "The answer is 42" diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py index 2454a4462..4cf9ae832 100644 --- a/tests/test_run_step_execution.py +++ b/tests/test_run_step_execution.py @@ -325,7 +325,7 @@ async def get_execute_result( run_config: RunConfig | None = None, ) -> SingleStepResult: output_schema = AgentRunner._get_output_schema(agent) - handoffs = AgentRunner._get_handoffs(agent) + handoffs = await AgentRunner._get_handoffs(agent, context_wrapper or RunContextWrapper(None)) processed_response = RunImpl.process_model_response( agent=agent, diff --git a/tests/test_run_step_processing.py b/tests/test_run_step_processing.py index 5a75ec837..6a2904791 100644 --- a/tests/test_run_step_processing.py +++ b/tests/test_run_step_processing.py @@ -186,7 +186,7 @@ async def test_handoffs_parsed_correctly(): agent=agent_3, response=response, output_schema=None, - handoffs=AgentRunner._get_handoffs(agent_3), + handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) assert len(result.handoffs) == 1, "Should have a handoff here" @@ -216,7 +216,7 @@ async def test_missing_handoff_fails(): agent=agent_3, response=response, output_schema=None, - handoffs=AgentRunner._get_handoffs(agent_3), + handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) @@ -239,7 +239,7 @@ async def test_multiple_handoffs_doesnt_error(): agent=agent_3, response=response, output_schema=None, - handoffs=AgentRunner._get_handoffs(agent_3), + handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) assert len(result.handoffs) == 2, "Should have multiple handoffs here" @@ -471,7 +471,7 @@ async def test_tool_and_handoff_parsed_correctly(): agent=agent_3, response=response, output_schema=None, - handoffs=AgentRunner._get_handoffs(agent_3), + handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) assert result.functions and len(result.functions) == 1 diff --git a/uv.lock b/uv.lock index 679d8d2c1..d882c9bc5 100644 --- a/uv.lock +++ b/uv.lock @@ -1480,7 +1480,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.0.19" +version = "0.1.0" source = { editable = "." } dependencies = [ { name = "griffe" },