Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,13 @@ def inject_special_parameters(
Args:
validated_input: The validated input parameters (modified in place).
tool_use: The tool use request containing tool invocation details.
invocation_state: Context for the tool invocation, including agent state.
invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(),
agent.invoke_async(), etc.).
"""
if self._context_param and self._context_param in self.signature.parameters:
tool_context = ToolContext(tool_use=tool_use, agent=invocation_state["agent"])
tool_context = ToolContext(
tool_use=tool_use, agent=invocation_state["agent"], invocation_state=invocation_state
)
validated_input[self._context_param] = tool_context

# Inject agent if requested (backward compatibility)
Expand Down Expand Up @@ -433,7 +436,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw

Args:
tool_use: The tool use specification from the Agent.
invocation_state: Context for the tool invocation, including agent state.
invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(),
agent.invoke_async(), etc.).
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Expand Down
6 changes: 5 additions & 1 deletion src/strands/types/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ class ToolContext:
tool_use: The complete ToolUse object containing tool invocation details.
agent: The Agent instance executing this tool, providing access to conversation history,
model configuration, and other agent state.
invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(),
agent.invoke_async(), etc.).

Note:
This class is intended to be instantiated by the SDK. Direct construction by users
Expand All @@ -140,6 +142,7 @@ class ToolContext:

tool_use: ToolUse
agent: "Agent"
invocation_state: dict[str, Any]


ToolChoice = Union[
Expand Down Expand Up @@ -246,7 +249,8 @@ def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs:

Args:
tool_use: The tool use request containing tool ID and parameters.
invocation_state: Context for the tool invocation, including agent state.
invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(),
agent.invoke_async(), etc.).
**kwargs: Additional keyword arguments for future extensibility.

Yields:
Expand Down
15 changes: 13 additions & 2 deletions tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Tests for the function-based tool decorator pattern.
"""

from asyncio import Queue
from typing import Any, Dict, Optional, Union
from unittest.mock import MagicMock

Expand Down Expand Up @@ -1039,7 +1040,7 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]
assert "NoneType: None" in result["content"][0]["text"]


async def _run_context_injection_test(context_tool: AgentTool):
async def _run_context_injection_test(context_tool: AgentTool, additional_context=None):
"""Common test logic for context injection tests."""
tool: AgentTool = context_tool
generator = tool.stream(
Expand All @@ -1052,6 +1053,7 @@ async def _run_context_injection_test(context_tool: AgentTool):
},
invocation_state={
"agent": Agent(name="test_agent"),
**(additional_context or {}),
},
)
tool_results = [value async for value in generator]
Expand All @@ -1074,13 +1076,17 @@ async def _run_context_injection_test(context_tool: AgentTool):
async def test_tool_context_injection_default():
"""Test that ToolContext is properly injected with default parameter name (tool_context)."""

value_to_pass = Queue() # a complex value that is not serializable

@strands.tool(context=True)
def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict:
"""Tool that uses ToolContext to access tool_use_id."""
tool_use_id = tool_context.tool_use["toolUseId"]
tool_name = tool_context.tool_use["name"]
agent_from_tool_context = tool_context.agent

assert tool_context.invocation_state["test_reference"] is value_to_pass

return {
"status": "success",
"content": [
Expand All @@ -1090,7 +1096,12 @@ def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict:
],
}

await _run_context_injection_test(context_tool)
await _run_context_injection_test(
context_tool,
{
"test_reference": value_to_pass,
},
)


@pytest.mark.asyncio
Expand Down
Loading