diff --git a/packages/agent-cache-py/betterdb_agent_cache/adapters/openai_agents.py b/packages/agent-cache-py/betterdb_agent_cache/adapters/openai_agents.py new file mode 100644 index 00000000..9b45830d --- /dev/null +++ b/packages/agent-cache-py/betterdb_agent_cache/adapters/openai_agents.py @@ -0,0 +1,459 @@ +"""OpenAI Agents SDK adapter. + +Wraps any Agents SDK ``Model`` with an exact-match LLM cache. Cache is +consulted before each ``get_response()`` call; on miss the underlying model +is invoked and the response is stored. ``stream_response()`` is not cached +(streaming responses are not cached by any adapter — documented convention). + +Usage via ModelProvider (recommended):: + + from agents import Agent, RunConfig, Runner + from betterdb_agent_cache.adapters.openai_agents import CachedModelProvider + + cached_provider = CachedModelProvider(provider, cache=agent_cache) + result = await Runner.run( + agent, "Hello", run_config=RunConfig(model_provider=cached_provider), + ) + +Usage via direct Model wrapping:: + + from agents import Agent + from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel + from betterdb_agent_cache.adapters.openai_agents import CachedModel + + base_model = OpenAIChatCompletionsModel(model="gpt-4o", openai_client=client) + agent = Agent(name="Assistant", model=CachedModel(base_model, cache=agent_cache)) + +Also exposes ``prepare_params`` for users who want to manage caching +manually rather than through the wrapper. + +Limitations +~~~~~~~~~~~ +* ``stream_response()`` is delegated directly — streaming is not cached. +* Binary / multimodal content in input items is JSON-serialised raw via + ``_to_text()``. A follow-up can add explicit normalizer dispatch + matching ``openai.py``. +* ``tools``, ``handoffs``, and ``output_schema`` are excluded from the + cache key — safe when one CachedModel wraps a single Agent whose tools + don't change between calls. 
+* ``ResponseOutputRefusal`` content is stored as a plain text block; the + cached hit returns the refusal message as text rather than a typed refusal + object. +""" +from __future__ import annotations + +import inspect +import json +from dataclasses import dataclass, field, is_dataclass +from types import SimpleNamespace +from typing import TYPE_CHECKING, Any + +from ..normalizer import BinaryNormalizer, default_normalizer +from ..types import ContentBlock, LlmCacheParams, LlmStoreOptions +from ..utils import parse_tool_call_args + +if TYPE_CHECKING: + from ..agent_cache import AgentCache + + +@dataclass +class OpenAIAgentsPrepareOptions: + normalizer: BinaryNormalizer = field(default_factory=lambda: default_normalizer) + + +def _to_text(value: Any) -> str: + """Serialize a value to a text representation for cache keys. + + ``None`` becomes ``""`` and strings pass through unchanged. Other values + are JSON-encoded with sorted keys; values that ``json.dumps`` rejects fall + back to ``repr()`` so building a cache key never raises. + """ + if value is None: + return "" + if isinstance(value, str): + return value + try: + return json.dumps(value, ensure_ascii=False, sort_keys=True) + except (TypeError, ValueError): + # Not JSON-serialisable (arbitrary object, circular reference, mixed + # key types). repr() may differ between runs, which only costs a cache + # miss instead of failing the underlying model call. + return repr(value) + + +async def _normalize_input_item( + item: Any, +) -> dict[str, Any]: + """Reduce a single Responses API input item to a canonical dict for hashing. + + .. note:: + Binary / image content is JSON-serialised raw via ``_to_text()``. + A follow-up can add explicit normalizer dispatch matching ``openai.py``. + """ + if isinstance(item, str): + return {"type": "message", "role": "user", "content": item} + if isinstance(item, dict): + # Responses API items are already dicts — normalize nested content + # by sorting keys for deterministic hashing. 
+ return json.loads(json.dumps(item, ensure_ascii=False, sort_keys=True)) + if hasattr(item, "model_dump"): + return json.loads( + json.dumps(item.model_dump(exclude_none=True), ensure_ascii=False, sort_keys=True), + ) + if is_dataclass(item) and not isinstance(item, type): + try: + from dataclasses import asdict + + return json.loads(json.dumps(asdict(item), ensure_ascii=False, sort_keys=True)) + except TypeError: + pass + return {"type": "unknown", "content": _to_text(item)} + + +async def prepare_params( + system_instructions: str | None, + input: str | list[Any], + model_name: str, + model_settings: Any | None = None, + opts: OpenAIAgentsPrepareOptions | None = None, +) -> LlmCacheParams: + """Convert OpenAI Agents SDK get_response() args to canonical ``LlmCacheParams``.""" + # opts.normalizer is reserved for follow-up binary/multimodal normalizer + # dispatch in _normalize_input_item — matching the peer adapter API surface. + + messages: list[Any] = [] + + if system_instructions: + messages.append({"role": "system", "content": system_instructions}) + + if isinstance(input, str): + messages.append({"role": "user", "content": [{"type": "text", "text": input}]}) + else: + for item in input: + messages.append(await _normalize_input_item(item)) + + result: LlmCacheParams = {"model": model_name, "messages": messages} + + settings: dict[str, Any] = {} + if model_settings is not None: + if hasattr(model_settings, "model_dump"): + settings = model_settings.model_dump(exclude_none=True) or {} + elif isinstance(model_settings, dict): + settings = model_settings + else: + try: + settings = {k: v for k, v in vars(model_settings).items() if v is not None} + except TypeError: + settings = {} + + if settings.get("temperature") is not None: + result["temperature"] = settings["temperature"] + if settings.get("top_p") is not None: + result["top_p"] = settings["top_p"] + if settings.get("max_tokens") is not None: + result["max_tokens"] = settings["max_tokens"] + if 
settings.get("max_output_tokens") is not None: + result["max_tokens"] = settings["max_output_tokens"] + if settings.get("seed") is not None: + result["seed"] = settings["seed"] + if settings.get("stop") is not None: + stop = settings["stop"] + result["stop"] = [stop] if isinstance(stop, str) else stop + if settings.get("tool_choice") is not None: + result["tool_choice"] = settings["tool_choice"] + if settings.get("frequency_penalty") is not None: + result["frequency_penalty"] = settings["frequency_penalty"] + if settings.get("presence_penalty") is not None: + result["presence_penalty"] = settings["presence_penalty"] + if settings.get("parallel_tool_calls") is not None: + result["parallel_tool_calls"] = settings["parallel_tool_calls"] + if settings.get("reasoning") is not None: + result["reasoning"] = settings["reasoning"] + + return result + + +def _parse_args(args: Any) -> dict[str, Any]: + """Parse function call arguments (string or dict).""" + if isinstance(args, dict): + return args + return parse_tool_call_args(args) if isinstance(args, str) else {} + + +def _extract_blocks(response: Any) -> list[ContentBlock]: + """Extract ContentBlock dicts from a ModelResponse.output list.""" + blocks: list[ContentBlock] = [] + raw_out = getattr(response, "output", []) or [] + for item in raw_out: + item_type = item.get("type") if isinstance(item, dict) else getattr(item, "type", None) + if item_type == "message": + parts = item.get("content") if isinstance(item, dict) else getattr(item, "content", []) + parts = parts or [] + for part in parts: + part_type = part.get("type") if isinstance(part, dict) else getattr(part, "type", None) + if part_type in ("output_text", "text"): + text_val = "" + if isinstance(part, dict): + text_val = part.get("text") or "" + else: + text_val = getattr(part, "text", "") or "" + blocks.append({"type": "text", "text": text_val}) + elif part_type == "refusal": + # ResponseOutputRefusal — store refusal text so cache hits + # preserve the refusal 
content rather than silently dropping it. + refusal_text = "" + if isinstance(part, dict): + refusal_text = part.get("refusal") or "" + else: + refusal_text = getattr(part, "refusal", "") or "" + blocks.append({"type": "text", "text": refusal_text}) + elif item_type == "function_call": + if isinstance(item, dict): + call_id = item.get("call_id", "") + name = item.get("name", "") + arguments = item.get("arguments", "") + else: + call_id = getattr(item, "call_id", "") or "" + name = getattr(item, "name", "") or "" + arguments = getattr(item, "arguments", "") or "" + blocks.append({ + "type": "tool_call", + "id": call_id, + "name": name, + "args": _parse_args(arguments), + }) + return blocks + + +def _rebuild_output( + content_blocks: list[ContentBlock] | None, + response_text: str | None, +) -> list[Any]: + """Rebuild Responses API output items from cached ContentBlocks. + + Uses OpenAI SDK output models when available so ``ModelResponse`` passes + Pydantic validation (``openai-agents`` 0.1+). Falls back to ``SimpleNamespace`` + for older stacks that use plain dataclasses. 
+ """ + try: + from openai.types.responses import ( + ResponseFunctionToolCall, + ResponseOutputMessage, + ResponseOutputText, + ) + except ImportError: + ResponseOutputMessage = None # type: ignore[assignment,misc] + ResponseOutputText = None # type: ignore[assignment,misc] + ResponseFunctionToolCall = None # type: ignore[assignment,misc] + + def text_part(text_val: str) -> Any: + if ResponseOutputText is None: + return SimpleNamespace(type="output_text", text=text_val) + try: + return ResponseOutputText.model_construct( + type="output_text", + text=text_val, + annotations=[], + ) + except TypeError: + try: + return ResponseOutputText.model_construct(type="output_text", text=text_val) + except Exception: + return SimpleNamespace(type="output_text", text=text_val) + + def tool_part(call_id: str, name: str, arguments: str) -> Any: + if ResponseFunctionToolCall is None: + return SimpleNamespace( + type="function_call", + call_id=call_id, + name=name, + arguments=arguments, + ) + try: + return ResponseFunctionToolCall.model_construct( + type="function_call", + call_id=call_id, + name=name, + arguments=arguments, + ) + except Exception: + return SimpleNamespace( + type="function_call", + call_id=call_id, + name=name, + arguments=arguments, + ) + + output: list[Any] = [] + text_parts: list[Any] = [] + + if content_blocks: + for block in content_blocks: + if block["type"] == "text": + text_parts.append(text_part(block["text"])) + elif block["type"] == "tool_call": + args_str = json.dumps(block.get("args", {}), ensure_ascii=False, sort_keys=True) + output.append(tool_part(block.get("id", ""), block.get("name", ""), args_str)) + elif response_text is not None: + text_parts.append(text_part(response_text)) + + if text_parts: + if ResponseOutputMessage is None: + output.insert(0, SimpleNamespace( + type="message", role="assistant", content=text_parts, + )) + else: + try: + output.insert( + 0, + ResponseOutputMessage.model_construct( + id="betterdb-cache", + type="message", + 
role="assistant", + status="completed", + content=text_parts, + ), + ) + except TypeError: + output.insert( + 0, + ResponseOutputMessage.model_construct( + id="betterdb-cache", + type="message", + role="assistant", + content=text_parts, + ), + ) + except Exception: + output.insert(0, SimpleNamespace( + type="message", role="assistant", content=text_parts, + )) + + return output + + +def _make_usage(input_tokens: int, output_tokens: int) -> Any: + """Create a minimal ``Usage`` object for cache hits.""" + from agents.usage import Usage + + return Usage( + requests=0, + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=input_tokens + output_tokens, + ) + + +def _cache_hit_model_response(output: list[Any], usage: Any) -> Any: + """Build ``ModelResponse`` for a cache hit, compatible across SDK releases.""" + from agents.items import ModelResponse + + fields = inspect.signature(ModelResponse.__init__).parameters + kw: dict[str, Any] = {"output": output, "usage": usage, "response_id": None} + if "request_id" in fields: + kw["request_id"] = None + if "referenceable_id" in fields: + kw["referenceable_id"] = None + return ModelResponse(**kw) + + +class CachedModel: + """Agents SDK ``Model`` wrapper that checks the cache before each + ``get_response()`` call. ``stream_response()`` is delegated directly. 
+ """ + + def __init__( + self, + model: Any, + cache: "AgentCache", + opts: OpenAIAgentsPrepareOptions | None = None, + ) -> None: + self._model = model + self._cache = cache + self._opts = opts or OpenAIAgentsPrepareOptions() + + def __getattr__(self, name: str) -> Any: + """Delegate attributes not defined on the wrapper to the wrapped model. + + Raises ``AttributeError`` for ``_model`` itself so a partially + initialised instance fails cleanly instead of recursing. + """ + # __getattr__ only runs when normal lookup fails. If ``_model`` is not in + # the instance dict yet (instance created via __new__ by copy/pickle + # before __init__ ran), reading self._model below would re-enter this + # method without end. + if name == "_model": + raise AttributeError(name) + return getattr(self._model, name) + + def stream_response(self, *args: Any, **kwargs: Any) -> Any: + """Streaming is not cached — delegate directly.""" + return self._model.stream_response(*args, **kwargs) + + async def get_response( + self, + system_instructions: str | None, + input: str | list[Any], + model_settings: Any, + tools: list[Any], + output_schema: Any | None, + handoffs: list[Any], + tracing: Any, + *, + previous_response_id: str | None = None, + **kwargs: Any, + ) -> Any: + model_name = str(getattr(self._model, "model", "unknown")) + + # tools, handoffs, and output_schema are excluded from the cache key. + # This is safe when one CachedModel wraps a single Agent whose tools + # don't change between calls — the typical usage pattern. + # previous_response_id, conversation_id, and prompt are also excluded: + # they are server-side context references, not content. Including them + # would prevent caching the same logical prompt across conversation turns. + # If server-side context affects your responses, create separate + # CachedModel instances per conversation thread. 
+ params = await prepare_params( + system_instructions, input, model_name, model_settings, self._opts, + ) + + cached = await self._cache.llm.check(params) + if cached.hit: + output = _rebuild_output(cached.content_blocks, cached.response) + return _cache_hit_model_response( + output, + _make_usage(cached.input_tokens, cached.output_tokens), + ) + + response = await self._model.get_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + tracing, + previous_response_id=previous_response_id, + **kwargs, + ) + + store_blocks = _extract_blocks(response) + if not store_blocks: + # Nothing cacheable was extracted (the output holds only item types + # that _extract_blocks does not handle). Storing an empty entry would + # make every later hit replay an empty output, so skip the write and + # let the next identical request reach the model again. + return response + + usage = getattr(response, "usage", None) + inp = int(getattr(usage, "input_tokens", 0) or 0) + out_tok = int(getattr(usage, "output_tokens", 0) or 0) + await self._cache.llm.store_multipart( + params, + store_blocks, + LlmStoreOptions(tokens={"input": inp, "output": out_tok}), + ) + return response + + +class CachedModelProvider: + """Wraps a ``ModelProvider`` so every ``Model`` it returns is cache-enabled. 
+ + This is the recommended integration point:: + + from agents import RunConfig, Runner + from betterdb_agent_cache.adapters.openai_agents import CachedModelProvider + + provider = CachedModelProvider(original_provider, cache=agent_cache) + result = await Runner.run(agent, "hi", run_config=RunConfig(model_provider=provider)) + """ + + def __init__( + self, + provider: Any, + cache: "AgentCache", + opts: OpenAIAgentsPrepareOptions | None = None, + ) -> None: + self._provider = provider + self._cache = cache + self._opts = opts or OpenAIAgentsPrepareOptions() + + def get_model(self, model_name: str | None) -> CachedModel: + base = self._provider.get_model(model_name) + return CachedModel(base, self._cache, self._opts) + + async def aclose(self) -> None: + if hasattr(self._provider, "aclose"): + await self._provider.aclose() diff --git a/packages/agent-cache-py/betterdb_agent_cache/tiers/llm_cache.py b/packages/agent-cache-py/betterdb_agent_cache/tiers/llm_cache.py index b6afd110..f4804d98 100644 --- a/packages/agent-cache-py/betterdb_agent_cache/tiers/llm_cache.py +++ b/packages/agent-cache-py/betterdb_agent_cache/tiers/llm_cache.py @@ -107,11 +107,14 @@ async def check(self, params: LlmCacheParams) -> LlmCacheResult: ).inc() span.set_attribute("cache.hit", True) + stored_tokens: dict[str, int] = entry.get("tokens") or {} return LlmCacheResult( hit=True, response=entry.get("response"), content_blocks=entry.get("contentBlocks"), key=key, + input_tokens=int(stored_tokens.get("input", 0)), + output_tokens=int(stored_tokens.get("output", 0)), ) await self._inc_stats({"llm:misses": 1}) diff --git a/packages/agent-cache-py/betterdb_agent_cache/types.py b/packages/agent-cache-py/betterdb_agent_cache/types.py index 227898e4..6d19f59f 100644 --- a/packages/agent-cache-py/betterdb_agent_cache/types.py +++ b/packages/agent-cache-py/betterdb_agent_cache/types.py @@ -183,6 +183,8 @@ class LlmCacheResult: response: str | None = None content_blocks: list[ContentBlock] | None = 
None key: str | None = None + input_tokens: int = 0 + output_tokens: int = 0 @dataclass diff --git a/packages/agent-cache-py/examples/openai_agents/README.md b/packages/agent-cache-py/examples/openai_agents/README.md new file mode 100644 index 00000000..34b6ad66 --- /dev/null +++ b/packages/agent-cache-py/examples/openai_agents/README.md @@ -0,0 +1,25 @@ +# OpenAI Agents SDK example + +This example shows how to wrap an OpenAI Agents SDK `ModelProvider` with `CachedModelProvider` so LLM responses are served from `betterdb-agent-cache` on repeat requests. + +It demonstrates: +- text prompts via `Runner.run()` +- tool-calling flows with `@function_tool` + +## Install + +```bash +docker run -d --name valkey -p 6379:6379 valkey/valkey:8 +pip install "betterdb-agent-cache[openai_agents]" +export OPENAI_API_KEY=sk-... +``` + +## Run + +```bash +python main.py +``` + +## Expected output + +The first call in each scenario is a miss and the second is a hit. At the end, cache stats show non-zero LLM hits and a positive cost-saved value. diff --git a/packages/agent-cache-py/examples/openai_agents/main.py b/packages/agent-cache-py/examples/openai_agents/main.py new file mode 100644 index 00000000..7c3a71c6 --- /dev/null +++ b/packages/agent-cache-py/examples/openai_agents/main.py @@ -0,0 +1,72 @@ +""" +OpenAI Agents SDK + betterdb-agent-cache example + +Demonstrates caching agent responses with two scenarios: + 1. Simple text agent — responses cached by prompt hash + 2. Agent with tools — tool calls round-trip through cache + +Usage: + docker run -d --name valkey -p 6379:6379 valkey/valkey:8 + pip install "betterdb-agent-cache[openai_agents]" + export OPENAI_API_KEY=sk-... 
+ python main.py +""" +from __future__ import annotations + +import asyncio + +import valkey.asyncio as valkey_client +from agents import Agent, Runner, RunConfig, function_tool, OpenAIProvider + +from betterdb_agent_cache import AgentCache, ModelCost, TierDefaults +from betterdb_agent_cache.adapters.openai_agents import CachedModelProvider +from betterdb_agent_cache.types import AgentCacheOptions + + +@function_tool +def get_weather(city: str) -> str: + """Get the current weather for a city.""" + return f"Weather in {city}: sunny, 22°C" + + +async def main() -> None: + client = valkey_client.Valkey(host="localhost", port=6379) + cache = AgentCache( + AgentCacheOptions( + client=client, + tier_defaults={"llm": TierDefaults(ttl=3600)}, + cost_table={ + "gpt-4o-mini": ModelCost(input_per_1k=0.00015, output_per_1k=0.0006), + }, + ), + ) + + cached_provider = CachedModelProvider(OpenAIProvider(), cache=cache) + run_config = RunConfig(model="gpt-4o-mini", model_provider=cached_provider) + + text_agent = Agent(name="Concise", instructions="You are concise.") + print("\n=== 1. Simple text agent ===") + for i in range(2): + result = await Runner.run(text_agent, "What is 2+2? One word.", run_config=run_config) + print(f" [{i + 1}] {result.final_output}") + + tools_agent = Agent(name="Weather", instructions="Use tools.", tools=[get_weather]) + print("\n=== 2. 
Agent with tools ===") + for i in range(2): + result = await Runner.run(tools_agent, "Weather in London?", run_config=run_config) + print(f" [{i + 1}] {result.final_output}") + + stats = await cache.stats() + print("\n-- Cache Stats --") + print( + "LLM: " + f"{stats.llm.hits} hits / {stats.llm.misses} misses ({stats.llm.hit_rate:.0%})", + ) + print(f"Cost saved: ${stats.cost_saved_micros / 1_000_000:.6f}") + + await cache.shutdown() + await client.aclose() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/packages/agent-cache-py/pyproject.toml b/packages/agent-cache-py/pyproject.toml index 353ebb9e..ceaf6a68 100644 --- a/packages/agent-cache-py/pyproject.toml +++ b/packages/agent-cache-py/pyproject.toml @@ -6,7 +6,18 @@ build-backend = "hatchling.build" name = "betterdb-agent-cache" version = "0.6.0" description = "Multi-tier exact-match cache for AI agent workloads backed by Valkey. LLM responses, tool results, and session state with built-in OpenTelemetry and Prometheus instrumentation." 
-keywords = ["valkey", "redis", "agent", "cache", "llm", "opentelemetry", "prometheus", "langchain", "langgraph"] +keywords = [ + "valkey", + "redis", + "agent", + "cache", + "llm", + "opentelemetry", + "prometheus", + "langchain", + "langgraph", + "openai-agents", +] license = { text = "MIT" } readme = "README.md" requires-python = ">=3.11" @@ -22,12 +33,14 @@ anthropic = ["anthropic>=0.20.0"] langchain = ["langchain-core>=0.1.0"] langgraph = ["langgraph>=0.1.0"] llamaindex = ["llama-index-core>=0.10.0"] +openai_agents = ["openai-agents>=0.0.14"] analytics = ["posthog>=3.0.0"] normalizer = ["aiohttp>=3.9.0"] dev = [ "pytest>=8.0.0", "pytest-asyncio>=0.23.0", "fakeredis[aioredis]>=2.20.0", + "openai-agents>=0.0.14", ] all = [ "openai>=1.0.0", @@ -36,6 +49,7 @@ all = [ "langgraph>=0.1.0", "llama-index-core>=0.10.0", "posthog>=3.0.0", + "openai-agents>=0.0.14", "aiohttp>=3.9.0", ] diff --git a/packages/agent-cache-py/tests/adapters/test_openai_agents.py b/packages/agent-cache-py/tests/adapters/test_openai_agents.py new file mode 100644 index 00000000..a9097016 --- /dev/null +++ b/packages/agent-cache-py/tests/adapters/test_openai_agents.py @@ -0,0 +1,314 @@ +"""Tests for the OpenAI Agents SDK adapter.""" +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from betterdb_agent_cache.adapters.openai_agents import ( + CachedModel, + CachedModelProvider, + prepare_params, +) +from betterdb_agent_cache.agent_cache import AgentCache +from betterdb_agent_cache.types import AgentCacheOptions, TierDefaults + +from ..conftest import make_persisting_valkey_client + +try: + import agents # noqa: F401 +except Exception as exc: # pragma: no cover - environment dependent + pytest.skip( + f"openai-agents unavailable or incompatible in this environment: {exc}", + allow_module_level=True, + ) + + +def _make_cache() -> AgentCache: + client = make_persisting_valkey_client() + with 
patch("betterdb_agent_cache.agent_cache.create_analytics"): + return AgentCache( + AgentCacheOptions( + client=client, + tier_defaults={"llm": TierDefaults(ttl=300)}, + ), + ) + + +class _FakeModel: + """Minimal mock of agents.models.interface.Model.""" + model = "fake-model" + + def __init__(self, response: object, *, raise_error: Exception | None = None) -> None: + self.response = response + self.raise_error = raise_error + self.calls = 0 + + async def get_response( + self, + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + tracing, + *, + previous_response_id=None, + **kwargs, + ): + self.calls += 1 + if self.raise_error is not None: + raise self.raise_error + return self.response + + def stream_response(self, *args, **kwargs): + raise NotImplementedError("stream not mocked") + + async def close(self): + pass + + +class _FakeProvider: + def __init__(self, model: _FakeModel): + self._model = model + + def get_model(self, model_name: str | None) -> _FakeModel: + return self._model + + async def aclose(self): + pass + + +def _make_text_response(text: str) -> SimpleNamespace: + return SimpleNamespace( + output=[ + SimpleNamespace( + type="message", + role="assistant", + content=[ + SimpleNamespace(type="output_text", text=text), + ], + ), + ], + usage=SimpleNamespace(input_tokens=10, output_tokens=5), + referenceable_id=None, + request_id=None, + ) + + +def _make_tool_response(call_id: str, name: str, args: str) -> SimpleNamespace: + return SimpleNamespace( + output=[ + SimpleNamespace( + type="function_call", + call_id=call_id, + name=name, + arguments=args, + ), + ], + usage=SimpleNamespace(input_tokens=8, output_tokens=12), + referenceable_id=None, + request_id=None, + ) + + +_DEFAULT_KWARGS = dict( + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + previous_response_id=None, + conversation_id=None, + prompt=None, +) + + +@pytest.mark.asyncio +async def test_prepare_params_string_input(): + params = await 
prepare_params("Be concise.", "hello", "gpt-4o") + assert params["model"] == "gpt-4o" + assert params["messages"][0] == {"role": "system", "content": "Be concise."} + assert params["messages"][1]["role"] == "user" + + +@pytest.mark.asyncio +async def test_prepare_params_list_input(): + items = [ + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]}, + ] + params = await prepare_params(None, items, "gpt-4o-mini") + assert params["model"] == "gpt-4o-mini" + assert len(params["messages"]) == 1 + + +@pytest.mark.asyncio +async def test_prepare_params_settings(): + settings = SimpleNamespace( + temperature=0.5, + top_p=0.9, + max_tokens=100, + seed=42, + stop=None, + tool_choice=None, + max_output_tokens=None, + ) + settings.model_dump = lambda exclude_none=False: { + "temperature": 0.5, + "top_p": 0.9, + "max_tokens": 100, + "seed": 42, + } + params = await prepare_params(None, "test", "gpt-4o", settings) + assert params["temperature"] == 0.5 + assert params["top_p"] == 0.9 + assert params["max_tokens"] == 100 + assert params["seed"] == 42 + + +@pytest.mark.asyncio +async def test_cached_model_getattr_delegation(): + base = _FakeModel(_make_text_response("ok")) + wrapped = CachedModel(base, _make_cache()) + assert wrapped.model == "fake-model" + + +@pytest.mark.asyncio +async def test_cached_model_miss_stores_tool_calls(): + cache = _make_cache() + response = _make_tool_response("call_fn", "get_weather", '{"city":"Berlin"}') + base = _FakeModel(response) + wrapped = CachedModel(base, cache) + + await wrapped.get_response(None, "weather?", None, **_DEFAULT_KWARGS) + + params = await prepare_params(None, "weather?", "fake-model") + cached = await cache.llm.check(params) + assert cached.hit is True + assert cached.content_blocks[0]["type"] == "tool_call" + assert cached.content_blocks[0]["name"] == "get_weather" + assert cached.content_blocks[0]["args"] == {"city": "Berlin"} + + +@pytest.mark.asyncio +async def 
test_cached_model_miss_stores(): + cache = _make_cache() + response = _make_text_response("miss response") + base = _FakeModel(response) + wrapped = CachedModel(base, cache) + + out = await wrapped.get_response( + "Be concise.", + "hello", + None, + **_DEFAULT_KWARGS, + ) + assert out is response + assert base.calls == 1 + + params = await prepare_params("Be concise.", "hello", "fake-model") + cached = await cache.llm.check(params) + assert cached.hit is True + assert cached.content_blocks[0]["text"] == "miss response" + + +@pytest.mark.asyncio +async def test_cached_model_hit_skips_underlying(): + cache = _make_cache() + params = await prepare_params(None, "cached prompt", "fake-model") + await cache.llm.store_multipart( + params, + [ + {"type": "text", "text": "from cache"}, + {"type": "tool_call", "id": "call_1", "name": "lookup", "args": {"q": "x"}}, + ], + ) + + base = _FakeModel(_make_text_response("should not be called")) + wrapped = CachedModel(base, cache) + out = await wrapped.get_response( + None, + "cached prompt", + None, + **_DEFAULT_KWARGS, + ) + assert base.calls == 0 + assert hasattr(out, "output") + # The entry above was written directly via store_multipart() without + # LlmStoreOptions, so no token counts were stored and the hit reports zero usage. + assert out.usage.input_tokens == 0 + assert out.usage.output_tokens == 0 + + +@pytest.mark.asyncio +async def test_cached_model_hit_propagates_stored_tokens(): + """Cache hit returns Usage with the token counts from the original miss.""" + cache = _make_cache() + response = _make_text_response("response with tokens") + # _make_text_response sets usage.input_tokens=10, output_tokens=5 + base = _FakeModel(response) + wrapped = CachedModel(base, cache) + + # Miss: stores with real token counts (10 input, 5 output from _make_text_response) + await wrapped.get_response(None, "prompt", None, **_DEFAULT_KWARGS) + assert base.calls == 1 + + # Hit: should return stored 
token counts + out = await wrapped.get_response(None, "prompt", None, **_DEFAULT_KWARGS) + assert base.calls == 1 # not called again + assert out.usage.input_tokens == 10 + assert out.usage.output_tokens == 5 + + +@pytest.mark.asyncio +async def test_cached_model_different_prompts(): + cache = _make_cache() + base = _FakeModel(_make_text_response("live")) + wrapped = CachedModel(base, cache) + + await wrapped.get_response("sys", "first", None, **_DEFAULT_KWARGS) + await wrapped.get_response("sys", "first", None, **_DEFAULT_KWARGS) # hit + await wrapped.get_response("sys", "second", None, **_DEFAULT_KWARGS) # miss + assert base.calls == 2 + + +@pytest.mark.asyncio +async def test_cached_model_propagates_errors(): + cache = _make_cache() + base = _FakeModel(_make_text_response(""), raise_error=RuntimeError("boom")) + wrapped = CachedModel(base, cache) + with pytest.raises(RuntimeError, match="boom"): + await wrapped.get_response(None, "hello", None, **_DEFAULT_KWARGS) + + +@pytest.mark.asyncio +async def test_stream_response_delegates_directly(): + """stream_response is not cached — it must delegate without interception.""" + base = _FakeModel(_make_text_response("ok")) + wrapped = CachedModel(base, _make_cache()) + with pytest.raises(NotImplementedError, match="stream not mocked"): + wrapped.stream_response( + None, + "hello", + None, + [], + None, + [], + None, + previous_response_id=None, + conversation_id=None, + prompt=None, + ) + + +@pytest.mark.asyncio +async def test_cached_provider_wraps_models(): + cache = _make_cache() + base_model = _FakeModel(_make_text_response("provided")) + provider = CachedModelProvider(_FakeProvider(base_model), cache) + wrapped = provider.get_model("gpt-4o") + assert isinstance(wrapped, CachedModel) + out = await wrapped.get_response(None, "test", None, **_DEFAULT_KWARGS) + assert base_model.calls == 1 + assert out is base_model.response diff --git a/packages/agent-cache-py/tests/conftest.py 
b/packages/agent-cache-py/tests/conftest.py index b326830a..11b44f6c 100644 --- a/packages/agent-cache-py/tests/conftest.py +++ b/packages/agent-cache-py/tests/conftest.py @@ -77,6 +77,68 @@ def make_client() -> MagicMock: return client +def make_persisting_valkey_client() -> MagicMock: + """Return an async mock valkey client backed by in-memory state.""" + kv: dict[str, str] = {} + hashes: dict[str, dict[str, str]] = {} + + client = make_client() + + async def _get(key: str): + return kv.get(key) + + async def _set(key: str, value: str, ex=None): # noqa: ANN001 + _ = ex + kv[key] = value + return True + + async def _delete(*keys: str): + deleted = 0 + for key in keys: + if key in kv: + del kv[key] + deleted += 1 + if key in hashes: + del hashes[key] + deleted += 1 + return deleted + + async def _hget(name: str, key: str): + return hashes.get(name, {}).get(key) + + async def _hset(name: str, key: str, value: str): + bucket = hashes.setdefault(name, {}) + is_new = key not in bucket + bucket[key] = value + return 1 if is_new else 0 + + async def _hgetall(name: str): + return dict(hashes.get(name, {})) + + async def _hincrby(name: str, key: str, amount: int): + bucket = hashes.setdefault(name, {}) + current = int(bucket.get(key, "0")) + updated = current + amount + bucket[key] = str(updated) + return updated + + async def _scan(cursor=0, match=None, count=None): # noqa: ANN001 + _ = (cursor, match, count) + return (0, []) + + client.get = AsyncMock(side_effect=_get) + client.set = AsyncMock(side_effect=_set) + client.delete = AsyncMock(side_effect=_delete) + client.hget = AsyncMock(side_effect=_hget) + client.hset = AsyncMock(side_effect=_hset) + client.hgetall = AsyncMock(side_effect=_hgetall) + client.hincrby = AsyncMock(side_effect=_hincrby) + client.scan = AsyncMock(side_effect=_scan) + client.expire = AsyncMock(return_value=1) + + return client + + @pytest.fixture def telemetry() -> Telemetry: return make_telemetry()