Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 18 additions & 22 deletions src/google/adk/evaluation/evaluation_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ async def _generate_inferences_for_single_user_invocation_live(
current_invocation_id: str,
turn_complete_event: asyncio.Event,
live_timeout_seconds: int,
agent_name: str = _DEFAULT_AUTHOR,
) -> AsyncGenerator[Event, None]:
"""Generates inferences for a single user invocation in live mode."""
yield Event(
Expand All @@ -408,6 +409,22 @@ async def _generate_inferences_for_single_user_invocation_live(
event = await event_queue.get()
if event.invocation_id == current_invocation_id:
yield event
# Emit a synthetic text event for each transcription, preserving
# the order in which events are received.
if (
event.author != _USER_AUTHOR
and event.output_transcription
and event.output_transcription.text
and event.partial
):
yield Event(
content=Content(
role="model",
parts=[types.Part(text=event.output_transcription.text)],
),
author=agent_name,
invocation_id=current_invocation_id,
)

@staticmethod
async def _generate_inferences_from_root_agent_live(
Expand Down Expand Up @@ -495,31 +512,10 @@ async def _generate_inferences_from_root_agent_live(
current_invocation_id=live_session.current_invocation_id,
turn_complete_event=live_session.turn_complete_event,
live_timeout_seconds=live_timeout_seconds,
agent_name=runner.agent.name,
):
events.append(event)

turn_transcription = ""
for evt in events:
if (
evt.invocation_id == live_session.current_invocation_id
and evt.author != _USER_AUTHOR
and evt.output_transcription
):
if not evt.partial and evt.output_transcription.text:
turn_transcription = evt.output_transcription.text
else:
turn_transcription += evt.output_transcription.text
if turn_transcription:
synthetic_event = Event(
content=Content(
role="model",
parts=[types.Part(text=turn_transcription)],
),
author=runner.agent.name,
invocation_id=live_session.current_invocation_id,
)
events.append(synthetic_event)

if live_session.live_finished.is_set():
logger.info("Live session finished signal detected.")
break
Expand Down
56 changes: 56 additions & 0 deletions tests/unittests/evaluation/test_evaluation_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,62 @@ async def test_generate_inferences_live(self, mocker):
with pytest.raises(StopAsyncIteration):
await gen.__anext__()

@pytest.mark.asyncio
async def test_generate_inferences_live_with_synthetic_events(self, mocker):
"""Tests live inference generation with synthetic events."""
mock_live_request_queue = mocker.MagicMock()
event_queue = asyncio.Queue()
turn_complete_event = asyncio.Event()

user_content = types.Content(parts=[types.Part(text="User query")])
invocation_id = "inv1"

transcription = types.Transcription(text="Partial transcription")
partial_event = Event(
author="agent",
content=types.Content(parts=[]),
invocation_id=invocation_id,
output_transcription=transcription,
partial=True,
)

gen = EvaluationGenerator._generate_inferences_for_single_user_invocation_live(
live_request_queue=mock_live_request_queue,
event_queue=event_queue,
user_message=user_content,
current_invocation_id=invocation_id,
turn_complete_event=turn_complete_event,
live_timeout_seconds=300,
agent_name="custom_agent_name",
)

# First yield should be the user message
first_event = await gen.__anext__()
assert first_event.author == "user"
assert first_event.content == user_content
assert first_event.invocation_id == invocation_id

# Mock turn_complete_event.wait to avoid blocking
turn_complete_event.wait = mocker.AsyncMock()

# Put the partial event in the queue
await event_queue.put(partial_event)

# Now advance
second_event = await gen.__anext__()
assert second_event == partial_event

# Next should be the synthetic event
third_event = await gen.__anext__()
assert third_event.author == "custom_agent_name"
assert third_event.invocation_id == invocation_id
assert third_event.content.role == "model"
assert third_event.content.parts[0].text == "Partial transcription"

# The generator should be exhausted now
with pytest.raises(StopAsyncIteration):
await gen.__anext__()


@pytest.fixture
def mock_runner(mocker):
Expand Down
Loading