diff --git a/packages/semantic-cache-py/betterdb_semantic_cache/embed/google.py b/packages/semantic-cache-py/betterdb_semantic_cache/embed/google.py new file mode 100644 index 00000000..7eda1bd8 --- /dev/null +++ b/packages/semantic-cache-py/betterdb_semantic_cache/embed/google.py @@ -0,0 +1,99 @@ +"""Google AI (Gemini) embedding helper for betterdb-semantic-cache. + +Uses the Google AI REST API directly via httpx. +Requires the 'httpx' extra: pip install betterdb-semantic-cache[httpx] + +Usage:: + + from betterdb_semantic_cache.embed.google import create_google_embed + embed = create_google_embed(model="text-embedding-004") + cache = SemanticCache(SemanticCacheOptions(client=client, embed_fn=embed)) +""" +from __future__ import annotations + +import os +from typing import Any, Literal + +from betterdb_semantic_cache.types import EmbedFn + +GoogleEmbedTaskType = Literal[ + "RETRIEVAL_QUERY", + "RETRIEVAL_DOCUMENT", + "SEMANTIC_SIMILARITY", + "CLASSIFICATION", + "CLUSTERING", +] + + +def create_google_embed( + *, + model: str = "text-embedding-004", + api_key: str | None = None, + base_url: str = "https://generativelanguage.googleapis.com/v1beta", + task_type: GoogleEmbedTaskType = "RETRIEVAL_QUERY", + title: str | None = None, + output_dimensionality: int | None = None, +) -> EmbedFn: + """Create an EmbedFn backed by the Google AI (Gemini) Embeddings API. + + Args: + model: Google AI embedding model. Default: 'text-embedding-004' (768-dim). + Other options: 'text-multilingual-embedding-002', 'embedding-001'. + api_key: Google AI API key. Default: GOOGLE_API_KEY env var. + base_url: API base URL. + task_type: Task type hint. Default: 'RETRIEVAL_QUERY'. + Use 'RETRIEVAL_DOCUMENT' when storing documents. + title: Optional document title. Only used with task_type='RETRIEVAL_DOCUMENT'. + output_dimensionality: Optional output dimensionality (truncation). + Supported by text-embedding-004+. + + When finished, release the connection pool:: + + await embed.close() + """ + _client: list[Any] = [] + + async def _get_client() -> Any: + if not _client: + try: + import httpx + except ImportError: + raise ImportError( + 'betterdb-semantic-cache embed/google requires the "httpx" package. ' + "Install it: pip install betterdb-semantic-cache[httpx]" + ) + _client.append(httpx.AsyncClient(timeout=30)) + return _client[0] + + async def embed(text: str) -> list[float]: + key = api_key or os.environ.get("GOOGLE_API_KEY") + if not key: + raise ValueError( + "Google API key is required. Set GOOGLE_API_KEY env var or pass api_key." + ) + client = await _get_client() + body: dict[str, Any] = { + "model": f"models/{model}", + "content": {"parts": [{"text": text}]}, + "taskType": task_type, + } + if title is not None: + body["title"] = title + if output_dimensionality is not None: + body["outputDimensionality"] = output_dimensionality + + resp = await client.post( + f"{base_url}/models/{model}:embedContent", + headers={"Content-Type": "application/json", "x-goog-api-key": key}, + json=body, + ) + resp.raise_for_status() + return resp.json().get("embedding", {}).get("values") or [] + + async def close() -> None: + if _client: + await _client[0].aclose() + _client.clear() + + embed.close = close # type: ignore[attr-defined] + return embed diff --git a/packages/semantic-cache/src/embed/google.ts b/packages/semantic-cache/src/embed/google.ts new file mode 100644 index 00000000..4d029e01 --- /dev/null +++ b/packages/semantic-cache/src/embed/google.ts @@ -0,0 +1,97 @@ +/** + * Google AI (Gemini) embedding helper for @betterdb/semantic-cache. + * + * Supports text-embedding-004 and other Gemini embedding models via the + * Google AI REST API. Uses native fetch - no SDK required. + * + * Usage: + * import { createGoogleEmbed } from '@betterdb/semantic-cache/embed/google'; + * const embed = createGoogleEmbed({ model: 'text-embedding-004' }); + * const cache = new SemanticCache({ client, embedFn: embed }); + */ +import type { EmbedFn } from '../types'; + +export type GoogleEmbedTaskType = + | 'RETRIEVAL_QUERY' + | 'RETRIEVAL_DOCUMENT' + | 'SEMANTIC_SIMILARITY' + | 'CLASSIFICATION' + | 'CLUSTERING' + | (string & {}); + +export interface GoogleEmbedOptions { + /** + * Google AI embedding model. + * Default: 'text-embedding-004' (768 dimensions). + * Other options: 'text-multilingual-embedding-002', 'embedding-001'. + */ + model?: string; + /** Google AI (Gemini) API key. Default: GOOGLE_API_KEY env var. */ + apiKey?: string; + /** API base URL. Default: 'https://generativelanguage.googleapis.com/v1beta'. */ + baseUrl?: string; + /** + * Task type hint for the embedding. + * Default: 'RETRIEVAL_QUERY'. Use 'RETRIEVAL_DOCUMENT' when storing. + */ + taskType?: GoogleEmbedTaskType; + /** + * Optional document title, used only with taskType 'RETRIEVAL_DOCUMENT'. + * Improves retrieval quality when provided alongside the document body. + */ + title?: string; + /** + * Optional output dimensionality (truncation). Supported by text-embedding-004+. + * When omitted, the model's full dimensionality is returned. + */ + outputDimensionality?: number; +} + +/** + * Create an EmbedFn backed by the Google AI (Gemini) Embeddings API. + * Uses native fetch - no SDK required. + */ +export function createGoogleEmbed(opts?: GoogleEmbedOptions): EmbedFn { + const model = opts?.model ?? 'text-embedding-004'; + const baseUrl = opts?.baseUrl ?? 'https://generativelanguage.googleapis.com/v1beta'; + const taskType = opts?.taskType ?? 'RETRIEVAL_QUERY'; + + return async (text: string): Promise => { + const apiKey = opts?.apiKey ?? process.env.GOOGLE_API_KEY; + if (!apiKey) { + throw new Error( + 'Google API key is required. Set GOOGLE_API_KEY env var or pass apiKey in options.', + ); + } + + const requestBody: Record = { + model: `models/${model}`, + content: { parts: [{ text }] }, + taskType, + }; + + if (opts?.title !== undefined) { + requestBody.title = opts.title; + } + if (opts?.outputDimensionality !== undefined) { + requestBody.outputDimensionality = opts.outputDimensionality; + } + + const res = await fetch(`${baseUrl}/models/${model}:embedContent`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'x-goog-api-key': apiKey, + }, + body: JSON.stringify(requestBody), + }); + + if (!res.ok) { + const body = await res.text().catch(() => ''); + throw new Error(`Google AI API error: ${res.status} ${body}`); + } + + const json = (await res.json()) as { embedding: { values: number[] } }; + return json.embedding?.values ?? []; + }; +}