diff --git a/examples/tools/image_function_tool.py b/examples/tools/image_function_tool.py new file mode 100644 index 00000000..e80949e8 --- /dev/null +++ b/examples/tools/image_function_tool.py @@ -0,0 +1,33 @@ +import asyncio +import base64 +import os + +from agents import Agent, Runner, image_function_tool + +FILEPATH = os.path.join(os.path.dirname(__file__), "media/small.webp") + + +@image_function_tool +def image_to_base64(path: str) -> str: + """ + This function takes a path to an image and returns a base64 encoded string of the image. + It is used to convert the image to a base64 encoded string so that it can be sent to the LLM. + """ + with open(path, "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()).decode("utf-8") + return f"data:image/jpeg;base64,{encoded_string}" + + +async def main(): + agent = Agent( + name="Assistant", + instructions="You are a helpful assistant.", + tools=[image_to_base64], + ) + + result = await Runner.run(agent, f"Read the image in {FILEPATH} and tell me what you see.") + print(result.final_output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/tools/media/small.webp b/examples/tools/media/small.webp new file mode 100644 index 00000000..49633857 Binary files /dev/null and b/examples/tools/media/small.webp differ diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 6d7c90b4..aa41e479 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -58,10 +58,12 @@ FileSearchTool, FunctionTool, FunctionToolResult, + ImageFunctionTool, Tool, WebSearchTool, default_tool_error_function, function_tool, + image_function_tool, ) from .tracing import ( AgentSpanData, @@ -203,12 +205,14 @@ def enable_verbose_stdout_logging(): "AgentUpdatedStreamEvent", "StreamEvent", "FunctionTool", + "ImageFunctionTool", "FunctionToolResult", "ComputerTool", "FileSearchTool", "Tool", "WebSearchTool", "function_tool", + "image_function_tool", "Usage", "add_trace_processor", "agent_span", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index b5a83685..1215db69 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -52,7 +52,14 @@ from .models.interface import ModelTracing from .run_context import RunContextWrapper, TContext from .stream_events import RunItemStreamEvent, StreamEvent -from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool +from .tool import ( + ComputerTool, + FunctionTool, + FunctionToolResult, + ImageFunctionTool, + ImageFunctionToolResult, + Tool, +) from .tracing import ( SpanError, Trace, @@ -106,6 +113,12 @@ class ToolRunFunction: function_tool: FunctionTool +@dataclass +class ToolRunImageFunction: + tool_call: ResponseFunctionToolCall + image_function_tool: ImageFunctionTool + + @dataclass class ToolRunComputerAction: tool_call: ResponseComputerToolCall @@ -117,6 +130,7 @@ class ProcessedResponse: new_items: list[RunItem] handoffs: list[ToolRunHandoff] functions: list[ToolRunFunction] + image_functions: list[ToolRunImageFunction] computer_actions: list[ToolRunComputerAction] tools_used: list[str] # Names of all tools used, including hosted tools @@ -127,6 +141,7 @@ def has_tools_to_run(self) -> bool: [ self.handoffs, self.functions, + self.image_functions, self.computer_actions, ] ) @@ -207,7 +222,7 @@ async def execute_tools_and_side_effects( new_step_items.extend(processed_response.new_items) # First, lets run the tool calls - function tools and computer actions - function_results, computer_results = await asyncio.gather( + function_results, image_function_results, computer_results = await asyncio.gather( cls.execute_function_tool_calls( agent=agent, tool_runs=processed_response.functions, @@ -215,6 +230,13 @@ async def execute_tools_and_side_effects( context_wrapper=context_wrapper, config=run_config, ), + cls.execute_image_function_tool_calls( + agent=agent, + tool_runs=processed_response.image_functions, + hooks=hooks, + context_wrapper=context_wrapper, + config=run_config, + ), cls.execute_computer_actions( agent=agent, actions=processed_response.computer_actions, @@ -224,6 +246,7 @@ async def execute_tools_and_side_effects( ), ) new_step_items.extend([result.run_item for result in function_results]) + new_step_items.extend([result.run_item for result in image_function_results]) new_step_items.extend(computer_results) # Second, check if there are any handoffs @@ -342,10 +365,14 @@ def process_model_response( run_handoffs = [] functions = [] + image_functions = [] computer_actions = [] tools_used: list[str] = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} + image_function_map = { + tool.name: tool for tool in all_tools if isinstance(tool, ImageFunctionTool) + } computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None) for output in response.output: @@ -393,6 +420,15 @@ def process_model_response( handoff=handoff_map[output.name], ) run_handoffs.append(handoff) + + elif output.name in image_function_map: + items.append(ToolCallItem(raw_item=output, agent=agent)) + image_functions.append( + ToolRunImageFunction( + tool_call=output, + image_function_tool=image_function_map[output.name], + ) + ) # Regular function tool call else: if output.name not in function_map: @@ -415,6 +451,7 @@ def process_model_response( new_items=items, handoffs=run_handoffs, functions=functions, + image_functions=image_functions, computer_actions=computer_actions, tools_used=tools_used, ) @@ -489,6 +526,78 @@ async def run_single_tool( for tool_run, result in zip(tool_runs, results) ] + @classmethod + async def execute_image_function_tool_calls( + cls, + *, + agent: Agent[TContext], + tool_runs: list[ToolRunImageFunction], + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + config: RunConfig, + ) -> list[ImageFunctionToolResult]: + async def run_single_tool( + func_tool: ImageFunctionTool, tool_call: ResponseFunctionToolCall + ) -> Any: + with function_span(func_tool.name) as span_fn: + if config.trace_include_sensitive_data: + span_fn.span_data.input = tool_call.arguments + try: + _, _, result = await asyncio.gather( + hooks.on_tool_start(context_wrapper, agent, func_tool), + ( + agent.hooks.on_tool_start(context_wrapper, agent, func_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + func_tool.on_invoke_tool(context_wrapper, tool_call.arguments), + ) + + await asyncio.gather( + hooks.on_tool_end(context_wrapper, agent, func_tool, result), + ( + agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + except Exception as e: + _error_tracing.attach_error_to_current_span( + SpanError( + message="Error running tool", + data={"tool_name": func_tool.name, "error": str(e)}, + ) + ) + if isinstance(e, AgentsException): + raise e + raise UserError(f"Error running tool {func_tool.name}: {e}") from e + + if config.trace_include_sensitive_data: + span_fn.span_data.output = result + return result + + tasks = [] + for tool_run in tool_runs: + image_function_tool = tool_run.image_function_tool + tasks.append(run_single_tool(image_function_tool, tool_run.tool_call)) + + results = await asyncio.gather(*tasks) + + return [ + ImageFunctionToolResult( + tool=tool_run.image_function_tool, + output=result, + run_item=ToolCallOutputItem( + output=result, + raw_item=ItemHelpers.image_function_tool_call_output_item( + tool_run.tool_call, result + ), + agent=agent, + ), + ) + for tool_run, result in zip(tool_runs, results) + ] + @classmethod async def execute_computer_actions( cls, diff --git a/src/agents/items.py b/src/agents/items.py index 8fb2b52a..db7dcd96 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -57,7 +57,7 @@ class RunItemBase(Generic[T], abc.ABC): def to_input_item(self) -> TResponseInputItem: """Converts this item into an input item suitable for passing to the model.""" - if isinstance(self.raw_item, dict): + if isinstance(self.raw_item, dict) or isinstance(self.raw_item, list): # We know that input items are dicts, so we can ignore the type error return self.raw_item # type: ignore elif isinstance(self.raw_item, BaseModel): @@ -248,3 +248,25 @@ def tool_call_output_item( "output": output, "type": "function_call_output", } + + @classmethod + def image_function_tool_call_output_item( + cls, tool_call: ResponseFunctionToolCall, output: str + ) -> FunctionCallOutput: + """Creates a tool call output item from a tool call and its output.""" + return [ + { + "call_id": tool_call.call_id, + "output": "Image generating tool is called.", + "type": "function_call_output", + }, + { + "role": "user", + "content": [ + { + "type": "input_image", + "image_url": output, + } + ], + }, + ] diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index c1ff85b9..f6591e38 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -23,7 +23,14 @@ from ..handoffs import Handoff from ..items import ItemHelpers, ModelResponse, TResponseInputItem from ..logger import logger -from ..tool import ComputerTool, FileSearchTool, FunctionTool, Tool, WebSearchTool +from ..tool import ( + ComputerTool, + FileSearchTool, + FunctionTool, + ImageFunctionTool, + Tool, + WebSearchTool, +) from ..tracing import SpanError, response_span from ..usage import Usage from ..version import __version__ @@ -358,6 +365,15 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]: "description": tool.description, } includes: IncludeLiteral | None = None + elif isinstance(tool, ImageFunctionTool): + converted_tool: ToolParam = { + "name": tool.name, + "parameters": tool.params_json_schema, + "strict": tool.strict_json_schema, + "type": "function", + "description": tool.description, + } + includes: IncludeLiteral | None = None elif isinstance(tool, WebSearchTool): ws: WebSearchToolParam = { "type": "web_search_preview", diff --git a/src/agents/run.py b/src/agents/run.py index 849da7bf..43da3dde 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -755,7 +755,14 @@ async def _run_single_turn( output_schema = cls._get_output_schema(agent) handoffs = cls._get_handoffs(agent) input = ItemHelpers.input_to_new_input_list(original_input) - input.extend([generated_item.to_input_item() for generated_item in generated_items]) + + # input.extend([generated_item.to_input_item() for generated_item in generated_items]) + for generated_item in generated_items: + input_item_from_generated_item = generated_item.to_input_item() + if isinstance(input_item_from_generated_item, list): + input.extend(input_item_from_generated_item) + else: + input.append(input_item_from_generated_item) new_response = await cls._get_new_response( agent, diff --git a/src/agents/tool.py b/src/agents/tool.py index c1c16242..6f3e11dd 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -42,6 +42,18 @@ class FunctionToolResult: """The run item that was produced as a result of the tool call.""" +@dataclass +class ImageFunctionToolResult: + tool: ImageFunctionTool + """The tool that was run.""" + + output: Any + """The output of the tool.""" + + run_item: RunItem + """The run item that was produced as a result of the tool call.""" + + @dataclass class FunctionTool: """A tool that wraps a function. In most cases, you should use the `function_tool` helpers to @@ -73,6 +85,37 @@ class FunctionTool: as it increases the likelihood of correct JSON input.""" +@dataclass +class ImageFunctionTool: + """A tool that wraps a function that generates an image. In most cases, you should use the `image_function_tool` helpers to + create a ImageFunctionTool, as they let you easily wrap a Python function. + """ + + name: str + """The name of the tool, as shown to the LLM. Generally the name of the function.""" + + description: str + """A description of the tool, as shown to the LLM.""" + + params_json_schema: dict[str, Any] + """The JSON schema for the tool's parameters.""" + + on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]] + """A function that invokes the tool with the given context and parameters. The params passed + are: + 1. The tool run context. + 2. The arguments from the LLM, as a JSON string. + + You must return a string representation of the tool output, or something we can call `str()` on. + In case of errors, you can either raise an Exception (which will cause the run to fail) or + return a string error message (which will be sent back to the LLM). + """ + + strict_json_schema: bool = True + """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, + as it increases the likelihood of correct JSON input.""" + + @dataclass class FileSearchTool: """A hosted tool that lets the LLM search through a vector store. Currently only supported with @@ -308,3 +351,142 @@ def decorator(real_func: ToolFunction[...]) -> FunctionTool: return _create_function_tool(real_func) return decorator + + +def image_function_tool( + func: ToolFunction[...] | None = None, + *, + name_override: str | None = None, + description_override: str | None = None, + docstring_style: DocstringStyle | None = None, + use_docstring_info: bool = True, + failure_error_function: ToolErrorFunction | None = default_tool_error_function, + strict_mode: bool = True, +) -> ImageFunctionTool | Callable[[ToolFunction[...]], ImageFunctionTool]: + """ + Decorator to create a FunctionTool from a function. By default, we will: + 1. Parse the function signature to create a JSON schema for the tool's parameters. + 2. Use the function's docstring to populate the tool's description. + 3. Use the function's docstring to populate argument descriptions. + The docstring style is detected automatically, but you can override it. + + If the function takes a `RunContextWrapper` as the first argument, it *must* match the + context type of the agent that uses the tool. + + Args: + func: The function to wrap. + name_override: If provided, use this name for the tool instead of the function's name. + description_override: If provided, use this description for the tool instead of the + function's docstring. + docstring_style: If provided, use this style for the tool's docstring. If not provided, + we will attempt to auto-detect the style. + use_docstring_info: If True, use the function's docstring to populate the tool's + description and argument descriptions. + failure_error_function: If provided, use this function to generate an error message when + the tool call fails. The error message is sent to the LLM. If you pass None, then no + error message will be sent and instead an Exception will be raised. + strict_mode: Whether to enable strict mode for the tool's JSON schema. We *strongly* + recommend setting this to True, as it increases the likelihood of correct JSON input. + If False, it allows non-strict JSON schemas. For example, if a parameter has a default + value, it will be optional, additional properties are allowed, etc. See here for more: + https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas + """ + + def _create_image_function_tool(the_func: ToolFunction[...]) -> ImageFunctionTool: + schema = function_schema( + func=the_func, + name_override=name_override, + description_override=description_override, + docstring_style=docstring_style, + use_docstring_info=use_docstring_info, + strict_json_schema=strict_mode, + ) + + async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any: + try: + json_data: dict[str, Any] = json.loads(input) if input else {} + except Exception as e: + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Invalid JSON input for tool {schema.name}") + else: + logger.debug(f"Invalid JSON input for tool {schema.name}: {input}") + raise ModelBehaviorError( + f"Invalid JSON input for tool {schema.name}: {input}" + ) from e + + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Invoking tool {schema.name}") + else: + logger.debug(f"Invoking tool {schema.name} with input {input}") + + try: + parsed = ( + schema.params_pydantic_model(**json_data) + if json_data + else schema.params_pydantic_model() + ) + except ValidationError as e: + raise ModelBehaviorError(f"Invalid JSON input for tool {schema.name}: {e}") from e + + args, kwargs_dict = schema.to_call_args(parsed) + + if not _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Tool call args: {args}, kwargs: {kwargs_dict}") + + if inspect.iscoroutinefunction(the_func): + if schema.takes_context: + result = await the_func(ctx, *args, **kwargs_dict) + else: + result = await the_func(*args, **kwargs_dict) + else: + if schema.takes_context: + result = the_func(ctx, *args, **kwargs_dict) + else: + result = the_func(*args, **kwargs_dict) + + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Tool {schema.name} completed.") + else: + logger.debug(f"Tool {schema.name} returned {result}") + + return result + + async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any: + try: + return await _on_invoke_tool_impl(ctx, input) + except Exception as e: + if failure_error_function is None: + raise + + result = failure_error_function(ctx, e) + if inspect.isawaitable(result): + return await result + + _error_tracing.attach_error_to_current_span( + SpanError( + message="Error running tool (non-fatal)", + data={ + "tool_name": schema.name, + "error": str(e), + }, + ) + ) + return result + + return ImageFunctionTool( + name=schema.name, + description=schema.description or "", + params_json_schema=schema.params_json_schema, + on_invoke_tool=_on_invoke_tool, + strict_json_schema=strict_mode, + ) + + # If func is actually a callable, we were used as @function_tool with no parentheses + if callable(func): + return _create_image_function_tool(func) + + # Otherwise, we were used as @function_tool(...), so return a decorator + def decorator(real_func: ToolFunction[...]) -> ImageFunctionTool: + return _create_image_function_tool(real_func) + + return decorator diff --git a/src/agents/voice/model.py b/src/agents/voice/model.py index c36a4de7..b048a452 100644 --- a/src/agents/voice/model.py +++ b/src/agents/voice/model.py @@ -17,9 +17,11 @@ TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] """Exportable type for the TTSModelSettings voice enum""" + @dataclass class TTSModelSettings: """Settings for a TTS model.""" + voice: TTSVoice | None = None """ The voice to use for the TTS model. If not provided, the default voice for the respective model diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py index f29c2540..8efa95a7 100644 --- a/tests/test_extra_headers.py +++ b/tests/test_extra_headers.py @@ -17,21 +17,21 @@ class DummyResponses: async def create(self, **kwargs): nonlocal called_kwargs called_kwargs = kwargs + class DummyResponse: id = "dummy" output = [] usage = type( "Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} )() + return DummyResponse() class DummyClient: def __init__(self): self.responses = DummyResponses() - - - model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore extra_headers = {"X-Test-Header": "test-value"} await model.get_response( system_instructions=None, @@ -47,7 +47,6 @@ def __init__(self): assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value" - @pytest.mark.allow_call_model_methods @pytest.mark.asyncio async def test_extra_headers_passed_to_openai_client(): @@ -76,7 +75,7 @@ def __init__(self): self.chat = type("_Chat", (), {"completions": DummyCompletions()})() self.base_url = "https://api.openai.com" - model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore extra_headers = {"X-Test-Header": "test-value"} await model.get_response( system_instructions=None, diff --git a/tests/tools/test_image_function_tool.py b/tests/tools/test_image_function_tool.py new file mode 100644 index 00000000..7d22a1c2 --- /dev/null +++ b/tests/tools/test_image_function_tool.py @@ -0,0 +1,248 @@ +import asyncio +import json +from typing import Any + +import pytest +from pydantic import BaseModel +from typing_extensions import TypedDict + +from src.agents import ( + ImageFunctionTool, + ModelBehaviorError, + RunContextWrapper, + image_function_tool, +) +from src.agents.tool import default_tool_error_function + +# A dummy base64 encoded image string (e.g., a tiny 1x1 pixel red PNG) +DUMMY_IMAGE_BASE64 = "" + +# =================== Basic Tests for image_function_tool =================== + + +def argless_image_function() -> str: + return DUMMY_IMAGE_BASE64 + + +@pytest.mark.asyncio +async def test_argless_image_function(): + tool = image_function_tool(argless_image_function) + assert tool.name == "argless_image_function" + + result = await tool.on_invoke_tool(RunContextWrapper(None), "") + assert result == DUMMY_IMAGE_BASE64 + + +def argless_with_context_image(ctx: RunContextWrapper[str]) -> str: + return DUMMY_IMAGE_BASE64 + + +@pytest.mark.asyncio +async def test_argless_with_context_image(): + tool = image_function_tool(argless_with_context_image) + assert tool.name == "argless_with_context_image" + + result = await tool.on_invoke_tool(RunContextWrapper(None), "") + assert result == DUMMY_IMAGE_BASE64 + + # Extra JSON should not raise an error + result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}') + assert result == DUMMY_IMAGE_BASE64 + + +def simple_image_function(prompt: str, style: str = "realistic") -> str: + # In a real scenario, these parameters would affect the generated image + return DUMMY_IMAGE_BASE64 + + +@pytest.mark.asyncio +async def test_simple_image_function(): + tool = image_function_tool(simple_image_function, failure_error_function=None) + assert tool.name == "simple_image_function" + + result = await tool.on_invoke_tool(RunContextWrapper(None), '{"prompt": "cat"}') + assert result == DUMMY_IMAGE_BASE64 + + result = await tool.on_invoke_tool( + RunContextWrapper(None), '{"prompt": "dog", "style": "cartoon"}' + ) + assert result == DUMMY_IMAGE_BASE64 + + # Missing required argument should raise an error + with pytest.raises(ModelBehaviorError): + await tool.on_invoke_tool(RunContextWrapper(None), "") + + +class ImageParams(BaseModel): + prompt: str + width: int = 512 + height: int = 512 + + +class StyleOptions(TypedDict): + style: str + seed: int + + +def complex_args_image_function(params: ImageParams, style_options: StyleOptions) -> str: + # In a real scenario, these parameters would affect the generated image + return DUMMY_IMAGE_BASE64 + + +@pytest.mark.asyncio +async def test_complex_args_image_function(): + tool = image_function_tool(complex_args_image_function, failure_error_function=None) + assert tool.name == "complex_args_image_function" + + valid_json = json.dumps( + { + "params": ImageParams(prompt="sunset").model_dump(), + "style_options": StyleOptions(style="realistic", seed=42), + } + ) + result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json) + assert result == DUMMY_IMAGE_BASE64 + + valid_json = json.dumps( + { + "params": ImageParams(prompt="mountains", width=1024, height=768).model_dump(), + "style_options": StyleOptions(style="abstract", seed=123), + } + ) + result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json) + assert result == DUMMY_IMAGE_BASE64 + + # Missing required argument should raise an error + with pytest.raises(ModelBehaviorError): + await tool.on_invoke_tool(RunContextWrapper(None), '{"params": {"prompt": "forest"}}') + + +def test_image_function_config_overrides(): + tool = image_function_tool(simple_image_function, name_override="custom_image_name") + assert tool.name == "custom_image_name" + + tool = image_function_tool(simple_image_function, description_override="Generate custom images") + assert tool.description == "Generate custom images" + + tool = image_function_tool( + simple_image_function, + name_override="art_generator", + description_override="Creates beautiful art images", + ) + assert tool.name == "art_generator" + assert tool.description == "Creates beautiful art images" + + +def test_image_function_schema_is_strict(): + tool = image_function_tool(simple_image_function) + assert tool.strict_json_schema, "Should be strict by default" + assert ( + "additionalProperties" in tool.params_json_schema + and not tool.params_json_schema["additionalProperties"] + ) + + tool = image_function_tool(complex_args_image_function) + assert tool.strict_json_schema, "Should be strict by default" + assert ( + "additionalProperties" in tool.params_json_schema + and not tool.params_json_schema["additionalProperties"] + ) + + +@pytest.mark.asyncio +async def test_manual_image_function_tool_creation_works(): + def generate_image(prompt: str) -> str: + return DUMMY_IMAGE_BASE64 + + class ImageArgs(BaseModel): + prompt: str + + async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: + parsed = ImageArgs.model_validate_json(args) + return generate_image(prompt=parsed.prompt) + + tool = ImageFunctionTool( + name="image_creator", + description="Creates images from text prompts", + params_json_schema=ImageArgs.model_json_schema(), + on_invoke_tool=run_function, + ) + + assert tool.name == "image_creator" + assert tool.description == "Creates images from text prompts" + for key, value in ImageArgs.model_json_schema().items(): + assert tool.params_json_schema[key] == value + assert tool.strict_json_schema + + result = await tool.on_invoke_tool(RunContextWrapper(None), '{"prompt": "sunset"}') + assert result == DUMMY_IMAGE_BASE64 + + tool_not_strict = ImageFunctionTool( + name="image_creator", + description="Creates images from text prompts", + params_json_schema=ImageArgs.model_json_schema(), + on_invoke_tool=run_function, + strict_json_schema=False, + ) + + assert not tool_not_strict.strict_json_schema + assert "additionalProperties" not in tool_not_strict.params_json_schema + + result = await tool_not_strict.on_invoke_tool( + RunContextWrapper(None), '{"prompt": "sunset", "style": "realistic"}' + ) + assert result == DUMMY_IMAGE_BASE64 + + +@pytest.mark.asyncio +async def test_image_function_tool_default_error_works(): + def failing_image_generator(prompt: str) -> str: + raise ValueError("Image generation failed") + + tool = image_function_tool(failing_image_generator) + ctx = RunContextWrapper(None) + + result = await tool.on_invoke_tool(ctx, "") + assert "Invalid JSON" in str(result) + + result = await tool.on_invoke_tool(ctx, "{}") + assert "Invalid JSON" in str(result) + + result = await tool.on_invoke_tool(ctx, '{"prompt": "sunset"}') + assert result == default_tool_error_function(ctx, ValueError("Image generation failed")) + + +@pytest.mark.asyncio +async def test_sync_custom_error_function_works_for_image_tool(): + def failing_image_generator(prompt: str) -> str: + raise ValueError("Image generation failed") + + def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> str: + return f"error_{error.__class__.__name__}_image" + + tool = image_function_tool( + failing_image_generator, failure_error_function=custom_sync_error_function + ) + ctx = RunContextWrapper(None) + + result = await tool.on_invoke_tool(ctx, "") + assert result == "error_ModelBehaviorError_image" + + result = await tool.on_invoke_tool(ctx, "{}") + assert result == "error_ModelBehaviorError_image" + + result = await tool.on_invoke_tool(ctx, '{"prompt": "sunset"}') + assert result == "error_ValueError_image" + + +@pytest.mark.asyncio +async def test_async_image_generator(): + async def async_image_generator(prompt: str) -> str: + # Simulate some async operation + await asyncio.sleep(0.01) + return DUMMY_IMAGE_BASE64 + + tool = image_function_tool(async_image_generator) + + result = await tool.on_invoke_tool(RunContextWrapper(None), '{"prompt": "sunset"}') + assert result == DUMMY_IMAGE_BASE64