diff --git a/backend/tests/unit/test_prompt_cache_integration.py b/backend/tests/unit/test_prompt_cache_integration.py
index 0cff89a867..868a13d8d1 100644
--- a/backend/tests/unit/test_prompt_cache_integration.py
+++ b/backend/tests/unit/test_prompt_cache_integration.py
@@ -691,6 +691,112 @@ def __init__(self, **kwargs):
         assert "prompt_cache_key" not in mkw, f"Client {call.get('model')} should not have prompt_cache_key"
 
 
+def _load_clients_namespace(captured_calls, byok_key=None):
+    """Exec clients.py in an isolated namespace with fake providers, returning the namespace.
+
+    captured_calls collects every ChatOpenAI(**kwargs) construction so tests can assert the
+    runtime kwargs (extra_body / prompt_cache_retention) per model.
+    """
+
+    class FakeChatOpenAI:
+        def __init__(self, **kwargs):
+            self.kwargs = kwargs
+            captured_calls.append(kwargs)
+
+        def bind(self, **kwargs):
+            self.bound = kwargs
+            return self
+
+    class FakeOpenAIEmbeddings:
+        def __init__(self, **kwargs):
+            pass
+
+    fake_tiktoken = _stub_module("tiktoken_fake2")
+    fake_tiktoken.encoding_for_model = MagicMock(return_value=MagicMock())
+    fake_anthropic = _stub_module("anthropic_fake2")
+    fake_anthropic.AsyncAnthropic = MagicMock
+
+    source = (BACKEND_DIR / "utils" / "llm" / "clients.py").read_text(encoding="utf-8")
+    for line in [
+        "from langchain_core.language_models import BaseChatModel",
+        "from langchain_openai import ChatOpenAI, OpenAIEmbeddings",
+        "from langchain_google_genai import ChatGoogleGenerativeAI",
+        "import tiktoken",
+        "import anthropic",
+        "from langchain_core.output_parsers import PydanticOutputParser",
+        "from models.structured import Structured",
+        "from utils.byok import get_byok_key",
+        "from utils.llm.usage_tracker import get_usage_callback",
+    ]:
+        source = source.replace(line, "")
+
+    ns = {
+        "os": os,
+        "BaseChatModel": object,
+        "ChatOpenAI": FakeChatOpenAI,
+        "ChatGoogleGenerativeAI": FakeChatOpenAI,
+        "OpenAIEmbeddings": FakeOpenAIEmbeddings,
+        "tiktoken": fake_tiktoken,
+        "anthropic": fake_anthropic,
+        "PydanticOutputParser": MagicMock(),
+        "Structured": MagicMock(),
+        "get_byok_key": MagicMock(return_value=byok_key),
+        "get_usage_callback": MagicMock(return_value=[]),
+        "List": list,
+    }
+    exec(source, ns)
+    return ns
+
+
+def test_renamed_gpt5_model_still_gets_cache_features():
+    """
+    Capability-based gating (not exact model names): a different gpt-5 family model that is
+    not the hardcoded 'gpt-5.1'/'gpt-5.4' must still receive prompt_cache_retention and be
+    eligible for prompt_cache_key routing.
+    """
+    captured_calls = []
+    ns = _load_clients_namespace(captured_calls)
+
+    # A hypothetical future/renamed gpt-5 family model.
+    new_model = "gpt-5.9-turbo"
+
+    # Retention must be applied for the new gpt-5 model via the default factory.
+    ns["_get_or_create_openai_llm"](new_model)
+    new_calls = [c for c in captured_calls if c.get("model") == new_model]
+    assert new_calls, "Factory should have constructed the new gpt-5 model client"
+    for call in new_calls:
+        assert (
+            call.get("extra_body", {}).get("prompt_cache_retention") == "24h"
+        ), f"Renamed gpt-5 model should get 24h retention: {call}"
+
+    # prompt_cache_key routing must be eligible for the new model.
+    assert ns["_supports_prompt_cache_key"](new_model), "Renamed gpt-5 model should support prompt_cache_key"
+    assert ns["_supports_cache_retention"](new_model), "Renamed gpt-5 model should support cache retention"
+
+
+def test_non_cache_capable_model_is_unchanged():
+    """A non-cache-capable model (e.g. a Gemini name) must not get cache retention/routing."""
+    captured_calls = []
+    ns = _load_clients_namespace(captured_calls)
+
+    assert not ns["_supports_prompt_cache_key"]("gemini-2.5-flash-lite")
+    assert not ns["_supports_cache_retention"]("gemini-2.5-flash-lite")
+    # gpt-4.1-mini supports routing but not 24h retention.
+    assert ns["_supports_prompt_cache_key"]("gpt-4.1-mini")
+    assert not ns["_supports_cache_retention"]("gpt-4.1-mini")
+
+
+def test_get_llm_binds_cache_key_for_cache_capable_models():
+    """get_llm must bind prompt_cache_key for cache-capable models beyond the old hardcoded set."""
+    captured_calls = []
+    ns = _load_clients_namespace(captured_calls)
+
+    # Route chat_responses to gpt-5.1: gpt-5 family, but outside the old {'gpt-5.4', 'gpt-5.4-mini'} set.
+    ns["_active_profile"]["chat_responses"] = ("gpt-5.1", "openai")
+    llm = ns["get_llm"]("chat_responses", cache_key="omi-chat")
+    assert getattr(llm, "bound", {}).get("prompt_cache_key") == "omi-chat", "get_llm should bind prompt_cache_key"
+
+
 # ---------------------------------------------------------------------------
 # Tests: Tool list construction in execute functions
 # ---------------------------------------------------------------------------
diff --git a/backend/tests/unit/test_prompt_cache_optimization.py b/backend/tests/unit/test_prompt_cache_optimization.py
index 45df596517..75c41305f2 100644
--- a/backend/tests/unit/test_prompt_cache_optimization.py
+++ b/backend/tests/unit/test_prompt_cache_optimization.py
@@ -129,7 +129,9 @@ def test_qos_cache_key_in_clients():
     """Omi QoS get_llm() should support cache_key parameter for prompt cache routing."""
     source = _read_clients_source()
     assert "cache_key" in source, "clients.py get_llm() should accept cache_key parameter"
-    assert "_CACHE_KEY_MODELS" in source, "clients.py should define _CACHE_KEY_MODELS for model-safe cache key handling"
+    assert (
+        "_supports_prompt_cache_key" in source
+    ), "clients.py should gate prompt_cache_key by capability (_supports_prompt_cache_key)"
 
 
 def test_qos_medium_tier_uses_extra_body_for_cache_retention():
diff --git a/backend/tests/unit/test_prompt_caching.py b/backend/tests/unit/test_prompt_caching.py
index 2805c1e854..21bc38ae93 100644
--- a/backend/tests/unit/test_prompt_caching.py
+++ b/backend/tests/unit/test_prompt_caching.py
@@ -330,21 +330,29 @@ def _read_clients_source():
         clients_path = Path(__file__).resolve().parent.parent.parent / "utils" / "llm" / "clients.py"
         return clients_path.read_text(encoding="utf-8")
 
-    def test_qos_gpt51_has_cache_retention(self):
-        """QoS _get_or_create_openai_llm must set prompt_cache_retention=24h for gpt-5.1."""
+    def test_qos_openai_llm_gates_retention_by_capability(self):
+        """_get_or_create_openai_llm must set prompt_cache_retention=24h via a capability check."""
         source = self._read_clients_source()
         match = re.search(
-            r"_get_or_create_openai_llm.*?gpt-5\.1.*?prompt_cache_retention.*?24h",
+            r"_get_or_create_openai_llm.*?_supports_cache_retention\(.*?prompt_cache_retention.*?24h",
             source,
             re.DOTALL,
         )
-        assert match, "_get_or_create_openai_llm should set prompt_cache_retention='24h' for gpt-5.1"
+        assert (
+            match
+        ), "_get_or_create_openai_llm should gate prompt_cache_retention='24h' by _supports_cache_retention()"
 
-    def test_qos_tier_medium_gets_cache_retention(self):
-        """Omi QoS tier medium (gpt-5.1) must set prompt_cache_retention=24h via _get_or_create_openai_llm."""
-        source = self._read_clients_source()
-        match = re.search(r'_get_or_create_openai_llm.*?gpt-5\.1.*?prompt_cache_retention.*?24h', source, re.DOTALL)
-        assert match, "QoS _get_or_create_openai_llm should set prompt_cache_retention='24h' for gpt-5.1"
+    def test_gpt51_is_cache_retention_capable(self):
+        """gpt-5.1 must still be recognized as 24h-retention capable after the capability refactor."""
+        # Parse the prefix tuple straight out of the source text rather than importing
+        # clients.py (which pulls heavy provider SDKs not available in the unit-test env).
+        source = self._read_clients_source()
+        m = re.search(r"_CACHE_RETENTION_MODEL_PREFIXES\s*=\s*\(([^)]*)\)", source)
+        assert m, "clients.py should define _CACHE_RETENTION_MODEL_PREFIXES"
+        prefixes = tuple(p.strip().strip("'\"") for p in m.group(1).split(",") if p.strip())
+        assert "gpt-5.1".startswith(prefixes), f"gpt-5.1 should match a retention-capable prefix in {prefixes}"
+        # A renamed gpt-5 family model should also be covered (the whole point of the refactor).
+        assert "gpt-5.4".startswith(prefixes), "gpt-5.4 should be retention-capable"
 
     def test_cache_retention_not_in_model_kwargs(self):
         """prompt_cache_retention must NOT be in model_kwargs (SDK rejects it there)."""
