Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 9bd539e

Browse files
Merge pull request google#1656 from google:fix-transfer-to-agent-parameters-issue
PiperOrigin-RevId: 778334126
2 parents 3d2f13c + 0959b06 commit 9bd539e

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

‎src/google/adk/tools/function_tool.py‎

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import inspect
1618
from typing import Any
1719
from typing import Callable
@@ -79,9 +81,13 @@ async def run_async(
7981
) -> Any:
8082
args_to_call = args.copy()
8183
signature = inspect.signature(self.func)
82-
if 'tool_context' in signature.parameters:
84+
valid_params = {param for param in signature.parameters}
85+
if 'tool_context' in valid_params:
8386
args_to_call['tool_context'] = tool_context
8487

88+
# Filter args_to_call to only include valid parameters for the function
89+
args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params}
90+
8591
# Before invoking the function, we check for if the list of args passed in
8692
# has all the mandatory arguments or not.
8793
# If the check fails, then we don't invoke the tool and let the Agent know

‎tests/unittests/tools/test_function_tool.py‎

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
from unittest.mock import MagicMock
1616

17+
from google.adk.agents.invocation_context import InvocationContext
18+
from google.adk.sessions.session import Session
1719
from google.adk.tools.function_tool import FunctionTool
20+
from google.adk.tools.tool_context import ToolContext
1821
import pytest
1922

2023

@@ -294,3 +297,51 @@ async def async_func_with_optional_args(
294297
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
295298
result = await tool.run_async(args=args, tool_context=MagicMock())
296299
assert result == "test_value_1,test_value_3"
300+
301+
302+
@pytest.mark.asyncio
303+
async def test_run_async_with_unexpected_argument():
304+
"""Test that run_async filters out unexpected arguments."""
305+
306+
def sample_func(expected_arg: str):
307+
return {"received_arg": expected_arg}
308+
309+
tool = FunctionTool(sample_func)
310+
mock_invocation_context = MagicMock(spec=InvocationContext)
311+
mock_invocation_context.session = MagicMock(spec=Session)
312+
# Add the missing state attribute to the session mock
313+
mock_invocation_context.session.state = MagicMock()
314+
tool_context_mock = ToolContext(invocation_context=mock_invocation_context)
315+
316+
result = await tool.run_async(
317+
args={"expected_arg": "hello", "parameters": "should_be_filtered"},
318+
tool_context=tool_context_mock,
319+
)
320+
assert result == {"received_arg": "hello"}
321+
322+
323+
@pytest.mark.asyncio
324+
async def test_run_async_with_tool_context_and_unexpected_argument():
325+
"""Test that run_async handles tool_context and filters out unexpected arguments."""
326+
327+
def sample_func_with_context(expected_arg: str, tool_context: ToolContext):
328+
return {"received_arg": expected_arg, "context_present": bool(tool_context)}
329+
330+
tool = FunctionTool(sample_func_with_context)
331+
mock_invocation_context = MagicMock(spec=InvocationContext)
332+
mock_invocation_context.session = MagicMock(spec=Session)
333+
# Add the missing state attribute to the session mock
334+
mock_invocation_context.session.state = MagicMock()
335+
mock_tool_context = ToolContext(invocation_context=mock_invocation_context)
336+
337+
result = await tool.run_async(
338+
args={
339+
"expected_arg": "world",
340+
"parameters": "should_also_be_filtered",
341+
},
342+
tool_context=mock_tool_context,
343+
)
344+
assert result == {
345+
"received_arg": "world",
346+
"context_present": True,
347+
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /