diff --git a/backend/test.sh b/backend/test.sh index 5ee65c457e..ab6de38e32 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -82,6 +82,7 @@ pytest tests/unit/test_action_item_idempotency.py -v pytest tests/unit/test_tools_router.py -v pytest tests/unit/test_kg_user_type_mismatch.py -v pytest tests/unit/test_kg_edge_id_sanitization.py -v +pytest tests/unit/test_goal_extraction_batch.py -v pytest tests/unit/test_listen_pipeline.py -v pytest tests/unit/test_fair_use_models.py -v pytest tests/unit/test_fair_use_engine.py -v diff --git a/backend/tests/unit/test_goal_extraction_batch.py b/backend/tests/unit/test_goal_extraction_batch.py index b6b0d7757f..3531e39de5 100644 --- a/backend/tests/unit/test_goal_extraction_batch.py +++ b/backend/tests/unit/test_goal_extraction_batch.py @@ -3,10 +3,12 @@ Verifies extract_and_update_goal_progress makes exactly 1 LLM call regardless of goal count. """ +import importlib import json import os import sys import types +from contextlib import nullcontext from unittest.mock import MagicMock, patch import pytest @@ -21,14 +23,15 @@ def _stub_module(name: str) -> types.ModuleType: if name not in sys.modules: mod = types.ModuleType(name) sys.modules[name] = mod + if "." in name: + parent_name, attr = name.rsplit(".", 1) + parent = sys.modules.get(parent_name) + if parent is not None: + setattr(parent, attr, sys.modules[name]) return sys.modules[name] -# --- Stub database package and submodules --- -database_mod = _stub_module("database") -if not hasattr(database_mod, '__path__'): - database_mod.__path__ = [] -for submodule in [ +_DATABASE_SUBMODULES = ( "redis_db", "memories", "conversations", @@ -48,7 +51,65 @@ def _stub_module(name: str) -> types.ModuleType: "daily_summaries", "mem_db", "notifications", -]: +) +_RESTORED_MODULES = tuple( + ["database"] + + [f"database.{submodule}" for submodule in _DATABASE_SUBMODULES] + + [ + "utils", + "utils.llm", + "utils.llms", + "utils.llms.memory", + "utils.llm.clients", + "utils.llm.usage_tracker", + "utils.llm.goals", + ] +) +_PARENT_ATTRS = tuple( + [("database", submodule) for submodule in _DATABASE_SUBMODULES] + + [ + ("utils", "llm"), + ("utils", "llms"), + ("utils.llms", "memory"), + ("utils.llm", "clients"), + ("utils.llm", "usage_tracker"), + ("utils.llm", "goals"), + ] +) +_MISSING = object() +_saved_modules = {name: sys.modules.get(name, _MISSING) for name in _RESTORED_MODULES} +_saved_parent_attrs = { + (parent_name, attr): getattr(sys.modules.get(parent_name), attr, _MISSING) for parent_name, attr in _PARENT_ATTRS +} + + +def _restore_stub_modules(): + current_modules = {name: sys.modules.get(name, _MISSING) for name in _RESTORED_MODULES} + for name in sorted(_RESTORED_MODULES, key=lambda module_name: module_name.count("."), reverse=True): + original = _saved_modules[name] + if original is _MISSING: + sys.modules.pop(name, None) + else: + sys.modules[name] = original + + for (parent_name, attr), original in _saved_parent_attrs.items(): + parent = sys.modules.get(parent_name) + if parent is None: + continue + if original is _MISSING: + child_name = f"{parent_name}.{attr}" + current = current_modules.get(child_name, _MISSING) + if current is not _MISSING and getattr(parent, attr, _MISSING) is current: + delattr(parent, attr) + else: + setattr(parent, attr, original) + + +# --- Stub database package and submodules --- +database_mod = _stub_module("database") +if not hasattr(database_mod, '__path__'): + database_mod.__path__ = [] +for submodule in _DATABASE_SUBMODULES: mod = _stub_module(f"database.{submodule}") setattr(database_mod, submodule, mod) @@ -76,9 +137,27 @@ def _stub_module(name: str) -> types.ModuleType: clients_mod.llm_mini = MagicMock() if not hasattr(clients_mod, 'llm_medium'): clients_mod.llm_medium = MagicMock() +if not hasattr(clients_mod, 'get_llm'): + clients_mod.get_llm = MagicMock() -# Shortcut references to mocked db functions +# Stub usage tracking so importing utils.llm.goals does not pull optional usage deps. +usage_tracker_mod = _stub_module("utils.llm.usage_tracker") +usage_tracker_mod.track_usage = MagicMock(side_effect=lambda *args, **kwargs: nullcontext()) +usage_tracker_mod.Features = types.SimpleNamespace(GOALS="goals") + +# Shortcut references to mocked modules and functions +mock_llm_usage_db = sys.modules["database.llm_usage"] mock_goals_db = sys.modules["database.goals"] +mock_memories_db = sys.modules["database.memories"] +mock_conversations_db = sys.modules["database.conversations"] +mock_chat_db = sys.modules["database.chat"] +mock_vector_db = sys.modules["database.vector_db"] +mock_memory_module = sys.modules["utils.llms.memory"] + +try: + _goals_module = importlib.import_module("utils.llm.goals") +finally: + _restore_stub_modules() # --- Test data --- @@ -108,17 +187,41 @@ def _stub_module(name: str) -> types.ModuleType: } -def _import_fn(): - """Lazy import to avoid capturing mock references at module load time.""" - from utils.llm.goals import extract_and_update_goal_progress +def _import_goals_module(): + """Return the isolated module imported while heavy dependencies were stubbed.""" + return _goals_module + + +def _run_with_llm(mock_llm, uid: str, text: str): + goals_module = _import_goals_module() + with patch.object(goals_module, "get_llm", MagicMock(return_value=mock_llm)): + return goals_module.extract_and_update_goal_progress(uid, text) + - return extract_and_update_goal_progress +def _reset_mock(mock, *, return_value=_MISSING, side_effect=_MISSING): + mock.reset_mock(return_value=True, side_effect=True) + if return_value is not _MISSING: + mock.return_value = return_value + if side_effect is not _MISSING: + mock.side_effect = side_effect @pytest.fixture(autouse=True) def reset_mocks(): - mock_goals_db.get_user_goals.reset_mock() - mock_goals_db.update_goal_progress.reset_mock() + _reset_mock(mock_llm_usage_db.record_llm_usage) + _reset_mock(mock_goals_db.get_user_goal, return_value=None) + _reset_mock(mock_goals_db.get_user_goals, return_value=[]) + _reset_mock(mock_goals_db.update_goal_progress) + _reset_mock(mock_memories_db.get_memories, return_value=[]) + _reset_mock(mock_conversations_db.get_conversations, return_value=[]) + _reset_mock(mock_conversations_db.get_conversations_by_id, return_value=[]) + _reset_mock(mock_chat_db.get_messages, return_value=[]) + _reset_mock(mock_vector_db.query_vectors, return_value=[]) + _reset_mock(mock_memory_module.get_prompt_memories, return_value=("TestUser", "some memories")) + + goals_module = _import_goals_module() + _reset_mock(goals_module.track_usage, side_effect=lambda *args, **kwargs: nullcontext()) + _reset_mock(goals_module.get_llm) yield @@ -131,8 +234,7 @@ def test_single_goal_one_llm_call(self): mock_llm = MagicMock() mock_llm.invoke.return_value = MagicMock(content='[]') - with patch("utils.llm.goals.llm_mini", mock_llm): - _import_fn()("uid-1", "I went for a walk today") + _run_with_llm(mock_llm, "uid-1", "I went for a walk today") assert mock_llm.invoke.call_count == 1 def test_three_goals_one_llm_call(self): @@ -141,8 +243,7 @@ def test_three_goals_one_llm_call(self): mock_llm = MagicMock() mock_llm.invoke.return_value = MagicMock(content='[]') - with patch("utils.llm.goals.llm_mini", mock_llm): - _import_fn()("uid-1", "Just had lunch") + _run_with_llm(mock_llm, "uid-1", "Just had lunch") assert mock_llm.invoke.call_count == 1 def test_prompt_contains_all_goals(self): @@ -151,8 +252,7 @@ def test_prompt_contains_all_goals(self): mock_llm = MagicMock() mock_llm.invoke.return_value = MagicMock(content='[]') - with patch("utils.llm.goals.llm_mini", mock_llm): - _import_fn()("uid-1", "Some message") + _run_with_llm(mock_llm, "uid-1", "Some message") prompt = mock_llm.invoke.call_args[0][0] assert "Save $10,000" in prompt @@ -171,8 +271,7 @@ def test_single_match_updates_db(self): content=json.dumps([{"goal_id": "goal-a", "found": True, "value": 2500, "reasoning": "saved $500 more"}]) ) - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "I saved another $500 today") + result = _run_with_llm(mock_llm, "uid-1", "I saved another $500 today") assert result["status"] == "updated" assert len(result["updates"]) == 1 @@ -193,8 +292,7 @@ def test_multiple_matches_update_all(self): ) ) - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "Saved $1000 and finished reading a book") + result = _run_with_llm(mock_llm, "uid-1", "Saved $1000 and finished reading a book") assert result["status"] == "updated" assert len(result["updates"]) == 2 @@ -209,8 +307,7 @@ def test_no_match_returns_no_update(self): mock_llm = MagicMock() mock_llm.invoke.return_value = MagicMock(content='[]') - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "Weather is nice today") + result = _run_with_llm(mock_llm, "uid-1", "Weather is nice today") assert result["status"] == "no_update" mock_goals_db.update_goal_progress.assert_not_called() @@ -223,8 +320,7 @@ def test_same_value_not_updated(self): content=json.dumps([{"goal_id": "goal-a", "found": True, "value": 2000, "reasoning": "same value"}]) ) - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "I have $2000 saved") + result = _run_with_llm(mock_llm, "uid-1", "I have $2000 saved") assert result["status"] == "no_update" mock_goals_db.update_goal_progress.assert_not_called() @@ -237,8 +333,7 @@ def test_unknown_goal_id_ignored(self): content=json.dumps([{"goal_id": "nonexistent", "found": True, "value": 999, "reasoning": "wrong id"}]) ) - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "Some message") + result = _run_with_llm(mock_llm, "uid-1", "Some message") assert result["status"] == "no_update" mock_goals_db.update_goal_progress.assert_not_called() @@ -252,8 +347,7 @@ def test_no_goals_returns_none(self): mock_goals_db.get_user_goals.return_value = [] mock_llm = MagicMock() - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "I saved $500") + result = _run_with_llm(mock_llm, "uid-1", "I saved $500") assert result is None mock_llm.invoke.assert_not_called() @@ -263,8 +357,7 @@ def test_short_text_returns_none(self): mock_goals_db.get_user_goals.return_value = [GOAL_A] mock_llm = MagicMock() - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "hi") + result = _run_with_llm(mock_llm, "uid-1", "hi") assert result is None mock_llm.invoke.assert_not_called() @@ -273,8 +366,7 @@ def test_empty_text_returns_none(self): mock_goals_db.get_user_goals.return_value = [GOAL_A] mock_llm = MagicMock() - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "") + result = _run_with_llm(mock_llm, "uid-1", "") assert result is None mock_llm.invoke.assert_not_called() @@ -285,8 +377,7 @@ def test_malformed_llm_response_no_crash(self): mock_llm = MagicMock() mock_llm.invoke.return_value = MagicMock(content="Sorry, I cannot help with that.") - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "I saved $500 today") + result = _run_with_llm(mock_llm, "uid-1", "I saved $500 today") assert result["status"] == "no_update" mock_goals_db.update_goal_progress.assert_not_called() @@ -297,8 +388,7 @@ def test_llm_exception_returns_error(self): mock_llm = MagicMock() mock_llm.invoke.side_effect = Exception("API timeout") - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "I saved $500 today") + result = _run_with_llm(mock_llm, "uid-1", "I saved $500 today") assert result["status"] == "error" @@ -310,8 +400,7 @@ def test_negative_value_rejected(self): content=json.dumps([{"goal_id": "goal-a", "found": True, "value": -500, "reasoning": "negative"}]) ) - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "I lost $500") + result = _run_with_llm(mock_llm, "uid-1", "I lost $500") assert result["status"] == "no_update" mock_goals_db.update_goal_progress.assert_not_called() @@ -324,8 +413,7 @@ def test_nan_value_rejected(self): content='[{"goal_id": "goal-a", "found": true, "value": NaN, "reasoning": "bad"}]' ) - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "Something happened") + result = _run_with_llm(mock_llm, "uid-1", "Something happened") # NaN in JSON is invalid, so parsing fails gracefully assert result["status"] == "no_update" @@ -344,8 +432,7 @@ def test_duplicate_goal_id_deduped(self): ) ) - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "I saved $3000 or maybe $5000") + result = _run_with_llm(mock_llm, "uid-1", "I saved $3000 or maybe $5000") assert result["status"] == "updated" assert len(result["updates"]) == 1 @@ -360,8 +447,7 @@ def test_llm_returns_array_with_extra_text(self): content='Here is the analysis:\n[{"goal_id": "goal-a", "found": true, "value": 4000, "reasoning": "saved more"}]\nHope this helps!' ) - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "I now have $4000 saved") + result = _run_with_llm(mock_llm, "uid-1", "I now have $4000 saved") assert result["status"] == "updated" assert result["updates"][0]["new_value"] == 4000 @@ -379,8 +465,7 @@ def test_one_bad_result_doesnt_block_others(self): ) ) - with patch("utils.llm.goals.llm_mini", mock_llm): - result = _import_fn()("uid-1", "I ran 50 miles and saved some money") + result = _run_with_llm(mock_llm, "uid-1", "I ran 50 miles and saved some money") assert result["status"] == "updated" assert len(result["updates"]) == 1