Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit a0cd676

Browse files
author
Lin-Nikaido
committed
feat: enable to will_continue-like function response in BaseLlmFlow.run_async method with streaming mode.
1 parent 67f23df commit a0cd676

File tree

8 files changed

+713
-89
lines changed

8 files changed

+713
-89
lines changed

‎src/google/adk/agents/readonly_context.py‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
if TYPE_CHECKING:
2323
from google.genai import types
2424

25+
from ..sessions.session import Session
2526
from .invocation_context import InvocationContext
2627

2728

@@ -52,3 +53,8 @@ def agent_name(self) -> str:
5253
def state(self) -> MappingProxyType[str, Any]:
5354
"""The state of the current session. READONLY field."""
5455
return MappingProxyType(self._invocation_context.session.state)
56+
57+
@property
58+
def session(self) -> Session:
59+
"""The current session. READONLY field."""
60+
return self._invocation_context.session

‎src/google/adk/flows/llm_flows/base_llm_flow.py‎

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
from typing import Optional
2525
from typing import TYPE_CHECKING
2626

27-
from google.genai import types
2827
from websockets.exceptions import ConnectionClosed
2928
from websockets.exceptions import ConnectionClosedOK
3029

30+
from google.genai import types
31+
3132
from . import _output_schema_processor
3233
from . import functions
3334
from ...agents.base_agent import BaseAgent
@@ -629,43 +630,51 @@ async def _postprocess_handle_function_calls_async(
629630
function_call_event: Event,
630631
llm_request: LlmRequest,
631632
) -> AsyncGenerator[Event, None]:
632-
if function_response_event := await functions.handle_function_calls_async(
633+
# if invocation_context.run_config.streaming_mode == StreamingMode.SSE:
634+
#
635+
# else:
636+
if function_response_event_agen := functions.handle_function_calls_async_gen(
633637
invocation_context, function_call_event, llm_request.tools_dict
634638
):
635-
auth_event = functions.generate_auth_event(
636-
invocation_context, function_response_event
637-
)
638-
if auth_event:
639-
yield auth_event
640-
641-
tool_confirmation_event = functions.generate_request_confirmation_event(
642-
invocation_context, function_call_event, function_response_event
643-
)
644-
if tool_confirmation_event:
645-
yield tool_confirmation_event
646-
647-
# Always yield the function response event first
648-
yield function_response_event
649-
650-
# Check if this is a set_model_response function response
651-
if json_response := _output_schema_processor.get_structured_model_response(
652-
function_response_event
653-
):
654-
# Create and yield a final model response event
655-
final_event = (
656-
_output_schema_processor.create_final_model_response_event(
657-
invocation_context, json_response
658-
)
639+
function_response_event = None
640+
async for function_response_event in function_response_event_agen:
641+
auth_event = functions.generate_auth_event(
642+
invocation_context, function_response_event
659643
)
660-
yieldfinal_event
661-
transfer_to_agent=function_response_event.actions.transfer_to_agent
662-
iftransfer_to_agent:
663-
agent_to_run = self._get_agent_to_run(
664-
invocation_context, transfer_to_agent
644+
ifauth_event:
645+
yieldauth_event
646+
647+
tool_confirmation_event = functions.generate_request_confirmation_event(
648+
invocation_context, function_call_event, function_response_event
665649
)
666-
async with Aclosing(agent_to_run.run_async(invocation_context)) as agen:
667-
async for event in agen:
668-
yield event
650+
if tool_confirmation_event:
651+
yield tool_confirmation_event
652+
653+
# Always yield the function response event first
654+
yield function_response_event
655+
656+
# Check if this is a set_model_response function response
657+
if json_response := _output_schema_processor.get_structured_model_response(
658+
function_response_event
659+
):
660+
# Create and yield a final model response event
661+
final_event = (
662+
_output_schema_processor.create_final_model_response_event(
663+
invocation_context, json_response
664+
)
665+
)
666+
yield final_event
667+
if function_response_event:
668+
transfer_to_agent = function_response_event.actions.transfer_to_agent
669+
if transfer_to_agent:
670+
agent_to_run = self._get_agent_to_run(
671+
invocation_context, transfer_to_agent
672+
)
673+
async with Aclosing(
674+
agent_to_run.run_async(invocation_context)
675+
) as agen:
676+
async for event in agen:
677+
yield event
669678

670679
def _get_agent_to_run(
671680
self, invocation_context: InvocationContext, agent_name: str

‎src/google/adk/flows/llm_flows/functions.py‎

Lines changed: 142 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,24 @@
2020
import copy
2121
import inspect
2222
import logging
23-
import threading
2423
from typing import Any
2524
from typing import AsyncGenerator
25+
from typing import AsyncIterator
2626
from typing import cast
27+
from typing import Iterator
28+
from typing import List
2729
from typing import Optional
30+
from typing import Tuple
2831
from typing import TYPE_CHECKING
2932
import uuid
3033

34+
from nltk.sem.chat80 import continent
35+
3136
from google.genai import types
3237

3338
from ...agents.active_streaming_tool import ActiveStreamingTool
3439
from ...agents.invocation_context import InvocationContext
40+
from ...agents.run_config import StreamingMode
3541
from ...auth.auth_tool import AuthToolArguments
3642
from ...events.event import Event
3743
from ...events.event_actions import EventActions
@@ -184,70 +190,85 @@ async def handle_function_calls_async(
184190
tools_dict: dict[str, BaseTool],
185191
filters: Optional[set[str]] = None,
186192
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
193+
187194
) -> Optional[Event]:
188195
"""Calls the functions and returns the function response event."""
189-
function_calls = function_call_event.get_function_calls()
190-
return await handle_function_call_list_async(
191-
invocation_context,
192-
function_calls,
193-
tools_dict,
194-
filters,
195-
tool_confirmation_dict,
196-
)
196+
async with Aclosing(
197+
handle_function_calls_async_gen(
198+
invocation_context,
199+
function_call_event,
200+
tools_dict,
201+
filters,
202+
tool_confirmation_dict,
203+
)
204+
) as agen:
205+
last_event = None
206+
async for event in agen:
207+
last_event = event
208+
return last_event
197209

198210

199-
async def handle_function_call_list_async(
211+
async def handle_function_calls_async_gen(
200212
invocation_context: InvocationContext,
201-
function_calls: list[types.FunctionCall],
213+
function_call_event: Event,
202214
tools_dict: dict[str, BaseTool],
203215
filters: Optional[set[str]] = None,
204216
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
205-
) -> Optional[Event]:
206-
"""Calls the functions and returns the function response event."""
217+
) -> AsyncGenerator[Optional[Event]]:
218+
"""Calls the functions and returns the function response event as generator."""
207219
from ...agents.llm_agent import LlmAgent
208220

209221
agent = invocation_context.agent
210222
if not isinstance(agent, LlmAgent):
211-
return None
223+
yield None
224+
return
225+
226+
function_calls = function_call_event.get_function_calls()
212227

213228
# Filter function calls
214229
filtered_calls = [
215230
fc for fc in function_calls if not filters or fc.id in filters
216231
]
217232

218233
if not filtered_calls:
219-
return None
234+
yield None
235+
return
220236

221-
# Create tasks for parallel execution
222-
tasks = [
223-
asyncio.create_task(
224-
_execute_single_function_call_async(
225-
invocation_context,
226-
function_call,
227-
tools_dict,
228-
agent,
229-
tool_confirmation_dict[function_call.id]
230-
if tool_confirmation_dict
231-
else None,
232-
)
237+
function_call_async_gens = [
238+
_execute_single_function_call_async_gen(
239+
invocation_context,
240+
function_call,
241+
tools_dict,
242+
agent,
243+
tool_confirmation_dict[function_call.id]
244+
if tool_confirmation_dict
245+
else None,
233246
)
234247
for function_call in filtered_calls
235248
]
236249

237-
# Wait for all tasks to complete
238-
function_response_events = await asyncio.gather(*tasks)
239-
240-
# Filter out None results
241-
function_response_events = [
242-
event for event in function_response_events if event is not None
243-
]
250+
merged_event = None
251+
result_events: List[Optional[Event]] = [None] * len(function_call_async_gens)
252+
function_response_events = []
253+
async for idx, event in _concat_function_call_generators(
254+
function_call_async_gens
255+
):
256+
result_events[idx] = event
257+
function_response_events = [
258+
event for event in result_events if event is not None
259+
]
260+
if function_response_events:
261+
merged_event = merge_parallel_function_response_events(
262+
function_response_events
263+
)
264+
if invocation_context.run_config.streaming_mode == StreamingMode.SSE:
265+
yield merged_event
266+
if invocation_context.run_config.streaming_mode != StreamingMode.SSE:
267+
yield merged_event
244268

245269
if not function_response_events:
246-
return None
247-
248-
merged_event = merge_parallel_function_response_events(
249-
function_response_events
250-
)
270+
yield None
271+
return
251272

252273
if len(function_response_events) > 1:
253274
# this is needed for debug traces of parallel calls
@@ -258,16 +279,61 @@ async def handle_function_call_list_async(
258279
response_event_id=merged_event.id,
259280
function_response_event=merged_event,
260281
)
261-
return merged_event
262282

263283

264-
async def _execute_single_function_call_async(
284+
async def _concat_function_call_generators(
285+
gens: List[AsyncGenerator[Any]],
286+
) -> AsyncIterator[tuple[int, Any]]:
287+
_SENTINEL = object()
288+
q: asyncio.Queue[tuple[str, int, Any]] = asyncio.Queue()
289+
gens = list(gens)
290+
n = len(gens)
291+
292+
async def __pump(idx: int, agen_: AsyncIterator[Any]):
293+
try:
294+
async for x in agen_:
295+
await q.put(('ITEM', idx, x))
296+
except Exception as e:
297+
await q.put(('EXC', idx, e))
298+
finally:
299+
aclose = getattr(agen_, 'aclose', None)
300+
if callable(aclose):
301+
try:
302+
await aclose()
303+
except Exception: # noqa: ignore exception when task canceled.
304+
pass
305+
306+
await q.put(('END', idx, _SENTINEL))
307+
308+
tasks = [asyncio.create_task(__pump(i, agen)) for i, agen in enumerate(gens)]
309+
finished = 0
310+
try:
311+
while finished < n:
312+
kind, i, payload = await q.get()
313+
if kind == 'ITEM':
314+
yield i, payload
315+
316+
elif kind == 'EXC':
317+
for t in tasks:
318+
t.cancel()
319+
await asyncio.gather(*tasks, return_exceptions=True)
320+
raise payload
321+
322+
elif kind == 'END':
323+
finished += 1
324+
finally:
325+
for t in tasks:
326+
t.cancel()
327+
await asyncio.gather(*tasks, return_exceptions=True)
328+
329+
330+
async def _execute_single_function_call_async_gen(
265331
invocation_context: InvocationContext,
266332
function_call: types.FunctionCall,
267333
tools_dict: dict[str, BaseTool],
268334
agent: LlmAgent,
269335
tool_confirmation: Optional[ToolConfirmation] = None,
270-
) -> Optional[Event]:
336+
) -> AsyncGenerator[Optional[Event]]:
271337
"""Execute a single function call with thread safety for state modifications."""
272338
tool, tool_context = _get_tool_and_context(
273339
invocation_context,
@@ -310,6 +376,37 @@ async def _execute_single_function_call_async(
310376
function_response = await __call_tool_async(
311377
tool, args=function_args, tool_context=tool_context
312378
)
379+
if inspect.isasyncgen(function_response) or isinstance(
380+
function_response, AsyncIterator
381+
):
382+
res = None
383+
async for res in function_response:
384+
if inspect.isawaitable(res):
385+
res = await res
386+
if (
387+
invocation_context.run_config.streaming_mode
388+
== StreamingMode.SSE
389+
):
390+
yield __build_response_event(
391+
tool, res, tool_context, invocation_context
392+
)
393+
function_response = res
394+
elif inspect.isgenerator(function_response) or isinstance(
395+
function_response, Iterator
396+
):
397+
res = None
398+
for res in function_response:
399+
if inspect.isawaitable(res):
400+
res = await res
401+
if (
402+
invocation_context.run_config.streaming_mode
403+
== StreamingMode.SSE
404+
):
405+
yield __build_response_event(
406+
tool, res, tool_context, invocation_context
407+
)
408+
function_response = res
409+
313410
except Exception as tool_error:
314411
error_response = (
315412
await invocation_context.plugin_manager.run_on_tool_error_callback(
@@ -359,7 +456,8 @@ async def _execute_single_function_call_async(
359456
# Allow long running function to return None to not provide function
360457
# response.
361458
if not function_response:
362-
return None
459+
yield None
460+
return
363461

364462
# Note: State deltas are not applied here - they are collected in
365463
# tool_context.actions.state_delta and applied later when the session
@@ -374,7 +472,7 @@ async def _execute_single_function_call_async(
374472
args=function_args,
375473
function_response_event=function_response_event,
376474
)
377-
return function_response_event
475+
yield function_response_event
378476

379477

380478
async def handle_function_calls_live(

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /