diff --git a/.gitignore b/.gitignore index 09f3849bf..6c32a356f 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,9 @@ Thumbs.db # Local documentation and plans docs/ plans/ + +# Sample build artifacts / local scratch +contrib/samples/**/bin/ +mkpro_logs.db + + diff --git a/azure_readme.md b/azure_readme.md new file mode 100644 index 000000000..b2ec865d6 --- /dev/null +++ b/azure_readme.md @@ -0,0 +1,730 @@ +# Azure OpenAI Integration for ADK-Java + +This document describes how Azure-hosted models connect to the Agent Development Kit (ADK), how to configure and use them, which API contracts are supported, and how to extend the integration with new Azure API surfaces. + +--- + +## Overview + +ADK-Java treats Azure OpenAI as a first-class model provider through a **unified adapter** (`AzureBaseLM`) that delegates to **transport-specific implementations** based on the deployment name. All Azure code lives under: + +``` +core/src/main/java/com/google/adk/models/ +├── AzureBaseLM.java # Unified entry point (extends BaseLlm) +└── azure/ + ├── AzureConfig.java # Shared env-based configuration + ├── AzureTransport.java # Strategy interface for API contracts + ├── AzureRequestConverter.java # ADK → Azure request mapping + ├── AzureRestTransport.java # HTTP Responses API + ├── AzureRealtimeTransport.java # WebSocket voice-agent Realtime API + ├── AzureRealtimeLlmConnection.java + ├── AzureRealtimeTranslateTransport.java + └── AzureRealtimeTranslateLlmConnection.java +``` + +ADK agents never talk to Azure directly. They use the standard ADK model APIs (`BaseLlm.generateContent`, `BaseLlm.connect`), which are wired through `LlmRegistry` or explicit `Model` instances. + +--- + +## System Architecture + +### High-level data flow + +```mermaid +flowchart TB + subgraph ADK["ADK Agent Layer"] + Agent["LlmAgent"] + Flow["BaseLlmFlow / Basic"] + Registry["LlmRegistry"] + end + + subgraph AzureAdapter["Azure Adapter"] + AzureBaseLM["AzureBaseLM"] + Config["AzureConfig"] + Converter["AzureRequestConverter"] + end + + subgraph Transports["Azure Transports (Strategy)"] + Rest["AzureRestTransport
HTTP Responses API"] + Realtime["AzureRealtimeTransport
WebSocket Realtime"] + Translate["AzureRealtimeTranslateTransport
WebSocket Translate"] + end + + subgraph Connections["Live Connections"] + Generic["GenericLlmConnection"] + RealtimeConn["AzureRealtimeLlmConnection"] + TranslateConn["AzureRealtimeTranslateLlmConnection"] + end + + subgraph Azure["Azure OpenAI"] + ResponsesAPI["Responses API (REST/SSE)"] + RealtimeWS["Realtime WebSocket"] + TranslateWS["Realtime Translations WebSocket"] + end + + Agent --> Flow + Flow --> Registry + Registry --> AzureBaseLM + Flow --> AzureBaseLM + + AzureBaseLM --> Config + AzureBaseLM -->|"selectTransport(modelName)"| Rest + AzureBaseLM --> Realtime + AzureBaseLM --> Translate + + Rest --> Converter + Realtime --> Converter + Translate --> Config + + Rest -->|"generateContent / connect"| Generic + Realtime -->|"connect"| RealtimeConn + Translate -->|"connect"| TranslateConn + + Rest --> ResponsesAPI + RealtimeConn --> RealtimeWS + TranslateConn --> TranslateWS +``` + +### Transport selection logic + +`AzureBaseLM` picks a transport automatically from the deployment name: + +| Condition on `modelName` | Transport | Protocol | +|---|---|---| +| Contains `realtime-translate` (case-insensitive) | `AzureRealtimeTranslateTransport` | WebSocket `/openai/v1/realtime/translations` | +| Contains `realtime` but **not** `realtime-translate` | `AzureRealtimeTransport` | WebSocket `/openai/v1/realtime` | +| Everything else | `AzureRestTransport` | HTTP Responses API (REST + SSE streaming) | + +```java +// AzureBaseLM.selectTransport() — simplified +if (isTranslateModel(modelName)) → AzureRealtimeTranslateTransport +if (isRealtimeModel(modelName)) → AzureRealtimeTransport +else → AzureRestTransport +``` + +--- + +## Class Diagram + +```mermaid +classDiagram + direction TB + + class BaseLlm { + <> + +model() String + +generateContent(LlmRequest, boolean) Flowable~LlmResponse~ + +connect(LlmRequest) BaseLlmConnection + } + + class AzureBaseLM { + -AzureConfig config + -AzureTransport transport + +AzureBaseLM(String modelName) + +isRealtimeModel(String) boolean$ + +isTranslateModel(String) boolean$ + -selectTransport(String) AzureTransport$ + } + + class AzureTransport { + <> + +supports(String modelName) boolean + +generateContent(LlmRequest, AzureConfig, boolean) Flowable~LlmResponse~ + +connect(LlmRequest, AzureConfig) BaseLlmConnection + } + + class AzureRestTransport { + +supports() boolean + +generateContent() Flowable~LlmResponse~ + +connect() GenericLlmConnection + } + + class AzureRealtimeTransport { + +connect() AzureRealtimeLlmConnection + } + + class AzureRealtimeTranslateTransport { + +connect() AzureRealtimeTranslateLlmConnection + } + + class AzureConfig { + +fromEnvironment(String modelName)$ AzureConfig + +responseEndpoint() String + +realtimeWebSocketUrl() String + +translationsWebSocketUrl() String + +apiKey() String + +voice() String + +translateTargetLanguage() String + } + + class AzureRequestConverter { + +extractInstructions(LlmRequest)$ String + +buildTools(LlmRequest)$ JSONArray + +schemaToJson(Schema)$ JSONObject + +cleanForIdentifier(String)$ String + } + + class BaseLlmConnection { + <> + +sendHistory(List~Content~) Completable + +sendContent(Content) Completable + +sendRealtime(Blob) Completable + +clearRealtimeAudioBuffer() Completable + +receive() Flowable~LlmResponse~ + +close() + } + + class GenericLlmConnection { + -BaseLlm llm + -List~Content~ history + } + + class AzureRealtimeLlmConnection { + -WebSocketClient wsClient + -PublishProcessor~LlmResponse~ responseProcessor + } + + class AzureRealtimeTranslateLlmConnection { + -TranslateWebSocketClient wsClient + } + + class LlmRegistry { + +getLlm(String modelName)$ BaseLlm + +registerLlm(String pattern, LlmFactory)$ void + } + + BaseLlm <|-- AzureBaseLM + AzureTransport <|.. AzureRestTransport + AzureTransport <|.. AzureRealtimeTransport + AzureTransport <|.. AzureRealtimeTranslateTransport + + AzureBaseLM --> AzureConfig + AzureBaseLM --> AzureTransport + AzureRestTransport --> AzureRequestConverter + AzureRestTransport --> AzureConfig + AzureRealtimeTransport --> AzureRealtimeLlmConnection + AzureRealtimeTranslateTransport --> AzureRealtimeTranslateLlmConnection + AzureRealtimeLlmConnection ..|> BaseLlmConnection + AzureRealtimeTranslateLlmConnection ..|> BaseLlmConnection + GenericLlmConnection ..|> BaseLlmConnection + AzureRestTransport --> GenericLlmConnection + + LlmRegistry --> AzureBaseLM : creates +``` + +--- + +## Supported API Contracts + +ADK currently supports **three Azure API contracts**, each mapped to a transport: + +### 1. Responses API (REST / SSE) — `AzureRestTransport` + +**Use for:** Text chat, function calling, reasoning models, batch inference. + +| Feature | Support | +|---|---| +| Non-streaming `generateContent` | Yes | +| SSE streaming `generateContent` | Yes | +| Function / tool calling | Yes (via `AzureRequestConverter.buildTools`) | +| System instructions | Yes (from `GenerateContentConfig.systemInstruction`) | +| Temperature / max tokens | Yes | +| Reasoning summary streaming | Yes (emitted as partial text) | +| Live `connect()` | Yes (via `GenericLlmConnection` — HTTP round-trip per turn) | +| Real-time audio | No | + +**Endpoint env var:** `AZURE_RESPONSE_ENDPOINT` + +**Example endpoint:** +``` +https://.openai.azure.com/openai/v1/responses +``` + +**Typical deployment names:** Any name that does **not** contain `realtime`, e.g. `gpt-4o`, `gpt-5`, `o3-mini`, `gpt5pro`. + +--- + +### 2. Realtime Voice Agent API — `AzureRealtimeTransport` + +**Use for:** Bidirectional voice agents with VAD, barge-in, tool calling, and audio output. + +| Feature | Support | +|---|---| +| `connect()` + live session | Yes (primary mode) | +| `sendRealtime(Blob)` — PCM16 audio in | Yes | +| `clearRealtimeAudioBuffer()` | Yes | +| `sendContent()` — text / function responses | Yes | +| `sendHistory()` | Yes | +| Audio output (PCM16) | Yes (as `Blob` in `LlmResponse`) | +| Input transcription | Yes (Whisper, as `inputTranscription`) | +| Function calling | Yes | +| Barge-in / interrupted signal | Yes (`LlmResponse.interrupted`) | +| Turn completion | Yes (`LlmResponse.turnComplete`) | +| `generateContent()` | Fallback only (short-lived WebSocket) | + +**Endpoint env var:** `AZURE_REALTIME_ENDPOINT` + +**Example endpoint:** +``` +https://.openai.azure.com/openai/v1/realtime +``` + +**Typical deployment names:** Names containing `realtime` but not `realtime-translate`, e.g. `gpt-4o-realtime-preview`, `gpt-realtime`. + +**Optional env vars:** +- `AZURE_REALTIME_VOICE` — voice name (default: `alloy`) + +--- + +### 3. GPT Realtime Translate — `AzureRealtimeTranslateTransport` + +**Use for:** Continuous speech translation (source audio in → translated audio + transcript out). + +| Feature | Support | +|---|---| +| `connect()` + live session | Yes (required) | +| `sendRealtime(Blob)` — source audio | Yes | +| Translated audio output | Yes | +| Output transcript deltas | Yes | +| Input transcript deltas | Yes (`inputTranscription`) | +| Target language config | Yes (`AZURE_TRANSLATE_TARGET_LANGUAGE`) | +| Agent turn / function calling | No | +| `generateContent()` | Not supported (throws) | + +**Endpoint env var:** `AZURE_TRANSLATE_ENDPOINT` + +**Example endpoint (GA format):** +``` +wss://.openai.azure.com/openai/v1/realtime/translations?model= +``` + +**Typical deployment names:** Names containing `realtime-translate`, e.g. `gpt-realtime-translate`. + +**Optional env vars:** +- `AZURE_TRANSLATE_TARGET_LANGUAGE` — ISO language code (default: `en`) + +> **Note:** ADK normalizes translate URLs to the GA shape (`/openai/v1/realtime/translations?model=`) and strips legacy `api-version` query params that cause HTTP 400. + +--- + +## Configuration + +### Environment variables + +| Variable | Required | Used by | Description | +|---|---|---|---| +| `AZURE_OPENAI_API_KEY` | **Yes** | All transports | API key sent as `api-key` header | +| `AZURE_RESPONSE_ENDPOINT` | For REST | `AzureRestTransport` | HTTP Responses API URL | +| `AZURE_REALTIME_ENDPOINT` | For Realtime | `AzureRealtimeTransport` | Realtime WebSocket base URL | +| `AZURE_TRANSLATE_ENDPOINT` | For Translate | `AzureRealtimeTranslateTransport` | Translate WebSocket URL | +| `AZURE_MODEL_ENDPOINT` | Fallback | All (legacy) | Used when contract-specific vars are unset | +| `AZURE_REALTIME_VOICE` | No | Realtime | Voice (default: `alloy`) | +| `AZURE_TRANSLATE_TARGET_LANGUAGE` | No | Translate | Target language (default: `en`) | + +### Example `.env` / shell setup + +```bash +# Required +export AZURE_OPENAI_API_KEY="your-api-key" + +# REST chat / tools +export AZURE_RESPONSE_ENDPOINT="https://my-resource.openai.azure.com/openai/v1/responses" + +# Voice agent +export AZURE_REALTIME_ENDPOINT="https://my-resource.openai.azure.com/openai/v1/realtime" +export AZURE_REALTIME_VOICE="alloy" + +# Speech translation +export AZURE_TRANSLATE_ENDPOINT="https://my-resource.openai.azure.com/openai/v1/realtime/translations" +export AZURE_TRANSLATE_TARGET_LANGUAGE="hi" +``` + +### Legacy single-endpoint setup + +If you only set `AZURE_MODEL_ENDPOINT`, it is used as a fallback for REST, Realtime, and Translate when the contract-specific variables are missing. Prefer contract-specific variables in production. + +--- + +## How Azure Connects to ADK + +### Registration via `LlmRegistry` + +Azure models are resolved through `LlmRegistry`, the central factory for all LLM providers. Two patterns match Azure deployments: + +```java +// Pattern 1: Explicit Azure prefix (recommended) +// Model name: "Azure|" +registerLlm("Azure\\|.*", modelName -> { + String actualModel = modelName.split("\\|", 2)[1]; + return new AzureBaseLM(actualModel); +}); + +// Pattern 2: Any model name containing "realtime" +registerLlm(".*realtime.*", modelName -> { + String actualModel = modelName.contains("|") + ? modelName.split("\\|", 2)[1] + : modelName; + return new AzureBaseLM(actualModel); +}); +``` + +At runtime, `LlmAgent` resolves the model via `LlmRegistry.getLlm(modelName)` (see `LlmAgent.resolveModelInternal()`), and `BaseLlmFlow` calls `generateContent` or `connect` on the resolved `BaseLlm`. + +### Request lifecycle (REST) + +```mermaid +sequenceDiagram + participant Agent as LlmAgent + participant Flow as BaseLlmFlow + participant LM as AzureBaseLM + participant T as AzureRestTransport + participant C as AzureRequestConverter + participant API as Azure Responses API + + Agent->>Flow: run (SSE or batch) + Flow->>LM: generateContent(LlmRequest, stream) + LM->>T: generateContent(request, config, stream) + T->>C: extractInstructions / buildTools + T->>T: buildInputItems(contents) + T->>API: POST /responses (JSON or SSE) + API-->>T: response / SSE events + T-->>Flow: Flowable + Flow-->>Agent: Event stream +``` + +### Request lifecycle (Realtime voice) + +```mermaid +sequenceDiagram + participant Agent as LlmAgent + participant Flow as BaseLlmFlow + participant LM as AzureBaseLM + participant T as AzureRealtimeTransport + participant Conn as AzureRealtimeLlmConnection + participant WS as Azure Realtime WS + + Agent->>Flow: run (live mode) + Flow->>LM: connect(LlmRequest) + LM->>T: connect(request, config) + T->>Conn: new AzureRealtimeLlmConnection + Conn->>WS: WebSocket connect + session.update + Flow->>Conn: sendHistory / sendRealtime + Conn->>WS: input_audio_buffer.append + WS-->>Conn: response.audio.delta / transcript / function_call + Conn-->>Flow: Flowable + Flow-->>Agent: Event stream (audio, text, tools) +``` + +--- + +## Usage Guide + +### 1. REST chat agent (Responses API) + +```java +import com.google.adk.agents.LlmAgent; +import com.google.adk.models.Model; + +LlmAgent agent = LlmAgent.builder() + .name("azure-chat-agent") + .model(Model.builder().modelName("Azure|gpt-4o").build()) + .instruction("You are a helpful assistant.") + .build(); +``` + +Or instantiate the LLM directly: + +```java +import com.google.adk.models.AzureBaseLM; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.genai.types.Content; +import com.google.genai.types.Part; + +AzureBaseLM llm = new AzureBaseLM("gpt-4o"); + +LlmRequest request = LlmRequest.builder() + .contents(Content.fromParts(Part.fromText("Explain quantum computing briefly."))) + .build(); + +llm.generateContent(request, false) // false = non-streaming + .blockingForEach(response -> { + response.content().ifPresent(c -> + c.parts().ifPresent(parts -> + parts.forEach(p -> p.text().ifPresent(System.out::println)))); + }); +``` + +### 2. Streaming REST + +```java +llm.generateContent(request, true) // true = SSE streaming + .subscribe( + response -> { /* handle partial LlmResponse */ }, + error -> { /* handle error */ }, + () -> { /* stream complete */ }); +``` + +### 3. Function calling (tools) + +Define tools on the agent as usual. ADK converts them to Azure function schemas via `AzureRequestConverter.buildTools()`: + +```java +LlmAgent agent = LlmAgent.builder() + .name("azure-tools-agent") + .model(Model.builder().modelName("Azure|gpt-4o").build()) + .tools(myTool) + .build(); +``` + +The REST transport maps ADK `FunctionCall` / `FunctionResponse` parts to Azure Responses API `function_call` and `function_call_output` items. + +### 4. Realtime voice agent + +Set the model to a Realtime deployment and run the agent in live mode (ADK handles `connect()`, `sendRealtime`, and `receive()` via `BaseLlmFlow`): + +```java +LlmAgent voiceAgent = LlmAgent.builder() + .name("azure-voice-agent") + .model(Model.builder().modelName("Azure|gpt-4o-realtime-preview").build()) + .instruction("You are a voice assistant.") + .tools(searchTool) + .build(); +``` + +Ensure `AZURE_REALTIME_ENDPOINT` and `AZURE_OPENAI_API_KEY` are set. Audio is PCM16 (`audio/pcm` MIME type). + +Direct connection API (without full agent flow): + +```java +AzureBaseLM llm = new AzureBaseLM("gpt-4o-realtime-preview"); +BaseLlmConnection conn = llm.connect(LlmRequest.builder().build()); + +conn.receive().subscribe(response -> { /* audio blobs, transcripts, tool calls */ }); + +// Send PCM16 audio chunks +conn.sendRealtime(Blob.builder() + .mimeType("audio/pcm") + .data(pcmBytes) + .build()).blockingAwait(); + +conn.close(); +``` + +### 5. Realtime translation + +```java +AzureBaseLM translateLlm = new AzureBaseLM("gpt-realtime-translate"); +BaseLlmConnection conn = translateLlm.connect(LlmRequest.builder().build()); + +conn.receive().subscribe(response -> { + // Translated audio: response.content() → Part.inlineData (audio/pcm) + // Translated text: response.content() → Part.text (partial deltas) + // Source text: response.inputTranscription() +}); + +conn.sendRealtime(sourceAudioBlob).blockingAwait(); +``` + +Override target language programmatically: + +```java +// AzureConfig supports withTranslateTargetLanguage() if you construct config manually +``` + +--- + +## Supported Models (Deployment Names) + +ADK does not hard-code a model catalog. It routes by **deployment name pattern** and **Azure endpoint**. Any deployment hosted on your Azure resource works as long as the API contract matches. + +| Category | Name pattern | Azure API | Example deployment names | +|---|---|---|---| +| Chat / reasoning / tools | No `realtime` in name | Responses API | `gpt-4o`, `gpt-4.1`, `gpt-5`, `gpt5pro`, `o3-mini`, `o4-mini` | +| Voice agent | Contains `realtime`, not `realtime-translate` | Realtime WebSocket | `gpt-4o-realtime-preview`, `gpt-realtime` | +| Speech translation | Contains `realtime-translate` | Realtime Translations | `gpt-realtime-translate` | + +The string passed to `AzureBaseLM` or after the `Azure|` prefix must match your **Azure deployment name**, not necessarily the base model ID. + +--- + +## ADK ↔ Azure Request Mapping + +`AzureRequestConverter` is the shared conversion layer used by all transports: + +| ADK concept | Azure / OpenAI field | +|---|---| +| `GenerateContentConfig.systemInstruction` | `instructions` (REST) or `session.instructions` (Realtime) | +| `LlmRequest.tools` | `tools[]` with `type: function` | +| `Schema` (tool parameters) | JSON Schema object | +| `Content` with text parts | `input[]` messages (REST) or `conversation.item.create` (Realtime) | +| `FunctionCall` part | `function_call` item | +| `FunctionResponse` part | `function_call_output` item | +| `GenerateContentConfig.temperature` | `temperature` | +| `GenerateContentConfig.maxOutputTokens` | `max_output_tokens` | + +Tool names are sanitized via `cleanForIdentifier()` to match Azure's allowed character set `[a-zA-Z0-9_.-]`. + +--- + +## Adding a New Azure API Contract + +To add support for another Azure API surface (e.g. Chat Completions, Embeddings, a new Realtime variant): + +### Step 1 — Add configuration + +Extend `AzureConfig` with a new endpoint environment variable and accessor: + +```java +public static final String EMBEDDINGS_ENDPOINT_ENV = "AZURE_EMBEDDINGS_ENDPOINT"; + +public String embeddingsEndpoint() { + return embeddingsEndpoint; +} +``` + +Resolve it in `fromEnvironment()` using the same `resolveContractEndpoint()` helper pattern. + +### Step 2 — Create a transport + +Implement `AzureTransport`: + +```java +public final class AzureEmbeddingsTransport implements AzureTransport { + + @Override + public boolean supports(String modelName) { + return modelName != null && modelName.toLowerCase().contains("embedding"); + } + + @Override + public Flowable generateContent( + LlmRequest request, AzureConfig config, boolean stream) { + // Call Azure Embeddings API, map result to LlmResponse + } + + @Override + public BaseLlmConnection connect(LlmRequest request, AzureConfig config) { + throw new UnsupportedOperationException("Embeddings does not support live connections"); + } +} +``` + +Reuse `AzureRequestConverter` wherever ADK types need conversion. + +### Step 3 — Wire transport selection + +Update `AzureBaseLM.selectTransport()`: + +```java +private static AzureTransport selectTransport(String modelName) { + if (isTranslateModel(modelName)) return new AzureRealtimeTranslateTransport(); + if (isRealtimeModel(modelName)) return new AzureRealtimeTransport(); + if (isEmbeddingModel(modelName)) return new AzureEmbeddingsTransport(); // new + return new AzureRestTransport(); +} +``` + +Add a public static detection helper alongside `isRealtimeModel()` / `isTranslateModel()`. + +### Step 4 — (Optional) Add a live connection class + +If the new contract uses WebSocket or another persistent protocol, implement `BaseLlmConnection` in the `azure` subpackage (follow `AzureRealtimeLlmConnection` as a reference): + +- Open connection in constructor +- Map protocol events → `LlmResponse` via `PublishProcessor` +- Implement `sendHistory`, `sendContent`, `sendRealtime` as appropriate +- Handle barge-in, errors, and cleanup in `close()` + +Return the connection from your transport's `connect()` method. + +### Step 5 — Register in `LlmRegistry` (if needed) + +If the new contract uses a distinct model name pattern, register a factory: + +```java +LlmRegistry.registerLlm("Azure\\|.*embedding.*", name -> new AzureBaseLM(name.split("\\|", 2)[1])); +``` + +Existing `Azure|*` and `.*realtime.*` patterns already route to `AzureBaseLM` for most cases. + +### Step 6 — Document and test + +- Add env var docs to this file +- Add unit tests for URL normalization, request conversion, and response parsing +- Add an integration test gated on env vars (see existing patterns in `contrib/spring-ai`) + +### Design principles to follow + +1. **One transport per API contract** — do not mix REST and WebSocket logic in the same class. +2. **Shared config in `AzureConfig`** — never read env vars directly from transports. +3. **Shared conversion in `AzureRequestConverter`** — avoid duplicating tool/instruction mapping. +4. **Return ADK types** — all transports must emit `LlmResponse` / `BaseLlmConnection`, never leak raw Azure JSON to agent code. +5. **Keep `AzureBaseLM` thin** — it should only select transport and delegate. + +--- + +## Package Reference + +| Class | Responsibility | +|---|---| +| `AzureBaseLM` | Unified `BaseLlm` entry point; transport selection | +| `AzureConfig` | Env-based endpoints, API key, voice, translate language | +| `AzureTransport` | Strategy interface for API contracts | +| `AzureRequestConverter` | ADK `LlmRequest` → Azure JSON (instructions, tools, schemas) | +| `AzureRestTransport` | HTTP Responses API (sync + SSE streaming) | +| `AzureRealtimeTransport` | Realtime WebSocket transport wrapper | +| `AzureRealtimeLlmConnection` | Full Realtime protocol (audio, VAD, tools, barge-in) | +| `AzureRealtimeTranslateTransport` | Translate WebSocket transport wrapper | +| `AzureRealtimeTranslateLlmConnection` | Translation session protocol | +| `GenericLlmConnection` | HTTP-based pseudo-connection used by REST transport | +| `LlmRegistry` | Factory/registry that creates `AzureBaseLM` instances | +| `BaseLlmFlow` | Agent flow that calls `generateContent` or `connect` | + +--- + +## Troubleshooting + +| Symptom | Likely cause | Fix | +|---|---|---| +| `AZURE_OPENAI_API_KEY environment variable is not set` | Missing API key | Set `AZURE_OPENAI_API_KEY` | +| `Azure Responses API endpoint not configured` | Missing REST endpoint | Set `AZURE_RESPONSE_ENDPOINT` | +| Translate returns HTTP 400 | Legacy preview URL with `api-version` | Use GA URL: `/openai/v1/realtime/translations?model=` | +| `Unsupported model: ...` | Name doesn't match any `LlmRegistry` pattern | Use `Azure\|` or register a custom pattern | +| Realtime connects but no audio | Wrong MIME type | Send PCM16 as `audio/pcm` | +| Function calls missing name on Realtime | API version omits fields on `function_call_arguments.done` | Already handled via `pendingFunctionCalls` map in `AzureRealtimeLlmConnection` | +| Voice agent gets empty REST response | Realtime deployment used with REST endpoint | Use `AZURE_REALTIME_ENDPOINT` and a `realtime` deployment name | + +--- + +## Related Documentation + +- [Azure OpenAI Responses API](https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses) +- [Azure OpenAI Realtime Audio WebSockets](https://learn.microsoft.com/en-us/azure/foundry/openai/how-to/realtime-audio-websockets) +- [GPT Realtime Translate overview](https://learn.microsoft.com/en-us/azure/foundry/openai/concepts/gpt-realtime-translate) +- ADK transcription capability: [`TRANSCRIPTION_CAPABILITY.md`](TRANSCRIPTION_CAPABILITY.md) +- Spring AI bridge (alternative Azure path): [`contrib/spring-ai/README.md`](contrib/spring-ai/README.md) + +--- + +## Quick Reference + +```bash +# Minimal REST setup +export AZURE_OPENAI_API_KEY="..." +export AZURE_RESPONSE_ENDPOINT="https://.openai.azure.com/openai/v1/responses" +``` + +```java +// Minimal agent +LlmAgent.builder() + .name("my-agent") + .model(Model.builder().modelName("Azure|my-deployment").build()) + .build(); +``` + +```java +// Direct LLM access +BaseLlm llm = new AzureBaseLM("my-deployment"); +llm.generateContent(request, stream).subscribe(...); +``` diff --git a/core/src/main/java/com/google/adk/events/Event.java b/core/src/main/java/com/google/adk/events/Event.java index 28f675df9..f0d12a1a9 100644 --- a/core/src/main/java/com/google/adk/events/Event.java +++ b/core/src/main/java/com/google/adk/events/Event.java @@ -68,7 +68,6 @@ public class Event extends JsonBaseModel { private @Nullable String modelVersion; private @Nullable Transcription inputTranscription; private @Nullable Transcription outputTranscription; - private long timestamp; private Event() {} @@ -586,10 +585,10 @@ public Event build() { event.setGroundingMetadata(groundingMetadata); event.setCustomMetadata(customMetadata); event.setModelVersion(modelVersion); - event.setActions(actions().orElseGet(() -> EventActions.builder().build())); - event.setTimestamp(timestamp().orElseGet(() -> Instant.now().toEpochMilli())); event.setInputTranscription(inputTranscription); event.setOutputTranscription(outputTranscription); + event.setActions(actions().orElseGet(() -> EventActions.builder().build())); + event.setTimestamp(timestamp().orElseGet(() -> Instant.now().toEpochMilli())); return event; } } diff --git a/core/src/main/java/com/google/adk/models/AzureBaseLM.java b/core/src/main/java/com/google/adk/models/AzureBaseLM.java index 8efed09e8..ee7564fcb 100644 --- a/core/src/main/java/com/google/adk/models/AzureBaseLM.java +++ b/core/src/main/java/com/google/adk/models/AzureBaseLM.java @@ -1,985 +1,88 @@ package com.google.adk.models; -import static com.google.adk.models.RedbusADG.cleanForIdentifierPattern; -import static com.google.common.collect.ImmutableList.toImmutableList; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Iterables; -import com.google.genai.types.Content; -import com.google.genai.types.FunctionCall; -import com.google.genai.types.FunctionDeclaration; -import com.google.genai.types.GenerateContentConfig; -import com.google.genai.types.GenerateContentResponseUsageMetadata; -import com.google.genai.types.Part; -import com.google.genai.types.Schema; +import com.google.adk.models.azure.AzureConfig; +import com.google.adk.models.azure.AzureRealtimeTranslateTransport; +import com.google.adk.models.azure.AzureRealtimeTransport; +import com.google.adk.models.azure.AzureRestTransport; +import com.google.adk.models.azure.AzureTransport; import io.reactivex.rxjava3.core.Flowable; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.json.JSONArray; -import org.json.JSONException; -import org.json.JSONObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * BaseLlm implementation for Azure OpenAI models via the Responses API. + * Unified Azure LLM adapter that delegates to the appropriate transport based on model type. + * + *

Supports all Azure-hosted models (REST Responses API, WebSocket Realtime API, and future + * transports) through a single entry point. Transport selection is automatic based on model name. + * + *

Environment variables (see {@link AzureConfig}): * - *

Reads the endpoint from {@code AZURE_MODEL_ENDPOINT} and the API key from {@code - * AZURE_OPENAI_API_KEY} environment variables. The model/deployment name is passed to the - * constructor and sent in the request body. + *

    + *
  • {@code AZURE_RESPONSE_ENDPOINT} — REST Responses API + *
  • {@code AZURE_REALTIME_ENDPOINT} — WebSocket voice-agent Realtime API + *
  • {@code AZURE_TRANSLATE_ENDPOINT} — WebSocket GPT Realtime Translate + *
  • {@code AZURE_MODEL_ENDPOINT} — (legacy) fallback for all contracts above + *
  • {@code AZURE_OPENAI_API_KEY} — API key + *
  • {@code AZURE_REALTIME_VOICE} — (optional) voice for realtime models + *
* * @author Alfred Jimmy - * @see Azure - * OpenAI Responses API documentation */ public class AzureBaseLM extends BaseLlm { private static final Logger logger = LoggerFactory.getLogger(AzureBaseLM.class); - public static final String API_KEY_ENV = "AZURE_OPENAI_API_KEY"; - public static final String ENDPOINT_ENV = "AZURE_MODEL_ENDPOINT"; - - private static final int CONNECT_TIMEOUT_SECONDS = 60; - private static final int READ_TIMEOUT_SECONDS = 180; - - private static final ObjectMapper OBJECT_MAPPER = - new ObjectMapper().registerModule(new Jdk8Module()); - - private static final String CONTINUE_OUTPUT_MESSAGE = - "Continue output. DO NOT look at this line. ONLY look at the content before this line and" - + " system instruction."; - - private static final HttpClient httpClient = - HttpClient.newBuilder() - .version(HttpClient.Version.HTTP_2) - .connectTimeout(Duration.ofSeconds(CONNECT_TIMEOUT_SECONDS)) - .build(); - - private final String modelName; + private final AzureConfig config; + private final AzureTransport transport; /** - * Creates an AzureBaseLM for the given model name. The endpoint URL and API key are resolved from - * environment variables {@code AZURE_MODEL_ENDPOINT} and {@code AZURE_OPENAI_API_KEY}. + * Creates an AzureBaseLM for the given model/deployment name. * - * @param modelName model/deployment name sent in the request body (e.g. "gpt5pro") + * @param modelName the Azure deployment name (e.g. "gpt5pro", "gpt-4o-realtime-preview") */ public AzureBaseLM(String modelName) { super(modelName); - this.modelName = modelName; - warnIfMissing(ENDPOINT_ENV); - warnIfMissing(API_KEY_ENV); - } - - private void warnIfMissing(String envVar) { - String val = System.getenv(envVar); - if (val == null || val.isBlank()) { - logger.warn("{} is not set. Azure API calls for '{}' will fail.", envVar, modelName); - } - } - - private String resolveEndpointUrl() { - String envUrl = System.getenv(ENDPOINT_ENV); - if (envUrl != null && !envUrl.isBlank()) { - return envUrl; - } - throw new IllegalStateException(ENDPOINT_ENV + " environment variable is not set."); + this.config = AzureConfig.fromEnvironment(modelName); + this.transport = selectTransport(modelName); + logger.info( + "AzureBaseLM initialized: model={}, transport={}", + modelName, + transport.getClass().getSimpleName()); } - private String resolveApiKey() { - String key = System.getenv(API_KEY_ENV); - if (key == null || key.isBlank()) { - throw new IllegalStateException(API_KEY_ENV + " environment variable is not set."); - } - return key; - } - - // ==================== BaseLlm contract ==================== - @Override public Flowable generateContent(LlmRequest llmRequest, boolean stream) { - return stream ? generateContentStream(llmRequest) : generateContentSync(llmRequest); + return transport.generateContent(llmRequest, config, stream); } @Override public BaseLlmConnection connect(LlmRequest llmRequest) { - return new GenericLlmConnection(this, llmRequest); + return transport.connect(llmRequest, config); } - // ==================== Non-streaming ==================== - - private Flowable generateContentSync(LlmRequest llmRequest) { - List contents = ensureLastContentIsUser(llmRequest.contents()); - String instructions = extractInstructions(llmRequest); - JSONArray inputItems = buildInputItems(contents); - JSONArray tools = buildTools(llmRequest); - - boolean lastRespToolExecuted = - Iterables.getLast(Iterables.getLast(contents).parts().get()).functionResponse().isPresent(); - - Optional temperature = llmRequest.config().flatMap(GenerateContentConfig::temperature); - Optional maxTokens = - llmRequest.config().flatMap(GenerateContentConfig::maxOutputTokens); - - JSONObject payload = new JSONObject(); - payload.put("model", modelName); - payload.put("input", inputItems); - if (!instructions.isEmpty()) { - payload.put("instructions", instructions); - } - temperature.ifPresent(t -> payload.put("temperature", t)); - payload.put("stream", false); - payload.put("store", false); - payload.put("reasoning", new JSONObject().put("summary", "auto")); - if (maxTokens.isPresent() && maxTokens.get() > 0) { - payload.put("max_output_tokens", maxTokens.get()); - } - if (!lastRespToolExecuted && tools.length() > 0) { - payload.put("tools", tools); - } - - logger.debug("Azure Responses API request payload size: {} bytes", payload.toString().length()); - - JSONObject response = callApi(payload); - - if (response.has("error") && !response.isNull("error")) { - logger.error("Azure Responses API error: {}", response); - return Flowable.just( - LlmResponse.builder() - .content(Content.builder().role("model").parts(Part.fromText("")).build()) - .build()); - } - - GenerateContentResponseUsageMetadata usageMetadata = extractUsageMetadata(response); - LlmResponse llmResponse = parseOutputToLlmResponse(response, usageMetadata); - return Flowable.just(llmResponse); - } - - // ==================== Streaming ==================== - - private Flowable generateContentStream(LlmRequest llmRequest) { - List contents = ensureLastContentIsUser(llmRequest.contents()); - String instructions = extractInstructions(llmRequest); - JSONArray inputItems = buildInputItems(contents); - JSONArray tools = buildTools(llmRequest); - - boolean lastRespToolExecuted = - Iterables.getLast(Iterables.getLast(contents).parts().get()).functionResponse().isPresent(); - - Optional temperature = llmRequest.config().flatMap(GenerateContentConfig::temperature); - Optional maxTokens = - llmRequest.config().flatMap(GenerateContentConfig::maxOutputTokens); - - JSONObject payload = new JSONObject(); - payload.put("model", modelName); - payload.put("input", inputItems); - if (!instructions.isEmpty()) { - payload.put("instructions", instructions); - } - temperature.ifPresent(t -> payload.put("temperature", t)); - payload.put("stream", true); - payload.put("store", false); - payload.put("reasoning", new JSONObject().put("summary", "auto")); - if (maxTokens.isPresent() && maxTokens.get() > 0) { - payload.put("max_output_tokens", maxTokens.get()); - } - if (!lastRespToolExecuted && tools.length() > 0) { - payload.put("tools", tools); - } - - final StringBuilder accumulatedText = new StringBuilder(); - final StringBuilder reasoningSummary = new StringBuilder(); - final StringBuilder functionCallName = new StringBuilder(); - final StringBuilder functionCallCallId = new StringBuilder(); - final StringBuilder functionCallArgs = new StringBuilder(); - final AtomicBoolean inFunctionCall = new AtomicBoolean(false); - final AtomicBoolean finalTextEmitted = new AtomicBoolean(false); - final AtomicInteger inputTokens = new AtomicInteger(0); - final AtomicInteger outputTokens = new AtomicInteger(0); - - logger.info("[STREAM-DEBUG] Starting streaming request for model: {}", modelName); - logger.info("[STREAM-DEBUG] Payload size: {} bytes", payload.toString().length()); - - return Flowable.create( - emitter -> { - BufferedReader reader = null; - try { - logger.info("[STREAM-DEBUG] Opening SSE connection..."); - reader = callApiStream(payload); - if (reader == null) { - logger.warn("[STREAM-DEBUG] Reader is null — stream failed to open."); - emitter.onComplete(); - return; - } - logger.info("[STREAM-DEBUG] SSE connection opened successfully."); - long streamStartMs = System.currentTimeMillis(); - int chunkCount = 0; - - String lastEventName = null; - String line; - while ((line = reader.readLine()) != null) { - if (emitter.isCancelled()) { - logger.info("[STREAM-DEBUG] Emitter cancelled, breaking out of read loop."); - break; - } - - logger.debug( - "SSE raw: {}", line.length() > 200 ? line.substring(0, 200) + "..." : line); - - if (line.isEmpty()) continue; - if (line.startsWith("event:")) { - lastEventName = line.substring(6).trim(); - continue; - } - if (!line.startsWith("data:")) continue; - - String jsonStr = line.substring(5).trim(); - if (jsonStr.equals("[DONE]")) { - long elapsed = System.currentTimeMillis() - streamStartMs; - logger.info( - "[STREAM-DEBUG] [DONE] marker received after {}ms, total chunks: {}", - elapsed, - chunkCount); - break; - } - - chunkCount++; - JSONObject event; - try { - event = new JSONObject(jsonStr); - } catch (JSONException e) { - logger.warn( - "[STREAM-DEBUG] Failed to parse SSE chunk #{}: {}", chunkCount, jsonStr); - logger.warn("Failed to parse Azure SSE chunk: {}", jsonStr); - continue; - } - - String eventType = event.optString("type", ""); - if (eventType.isEmpty() && lastEventName != null) { - eventType = lastEventName; - } - lastEventName = null; - - logger.debug( - "[STREAM-DEBUG] Chunk #{} eventType='{}' keys={}", - chunkCount, - eventType, - event.keySet()); - logger.debug("SSE event type='{}' keys={}", eventType, event.keySet()); - - switch (eventType) { - case "response.output_item.added": - { - JSONObject item = event.optJSONObject("item"); - if (item == null) break; - String itemType = item.optString("type", ""); - logger.debug("[STREAM-DEBUG] output_item.added — itemType='{}'", itemType); - if ("function_call".equals(itemType)) { - inFunctionCall.set(true); - String name = item.optString("name", ""); - String callId = item.optString("call_id", ""); - logger.info( - "[STREAM-DEBUG] Function call starting: name='{}' callId='{}'", - name, - callId); - if (!name.isEmpty()) functionCallName.append(name); - if (!callId.isEmpty()) functionCallCallId.append(callId); - } else if ("reasoning".equals(itemType)) { - emitter.onNext( - LlmResponse.builder() - .content( - Content.builder() - .role("model") - .parts(Part.fromText("\ud83e\udde0 Thinking...\n")) - .build()) - .partial(true) - .build()); - } - break; - } - - case "response.reasoning_summary_text.delta": - { - String delta = event.optString("delta", ""); - if (!delta.isEmpty()) { - logger.debug( - "[STREAM-DEBUG] Reasoning delta ({} chars): {}", - delta.length(), - delta.length() > 80 ? delta.substring(0, 80) + "..." : delta); - reasoningSummary.append(delta); - emitter.onNext( - LlmResponse.builder() - .content( - Content.builder() - .role("model") - .parts(Part.fromText(delta)) - .build()) - .partial(true) - .build()); - } - break; - } - - case "response.reasoning_summary_text.done": - { - emitter.onNext( - LlmResponse.builder() - .content( - Content.builder() - .role("model") - .parts(Part.fromText("\n\n")) - .build()) - .partial(true) - .build()); - break; - } - - case "response.output_text.delta": - { - String delta = extractTextDeltaFromStreamEvent(event); - if (!delta.isEmpty()) { - logger.debug( - "[STREAM-DEBUG] Text delta ({} chars): {}", - delta.length(), - delta.length() > 100 ? delta.substring(0, 100) + "..." : delta); - logger.debug( - "[STREAM-DEBUG] Accumulated text so far: {} chars", - accumulatedText.length()); - accumulatedText.append(delta); - emitter.onNext( - LlmResponse.builder() - .content( - Content.builder() - .role("model") - .parts(Part.fromText(delta)) - .build()) - .partial(true) - .build()); - } - break; - } - - case "response.output_text.done": - { - String fullText = event.optString("text", ""); - logger.info( - "[STREAM-DEBUG] output_text.done — full text length: {} chars", - fullText.length()); - if (!fullText.isEmpty()) { - accumulatedText.setLength(0); - accumulatedText.append(fullText); - finalTextEmitted.set(true); - String finalContent = fullText; - if (reasoningSummary.length() > 0) { - finalContent = - "\ud83e\udde0 **Thinking:**\n> " - + reasoningSummary.toString().replace("\n", "\n> ") - + "\n\n" - + fullText; - } - emitter.onNext( - LlmResponse.builder() - .content( - Content.builder() - .role("model") - .parts(Part.fromText(finalContent)) - .build()) - .partial(false) - .build()); - } - break; - } - - case "response.output_item.done": - { - logger.debug( - "[STREAM-DEBUG] output_item.done — finalTextEmitted={}", - finalTextEmitted.get()); - if (finalTextEmitted.get()) break; - JSONObject item = event.optJSONObject("item"); - if (item != null && "message".equals(item.optString("type"))) { - String fullText = extractTextFromOutputMessageItem(item); - if (!fullText.isEmpty()) { - accumulatedText.setLength(0); - accumulatedText.append(fullText); - finalTextEmitted.set(true); - String finalContent = fullText; - if (reasoningSummary.length() > 0) { - finalContent = - "\ud83e\udde0 **Thinking:**\n> " - + reasoningSummary.toString().replace("\n", "\n> ") - + "\n\n" - + fullText; - } - emitter.onNext( - LlmResponse.builder() - .content( - Content.builder() - .role("model") - .parts(Part.fromText(finalContent)) - .build()) - .partial(false) - .build()); - } - } - break; - } - - case "response.function_call_arguments.delta": - { - String delta = extractTextDeltaFromStreamEvent(event); - if (!delta.isEmpty()) { - logger.debug( - "[STREAM-DEBUG] Function args delta ({} chars): {}", - delta.length(), - delta.length() > 100 ? delta.substring(0, 100) + "..." : delta); - functionCallArgs.append(delta); - } - break; - } - - case "response.function_call_arguments.done": - { - logger.info( - "[STREAM-DEBUG] function_call_arguments.done — name='{}' argsLength={}", - functionCallName, - functionCallArgs.length()); - if (functionCallName.length() > 0) { - String argsStr = - functionCallArgs.length() > 0 ? functionCallArgs.toString() : "{}"; - Map args; - try { - args = new JSONObject(argsStr).toMap(); - } catch (JSONException e) { - logger.warn("Failed to parse function args: {}", argsStr); - args = Map.of(); - } - FunctionCall fc = - FunctionCall.builder() - .name(functionCallName.toString()) - .args(args) - .build(); - emitter.onNext( - LlmResponse.builder() - .content( - Content.builder() - .role("model") - .parts( - ImmutableList.of(Part.builder().functionCall(fc).build())) - .build()) - .partial(false) - .build()); - } - break; - } - - case "response.completed": - { - logger.info("[STREAM-DEBUG] response.completed received."); - JSONObject resp = event.optJSONObject("response"); - if (resp != null) { - JSONObject usage = resp.optJSONObject("usage"); - if (usage != null) { - inputTokens.set(usage.optInt("input_tokens", 0)); - outputTokens.set(usage.optInt("output_tokens", 0)); - logger.info( - "[STREAM-DEBUG] Token usage — input: {}, output: {}", - inputTokens.get(), - outputTokens.get()); - } - } - break; - } - - default: - break; - } - } - - long totalElapsed = System.currentTimeMillis() - streamStartMs; - logger.info( - "[STREAM-DEBUG] Stream read loop finished — elapsed: {}ms, chunks: {}," - + " accumulatedText: {} chars, finalTextEmitted: {}, inFunctionCall: {}", - totalElapsed, - chunkCount, - accumulatedText.length(), - finalTextEmitted.get(), - inFunctionCall.get()); - - if (!emitter.isCancelled()) { - if (!finalTextEmitted.get()) { - logger.info("[STREAM-DEBUG] Emitting final accumulated response from post-loop."); - emitFinalStreamResponse( - emitter, - accumulatedText, - inFunctionCall, - functionCallName, - functionCallCallId, - functionCallArgs, - inputTokens.get(), - outputTokens.get()); - } - logger.info("[STREAM-DEBUG] Calling emitter.onComplete()."); - emitter.onComplete(); - } - } catch (IOException e) { - logger.error("[STREAM-DEBUG] IOException in stream: {}", e.getMessage()); - logger.error("IOException in Azure stream", e); - if (!emitter.isCancelled()) emitter.onError(e); - } catch (Exception e) { - logger.error("[STREAM-DEBUG] Exception in stream: {}", e.getMessage()); - logger.error("Error in Azure streaming", e); - if (!emitter.isCancelled()) emitter.onError(e); - } finally { - if (reader != null) { - try { - reader.close(); - } catch (IOException e) { - logger.error("Error closing stream reader", e); - } - } - } - }, - io.reactivex.rxjava3.core.BackpressureStrategy.BUFFER); - } - - /** Delta may be a string or a nested object depending on API version. */ - private static String extractTextDeltaFromStreamEvent(JSONObject event) { - if (event == null || event.isNull("delta")) { - return ""; - } - Object delta = event.opt("delta"); - if (delta instanceof String) { - return (String) delta; - } - if (delta instanceof JSONObject) { - JSONObject o = (JSONObject) delta; - return o.optString("text", o.optString("content", "")); - } - return ""; - } - - /** Full assistant text from a Responses API output message item (streaming completion). */ - private static String extractTextFromOutputMessageItem(JSONObject messageItem) { - JSONArray content = messageItem.optJSONArray("content"); - if (content == null) { - return ""; - } - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < content.length(); i++) { - JSONObject part = content.optJSONObject(i); - if (part == null) { - continue; - } - String pType = part.optString("type", ""); - if ("output_text".equals(pType) || "text".equals(pType)) { - sb.append(part.optString("text", "")); - } - } - return sb.toString(); - } - - private void emitFinalStreamResponse( - io.reactivex.rxjava3.core.Emitter emitter, - StringBuilder accumulatedText, - AtomicBoolean inFunctionCall, - StringBuilder functionCallName, - StringBuilder functionCallCallId, - StringBuilder functionCallArgs, - int promptTokens, - int completionTokens) { - - GenerateContentResponseUsageMetadata usageMetadata = - buildUsageMetadata(promptTokens, completionTokens); - - if (inFunctionCall.get() && functionCallName.length() > 0) { - // Function call was already emitted in response.function_call_arguments.done - // but if it wasn't (edge case), emit it now with usage - return; - } - - if (accumulatedText.length() > 0) { - LlmResponse.Builder builder = - LlmResponse.builder() - .content( - Content.builder() - .role("model") - .parts(Part.fromText(accumulatedText.toString())) - .build()) - .partial(false); - if (usageMetadata != null) { - builder.usageMetadata(usageMetadata); - } - emitter.onNext(builder.build()); - } - } - - // ==================== Request building ==================== - - private List ensureLastContentIsUser(List contents) { - if (contents.isEmpty() || !Iterables.getLast(contents).role().orElse("").equals("user")) { - Content userContent = Content.fromParts(Part.fromText(CONTINUE_OUTPUT_MESSAGE)); - return Stream.concat(contents.stream(), Stream.of(userContent)).collect(toImmutableList()); - } - return contents; - } - - private String extractInstructions(LlmRequest llmRequest) { - return llmRequest - .config() - .flatMap(GenerateContentConfig::systemInstruction) - .flatMap(Content::parts) - .map( - parts -> - parts.stream() - .filter(p -> p.text().isPresent()) - .map(p -> p.text().get()) - .collect(Collectors.joining("\n"))) - .filter(text -> !text.isEmpty()) - .orElse(""); - } - - /** - * Converts ADK Content list to Responses API input items. - * - *

Unlike Chat Completions (which uses a flat messages array with roles), the Responses API - * uses typed items: plain messages use {@code {role, content}}, function calls use {@code {type: - * "function_call", ...}}, and tool results use {@code {type: "function_call_output", ...}}. - */ - private JSONArray buildInputItems(List contents) { - JSONArray items = new JSONArray(); - - for (Content item : contents) { - String role = item.role().orElse("user"); - List parts = item.parts().orElse(ImmutableList.of()); - - if (parts.isEmpty()) { - JSONObject msg = new JSONObject(); - msg.put("role", role.equals("model") ? "assistant" : role); - msg.put("content", item.text()); - items.put(msg); - continue; - } - - Part firstPart = parts.get(0); - - if (firstPart.functionResponse().isPresent()) { - JSONObject output = new JSONObject(); - output.put("type", "function_call_output"); - output.put( - "call_id", "call_" + firstPart.functionResponse().get().name().orElse("unknown")); - output.put( - "output", - new JSONObject(firstPart.functionResponse().get().response().get()).toString()); - items.put(output); - } else if (firstPart.functionCall().isPresent()) { - FunctionCall fc = firstPart.functionCall().get(); - JSONObject fcItem = new JSONObject(); - fcItem.put("type", "function_call"); - fcItem.put("call_id", "call_" + fc.name().orElse("unknown")); - fcItem.put("name", fc.name().orElse("")); - fcItem.put("arguments", new JSONObject(fc.args().orElse(Map.of())).toString()); - items.put(fcItem); - } else { - JSONObject msg = new JSONObject(); - msg.put("role", role.equals("model") ? "assistant" : role); - msg.put("content", item.text()); - items.put(msg); - } - } - return items; - } - - /** - * Builds Responses API tool definitions (internally-tagged). - * - *

Unlike Chat Completions' externally-tagged {@code {type:"function", function:{name:...}}}, - * the Responses API uses {@code {type:"function", name:..., parameters:...}} at the top level. - */ - private JSONArray buildTools(LlmRequest llmRequest) { - JSONArray tools = new JSONArray(); - llmRequest - .tools() - .forEach( - (name, baseTool) -> { - Optional declOpt = baseTool.declaration(); - if (declOpt.isEmpty()) { - logger.warn("Skipping tool '{}' with missing declaration.", baseTool.name()); - return; - } - - FunctionDeclaration decl = declOpt.get(); - JSONObject tool = new JSONObject(); - tool.put("type", "function"); - tool.put("name", cleanForIdentifierPattern(decl.name().get())); - tool.put("description", decl.description().orElse("")); - - Optional paramsOpt = decl.parameters(); - if (paramsOpt.isPresent()) { - Schema paramsSchema = paramsOpt.get(); - Map paramsMap = new HashMap<>(); - paramsMap.put("type", "object"); - - Optional> propsOpt = paramsSchema.properties(); - if (propsOpt.isPresent()) { - Map propsMap = new HashMap<>(); - propsOpt - .get() - .forEach( - (key, schema) -> { - Map schemaMap = - OBJECT_MAPPER.convertValue( - schema, new TypeReference>() {}); - normalizeTypeStrings(schemaMap); - propsMap.put(key, schemaMap); - }); - paramsMap.put("properties", propsMap); - } - - paramsSchema - .required() - .ifPresent(requiredList -> paramsMap.put("required", requiredList)); - tool.put("parameters", new JSONObject(paramsMap)); - } - - tools.put(tool); - }); - return tools; - } - - // ==================== HTTP transport ==================== - - private JSONObject callApi(JSONObject payload) { - try { - String url = resolveEndpointUrl(); - String apiKey = resolveApiKey(); - String jsonString = payload.toString(); - - HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create(url)) - .header("Content-Type", "application/json; charset=UTF-8") - .header("api-key", apiKey) - .timeout(Duration.ofSeconds(READ_TIMEOUT_SECONDS)) - .POST(HttpRequest.BodyPublishers.ofString(jsonString, StandardCharsets.UTF_8)) - .build(); - - HttpResponse response = - httpClient.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); - - int statusCode = response.statusCode(); - logger.info("Azure Responses API status: {} for model: {}", statusCode, model()); - - if (statusCode >= 200 && statusCode < 300) { - return new JSONObject(response.body()); - } else { - logger.error("Azure API error: status={} body={}", statusCode, response.body()); - try { - return new JSONObject(response.body()); - } catch (JSONException e) { - return new JSONObject().put("error", response.body()); - } - } - } catch (IOException | InterruptedException ex) { - logger.error("HTTP request failed for Azure Responses API", ex); - return new JSONObject().put("error", ex.getMessage()); - } - } - - private BufferedReader callApiStream(JSONObject payload) { - try { - String url = resolveEndpointUrl(); - String apiKey = resolveApiKey(); - String jsonString = payload.toString(); - - HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create(url)) - .header("Content-Type", "application/json; charset=UTF-8") - .header("api-key", apiKey) - .header("Accept", "text/event-stream") - .timeout(Duration.ofSeconds(READ_TIMEOUT_SECONDS)) - .POST(HttpRequest.BodyPublishers.ofString(jsonString, StandardCharsets.UTF_8)) - .build(); - - HttpResponse response = - httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); - - int statusCode = response.statusCode(); - logger.info("Azure Responses API streaming status: {} for model: {}", statusCode, model()); - - if (statusCode >= 200 && statusCode < 300) { - return new BufferedReader(new InputStreamReader(response.body(), StandardCharsets.UTF_8)); - } else { - try (BufferedReader errorReader = - new BufferedReader(new InputStreamReader(response.body(), StandardCharsets.UTF_8))) { - StringBuilder errorBody = new StringBuilder(); - String errorLine; - while ((errorLine = errorReader.readLine()) != null) { - errorBody.append(errorLine); - } - logger.error("Azure streaming failed: status={} body={}", statusCode, errorBody); - } - return null; - } - } catch (IOException | InterruptedException ex) { - logger.error("HTTP request failed for Azure streaming", ex); - return null; - } - } - - // ==================== Response parsing ==================== - - private LlmResponse parseOutputToLlmResponse( - JSONObject response, GenerateContentResponseUsageMetadata usageMetadata) { - - JSONArray output = response.optJSONArray("output"); - if (output == null || output.length() == 0) { - logger.warn("Azure Responses API returned empty output: {}", response); - return LlmResponse.builder() - .content(Content.builder().role("model").parts(Part.fromText("")).build()) - .build(); - } - - List parts = new ArrayList<>(); - - for (int i = 0; i < output.length(); i++) { - JSONObject item = output.getJSONObject(i); - String type = item.optString("type", ""); - - switch (type) { - case "message": - { - JSONArray content = item.optJSONArray("content"); - if (content != null) { - for (int j = 0; j < content.length(); j++) { - JSONObject contentItem = content.getJSONObject(j); - if ("output_text".equals(contentItem.optString("type"))) { - parts.add(Part.fromText(contentItem.optString("text", ""))); - } - } - } - break; - } - - case "function_call": - { - String name = item.optString("name", null); - String argsStr = item.optString("arguments", "{}"); - if (name != null) { - Map args; - try { - args = new JSONObject(argsStr).toMap(); - } catch (JSONException e) { - logger.warn("Failed to parse function arguments: {}", argsStr); - args = Map.of(); - } - FunctionCall fc = FunctionCall.builder().name(name).args(args).build(); - parts.add(Part.builder().functionCall(fc).build()); - } - break; - } - - default: - // Skip reasoning items and other non-actionable output types - break; - } - } - - if (parts.isEmpty()) { - parts.add(Part.fromText("")); - } - - boolean hasFunctionCall = parts.stream().anyMatch(p -> p.functionCall().isPresent()); - - LlmResponse.Builder builder = LlmResponse.builder(); - if (hasFunctionCall) { - Part fcPart = parts.stream().filter(p -> p.functionCall().isPresent()).findFirst().get(); - builder.content(Content.builder().role("model").parts(ImmutableList.of(fcPart)).build()); - } else { - builder.content(Content.builder().role("model").parts(ImmutableList.copyOf(parts)).build()); - } - - if (usageMetadata != null) { - builder.usageMetadata(usageMetadata); - } - - return builder.build(); - } - - private GenerateContentResponseUsageMetadata extractUsageMetadata(JSONObject response) { - if (response == null || !response.has("usage")) { - return null; - } - try { - JSONObject usage = response.getJSONObject("usage"); - int inputTok = usage.optInt("input_tokens", 0); - int outputTok = usage.optInt("output_tokens", 0); - int totalTok = usage.optInt("total_tokens", inputTok + outputTok); - - if (totalTok > 0 || inputTok > 0 || outputTok > 0) { - logger.info( - "Azure token usage: input={}, output={}, total={}", inputTok, outputTok, totalTok); - return GenerateContentResponseUsageMetadata.builder() - .promptTokenCount(inputTok) - .candidatesTokenCount(outputTok) - .totalTokenCount(totalTok) - .build(); - } - } catch (Exception e) { - logger.warn("Failed to parse token usage from Azure response", e); + /** Returns true if the given model name is GPT Realtime Translate. */ + public static boolean isTranslateModel(String modelName) { + if (modelName == null) { + return false; } - return null; + return modelName.toLowerCase().contains("realtime-translate"); } - private GenerateContentResponseUsageMetadata buildUsageMetadata(int inputTok, int outputTok) { - int totalTok = inputTok + outputTok; - if (totalTok > 0 || inputTok > 0 || outputTok > 0) { - return GenerateContentResponseUsageMetadata.builder() - .promptTokenCount(inputTok) - .candidatesTokenCount(outputTok) - .totalTokenCount(totalTok) - .build(); + /** Returns true if the given model name indicates an Azure Realtime voice-agent model. */ + public static boolean isRealtimeModel(String modelName) { + if (modelName == null) { + return false; } - return null; + return modelName.toLowerCase().contains("realtime") && !isTranslateModel(modelName); } - @SuppressWarnings("unchecked") - private void normalizeTypeStrings(Map valueDict) { - if (valueDict == null) return; - if (valueDict.containsKey("type") && valueDict.get("type") instanceof String) { - valueDict.put("type", ((String) valueDict.get("type")).toLowerCase()); + private static AzureTransport selectTransport(String modelName) { + if (isTranslateModel(modelName)) { + return new AzureRealtimeTranslateTransport(); } - if (valueDict.containsKey("items") && valueDict.get("items") instanceof Map) { - Map itemsMap = (Map) valueDict.get("items"); - normalizeTypeStrings(itemsMap); - if (itemsMap.containsKey("properties") && itemsMap.get("properties") instanceof Map) { - Map properties = (Map) itemsMap.get("properties"); - for (Object value : properties.values()) { - if (value instanceof Map) { - normalizeTypeStrings((Map) value); - } - } - } + if (isRealtimeModel(modelName)) { + return new AzureRealtimeTransport(); } + return new AzureRestTransport(); } } diff --git a/core/src/main/java/com/google/adk/models/BaseLlmConnection.java b/core/src/main/java/com/google/adk/models/BaseLlmConnection.java index c8093ff9c..6addc7f4b 100644 --- a/core/src/main/java/com/google/adk/models/BaseLlmConnection.java +++ b/core/src/main/java/com/google/adk/models/BaseLlmConnection.java @@ -49,6 +49,15 @@ public interface BaseLlmConnection { */ Completable sendRealtime(Blob blob); + /** + * Clears the realtime input audio buffer on connections that use the Realtime protocol (e.g. + * Azure OpenAI {@code input_audio_buffer}). Default is a no-op for connections that do not expose + * such a buffer. + */ + default Completable clearRealtimeAudioBuffer() { + return Completable.complete(); + } + /** Receives the model responses. */ Flowable receive(); diff --git a/core/src/main/java/com/google/adk/models/LlmRegistry.java b/core/src/main/java/com/google/adk/models/LlmRegistry.java index bb2930b95..36e519d85 100644 --- a/core/src/main/java/com/google/adk/models/LlmRegistry.java +++ b/core/src/main/java/com/google/adk/models/LlmRegistry.java @@ -41,6 +41,18 @@ public interface LlmFactory { registerLlm("gemma-.*", modelName -> Gemma.builder().modelName(modelName).build()); registerLlm("apigee/.*", modelName -> ApigeeLlm.builder().modelName(modelName).build()); registerLlm("gpt-oss-.*", modelName -> GptOssLlm.builder().modelName(modelName).build()); + registerLlm( + ".*realtime.*", + modelName -> { + String actualModel = modelName.contains("|") ? modelName.split("\\|", 2)[1] : modelName; + return new AzureBaseLM(actualModel); + }); + registerLlm( + "Azure\\|.*", + modelName -> { + String actualModel = modelName.split("\\|", 2)[1]; + return new AzureBaseLM(actualModel); + }); } /** diff --git a/core/src/main/java/com/google/adk/models/azure/AzureConfig.java b/core/src/main/java/com/google/adk/models/azure/AzureConfig.java new file mode 100644 index 000000000..c187caedd --- /dev/null +++ b/core/src/main/java/com/google/adk/models/azure/AzureConfig.java @@ -0,0 +1,274 @@ +package com.google.adk.models.azure; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Shared configuration for all Azure transports (REST, Realtime voice, Realtime translate). + * + *

Each API contract has its own endpoint environment variable. {@code AZURE_MODEL_ENDPOINT} is + * kept as a legacy fallback when a contract-specific variable is not set. + * + *

Environment variables: + * + *

    + *
  • {@code AZURE_RESPONSE_ENDPOINT} — HTTP Responses API + *
  • {@code AZURE_REALTIME_ENDPOINT} — WebSocket voice-agent Realtime API + *
  • {@code AZURE_TRANSLATE_ENDPOINT} — WebSocket GPT Realtime Translate + *
  • {@code AZURE_MODEL_ENDPOINT} — (legacy) fallback for all of the above + *
  • {@code AZURE_OPENAI_API_KEY} — API key + *
  • {@code AZURE_REALTIME_VOICE} — (optional) voice for realtime models, defaults to "alloy" + *
  • {@code AZURE_TRANSLATE_TARGET_LANGUAGE} — (optional) default target language, defaults to + * "en" + *
+ */ +public final class AzureConfig { + + private static final Logger logger = LoggerFactory.getLogger(AzureConfig.class); + + /** + * @deprecated Use contract-specific endpoint variables. + */ + public static final String LEGACY_ENDPOINT_ENV = "AZURE_MODEL_ENDPOINT"; + + /** + * @deprecated Use {@link #LEGACY_ENDPOINT_ENV} or contract-specific variables. + */ + @Deprecated public static final String ENDPOINT_ENV = LEGACY_ENDPOINT_ENV; + + public static final String RESPONSE_ENDPOINT_ENV = "AZURE_RESPONSE_ENDPOINT"; + public static final String REALTIME_ENDPOINT_ENV = "AZURE_REALTIME_ENDPOINT"; + public static final String TRANSLATE_ENDPOINT_ENV = "AZURE_TRANSLATE_ENDPOINT"; + + public static final String API_KEY_ENV = "AZURE_OPENAI_API_KEY"; + public static final String VOICE_ENV = "AZURE_REALTIME_VOICE"; + public static final String TRANSLATE_TARGET_LANGUAGE_ENV = "AZURE_TRANSLATE_TARGET_LANGUAGE"; + + private static final String DEFAULT_VOICE = "alloy"; + private static final String DEFAULT_TRANSLATE_LANGUAGE = "en"; + + private final String modelName; + private final String responseEndpoint; + private final String realtimeEndpoint; + private final String translateEndpoint; + private final String apiKey; + private final String voice; + private final String translateTargetLanguage; + + private AzureConfig( + String modelName, + String responseEndpoint, + String realtimeEndpoint, + String translateEndpoint, + String apiKey, + String voice, + String translateTargetLanguage) { + this.modelName = modelName; + this.responseEndpoint = responseEndpoint; + this.realtimeEndpoint = realtimeEndpoint; + this.translateEndpoint = translateEndpoint; + this.apiKey = apiKey; + this.voice = voice; + this.translateTargetLanguage = translateTargetLanguage; + } + + public static AzureConfig fromEnvironment(String modelName) { + String legacy = resolveOptionalEnv(LEGACY_ENDPOINT_ENV); + String responseEndpoint = + resolveContractEndpoint(RESPONSE_ENDPOINT_ENV, legacy, "Responses API"); + String realtimeEndpoint = + resolveContractEndpoint(REALTIME_ENDPOINT_ENV, legacy, "Realtime voice API"); + String translateEndpoint = resolveTranslateEndpoint(legacy, modelName); + + String apiKey = resolveRequired(API_KEY_ENV); + String voice = resolveOptional(VOICE_ENV, DEFAULT_VOICE); + String translateTargetLanguage = + resolveOptional(TRANSLATE_TARGET_LANGUAGE_ENV, DEFAULT_TRANSLATE_LANGUAGE); + + logger.info( + "AzureConfig for model={}: response={}, realtime={}, translate={}", + modelName, + maskEndpoint(responseEndpoint), + maskEndpoint(realtimeEndpoint), + maskEndpoint(translateEndpoint)); + + return new AzureConfig( + modelName, + responseEndpoint, + realtimeEndpoint, + translateEndpoint, + apiKey, + voice, + translateTargetLanguage); + } + + public String modelName() { + return modelName; + } + + /** HTTP endpoint for the Azure Responses API (REST). */ + public String responseEndpoint() { + return responseEndpoint; + } + + /** + * @deprecated Use {@link #responseEndpoint()}, {@link #realtimeWebSocketUrl()}, or {@link + * #translationsWebSocketUrl()}. + */ + @Deprecated + public String endpoint() { + return responseEndpoint; + } + + public String apiKey() { + return apiKey; + } + + public String voice() { + return voice; + } + + public String translateTargetLanguage() { + return translateTargetLanguage; + } + + public AzureConfig withTranslateTargetLanguage(String language) { + String lang = + (language != null && !language.isBlank()) ? language.trim() : translateTargetLanguage; + return new AzureConfig( + modelName, responseEndpoint, realtimeEndpoint, translateEndpoint, apiKey, voice, lang); + } + + /** WebSocket URL for bidirectional voice-agent Realtime. Uses {@link #REALTIME_ENDPOINT_ENV}. */ + public String realtimeWebSocketUrl() { + String ws = toWebSocketUrl(realtimeEndpoint); + if (ws.contains("deployment=") || ws.contains("model=")) { + return ws; + } + String param = realtimeEndpoint.contains("/v1/") ? "model" : "deployment"; + String separator = ws.contains("?") ? "&" : "?"; + return ws + separator + param + "=" + modelName; + } + + /** WebSocket URL for GPT Realtime Translate. Uses {@link #TRANSLATE_ENDPOINT_ENV}. */ + public String translationsWebSocketUrl() { + if (translateEndpoint == null || translateEndpoint.isBlank()) { + throw new IllegalStateException( + TRANSLATE_ENDPOINT_ENV + + " is not set. Example:" + + " wss://.openai.azure.com/openai/v1/realtime/translations?model=" + + modelName); + } + String normalized = normalizeTranslateWebSocketUrl(translateEndpoint, modelName); + if (!normalized.equals(toWebSocketUrl(translateEndpoint))) { + logger.warn( + "Normalized {} (was: {}). Use GA format:" + + " wss:///openai/v1/realtime/translations?model= — no api-version.", + maskEndpoint(normalized), + maskEndpoint(translateEndpoint)); + } + return normalized; + } + + /** + * Forces GA translate URL shape: {@code /openai/v1/realtime/translations?model=} without {@code + * api-version}. Preview-style URLs ({@code /openai/realtime/translations?api-version=...}) return + * HTTP 400. + */ + static String normalizeTranslateWebSocketUrl(String raw, String modelName) { + String ws = toWebSocketUrl(raw); + String http = ws.replaceFirst("^wss://", "https://").replaceFirst("^ws://", "http://"); + java.net.URI uri = java.net.URI.create(http); + String host = uri.getHost(); + if (host == null || host.isBlank()) { + throw new IllegalStateException("Invalid translate endpoint (no host): " + raw); + } + String modelParam = + extractQueryParam(raw, "model", extractQueryParam(raw, "deployment", modelName)); + return "wss://" + host + "/openai/v1/realtime/translations?model=" + modelParam; + } + + private static String resolveContractEndpoint( + String specificEnv, String legacyFallback, String label) { + String val = resolveOptionalEnv(specificEnv); + if (val == null) { + val = legacyFallback; + } + if (val == null || val.isBlank()) { + throw new IllegalStateException( + "Azure " + + label + + " endpoint not configured. Set " + + specificEnv + + " or " + + LEGACY_ENDPOINT_ENV); + } + return val; + } + + private static String resolveTranslateEndpoint(String legacyFallback, String modelName) { + String explicit = resolveOptionalEnv(TRANSLATE_ENDPOINT_ENV); + if (explicit != null) { + return normalizeTranslateWebSocketUrl(explicit, modelName); + } + + String base = resolveOptionalEnv(REALTIME_ENDPOINT_ENV); + if (base == null) { + base = legacyFallback; + } + if (base == null || base.isBlank()) { + return null; + } + + return normalizeTranslateWebSocketUrl(base, modelName); + } + + private static String extractQueryParam(String url, String key, String defaultValue) { + int q = url.indexOf('?'); + if (q < 0) { + return defaultValue; + } + for (String param : url.substring(q + 1).split("&")) { + if (param.startsWith(key + "=")) { + return param.substring((key + "=").length()); + } + } + return defaultValue; + } + + private static String toWebSocketUrl(String url) { + return url.replaceFirst("^https://", "wss://").replaceFirst("^http://", "ws://"); + } + + private static String resolveRequired(String envVar) { + String val = System.getenv(envVar); + if (val == null || val.isBlank()) { + throw new IllegalStateException(envVar + " environment variable is not set."); + } + return val.replaceAll("/+$", ""); + } + + private static String resolveOptional(String envVar, String defaultValue) { + String val = System.getenv(envVar); + return (val != null && !val.isBlank()) ? val : defaultValue; + } + + private static String resolveOptionalEnv(String envVar) { + String val = System.getenv(envVar); + return (val != null && !val.isBlank()) ? val.replaceAll("/+$", "") : null; + } + + private static String maskEndpoint(String url) { + if (url == null) { + return "unset"; + } + try { + java.net.URI u = + java.net.URI.create( + url.replaceFirst("^wss://", "https://").replaceFirst("^ws://", "http://")); + return (u.getHost() != null ? u.getHost() : "?") + (u.getPath() != null ? u.getPath() : ""); + } catch (Exception e) { + return "(configured)"; + } + } +} diff --git a/core/src/main/java/com/google/adk/models/azure/AzureRealtimeLlmConnection.java b/core/src/main/java/com/google/adk/models/azure/AzureRealtimeLlmConnection.java new file mode 100644 index 000000000..bd6251446 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/azure/AzureRealtimeLlmConnection.java @@ -0,0 +1,772 @@ +package com.google.adk.models.azure; + +import com.google.adk.models.BaseLlmConnection; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.Part; +import com.google.genai.types.Transcription; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.processors.PublishProcessor; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import org.java_websocket.client.WebSocketClient; +import org.java_websocket.handshake.ServerHandshake; +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * WebSocket-based connection to the Azure OpenAI Realtime API. + * + *

Implements the GA WebSocket protocol: + * + *

    + *
  1. Open a WebSocket to {@code + * wss://.openai.azure.com/openai/v1/realtime?model=} + *
  2. Authenticate via {@code api-key} header + *
  3. Send/receive JSON events for text, audio, and function calls + *
+ * + * @author Alfred Jimmy + * @see + * Azure OpenAI Realtime API via WebSockets + */ +public final class AzureRealtimeLlmConnection implements BaseLlmConnection { + + private static final Logger logger = LoggerFactory.getLogger(AzureRealtimeLlmConnection.class); + + private static final int CONNECT_TIMEOUT_SECONDS = 30; + + /** + * Close-mic / phone-held noise reduction (not {@code far_field}, which favors room/distant + * pickup). + */ + private static final String INPUT_AUDIO_NOISE_REDUCTION = "far_field"; + + private static final String SEMANTIC_VAD_EAGERNESS = "high"; + + private static final boolean CREATE_RESPONSE_AFTER_TURN = true; + + private static final boolean INTERRUPT_RESPONSE = true; + + private final AzureConfig config; + private final LlmRequest llmRequest; + private final PublishProcessor responseProcessor = PublishProcessor.create(); + private final Flowable responseFlowable = responseProcessor.serialize(); + private final AtomicBoolean closed = new AtomicBoolean(false); + private final AtomicBoolean sessionConfigured = new AtomicBoolean(false); + private final CountDownLatch connectedLatch = new CountDownLatch(1); + + private RealtimeWebSocketClient wsClient; + + /** + * When true, we already forwarded assistant text via {@code response.*.delta} events for this + * response; the matching {@code *.done} carries the full string again and must not be printed + * twice. + */ + private final AtomicBoolean assistantOutputTextHadDelta = new AtomicBoolean(false); + + private final AtomicBoolean assistantAudioTranscriptHadDelta = new AtomicBoolean(false); + + /** True while Azure is generating a response (between response.created and response.done). */ + private final AtomicBoolean activeResponse = new AtomicBoolean(false); + + /** + * Tracks in-flight function calls by item_id so that {@code + * response.function_call_arguments.done} (which may omit name/call_id on some API versions) can + * be resolved. Populated from {@code response.output_item.added} events. + */ + private final ConcurrentHashMap pendingFunctionCalls = + new ConcurrentHashMap<>(); + + private static final Set WHISPER_HALLUCINATIONS = + Set.of( + "thank you.", + "thanks for watching.", + "bye.", + "you", + "the end.", + "thanks for watching!", + "subscribe", + "продолжение следует...", + "thank you for watching.", + "."); + + private record FunctionCallInfo(String name, String callId) {} + + AzureRealtimeLlmConnection(AzureConfig config, LlmRequest llmRequest) { + this.config = Objects.requireNonNull(config, "config cannot be null"); + this.llmRequest = Objects.requireNonNull(llmRequest, "llmRequest cannot be null"); + + try { + initializeConnection(); + } catch (Exception e) { + logger.error("Failed to initialize Azure Realtime WebSocket connection", e); + responseProcessor.onError(e); + } + } + + // ==================== Connection Initialization ==================== + + private void initializeConnection() throws Exception { + logger.info( + "Initializing Azure Realtime WebSocket connection for model: {}", config.modelName()); + + String apiKey = config.apiKey(); + + String wsUrl = config.realtimeWebSocketUrl(); + + logger.info("Connecting to WebSocket: {}", wsUrl); + + URI uri = URI.create(wsUrl); + wsClient = new RealtimeWebSocketClient(uri, apiKey); + wsClient.connectBlocking(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS); + + if (!wsClient.isOpen()) { + throw new IllegalStateException("WebSocket connection failed to open within timeout"); + } + + if (!connectedLatch.await(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { + throw new IllegalStateException("WebSocket connected but session.created not received"); + } + + sendSessionUpdate(); + logger.info("Azure Realtime WebSocket connection established."); + } + + private void sendSessionUpdate() { + String voice = config.voice(); + String instructions = AzureRequestConverter.extractInstructions(llmRequest); + + JSONObject event = new JSONObject(); + event.put("type", "session.update"); + + JSONObject session = new JSONObject(); + if (!instructions.isEmpty()) { + session.put("instructions", instructions); + } + session.put("voice", voice); + session.put("modalities", new JSONArray().put("text").put("audio")); + + session.put("input_audio_format", "pcm16"); + session.put("output_audio_format", "pcm16"); + + JSONObject noiseReduction = new JSONObject(); + noiseReduction.put("type", INPUT_AUDIO_NOISE_REDUCTION); + session.put("input_audio_noise_reduction", noiseReduction); + + JSONObject turnDetection = new JSONObject(); + turnDetection.put("type", "semantic_vad"); + turnDetection.put("eagerness", SEMANTIC_VAD_EAGERNESS); + turnDetection.put("create_response", CREATE_RESPONSE_AFTER_TURN); + turnDetection.put("interrupt_response", INTERRUPT_RESPONSE); + session.put("turn_detection", turnDetection); + + JSONObject transcription = new JSONObject(); + transcription.put("model", "whisper-1"); + session.put("input_audio_transcription", transcription); + + JSONArray toolsArray = AzureRequestConverter.buildTools(llmRequest); + if (toolsArray.length() > 0) { + session.put("tools", toolsArray); + session.put("tool_choice", "auto"); + } + + event.put("session", session); + sendMessage(event.toString()); + logger.info( + "Sent session.update with voice={}, turn_detection={}, noise_reduction={}, tools={}", + voice, + turnDetection, + INPUT_AUDIO_NOISE_REDUCTION, + toolsArray.length()); + } + + // ==================== WebSocket Event Handling ==================== + + private void handleMessage(String json) { + if (closed.get()) return; + + try { + JSONObject event = new JSONObject(json); + String eventType = event.optString("type", ""); + + logger.debug("Realtime WS event: {}", eventType); + + switch (eventType) { + case "session.created": + logger.info( + "Realtime session created: {}", + event.optJSONObject("session") != null + ? event.optJSONObject("session").optString("id", "unknown") + : "unknown"); + sessionConfigured.set(true); + connectedLatch.countDown(); + break; + + case "session.updated": + JSONObject updatedSession = event.optJSONObject("session"); + JSONObject appliedTurnDetection = + updatedSession != null ? updatedSession.optJSONObject("turn_detection") : null; + logger.info( + "Realtime session updated; turn_detection={}", + appliedTurnDetection != null ? appliedTurnDetection.toString() : "none"); + break; + + case "response.created": + assistantOutputTextHadDelta.set(false); + assistantAudioTranscriptHadDelta.set(false); + activeResponse.set(true); + break; + + case "response.text.delta": + case "response.output_text.delta": + handleTextDelta(event); + break; + + case "response.text.done": + case "response.output_text.done": + handleTextDone(event); + break; + + case "response.audio_transcript.delta": + case "response.output_audio_transcript.delta": + handleTranscriptDelta(event); + break; + + case "response.audio_transcript.done": + case "response.output_audio_transcript.done": + handleTranscriptDone(event); + break; + + case "response.audio.delta": + case "response.output_audio.delta": + handleAudioDelta(event); + break; + + case "response.output_item.added": + handleOutputItemAdded(event); + break; + + case "response.function_call_arguments.delta": + break; + + case "response.function_call_arguments.done": + handleFunctionCallDone(event); + break; + + case "response.done": + handleResponseDone(event); + break; + + case "input_audio_buffer.speech_started": + // WebSocket clients should stop playback on speech_started during an active response + // (OpenAI Realtime guide). Gemini emits interrupted() immediately; Azure relies on + // server VAD + interrupt_response, then response.done status=cancelled — but that + // response.done can lag or be missed, so emit interrupted here as the primary signal. + if (activeResponse.get()) { + logger.info( + "Realtime: speech_started during active response — emitting interrupted (barge-in)."); + responseProcessor.onNext(LlmResponse.builder().interrupted(true).build()); + } else { + logger.debug("Realtime: speech_started (no active response)."); + } + break; + + case "input_audio_buffer.speech_stopped": + logger.debug("User speech stopped."); + break; + + case "input_audio_buffer.committed": + case "conversation.item.created": + case "response.output_item.done": + case "response.content_part.added": + case "response.content_part.done": + logger.debug("Lifecycle event: {}", eventType); + break; + + case "conversation.item.input_audio_transcription.completed": + handleInputTranscription(event); + break; + + case "error": + handleErrorEvent(event); + break; + + default: + logger.debug("Unhandled Realtime event type: {}", eventType); + break; + } + } catch (JSONException e) { + logger.warn("Failed to parse WebSocket message: {}", json, e); + } + } + + private void handleTextDelta(JSONObject event) { + String delta = event.optString("delta", ""); + if (!delta.isEmpty()) { + assistantOutputTextHadDelta.set(true); + responseProcessor.onNext( + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText(delta)).build()) + .partial(true) + .build()); + } + } + + private void handleTextDone(JSONObject event) { + String text = event.optString("text", ""); + if (assistantOutputTextHadDelta.compareAndSet(true, false)) { + emitAssistantTurnTerminatorOnly(); + return; + } + if (!text.isEmpty()) { + responseProcessor.onNext( + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText(text)).build()) + .partial(false) + .build()); + } + } + + private void handleTranscriptDelta(JSONObject event) { + String delta = event.optString("delta", ""); + if (!delta.isEmpty()) { + assistantAudioTranscriptHadDelta.set(true); + responseProcessor.onNext( + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText(delta)).build()) + .partial(true) + .build()); + } + } + + private void handleTranscriptDone(JSONObject event) { + String transcript = event.optString("transcript", ""); + if (assistantAudioTranscriptHadDelta.compareAndSet(true, false)) { + emitAssistantTurnTerminatorOnly(); + return; + } + if (!transcript.isEmpty()) { + responseProcessor.onNext( + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText(transcript)).build()) + .partial(false) + .build()); + } + } + + /** Ends the assistant line in the UI without repeating text already streamed via deltas. */ + private void emitAssistantTurnTerminatorOnly() { + responseProcessor.onNext( + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("")).build()) + .partial(false) + .build()); + } + + private void handleAudioDelta(JSONObject event) { + String base64Audio = event.optString("delta", ""); + if (!base64Audio.isEmpty()) { + try { + byte[] audioBytes = Base64.getDecoder().decode(base64Audio); + logger.debug("Received {} bytes of audio from model", audioBytes.length); + Blob audioBlob = Blob.builder().mimeType("audio/pcm").data(audioBytes).build(); + + responseProcessor.onNext( + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(ImmutableList.of(Part.builder().inlineData(audioBlob).build())) + .build()) + .partial(true) + .build()); + } catch (IllegalArgumentException e) { + logger.warn("Failed to decode audio delta", e); + } + } + } + + /** + * Captures function_call items from {@code response.output_item.added} so that name and call_id + * are available when {@code response.function_call_arguments.done} arrives (some API versions + * omit them from the latter event). + */ + private void handleOutputItemAdded(JSONObject event) { + JSONObject item = event.optJSONObject("item"); + if (item == null) return; + String type = item.optString("type", ""); + if (!"function_call".equals(type)) return; + + String itemId = item.optString("id", ""); + String name = item.optString("name", ""); + String callId = item.optString("call_id", ""); + if (!itemId.isEmpty() && !name.isEmpty()) { + pendingFunctionCalls.put(itemId, new FunctionCallInfo(name, callId)); + logger.info( + "Tracked pending function_call: item_id={}, name={}, call_id={}", itemId, name, callId); + } + } + + private void handleFunctionCallDone(JSONObject event) { + String name = event.optString("name", ""); + String callId = event.optString("call_id", ""); + String itemId = event.optString("item_id", ""); + String argsStr = event.optString("arguments", "{}"); + + if (name.isEmpty() && !itemId.isEmpty()) { + FunctionCallInfo tracked = pendingFunctionCalls.remove(itemId); + if (tracked != null) { + name = tracked.name(); + if (callId.isEmpty()) callId = tracked.callId(); + } + } else if (!itemId.isEmpty()) { + pendingFunctionCalls.remove(itemId); + } + + if (name.isEmpty()) { + logger.warn( + "Dropping function_call_arguments.done with no resolvable name (item_id={})", itemId); + return; + } + + Map args; + try { + args = new JSONObject(argsStr).toMap(); + } catch (JSONException e) { + logger.warn("Failed to parse function call arguments: {}", argsStr); + args = Map.of(); + } + + FunctionCall.Builder fcBuilder = FunctionCall.builder().name(name).args(args); + if (!callId.isEmpty()) { + fcBuilder.id(callId); + } + FunctionCall fc = fcBuilder.build(); + logger.info( + "Emitting FunctionCall: name={}, call_id={}, args_keys={}", name, callId, args.keySet()); + responseProcessor.onNext( + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(ImmutableList.of(Part.builder().functionCall(fc).build())) + .build()) + .partial(false) + .turnComplete(true) + .build()); + } + + private void handleResponseDone(JSONObject event) { + activeResponse.set(false); + JSONObject resp = event.optJSONObject("response"); + String status = + resp != null ? resp.optString("status", "").trim().toLowerCase(java.util.Locale.ROOT) : ""; + JSONObject statusDetails = resp != null ? resp.optJSONObject("status_details") : null; + String statusReason = + statusDetails != null + ? statusDetails.optString("reason", "").trim().toLowerCase(java.util.Locale.ROOT) + : ""; + boolean interrupted = + "cancelled".equals(status) + || "canceled".equals(status) + || "interrupted".equals(status) + || ("incomplete".equals(status) && "turn_detected".equals(statusReason)); + if (interrupted) { + logger.info( + "Realtime response ended with status={} reason={} — emitting interrupted playback signal.", + status, + statusReason.isEmpty() ? "n/a" : statusReason); + responseProcessor.onNext(LlmResponse.builder().interrupted(true).build()); + } else if ("completed".equals(status) || status.isEmpty()) { + // Align turnComplete with response.done (after audio finishes), not transcript.done. + logger.info( + "Realtime response completed (status={}) — emitting turnComplete.", + status.isEmpty() ? "unknown" : status); + responseProcessor.onNext(LlmResponse.builder().turnComplete(true).build()); + } else { + logger.info("Realtime response ended with status={}.", status); + } + + if (resp != null) { + JSONObject usage = resp.optJSONObject("usage"); + if (usage != null) { + logger.info( + "Realtime token usage — input: {}, output: {}", + usage.optInt("input_tokens", 0), + usage.optInt("output_tokens", 0)); + } + } + } + + private void handleInputTranscription(JSONObject event) { + String transcript = event.optString("transcript", "").trim(); + if (transcript.isEmpty()) return; + + if (transcript.length() <= 2 + || WHISPER_HALLUCINATIONS.contains(transcript.toLowerCase(java.util.Locale.ROOT))) { + logger.debug("Filtered likely Whisper hallucination: '{}'", transcript); + return; + } + + // Mirror Gemini Live: transcription is independent of the model turn and must NOT + // arrive as user-role content (LiveAudioSession treats user-role during playback + // as a turn boundary and fires voice_complete prematurely). + responseProcessor.onNext( + LlmResponse.builder() + .inputTranscription(Transcription.builder().text(transcript).finished(true).build()) + .build()); + } + + private void handleErrorEvent(JSONObject event) { + JSONObject error = event.optJSONObject("error"); + String message = error != null ? error.optString("message", "Unknown error") : "Unknown error"; + logger.error("Realtime API error: {}", message); + responseProcessor.onNext(LlmResponse.builder().errorMessage(message).build()); + } + + // ==================== BaseLlmConnection Methods ==================== + + @Override + public Completable sendHistory(List history) { + return Completable.fromAction( + () -> { + if (closed.get()) { + throw new IllegalStateException("Connection is closed"); + } + for (Content content : history) { + sendContentOverWebSocket(content); + } + }); + } + + @Override + public Completable sendContent(Content content) { + return Completable.fromAction( + () -> { + if (closed.get()) { + throw new IllegalStateException("Connection is closed"); + } + Objects.requireNonNull(content, "content cannot be null"); + + boolean isFunctionResponse = + content.parts().isPresent() + && !content.parts().get().isEmpty() + && content.parts().get().get(0).functionResponse().isPresent(); + + if (isFunctionResponse) { + sendFunctionResponseOverWebSocket(content); + } else { + sendContentOverWebSocket(content); + sendResponseCreate(); + } + }); + } + + @Override + public Completable sendRealtime(Blob blob) { + return Completable.fromAction( + () -> { + if (closed.get()) { + throw new IllegalStateException("Connection is closed"); + } + Objects.requireNonNull(blob, "blob cannot be null"); + + byte[] audioData = blob.data().orElse(new byte[0]); + if (audioData.length == 0) { + return; + } + + String base64Audio = Base64.getEncoder().encodeToString(audioData); + JSONObject event = new JSONObject(); + event.put("type", "input_audio_buffer.append"); + event.put("audio", base64Audio); + sendMessage(event.toString()); + }); + } + + @Override + public Completable clearRealtimeAudioBuffer() { + return Completable.fromAction( + () -> { + if (closed.get()) { + throw new IllegalStateException("Connection is closed"); + } + JSONObject event = new JSONObject(); + event.put("type", "input_audio_buffer.clear"); + logger.debug("Sending input_audio_buffer.clear"); + sendMessage(event.toString()); + }); + } + + @Override + public Flowable receive() { + return responseFlowable; + } + + @Override + public void close() { + closeInternal(null); + } + + @Override + public void close(Throwable throwable) { + Objects.requireNonNull(throwable, "throwable cannot be null"); + closeInternal(throwable); + } + + // ==================== Internal Helpers ==================== + + private void sendContentOverWebSocket(Content content) { + String role = content.role().orElse("user"); + String text = + content.parts().isPresent() + ? content.parts().get().stream() + .filter(p -> p.text().isPresent()) + .map(p -> p.text().get()) + .collect(Collectors.joining("\n")) + : ""; + + JSONObject event = new JSONObject(); + event.put("type", "conversation.item.create"); + + JSONObject item = new JSONObject(); + item.put("type", "message"); + item.put("role", role.equals("model") ? "assistant" : role); + + JSONArray contentArr = new JSONArray(); + JSONObject contentItem = new JSONObject(); + contentItem.put("type", "input_text"); + contentItem.put("text", text); + contentArr.put(contentItem); + item.put("content", contentArr); + + event.put("item", item); + sendMessage(event.toString()); + } + + private void sendFunctionResponseOverWebSocket(Content content) { + content + .parts() + .ifPresent( + parts -> + parts.forEach( + part -> + part.functionResponse() + .ifPresent( + fr -> { + JSONObject event = new JSONObject(); + event.put("type", "conversation.item.create"); + + JSONObject item = new JSONObject(); + item.put("type", "function_call_output"); + String callId = + fr.id().orElse("call_" + fr.name().orElse("unknown")); + item.put("call_id", callId); + item.put( + "output", + new JSONObject(fr.response().orElse(Map.of())).toString()); + + event.put("item", item); + sendMessage(event.toString()); + }))); + + sendResponseCreate(); + } + + private void sendResponseCreate() { + JSONObject event = new JSONObject(); + event.put("type", "response.create"); + sendMessage(event.toString()); + } + + private void sendMessage(String json) { + if (wsClient == null || !wsClient.isOpen()) { + logger.warn("WebSocket is not open, cannot send message."); + return; + } + try { + wsClient.send(json); + logger.debug("Sent over WebSocket: {} bytes", json.getBytes(StandardCharsets.UTF_8).length); + } catch (Exception e) { + logger.error("Failed to send over WebSocket", e); + } + } + + private void closeInternal(Throwable throwable) { + if (closed.compareAndSet(false, true)) { + logger.info("Closing AzureRealtimeLlmConnection."); + + if (throwable == null) { + responseProcessor.onComplete(); + } else { + responseProcessor.onError(throwable); + } + + try { + if (wsClient != null && wsClient.isOpen()) { + wsClient.closeBlocking(); + wsClient = null; + } + } catch (Exception e) { + logger.warn("Error closing WebSocket", e); + } + } + } + + // ==================== WebSocket Client ==================== + + private class RealtimeWebSocketClient extends WebSocketClient { + + RealtimeWebSocketClient(URI uri, String apiKey) { + super(uri); + addHeader("api-key", apiKey); + } + + @Override + public void onOpen(ServerHandshake handshake) { + logger.info("WebSocket connection opened (status: {})", handshake.getHttpStatus()); + } + + @Override + public void onMessage(String message) { + handleMessage(message); + } + + @Override + public void onClose(int code, String reason, boolean remote) { + logger.info("WebSocket closed: code={}, reason={}, remote={}", code, reason, remote); + if (!closed.get()) { + closeInternal( + new IllegalStateException("WebSocket closed unexpectedly: " + code + " " + reason)); + } + } + + @Override + public void onError(Exception ex) { + logger.error("WebSocket error", ex); + if (!closed.get()) { + closeInternal(ex); + } + } + } +} diff --git a/core/src/main/java/com/google/adk/models/azure/AzureRealtimeTranslateLlmConnection.java b/core/src/main/java/com/google/adk/models/azure/AzureRealtimeTranslateLlmConnection.java new file mode 100644 index 000000000..6eceb7540 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/azure/AzureRealtimeTranslateLlmConnection.java @@ -0,0 +1,381 @@ +package com.google.adk.models.azure; + +import com.google.adk.models.BaseLlmConnection; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import com.google.genai.types.Transcription; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.processors.PublishProcessor; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.java_websocket.client.WebSocketClient; +import org.java_websocket.handshake.ServerHandshake; +import org.json.JSONException; +import org.json.JSONObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * WebSocket connection to Azure OpenAI GPT Realtime Translate. + * + *

Uses the translation session protocol ({@code /openai/v1/realtime/translations}): continuous + * source audio in, translated audio and transcript deltas out. No {@code response.create} or agent + * turn lifecycle. + * + * @see Realtime + * translation + * @see + * GPT Realtime Translate overview + */ +public final class AzureRealtimeTranslateLlmConnection implements BaseLlmConnection { + + private static final Logger logger = + LoggerFactory.getLogger(AzureRealtimeTranslateLlmConnection.class); + + private static final int CONNECT_TIMEOUT_SECONDS = 30; + + private final AzureConfig config; + private final PublishProcessor responseProcessor = PublishProcessor.create(); + private final Flowable responseFlowable = responseProcessor.serialize(); + private final AtomicBoolean closed = new AtomicBoolean(false); + private final AtomicBoolean sessionClosing = new AtomicBoolean(false); + private final CountDownLatch connectedLatch = new CountDownLatch(1); + + private final AtomicBoolean outputTranscriptHadDelta = new AtomicBoolean(false); + + private TranslateWebSocketClient wsClient; + + AzureRealtimeTranslateLlmConnection(AzureConfig config, LlmRequest llmRequest) { + this.config = Objects.requireNonNull(config, "config cannot be null"); + Objects.requireNonNull(llmRequest, "llmRequest cannot be null"); + + try { + initializeConnection(); + } catch (Exception e) { + logger.error("Failed to initialize Azure Realtime Translate WebSocket connection", e); + responseProcessor.onError(e); + throw new IllegalStateException( + "Failed to initialize Azure Realtime Translate WebSocket connection", e); + } + } + + /** Returns true when the translation WebSocket is open and session.created was received. */ + public boolean isConnected() { + return wsClient != null && wsClient.isOpen() && connectedLatch.getCount() == 0; + } + + private void initializeConnection() throws Exception { + String apiKey = config.apiKey(); + String wsUrl = config.translationsWebSocketUrl(); + + logger.info("Connecting to Azure Realtime Translate WebSocket: {}", wsUrl); + + URI uri = URI.create(wsUrl); + wsClient = new TranslateWebSocketClient(uri, apiKey); + wsClient.connectBlocking(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS); + + if (!wsClient.isOpen()) { + throw new IllegalStateException("Translation WebSocket failed to open within timeout"); + } + + if (!connectedLatch.await(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { + throw new IllegalStateException( + "Translation WebSocket connected but session.created not received"); + } + + sendSessionUpdate(); + logger.info( + "Azure Realtime Translate connection established (target language={}).", + config.translateTargetLanguage()); + } + + private void sendSessionUpdate() { + JSONObject event = new JSONObject(); + event.put("type", "session.update"); + + JSONObject session = new JSONObject(); + JSONObject audio = new JSONObject(); + JSONObject output = new JSONObject(); + output.put("language", config.translateTargetLanguage()); + audio.put("output", output); + session.put("audio", audio); + + event.put("session", session); + sendMessage(event.toString()); + logger.info( + "Sent translation session.update with language={}", config.translateTargetLanguage()); + } + + private void handleMessage(String json) { + if (closed.get()) { + return; + } + + try { + JSONObject event = new JSONObject(json); + String eventType = event.optString("type", ""); + + logger.debug("Translate WS event: {}", eventType); + + switch (eventType) { + case "session.created": + logger.info( + "Translation session created: {}", + event.optJSONObject("session") != null + ? event.optJSONObject("session").optString("id", "unknown") + : "unknown"); + connectedLatch.countDown(); + break; + + case "session.updated": + logger.info("Translation session updated."); + break; + + case "session.output_audio.delta": + handleOutputAudioDelta(event); + break; + + case "session.output_transcript.delta": + handleOutputTranscriptDelta(event); + break; + + case "session.input_transcript.delta": + handleInputTranscriptDelta(event); + break; + + case "session.closed": + logger.info("Translation session closed by server."); + activeCloseComplete(); + break; + + case "error": + handleErrorEvent(event); + break; + + default: + logger.trace("Unhandled translation event type: {}", eventType); + break; + } + } catch (JSONException e) { + logger.warn("Failed to parse translation WebSocket message: {}", json, e); + } + } + + private void handleOutputAudioDelta(JSONObject event) { + String base64Audio = event.optString("delta", ""); + if (base64Audio.isEmpty()) { + return; + } + try { + byte[] audioBytes = Base64.getDecoder().decode(base64Audio); + Blob audioBlob = Blob.builder().mimeType("audio/pcm").data(audioBytes).build(); + responseProcessor.onNext( + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(ImmutableList.of(Part.builder().inlineData(audioBlob).build())) + .build()) + .partial(true) + .build()); + } catch (IllegalArgumentException e) { + logger.warn("Failed to decode translation audio delta", e); + } + } + + private void handleOutputTranscriptDelta(JSONObject event) { + String delta = event.optString("delta", ""); + if (!delta.isEmpty()) { + outputTranscriptHadDelta.set(true); + responseProcessor.onNext( + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText(delta)).build()) + .partial(true) + .build()); + } + } + + private void handleInputTranscriptDelta(JSONObject event) { + String delta = event.optString("delta", ""); + if (!delta.isEmpty()) { + responseProcessor.onNext( + LlmResponse.builder() + .inputTranscription(Transcription.builder().text(delta).finished(false).build()) + .build()); + } + } + + private void handleErrorEvent(JSONObject event) { + JSONObject error = event.optJSONObject("error"); + String message = error != null ? error.optString("message", "Unknown error") : "Unknown error"; + logger.error("Realtime Translate API error: {}", message); + responseProcessor.onNext(LlmResponse.builder().errorMessage(message).build()); + } + + private void activeCloseComplete() { + if (!closed.get()) { + responseProcessor.onNext(LlmResponse.builder().turnComplete(true).build()); + } + } + + @Override + public Completable sendHistory(List history) { + return Completable.complete(); + } + + @Override + public Completable sendContent(Content content) { + return Completable.complete(); + } + + @Override + public Completable sendRealtime(Blob blob) { + return Completable.fromAction( + () -> { + if (closed.get()) { + throw new IllegalStateException("Connection is closed"); + } + Objects.requireNonNull(blob, "blob cannot be null"); + + byte[] audioData = blob.data().orElse(new byte[0]); + if (audioData.length == 0) { + return; + } + + String base64Audio = Base64.getEncoder().encodeToString(audioData); + JSONObject event = new JSONObject(); + event.put("type", "session.input_audio_buffer.append"); + event.put("audio", base64Audio); + sendMessage(event.toString()); + }); + } + + @Override + public Completable clearRealtimeAudioBuffer() { + return Completable.complete(); + } + + /** Gracefully closes the translation session and flushes pending output. */ + public Completable closeTranslationSession() { + return Completable.fromAction( + () -> { + if (closed.get() || sessionClosing.getAndSet(true)) { + return; + } + JSONObject event = new JSONObject(); + event.put("type", "session.close"); + sendMessage(event.toString()); + logger.info("Sent session.close for translation."); + }); + } + + @Override + public Flowable receive() { + return responseFlowable; + } + + @Override + public void close() { + closeInternal(null); + } + + @Override + public void close(Throwable throwable) { + Objects.requireNonNull(throwable, "throwable cannot be null"); + closeInternal(throwable); + } + + private void sendMessage(String json) { + if (wsClient == null || !wsClient.isOpen()) { + logger.warn("Translation WebSocket is not open, cannot send message."); + return; + } + try { + wsClient.send(json); + logger.trace( + "Sent over translation WebSocket: {} bytes", + json.getBytes(StandardCharsets.UTF_8).length); + } catch (Exception e) { + logger.error("Failed to send over translation WebSocket", e); + } + } + + private void closeInternal(Throwable throwable) { + if (closed.compareAndSet(false, true)) { + logger.info("Closing AzureRealtimeTranslateLlmConnection."); + + if (throwable == null) { + responseProcessor.onComplete(); + } else { + responseProcessor.onError(throwable); + } + + try { + if (wsClient != null && wsClient.isOpen()) { + if (!sessionClosing.get()) { + try { + JSONObject event = new JSONObject(); + event.put("type", "session.close"); + wsClient.send(event.toString()); + } catch (Exception e) { + logger.debug("session.close on shutdown failed: {}", e.getMessage()); + } + } + wsClient.closeBlocking(); + wsClient = null; + } + } catch (Exception e) { + logger.warn("Error closing translation WebSocket", e); + } + } + } + + private class TranslateWebSocketClient extends WebSocketClient { + + TranslateWebSocketClient(URI uri, String apiKey) { + super(uri); + addHeader("api-key", apiKey); + } + + @Override + public void onOpen(ServerHandshake handshake) { + logger.info("Translation WebSocket opened (status: {})", handshake.getHttpStatus()); + } + + @Override + public void onMessage(String message) { + handleMessage(message); + } + + @Override + public void onClose(int code, String reason, boolean remote) { + logger.info( + "Translation WebSocket closed: code={}, reason={}, remote={}", code, reason, remote); + if (!closed.get()) { + closeInternal( + new IllegalStateException( + "Translation WebSocket closed unexpectedly: " + code + " " + reason)); + } + } + + @Override + public void onError(Exception ex) { + logger.error("Translation WebSocket error", ex); + if (!closed.get()) { + closeInternal(ex); + } + } + } +} diff --git a/core/src/main/java/com/google/adk/models/azure/AzureRealtimeTranslateTransport.java b/core/src/main/java/com/google/adk/models/azure/AzureRealtimeTranslateTransport.java new file mode 100644 index 000000000..68b5fc114 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/azure/AzureRealtimeTranslateTransport.java @@ -0,0 +1,33 @@ +package com.google.adk.models.azure; + +import com.google.adk.models.BaseLlmConnection; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import io.reactivex.rxjava3.core.Flowable; + +/** + * Azure transport for GPT Realtime Translate ({@code gpt-realtime-translate}). + * + *

Uses the {@code /openai/v1/realtime/translations} WebSocket endpoint and continuous + * translation events — not the bidirectional voice-agent protocol. + */ +public final class AzureRealtimeTranslateTransport implements AzureTransport { + + @Override + public boolean supports(String modelName) { + return modelName != null && modelName.toLowerCase().contains("realtime-translate"); + } + + @Override + public BaseLlmConnection connect(LlmRequest request, AzureConfig config) { + return new AzureRealtimeTranslateLlmConnection(config, request); + } + + @Override + public Flowable generateContent( + LlmRequest request, AzureConfig config, boolean stream) { + return Flowable.error( + new UnsupportedOperationException( + "gpt-realtime-translate requires a live WebSocket connection; use connect() instead.")); + } +} diff --git a/core/src/main/java/com/google/adk/models/azure/AzureRealtimeTransport.java b/core/src/main/java/com/google/adk/models/azure/AzureRealtimeTransport.java new file mode 100644 index 000000000..e2e2eff80 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/azure/AzureRealtimeTransport.java @@ -0,0 +1,75 @@ +package com.google.adk.models.azure; + +import com.google.adk.models.BaseLlmConnection; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Azure transport implementation for the WebSocket-based Realtime API. + * + *

Handles bidirectional audio/text streaming via persistent WebSocket connections. For + * non-realtime models, see {@link AzureRestTransport}. + */ +public final class AzureRealtimeTransport implements AzureTransport { + + private static final Logger logger = LoggerFactory.getLogger(AzureRealtimeTransport.class); + + @Override + public boolean supports(String modelName) { + return com.google.adk.models.AzureBaseLM.isRealtimeModel(modelName); + } + + @Override + public BaseLlmConnection connect(LlmRequest request, AzureConfig config) { + return new AzureRealtimeLlmConnection(config, request); + } + + /** + * For realtime models, {@code generateContent} is not the primary interaction mode. This provides + * a minimal fallback that opens a short-lived WebSocket, sends the last user content, and + * collects responses. + */ + @Override + public Flowable generateContent( + LlmRequest request, AzureConfig config, boolean stream) { + return Flowable.create( + emitter -> { + AzureRealtimeLlmConnection conn = null; + try { + conn = new AzureRealtimeLlmConnection(config, request); + + conn.receive() + .doOnNext(emitter::onNext) + .doOnError(emitter::onError) + .doOnComplete(emitter::onComplete) + .subscribe(); + + Optional lastUserContent = + request.contents().isEmpty() + ? Optional.empty() + : Optional.of(request.contents().get(request.contents().size() - 1)); + + if (lastUserContent.isPresent()) { + conn.sendContent(lastUserContent.get()).blockingAwait(); + } else { + conn.sendContent(Content.fromParts(Part.fromText(""))).blockingAwait(); + } + } catch (Exception e) { + logger.error("Error in AzureRealtimeTransport.generateContent", e); + if (!emitter.isCancelled()) { + emitter.onError(e); + } + if (conn != null) { + conn.close(e); + } + } + }, + io.reactivex.rxjava3.core.BackpressureStrategy.BUFFER); + } +} diff --git a/core/src/main/java/com/google/adk/models/azure/AzureRequestConverter.java b/core/src/main/java/com/google/adk/models/azure/AzureRequestConverter.java new file mode 100644 index 000000000..99abb83f4 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/azure/AzureRequestConverter.java @@ -0,0 +1,148 @@ +package com.google.adk.models.azure; + +import com.google.adk.models.LlmRequest; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Schema; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import org.json.JSONArray; +import org.json.JSONObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Shared request conversion utilities for all Azure transports. + * + *

Consolidates duplicated logic that was previously in both {@code AzureBaseLM} and {@code + * AzureRealtimeLlmConnection}: instruction extraction, tool schema conversion, and schema-to-JSON + * mapping. + */ +public final class AzureRequestConverter { + + private static final Logger logger = LoggerFactory.getLogger(AzureRequestConverter.class); + + private static final String FORBIDDEN_CHARACTERS_REGEX = "[^a-zA-Z0-9_\\.-]"; + + private AzureRequestConverter() {} + + /** + * Extracts system instructions from the LlmRequest config. + * + * @return combined system instruction text, or empty string if none + */ + public static String extractInstructions(LlmRequest llmRequest) { + return llmRequest + .config() + .flatMap(GenerateContentConfig::systemInstruction) + .flatMap(Content::parts) + .map( + parts -> + parts.stream() + .filter(p -> p.text().isPresent()) + .map(p -> p.text().get()) + .collect(Collectors.joining("\n"))) + .filter(text -> !text.isEmpty()) + .orElse(""); + } + + /** + * Builds a JSON array of tool definitions from the LlmRequest tools map. + * + *

Uses {@code llmRequest.tools()} (Map of BaseTool) as the single source of truth for all + * transports. Output format matches Azure/OpenAI function tool schema. + * + * @return JSONArray of tool objects, may be empty + */ + public static JSONArray buildTools(LlmRequest llmRequest) { + JSONArray tools = new JSONArray(); + + llmRequest + .tools() + .forEach( + (name, baseTool) -> { + Optional declOpt = baseTool.declaration(); + if (declOpt.isEmpty()) { + logger.warn("Skipping tool '{}' with missing declaration.", baseTool.name()); + return; + } + + FunctionDeclaration decl = declOpt.get(); + if (decl.name().isEmpty() || decl.name().get().isBlank()) { + logger.warn("Skipping function declaration without a name"); + return; + } + + JSONObject toolObj = new JSONObject(); + toolObj.put("type", "function"); + toolObj.put("name", cleanForIdentifier(decl.name().get())); + toolObj.put("description", decl.description().orElse("")); + toolObj.put( + "parameters", + decl.parameters() + .map(AzureRequestConverter::schemaToJson) + .orElseGet( + () -> + new JSONObject() + .put("type", "object") + .put("properties", new JSONObject()))); + + tools.put(toolObj); + }); + + return tools; + } + + /** + * Recursively converts a {@link Schema} to a JSON object suitable for the OpenAI/Azure tool + * parameter format. + */ + public static JSONObject schemaToJson(Schema schema) { + JSONObject obj = new JSONObject(); + schema + .type() + .ifPresent(type -> obj.put("type", type.knownEnum().name().toLowerCase(Locale.ROOT))); + schema.description().ifPresent(desc -> obj.put("description", desc)); + + schema + .properties() + .ifPresent( + props -> { + JSONObject propsObj = new JSONObject(); + for (Map.Entry entry : props.entrySet()) { + propsObj.put(entry.getKey(), schemaToJson(entry.getValue())); + } + obj.put("properties", propsObj); + }); + + schema.required().ifPresent(req -> obj.put("required", new JSONArray(req))); + schema.items().ifPresent(items -> obj.put("items", schemaToJson(items))); + + schema + .enum_() + .ifPresent( + enums -> { + JSONArray enumArr = new JSONArray(); + for (String e : enums) { + enumArr.put(e); + } + obj.put("enum", enumArr); + }); + + return obj; + } + + /** + * Sanitizes a string for use as a function/tool identifier by removing forbidden characters. + * Allows: {@code [a-zA-Z0-9_.-]} + */ + public static String cleanForIdentifier(String input) { + if (input == null) { + return null; + } + return input.replaceAll(FORBIDDEN_CHARACTERS_REGEX, ""); + } +} diff --git a/core/src/main/java/com/google/adk/models/azure/AzureRestTransport.java b/core/src/main/java/com/google/adk/models/azure/AzureRestTransport.java new file mode 100644 index 000000000..d6b37e35b --- /dev/null +++ b/core/src/main/java/com/google/adk/models/azure/AzureRestTransport.java @@ -0,0 +1,796 @@ +package com.google.adk.models.azure; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.adk.models.BaseLlm; +import com.google.adk.models.BaseLlmConnection; +import com.google.adk.models.GenericLlmConnection; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Azure transport implementation for the HTTP-based Responses API. + * + *

Handles both non-streaming and SSE streaming requests to Azure OpenAI. + */ +public final class AzureRestTransport implements AzureTransport { + + private static final Logger logger = LoggerFactory.getLogger(AzureRestTransport.class); + + private static final int CONNECT_TIMEOUT_SECONDS = 60; + private static final int READ_TIMEOUT_SECONDS = 180; + + private static final String CONTINUE_OUTPUT_MESSAGE = + "Continue output. DO NOT look at this line. ONLY look at the content before this line and" + + " system instruction."; + + private static final HttpClient httpClient = + HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .connectTimeout(Duration.ofSeconds(CONNECT_TIMEOUT_SECONDS)) + .build(); + + @Override + public boolean supports(String modelName) { + if (modelName == null) return false; + return !modelName.toLowerCase().contains("realtime"); + } + + @Override + public Flowable generateContent( + LlmRequest request, AzureConfig config, boolean stream) { + return stream ? generateContentStream(request, config) : generateContentSync(request, config); + } + + @Override + public BaseLlmConnection connect(LlmRequest request, AzureConfig config) { + BaseLlm proxy = + new BaseLlm(config.modelName()) { + @Override + public Flowable generateContent(LlmRequest req, boolean stream) { + return AzureRestTransport.this.generateContent(req, config, stream); + } + + @Override + public BaseLlmConnection connect(LlmRequest req) { + throw new UnsupportedOperationException("Nested connect not supported"); + } + }; + return new GenericLlmConnection(proxy, request); + } + + // ==================== Non-streaming ==================== + + private Flowable generateContentSync(LlmRequest llmRequest, AzureConfig config) { + List contents = ensureLastContentIsUser(llmRequest.contents()); + String instructions = AzureRequestConverter.extractInstructions(llmRequest); + JSONArray inputItems = buildInputItems(contents); + JSONArray tools = AzureRequestConverter.buildTools(llmRequest); + + boolean lastRespToolExecuted = + Iterables.getLast(Iterables.getLast(contents).parts().get()).functionResponse().isPresent(); + + Optional temperature = llmRequest.config().flatMap(GenerateContentConfig::temperature); + Optional maxTokens = + llmRequest.config().flatMap(GenerateContentConfig::maxOutputTokens); + + JSONObject payload = new JSONObject(); + payload.put("model", config.modelName()); + payload.put("input", inputItems); + if (!instructions.isEmpty()) { + payload.put("instructions", instructions); + } + temperature.ifPresent(t -> payload.put("temperature", t)); + payload.put("stream", false); + payload.put("store", false); + payload.put("reasoning", new JSONObject().put("summary", "auto")); + if (maxTokens.isPresent() && maxTokens.get() > 0) { + payload.put("max_output_tokens", maxTokens.get()); + } + if (!lastRespToolExecuted && tools.length() > 0) { + payload.put("tools", tools); + } + + logger.debug("Azure Responses API request payload size: {} bytes", payload.toString().length()); + + JSONObject response = callApi(payload, config); + + if (response.has("error") && !response.isNull("error")) { + logger.error("Azure Responses API error: {}", response); + return Flowable.just( + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("")).build()) + .build()); + } + + GenerateContentResponseUsageMetadata usageMetadata = extractUsageMetadata(response); + LlmResponse llmResponse = parseOutputToLlmResponse(response, usageMetadata); + return Flowable.just(llmResponse); + } + + // ==================== Streaming ==================== + + private Flowable generateContentStream(LlmRequest llmRequest, AzureConfig config) { + List contents = ensureLastContentIsUser(llmRequest.contents()); + String instructions = AzureRequestConverter.extractInstructions(llmRequest); + JSONArray inputItems = buildInputItems(contents); + JSONArray tools = AzureRequestConverter.buildTools(llmRequest); + + boolean lastRespToolExecuted = + Iterables.getLast(Iterables.getLast(contents).parts().get()).functionResponse().isPresent(); + + Optional temperature = llmRequest.config().flatMap(GenerateContentConfig::temperature); + Optional maxTokens = + llmRequest.config().flatMap(GenerateContentConfig::maxOutputTokens); + + JSONObject payload = new JSONObject(); + payload.put("model", config.modelName()); + payload.put("input", inputItems); + if (!instructions.isEmpty()) { + payload.put("instructions", instructions); + } + temperature.ifPresent(t -> payload.put("temperature", t)); + payload.put("stream", true); + payload.put("store", false); + payload.put("reasoning", new JSONObject().put("summary", "auto")); + if (maxTokens.isPresent() && maxTokens.get() > 0) { + payload.put("max_output_tokens", maxTokens.get()); + } + if (!lastRespToolExecuted && tools.length() > 0) { + payload.put("tools", tools); + } + + final StringBuilder accumulatedText = new StringBuilder(); + final StringBuilder reasoningSummary = new StringBuilder(); + final StringBuilder functionCallName = new StringBuilder(); + final StringBuilder functionCallCallId = new StringBuilder(); + final StringBuilder functionCallArgs = new StringBuilder(); + final AtomicBoolean inFunctionCall = new AtomicBoolean(false); + final AtomicBoolean finalTextEmitted = new AtomicBoolean(false); + final AtomicInteger inputTokens = new AtomicInteger(0); + final AtomicInteger outputTokens = new AtomicInteger(0); + + logger.debug("Starting streaming request for model: {}", config.modelName()); + logger.debug("Streaming payload size: {} bytes", payload.toString().length()); + + return Flowable.create( + emitter -> { + BufferedReader reader = null; + try { + logger.debug("Opening SSE connection..."); + reader = callApiStream(payload, config); + if (reader == null) { + logger.warn("Azure SSE reader is null — stream failed to open."); + emitter.onComplete(); + return; + } + logger.debug("SSE connection opened successfully."); + long streamStartMs = System.currentTimeMillis(); + int chunkCount = 0; + + String lastEventName = null; + String line; + while ((line = reader.readLine()) != null) { + if (emitter.isCancelled()) { + logger.debug("Emitter cancelled, breaking out of read loop."); + break; + } + + logger.debug( + "SSE raw: {}", line.length() > 200 ? line.substring(0, 200) + "..." : line); + + if (line.isEmpty()) continue; + if (line.startsWith("event:")) { + lastEventName = line.substring(6).trim(); + continue; + } + if (!line.startsWith("data:")) continue; + + String jsonStr = line.substring(5).trim(); + if (jsonStr.equals("[DONE]")) { + long elapsed = System.currentTimeMillis() - streamStartMs; + logger.debug( + "[DONE] marker received after {}ms, total chunks: {}", elapsed, chunkCount); + break; + } + + chunkCount++; + JSONObject event; + try { + event = new JSONObject(jsonStr); + } catch (JSONException e) { + logger.warn("Failed to parse SSE chunk #{}: {}", chunkCount, jsonStr); + continue; + } + + String eventType = event.optString("type", ""); + if (eventType.isEmpty() && lastEventName != null) { + eventType = lastEventName; + } + lastEventName = null; + + logger.debug( + "SSE chunk #{} eventType='{}' keys={}", chunkCount, eventType, event.keySet()); + + switch (eventType) { + case "response.output_item.added": + { + JSONObject item = event.optJSONObject("item"); + if (item == null) break; + String itemType = item.optString("type", ""); + if ("function_call".equals(itemType)) { + inFunctionCall.set(true); + String name = item.optString("name", ""); + String callId = item.optString("call_id", ""); + logger.debug("Function call starting: name='{}' callId='{}'", name, callId); + if (!name.isEmpty()) functionCallName.append(name); + if (!callId.isEmpty()) functionCallCallId.append(callId); + } else if ("reasoning".equals(itemType)) { + emitter.onNext( + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(Part.fromText("\ud83e\udde0 Thinking...\n")) + .build()) + .partial(true) + .build()); + } + break; + } + + case "response.reasoning_summary_text.delta": + { + String delta = event.optString("delta", ""); + if (!delta.isEmpty()) { + reasoningSummary.append(delta); + emitter.onNext( + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(Part.fromText(delta)) + .build()) + .partial(true) + .build()); + } + break; + } + + case "response.reasoning_summary_text.done": + { + emitter.onNext( + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(Part.fromText("\n\n")) + .build()) + .partial(true) + .build()); + break; + } + + case "response.output_text.delta": + { + String delta = extractTextDeltaFromStreamEvent(event); + if (!delta.isEmpty()) { + accumulatedText.append(delta); + emitter.onNext( + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(Part.fromText(delta)) + .build()) + .partial(true) + .build()); + } + break; + } + + case "response.output_text.done": + { + String fullText = event.optString("text", ""); + if (!fullText.isEmpty()) { + accumulatedText.setLength(0); + accumulatedText.append(fullText); + finalTextEmitted.set(true); + String finalContent = fullText; + if (reasoningSummary.length() > 0) { + finalContent = + "\ud83e\udde0 **Thinking:**\n> " + + reasoningSummary.toString().replace("\n", "\n> ") + + "\n\n" + + fullText; + } + emitter.onNext( + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(Part.fromText(finalContent)) + .build()) + .partial(false) + .build()); + } + break; + } + + case "response.output_item.done": + { + if (finalTextEmitted.get()) break; + JSONObject item = event.optJSONObject("item"); + if (item != null && "message".equals(item.optString("type"))) { + String fullText = extractTextFromOutputMessageItem(item); + if (!fullText.isEmpty()) { + accumulatedText.setLength(0); + accumulatedText.append(fullText); + finalTextEmitted.set(true); + String finalContent = fullText; + if (reasoningSummary.length() > 0) { + finalContent = + "\ud83e\udde0 **Thinking:**\n> " + + reasoningSummary.toString().replace("\n", "\n> ") + + "\n\n" + + fullText; + } + emitter.onNext( + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(Part.fromText(finalContent)) + .build()) + .partial(false) + .build()); + } + } + break; + } + + case "response.function_call_arguments.delta": + { + String delta = extractTextDeltaFromStreamEvent(event); + if (!delta.isEmpty()) { + functionCallArgs.append(delta); + } + break; + } + + case "response.function_call_arguments.done": + { + if (functionCallName.length() > 0) { + String argsStr = + functionCallArgs.length() > 0 ? functionCallArgs.toString() : "{}"; + Map args; + try { + args = new JSONObject(argsStr).toMap(); + } catch (JSONException e) { + logger.warn("Failed to parse function args: {}", argsStr); + args = Map.of(); + } + FunctionCall fc = + FunctionCall.builder() + .name(functionCallName.toString()) + .args(args) + .build(); + emitter.onNext( + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts( + ImmutableList.of(Part.builder().functionCall(fc).build())) + .build()) + .partial(false) + .build()); + } + break; + } + + case "response.completed": + { + JSONObject resp = event.optJSONObject("response"); + if (resp != null) { + JSONObject usage = resp.optJSONObject("usage"); + if (usage != null) { + inputTokens.set(usage.optInt("input_tokens", 0)); + outputTokens.set(usage.optInt("output_tokens", 0)); + logger.debug( + "Stream token usage — input: {}, output: {}", + inputTokens.get(), + outputTokens.get()); + } + } + break; + } + + default: + break; + } + } + + long totalElapsed = System.currentTimeMillis() - streamStartMs; + logger.debug( + "Stream read loop finished — elapsed: {}ms, chunks: {}, accumulatedText: {} chars," + + " finalTextEmitted: {}, inFunctionCall: {}", + totalElapsed, + chunkCount, + accumulatedText.length(), + finalTextEmitted.get(), + inFunctionCall.get()); + + if (!emitter.isCancelled()) { + if (!finalTextEmitted.get()) { + emitFinalStreamResponse( + emitter, + accumulatedText, + inFunctionCall, + functionCallName, + functionCallArgs, + inputTokens.get(), + outputTokens.get()); + } + emitter.onComplete(); + } + } catch (IOException e) { + logger.error("IOException in Azure stream", e); + if (!emitter.isCancelled()) emitter.onError(e); + } catch (Exception e) { + logger.error("Error in Azure streaming", e); + if (!emitter.isCancelled()) emitter.onError(e); + } finally { + if (reader != null) { + try { + reader.close(); + } catch (IOException e) { + logger.error("Error closing stream reader", e); + } + } + } + }, + io.reactivex.rxjava3.core.BackpressureStrategy.BUFFER); + } + + // ==================== Helpers ==================== + + private static String extractTextDeltaFromStreamEvent(JSONObject event) { + if (event == null || event.isNull("delta")) { + return ""; + } + Object delta = event.opt("delta"); + if (delta instanceof String) { + return (String) delta; + } + if (delta instanceof JSONObject) { + JSONObject o = (JSONObject) delta; + return o.optString("text", o.optString("content", "")); + } + return ""; + } + + private static String extractTextFromOutputMessageItem(JSONObject messageItem) { + JSONArray content = messageItem.optJSONArray("content"); + if (content == null) { + return ""; + } + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < content.length(); i++) { + JSONObject part = content.optJSONObject(i); + if (part == null) continue; + String pType = part.optString("type", ""); + if ("output_text".equals(pType) || "text".equals(pType)) { + sb.append(part.optString("text", "")); + } + } + return sb.toString(); + } + + private void emitFinalStreamResponse( + io.reactivex.rxjava3.core.Emitter emitter, + StringBuilder accumulatedText, + AtomicBoolean inFunctionCall, + StringBuilder functionCallName, + StringBuilder functionCallArgs, + int promptTokens, + int completionTokens) { + + GenerateContentResponseUsageMetadata usageMetadata = + buildUsageMetadata(promptTokens, completionTokens); + + if (inFunctionCall.get() && functionCallName.length() > 0) { + return; + } + + if (accumulatedText.length() > 0) { + LlmResponse.Builder builder = + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(Part.fromText(accumulatedText.toString())) + .build()) + .partial(false); + if (usageMetadata != null) { + builder.usageMetadata(usageMetadata); + } + emitter.onNext(builder.build()); + } + } + + private List ensureLastContentIsUser(List contents) { + if (contents.isEmpty() || !Iterables.getLast(contents).role().orElse("").equals("user")) { + Content userContent = Content.fromParts(Part.fromText(CONTINUE_OUTPUT_MESSAGE)); + return Stream.concat(contents.stream(), Stream.of(userContent)).collect(toImmutableList()); + } + return contents; + } + + private JSONArray buildInputItems(List contents) { + JSONArray items = new JSONArray(); + + for (Content item : contents) { + String role = item.role().orElse("user"); + List parts = item.parts().orElse(ImmutableList.of()); + + if (parts.isEmpty()) { + JSONObject msg = new JSONObject(); + msg.put("role", role.equals("model") ? "assistant" : role); + msg.put("content", item.text()); + items.put(msg); + continue; + } + + Part firstPart = parts.get(0); + + if (firstPart.functionResponse().isPresent()) { + JSONObject output = new JSONObject(); + output.put("type", "function_call_output"); + output.put( + "call_id", "call_" + firstPart.functionResponse().get().name().orElse("unknown")); + output.put( + "output", + new JSONObject(firstPart.functionResponse().get().response().get()).toString()); + items.put(output); + } else if (firstPart.functionCall().isPresent()) { + FunctionCall fc = firstPart.functionCall().get(); + JSONObject fcItem = new JSONObject(); + fcItem.put("type", "function_call"); + fcItem.put("call_id", "call_" + fc.name().orElse("unknown")); + fcItem.put("name", fc.name().orElse("")); + fcItem.put("arguments", new JSONObject(fc.args().orElse(Map.of())).toString()); + items.put(fcItem); + } else { + JSONObject msg = new JSONObject(); + msg.put("role", role.equals("model") ? "assistant" : role); + msg.put("content", item.text()); + items.put(msg); + } + } + return items; + } + + // ==================== HTTP transport ==================== + + private JSONObject callApi(JSONObject payload, AzureConfig config) { + try { + String jsonString = payload.toString(); + + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(config.responseEndpoint())) + .header("Content-Type", "application/json; charset=UTF-8") + .header("api-key", config.apiKey()) + .timeout(Duration.ofSeconds(READ_TIMEOUT_SECONDS)) + .POST(HttpRequest.BodyPublishers.ofString(jsonString, StandardCharsets.UTF_8)) + .build(); + + HttpResponse response = + httpClient.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + + int statusCode = response.statusCode(); + logger.info("Azure Responses API status: {} for model: {}", statusCode, config.modelName()); + + if (statusCode >= 200 && statusCode < 300) { + return new JSONObject(response.body()); + } else { + logger.error("Azure API error: status={} body={}", statusCode, response.body()); + try { + return new JSONObject(response.body()); + } catch (JSONException e) { + return new JSONObject().put("error", response.body()); + } + } + } catch (IOException | InterruptedException ex) { + logger.error("HTTP request failed for Azure Responses API", ex); + return new JSONObject().put("error", ex.getMessage()); + } + } + + private BufferedReader callApiStream(JSONObject payload, AzureConfig config) { + try { + String jsonString = payload.toString(); + + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(config.responseEndpoint())) + .header("Content-Type", "application/json; charset=UTF-8") + .header("api-key", config.apiKey()) + .header("Accept", "text/event-stream") + .timeout(Duration.ofSeconds(READ_TIMEOUT_SECONDS)) + .POST(HttpRequest.BodyPublishers.ofString(jsonString, StandardCharsets.UTF_8)) + .build(); + + HttpResponse response = + httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); + + int statusCode = response.statusCode(); + logger.info( + "Azure Responses API streaming status: {} for model: {}", statusCode, config.modelName()); + + if (statusCode >= 200 && statusCode < 300) { + return new BufferedReader(new InputStreamReader(response.body(), StandardCharsets.UTF_8)); + } else { + try (BufferedReader errorReader = + new BufferedReader(new InputStreamReader(response.body(), StandardCharsets.UTF_8))) { + StringBuilder errorBody = new StringBuilder(); + String errorLine; + while ((errorLine = errorReader.readLine()) != null) { + errorBody.append(errorLine); + } + logger.error("Azure streaming failed: status={} body={}", statusCode, errorBody); + } + return null; + } + } catch (IOException | InterruptedException ex) { + logger.error("HTTP request failed for Azure streaming", ex); + return null; + } + } + + // ==================== Response parsing ==================== + + private LlmResponse parseOutputToLlmResponse( + JSONObject response, GenerateContentResponseUsageMetadata usageMetadata) { + + JSONArray output = response.optJSONArray("output"); + if (output == null || output.length() == 0) { + logger.warn("Azure Responses API returned empty output: {}", response); + return LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("")).build()) + .build(); + } + + List parts = new ArrayList<>(); + + for (int i = 0; i < output.length(); i++) { + JSONObject item = output.getJSONObject(i); + String type = item.optString("type", ""); + + switch (type) { + case "message": + { + JSONArray content = item.optJSONArray("content"); + if (content != null) { + for (int j = 0; j < content.length(); j++) { + JSONObject contentItem = content.getJSONObject(j); + if ("output_text".equals(contentItem.optString("type"))) { + parts.add(Part.fromText(contentItem.optString("text", ""))); + } + } + } + break; + } + + case "function_call": + { + String name = item.optString("name", null); + String argsStr = item.optString("arguments", "{}"); + if (name != null) { + Map args; + try { + args = new JSONObject(argsStr).toMap(); + } catch (JSONException e) { + logger.warn("Failed to parse function arguments: {}", argsStr); + args = Map.of(); + } + FunctionCall fc = FunctionCall.builder().name(name).args(args).build(); + parts.add(Part.builder().functionCall(fc).build()); + } + break; + } + + default: + break; + } + } + + if (parts.isEmpty()) { + parts.add(Part.fromText("")); + } + + boolean hasFunctionCall = parts.stream().anyMatch(p -> p.functionCall().isPresent()); + + LlmResponse.Builder builder = LlmResponse.builder(); + if (hasFunctionCall) { + Part fcPart = parts.stream().filter(p -> p.functionCall().isPresent()).findFirst().get(); + builder.content(Content.builder().role("model").parts(ImmutableList.of(fcPart)).build()); + } else { + builder.content(Content.builder().role("model").parts(ImmutableList.copyOf(parts)).build()); + } + + if (usageMetadata != null) { + builder.usageMetadata(usageMetadata); + } + + return builder.build(); + } + + private GenerateContentResponseUsageMetadata extractUsageMetadata(JSONObject response) { + if (response == null || !response.has("usage")) { + return null; + } + try { + JSONObject usage = response.getJSONObject("usage"); + int inputTok = usage.optInt("input_tokens", 0); + int outputTok = usage.optInt("output_tokens", 0); + int totalTok = usage.optInt("total_tokens", inputTok + outputTok); + + if (totalTok > 0 || inputTok > 0 || outputTok > 0) { + logger.info( + "Azure token usage: input={}, output={}, total={}", inputTok, outputTok, totalTok); + return GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(inputTok) + .candidatesTokenCount(outputTok) + .totalTokenCount(totalTok) + .build(); + } + } catch (Exception e) { + logger.warn("Failed to parse token usage from Azure response", e); + } + return null; + } + + private GenerateContentResponseUsageMetadata buildUsageMetadata(int inputTok, int outputTok) { + int totalTok = inputTok + outputTok; + if (totalTok > 0 || inputTok > 0 || outputTok > 0) { + return GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(inputTok) + .candidatesTokenCount(outputTok) + .totalTokenCount(totalTok) + .build(); + } + return null; + } +} diff --git a/core/src/main/java/com/google/adk/models/azure/AzureTransport.java b/core/src/main/java/com/google/adk/models/azure/AzureTransport.java new file mode 100644 index 000000000..970d6bd16 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/azure/AzureTransport.java @@ -0,0 +1,38 @@ +package com.google.adk.models.azure; + +import com.google.adk.models.BaseLlmConnection; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import io.reactivex.rxjava3.core.Flowable; + +/** + * Strategy interface for Azure LLM transport protocols. + * + *

Each implementation handles a specific Azure API surface (REST Responses API, WebSocket + * Realtime API, etc.) while sharing common configuration and request conversion via {@link + * AzureConfig} and {@link AzureRequestConverter}. + */ +public interface AzureTransport { + + /** Returns true if this transport can handle the given model name. */ + boolean supports(String modelName); + + /** + * Generates content using this transport's protocol. + * + * @param request the ADK LLM request + * @param config shared Azure configuration + * @param stream whether to stream the response + * @return a Flowable of LLM responses + */ + Flowable generateContent(LlmRequest request, AzureConfig config, boolean stream); + + /** + * Opens a persistent bidirectional connection using this transport's protocol. + * + * @param request the ADK LLM request (tools, instructions, etc.) + * @param config shared Azure configuration + * @return a live connection + */ + BaseLlmConnection connect(LlmRequest request, AzureConfig config); +} diff --git a/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java b/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java index d031572aa..5577a8a47 100644 --- a/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java @@ -154,14 +154,7 @@ public void convertToServerResponse_withToolCall_mapsContentWithFunctionCall() { @Test public void convertToServerResponse_withUsageMetadata_returnsEmpty() { LiveServerMessage message = - LiveServerMessage.builder() - .usageMetadata( - UsageMetadata.builder() - .promptTokenCount(10) - .responseTokenCount(20) - .totalTokenCount(30) - .build()) - .build(); + LiveServerMessage.builder().usageMetadata(UsageMetadata.builder().build()).build(); Optional result = GeminiLlmConnection.convertToServerResponse(message);