From 26ccf7363ad05f6437e3a0d82db06f674e505874 Mon Sep 17 00:00:00 2001 From: Wen Zhang Date: Thu, 21 May 2026 11:50:00 -0700 Subject: [PATCH] fix: load partial transcription to trajectory in Live API ADK evaluation --- .../adk/evaluation/evaluation_generator.py | 40 ++++++------- .../evaluation/test_evaluation_generator.py | 56 +++++++++++++++++++ 2 files changed, 74 insertions(+), 22 deletions(-) diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index d5a6629366..5b0100818c 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -383,6 +383,7 @@ async def _generate_inferences_for_single_user_invocation_live( current_invocation_id: str, turn_complete_event: asyncio.Event, live_timeout_seconds: int, + agent_name: str = _DEFAULT_AUTHOR, ) -> AsyncGenerator[Event, None]: """Generates inferences for a single user invocation in live mode.""" yield Event( @@ -408,6 +409,22 @@ async def _generate_inferences_for_single_user_invocation_live( event = await event_queue.get() if event.invocation_id == current_invocation_id: yield event + # Emit a synthetic text event for each transcription, preserving + # the order in which events are received. + if ( + event.author != _USER_AUTHOR + and event.output_transcription + and event.output_transcription.text + and event.partial + ): + yield Event( + content=Content( + role="model", + parts=[types.Part(text=event.output_transcription.text)], + ), + author=agent_name, + invocation_id=current_invocation_id, + ) @staticmethod async def _generate_inferences_from_root_agent_live( @@ -495,31 +512,10 @@ async def _generate_inferences_from_root_agent_live( current_invocation_id=live_session.current_invocation_id, turn_complete_event=live_session.turn_complete_event, live_timeout_seconds=live_timeout_seconds, + agent_name=runner.agent.name, ): events.append(event) - turn_transcription = "" - for evt in events: - if ( - evt.invocation_id == live_session.current_invocation_id - and evt.author != _USER_AUTHOR - and evt.output_transcription - ): - if not evt.partial and evt.output_transcription.text: - turn_transcription = evt.output_transcription.text - else: - turn_transcription += evt.output_transcription.text - if turn_transcription: - synthetic_event = Event( - content=Content( - role="model", - parts=[types.Part(text=turn_transcription)], - ), - author=runner.agent.name, - invocation_id=live_session.current_invocation_id, - ) - events.append(synthetic_event) - if live_session.live_finished.is_set(): logger.info("Live session finished signal detected.") break diff --git a/tests/unittests/evaluation/test_evaluation_generator.py b/tests/unittests/evaluation/test_evaluation_generator.py index 508b6f5c9c..05ab25cc72 100644 --- a/tests/unittests/evaluation/test_evaluation_generator.py +++ b/tests/unittests/evaluation/test_evaluation_generator.py @@ -453,6 +453,62 @@ async def test_generate_inferences_live(self, mocker): with pytest.raises(StopAsyncIteration): await gen.__anext__() + @pytest.mark.asyncio + async def test_generate_inferences_live_with_synthetic_events(self, mocker): + """Tests live inference generation with synthetic events.""" + mock_live_request_queue = mocker.MagicMock() + event_queue = asyncio.Queue() + turn_complete_event = asyncio.Event() + + user_content = types.Content(parts=[types.Part(text="User query")]) + invocation_id = "inv1" + + transcription = types.Transcription(text="Partial transcription") + partial_event = Event( + author="agent", + content=types.Content(parts=[]), + invocation_id=invocation_id, + output_transcription=transcription, + partial=True, + ) + + gen = EvaluationGenerator._generate_inferences_for_single_user_invocation_live( + live_request_queue=mock_live_request_queue, + event_queue=event_queue, + user_message=user_content, + current_invocation_id=invocation_id, + turn_complete_event=turn_complete_event, + live_timeout_seconds=300, + agent_name="custom_agent_name", + ) + + # First yield should be the user message + first_event = await gen.__anext__() + assert first_event.author == "user" + assert first_event.content == user_content + assert first_event.invocation_id == invocation_id + + # Mock turn_complete_event.wait to avoid blocking + turn_complete_event.wait = mocker.AsyncMock() + + # Put the partial event in the queue + await event_queue.put(partial_event) + + # Now advance + second_event = await gen.__anext__() + assert second_event == partial_event + + # Next should be the synthetic event + third_event = await gen.__anext__() + assert third_event.author == "custom_agent_name" + assert third_event.invocation_id == invocation_id + assert third_event.content.role == "model" + assert third_event.content.parts[0].text == "Partial transcription" + + # The generator should be exhausted now + with pytest.raises(StopAsyncIteration): + await gen.__anext__() + @pytest.fixture def mock_runner(mocker):