diff --git a/backend/utils/llm/clients.py b/backend/utils/llm/clients.py
index 3e96730fcb..64e0a736b2 100644
--- a/backend/utils/llm/clients.py
+++ b/backend/utils/llm/clients.py
@@ -153,7 +153,7 @@ def _create_byok_client(
 ) -> Optional[ChatOpenAI]:
     """Create a ChatOpenAI using the user's BYOK key. Returns None if BYOK not supported for this provider."""
     kwargs: Dict[str, Any] = {'callbacks': [_usage_callback], 'request_timeout': 120, 'max_retries': 1}
-    if model == 'gpt-5.1':
+    if _supports_cache_retention(model):
         kwargs['extra_body'] = {"prompt_cache_retention": "24h"}
     if streaming:
         kwargs['streaming'] = True
@@ -397,8 +397,36 @@ def get_openai_chat(model: str, **kwargs) -> ChatOpenAI:
     'wrapped_analysis': 0.7,
 }
 
-# Models that support OpenAI prompt caching (prompt_cache_key routing).
-_CACHE_KEY_MODELS = {'gpt-5.4', 'gpt-5.4-mini'}
+# Prompt-cache capability detection.
+#
+# OpenAI prompt caching is a capability of whole model families, not specific point
+# releases. Gating on exact names (e.g. {'gpt-5.4', 'gpt-5.4-mini'}) silently breaks the
+# moment a model is renamed or a new family member ships, so we detect by family prefix
+# instead.
+#
+#   prompt_cache_key             — request routing for the prefix cache. Supported by the gpt-4o,
+#                                  gpt-4.1, gpt-5.x and o-series families.
+#   prompt_cache_retention='24h' — extended (24h) cache retention. Supported by the
+#                                  gpt-5.x and o-series families.
+
+# Family prefixes whose models support OpenAI prompt caching (prompt_cache_key routing).
+_CACHE_KEY_MODEL_PREFIXES = ('gpt-5', 'gpt-4.1', 'gpt-4o', 'o1', 'o3', 'o4')
+
+# Family prefixes whose models support 24h prompt-cache retention.
+# NOTE(review): a prefix match also opts in every other 'gpt-5*'/'o1*'/'o3*'/'o4*' name; confirm
+# against OpenAI's extended-retention model list that none of them reject the parameter.
+_CACHE_RETENTION_MODEL_PREFIXES = ('gpt-5', 'o1', 'o3', 'o4')
+
+
+def _supports_prompt_cache_key(model: Optional[str]) -> bool:
+    """Whether a model supports OpenAI prompt-cache routing (prompt_cache_key); False for None/''."""
+    return bool(model) and model.startswith(_CACHE_KEY_MODEL_PREFIXES)
+
+
+def _supports_cache_retention(model: Optional[str]) -> bool:
+    """Whether a model supports 24h OpenAI prompt-cache retention; False for None/''."""
+    return bool(model) and model.startswith(_CACHE_RETENTION_MODEL_PREFIXES)
+
 
 # Features that call .with_structured_output() — logged when resolving to Gemini for compat monitoring.
 _STRUCTURED_OUTPUT_FEATURES = {
@@ -463,7 +491,7 @@ def _get_or_create_openai_llm(model_name: str, streaming: bool = False) -> ChatO
             'request_timeout': 120,
             'max_retries': 1,
         }
-        if model_name == 'gpt-5.1':
+        if _supports_cache_retention(model_name):
             kwargs['extra_body'] = {"prompt_cache_retention": "24h"}
         if streaming:
             kwargs['streaming'] = True
@@ -642,7 +670,7 @@ def get_llm(feature: str, streaming: bool = False, cache_key: Optional[str] = No
     else:
         result = _get_default_client(model, provider, streaming, feature)
 
-    if cache_key and model in _CACHE_KEY_MODELS:
+    if cache_key and _supports_prompt_cache_key(model):
         return result.bind(prompt_cache_key=cache_key)
     return result
 