Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pytest tests/unit/test_action_item_idempotency.py -v
pytest tests/unit/test_tools_router.py -v
pytest tests/unit/test_kg_user_type_mismatch.py -v
pytest tests/unit/test_kg_edge_id_sanitization.py -v
pytest tests/unit/test_goal_extraction_batch.py -v
pytest tests/unit/test_listen_pipeline.py -v
pytest tests/unit/test_fair_use_models.py -v
pytest tests/unit/test_fair_use_engine.py -v
Expand Down
183 changes: 134 additions & 49 deletions backend/tests/unit/test_goal_extraction_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
Verifies extract_and_update_goal_progress makes exactly 1 LLM call regardless of goal count.
"""

import importlib
import json
import os
import sys
import types
from contextlib import nullcontext
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -21,14 +23,15 @@ def _stub_module(name: str) -> types.ModuleType:
if name not in sys.modules:
mod = types.ModuleType(name)
sys.modules[name] = mod
if "." in name:
parent_name, attr = name.rsplit(".", 1)
parent = sys.modules.get(parent_name)
if parent is not None:
setattr(parent, attr, sys.modules[name])
return sys.modules[name]


# --- Stub database package and submodules ---
database_mod = _stub_module("database")
if not hasattr(database_mod, '__path__'):
database_mod.__path__ = []
for submodule in [
_DATABASE_SUBMODULES = (
"redis_db",
"memories",
"conversations",
Expand All @@ -48,7 +51,65 @@ def _stub_module(name: str) -> types.ModuleType:
"daily_summaries",
"mem_db",
"notifications",
]:
)
# Names of every sys.modules entry that gets snapshotted before stubbing and
# put back afterwards: the ``database`` package, one entry per database
# submodule, and the ``utils`` modules listed explicitly below.
_RESTORED_MODULES = (
    "database",
    *(f"database.{db_submodule}" for db_submodule in _DATABASE_SUBMODULES),
    "utils",
    "utils.llm",
    "utils.llms",
    "utils.llms.memory",
    "utils.llm.clients",
    "utils.llm.usage_tracker",
    "utils.llm.goals",
)
Comment thread
tianmind-studio marked this conversation as resolved.
# (parent module name, attribute name) pairs: the attributes on parent
# packages that are snapshotted alongside the sys.modules entries.
_PARENT_ATTRS = (
    *(("database", db_submodule) for db_submodule in _DATABASE_SUBMODULES),
    ("utils", "llm"),
    ("utils", "llms"),
    ("utils.llms", "memory"),
    ("utils.llm", "clients"),
    ("utils.llm", "usage_tracker"),
    ("utils.llm", "goals"),
)
# Sentinel for "was not present"; distinct from every real value, None included.
_MISSING = object()
# State of sys.modules for each tracked name at the moment this line runs.
_saved_modules = {
    module_name: sys.modules.get(module_name, _MISSING) for module_name in _RESTORED_MODULES
}
# State of each tracked parent attribute; a parent that is not loaded yields
# _MISSING because getattr(None, ...) falls back to the default.
_saved_parent_attrs = {
    (package_name, attr_name): getattr(sys.modules.get(package_name), attr_name, _MISSING)
    for package_name, attr_name in _PARENT_ATTRS
}


def _restore_stub_modules():
    """Roll ``sys.modules`` and parent-package attributes back to the saved snapshot.

    Every name in ``_RESTORED_MODULES`` is reset to the entry recorded in
    ``_saved_modules`` (or removed when it was absent). Each attribute in
    ``_saved_parent_attrs`` is then reset on its parent module; an attribute
    that was absent in the snapshot is deleted only when it still refers to
    the module that was registered under the child name just before the
    rollback.
    """
    # Capture what is registered right now, before anything is rolled back, so
    # attributes pointing at these modules can be recognised further down.
    installed_now = {}
    for module_name in _RESTORED_MODULES:
        installed_now[module_name] = sys.modules.get(module_name, _MISSING)

    # Most deeply nested names are handled first.
    deepest_first = sorted(_RESTORED_MODULES, key=lambda dotted: dotted.count("."), reverse=True)
    for module_name in deepest_first:
        saved_module = _saved_modules[module_name]
        if saved_module is not _MISSING:
            sys.modules[module_name] = saved_module
        else:
            sys.modules.pop(module_name, None)

    for attr_key, saved_attr in _saved_parent_attrs.items():
        package_name, attr_name = attr_key
        package = sys.modules.get(package_name)
        if package is None:
            # The parent is not loaded after the rollback; nothing to fix up.
            continue
        if saved_attr is not _MISSING:
            setattr(package, attr_name, saved_attr)
            continue
        # The attribute did not exist in the snapshot: remove it, but only if
        # it is still the module that was registered before the rollback.
        installed_child = installed_now.get(f"{package_name}.{attr_name}", _MISSING)
        if installed_child is _MISSING:
            continue
        if getattr(package, attr_name, _MISSING) is installed_child:
            delattr(package, attr_name)


# --- Stub database package and submodules ---
# Ensure a ``database`` entry exists in sys.modules and carries a ``__path__``
# (presumably so the import system treats it as a package — confirm), then
# register one ``database.<name>`` entry per name in _DATABASE_SUBMODULES and
# expose each as an attribute of the package object.
database_mod = _stub_module("database")
if not hasattr(database_mod, '__path__'):
    database_mod.__path__ = []
for submodule in _DATABASE_SUBMODULES:
    mod = _stub_module(f"database.{submodule}")
    setattr(database_mod, submodule, mod)

Expand Down Expand Up @@ -76,9 +137,27 @@ def _stub_module(name: str) -> types.ModuleType:
clients_mod.llm_mini = MagicMock()
if not hasattr(clients_mod, 'llm_medium'):
clients_mod.llm_medium = MagicMock()
if not hasattr(clients_mod, 'get_llm'):
clients_mod.get_llm = MagicMock()

# Shortcut references to mocked db functions
# Stub usage tracking so importing utils.llm.goals does not pull optional usage deps.
# track_usage hands back a fresh no-op context manager (nullcontext) on every
# call, and Features exposes only the GOALS member.
usage_tracker_mod = _stub_module("utils.llm.usage_tracker")
usage_tracker_mod.track_usage = MagicMock(side_effect=lambda *args, **kwargs: nullcontext())
usage_tracker_mod.Features = types.SimpleNamespace(GOALS="goals")
Comment thread
tianmind-studio marked this conversation as resolved.

# Shortcut references to mocked modules and functions
# These names are bound to whatever objects are registered in sys.modules at
# this point, so they keep referring to those objects after
# _restore_stub_modules() has rolled sys.modules back.
mock_llm_usage_db = sys.modules["database.llm_usage"]
mock_goals_db = sys.modules["database.goals"]
mock_memories_db = sys.modules["database.memories"]
mock_conversations_db = sys.modules["database.conversations"]
mock_chat_db = sys.modules["database.chat"]
mock_vector_db = sys.modules["database.vector_db"]
mock_memory_module = sys.modules["utils.llms.memory"]

# Import the module under test while the current sys.modules entries are in
# place; the finally clause rolls sys.modules back whether or not the import
# succeeds.
try:
    _goals_module = importlib.import_module("utils.llm.goals")
finally:
    _restore_stub_modules()


# --- Test data ---
Expand Down Expand Up @@ -108,17 +187,41 @@ def _stub_module(name: str) -> types.ModuleType:
}


def _import_fn():
"""Lazy import to avoid capturing mock references at module load time."""
from utils.llm.goals import extract_and_update_goal_progress
def _import_goals_module():
    """Hand back the module object stored in ``_goals_module``.

    Nothing is imported here; the function only exposes the module that was
    imported once at load time, while heavy dependencies were stubbed.
    """
    return _goals_module


def _run_with_llm(mock_llm, uid: str, text: str):
    """Call ``extract_and_update_goal_progress`` with ``get_llm`` temporarily replaced.

    For the duration of the call the goals module's ``get_llm`` attribute is a
    ``MagicMock`` whose return value is ``mock_llm``; the previous attribute is
    put back when the call finishes or raises.

    Args:
        mock_llm: Object returned by the patched ``get_llm``.
        uid: User id, forwarded unchanged.
        text: Message text, forwarded unchanged.

    Returns:
        Whatever ``extract_and_update_goal_progress(uid, text)`` returns.
    """
    target_module = _import_goals_module()
    llm_factory = MagicMock(return_value=mock_llm)
    with patch.object(target_module, "get_llm", new=llm_factory):
        outcome = target_module.extract_and_update_goal_progress(uid, text)
    return outcome


return extract_and_update_goal_progress
def _reset_mock(mock, *, return_value=_MISSING, side_effect=_MISSING):
    """Fully reset ``mock`` and optionally install a new return value / side effect.

    The reset clears recorded calls together with any configured
    ``return_value`` and ``side_effect``. A keyword left at its ``_MISSING``
    default is not assigned afterwards, so ``None`` can be passed as a real
    value.

    Args:
        mock: The mock object to reset.
        return_value: New ``return_value`` to assign after the reset, if given.
        side_effect: New ``side_effect`` to assign after the reset, if given.
    """
    mock.reset_mock(return_value=True, side_effect=True)
    overrides = (("return_value", return_value), ("side_effect", side_effect))
    for attr_name, override in overrides:
        if override is _MISSING:
            continue
        setattr(mock, attr_name, override)


@pytest.fixture(autouse=True)
def reset_mocks():
    """Put every shared module-level mock into a known baseline before each test.

    Each mock is fully reset through ``_reset_mock`` (calls, return value and
    side effect) and then given the default listed here, so configuration made
    by one test cannot leak into the next. Tests override the defaults they
    care about after this fixture has run.
    """
    # Database-layer mocks: no stored goal, and empty collections everywhere.
    # The full reset in _reset_mock already clears call history, so no
    # separate bare reset_mock() calls are needed beforehand.
    _reset_mock(mock_llm_usage_db.record_llm_usage)
    _reset_mock(mock_goals_db.get_user_goal, return_value=None)
    _reset_mock(mock_goals_db.get_user_goals, return_value=[])
    _reset_mock(mock_goals_db.update_goal_progress)
    _reset_mock(mock_memories_db.get_memories, return_value=[])
    _reset_mock(mock_conversations_db.get_conversations, return_value=[])
    _reset_mock(mock_conversations_db.get_conversations_by_id, return_value=[])
    _reset_mock(mock_chat_db.get_messages, return_value=[])
    _reset_mock(mock_vector_db.query_vectors, return_value=[])
    _reset_mock(mock_memory_module.get_prompt_memories, return_value=("TestUser", "some memories"))

    # Names held by the module under test itself.
    goals_module = _import_goals_module()
    # track_usage must keep returning a no-op context manager after the reset;
    # presumably the goals module uses it in a ``with`` statement — confirm.
    _reset_mock(goals_module.track_usage, side_effect=lambda *args, **kwargs: nullcontext())
    _reset_mock(goals_module.get_llm)
    yield


Expand All @@ -131,8 +234,7 @@ def test_single_goal_one_llm_call(self):
mock_llm = MagicMock()
mock_llm.invoke.return_value = MagicMock(content='[]')

with patch("utils.llm.goals.llm_mini", mock_llm):
_import_fn()("uid-1", "I went for a walk today")
_run_with_llm(mock_llm, "uid-1", "I went for a walk today")
assert mock_llm.invoke.call_count == 1

def test_three_goals_one_llm_call(self):
Expand All @@ -141,8 +243,7 @@ def test_three_goals_one_llm_call(self):
mock_llm = MagicMock()
mock_llm.invoke.return_value = MagicMock(content='[]')

with patch("utils.llm.goals.llm_mini", mock_llm):
_import_fn()("uid-1", "Just had lunch")
_run_with_llm(mock_llm, "uid-1", "Just had lunch")
assert mock_llm.invoke.call_count == 1

def test_prompt_contains_all_goals(self):
Expand All @@ -151,8 +252,7 @@ def test_prompt_contains_all_goals(self):
mock_llm = MagicMock()
mock_llm.invoke.return_value = MagicMock(content='[]')

with patch("utils.llm.goals.llm_mini", mock_llm):
_import_fn()("uid-1", "Some message")
_run_with_llm(mock_llm, "uid-1", "Some message")

prompt = mock_llm.invoke.call_args[0][0]
assert "Save $10,000" in prompt
Expand All @@ -171,8 +271,7 @@ def test_single_match_updates_db(self):
content=json.dumps([{"goal_id": "goal-a", "found": True, "value": 2500, "reasoning": "saved $500 more"}])
)

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "I saved another $500 today")
result = _run_with_llm(mock_llm, "uid-1", "I saved another $500 today")

assert result["status"] == "updated"
assert len(result["updates"]) == 1
Expand All @@ -193,8 +292,7 @@ def test_multiple_matches_update_all(self):
)
)

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "Saved $1000 and finished reading a book")
result = _run_with_llm(mock_llm, "uid-1", "Saved $1000 and finished reading a book")

assert result["status"] == "updated"
assert len(result["updates"]) == 2
Expand All @@ -209,8 +307,7 @@ def test_no_match_returns_no_update(self):
mock_llm = MagicMock()
mock_llm.invoke.return_value = MagicMock(content='[]')

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "Weather is nice today")
result = _run_with_llm(mock_llm, "uid-1", "Weather is nice today")

assert result["status"] == "no_update"
mock_goals_db.update_goal_progress.assert_not_called()
Expand All @@ -223,8 +320,7 @@ def test_same_value_not_updated(self):
content=json.dumps([{"goal_id": "goal-a", "found": True, "value": 2000, "reasoning": "same value"}])
)

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "I have $2000 saved")
result = _run_with_llm(mock_llm, "uid-1", "I have $2000 saved")

assert result["status"] == "no_update"
mock_goals_db.update_goal_progress.assert_not_called()
Expand All @@ -237,8 +333,7 @@ def test_unknown_goal_id_ignored(self):
content=json.dumps([{"goal_id": "nonexistent", "found": True, "value": 999, "reasoning": "wrong id"}])
)

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "Some message")
result = _run_with_llm(mock_llm, "uid-1", "Some message")

assert result["status"] == "no_update"
mock_goals_db.update_goal_progress.assert_not_called()
Expand All @@ -252,8 +347,7 @@ def test_no_goals_returns_none(self):
mock_goals_db.get_user_goals.return_value = []
mock_llm = MagicMock()

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "I saved $500")
result = _run_with_llm(mock_llm, "uid-1", "I saved $500")

assert result is None
mock_llm.invoke.assert_not_called()
Expand All @@ -263,8 +357,7 @@ def test_short_text_returns_none(self):
mock_goals_db.get_user_goals.return_value = [GOAL_A]
mock_llm = MagicMock()

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "hi")
result = _run_with_llm(mock_llm, "uid-1", "hi")

assert result is None
mock_llm.invoke.assert_not_called()
Expand All @@ -273,8 +366,7 @@ def test_empty_text_returns_none(self):
mock_goals_db.get_user_goals.return_value = [GOAL_A]
mock_llm = MagicMock()

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "")
result = _run_with_llm(mock_llm, "uid-1", "")

assert result is None
mock_llm.invoke.assert_not_called()
Expand All @@ -285,8 +377,7 @@ def test_malformed_llm_response_no_crash(self):
mock_llm = MagicMock()
mock_llm.invoke.return_value = MagicMock(content="Sorry, I cannot help with that.")

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "I saved $500 today")
result = _run_with_llm(mock_llm, "uid-1", "I saved $500 today")

assert result["status"] == "no_update"
mock_goals_db.update_goal_progress.assert_not_called()
Expand All @@ -297,8 +388,7 @@ def test_llm_exception_returns_error(self):
mock_llm = MagicMock()
mock_llm.invoke.side_effect = Exception("API timeout")

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "I saved $500 today")
result = _run_with_llm(mock_llm, "uid-1", "I saved $500 today")

assert result["status"] == "error"

Expand All @@ -310,8 +400,7 @@ def test_negative_value_rejected(self):
content=json.dumps([{"goal_id": "goal-a", "found": True, "value": -500, "reasoning": "negative"}])
)

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "I lost $500")
result = _run_with_llm(mock_llm, "uid-1", "I lost $500")

assert result["status"] == "no_update"
mock_goals_db.update_goal_progress.assert_not_called()
Expand All @@ -324,8 +413,7 @@ def test_nan_value_rejected(self):
content='[{"goal_id": "goal-a", "found": true, "value": NaN, "reasoning": "bad"}]'
)

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "Something happened")
result = _run_with_llm(mock_llm, "uid-1", "Something happened")

# NaN in JSON is invalid, so parsing fails gracefully
assert result["status"] == "no_update"
Expand All @@ -344,8 +432,7 @@ def test_duplicate_goal_id_deduped(self):
)
)

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "I saved $3000 or maybe $5000")
result = _run_with_llm(mock_llm, "uid-1", "I saved $3000 or maybe $5000")

assert result["status"] == "updated"
assert len(result["updates"]) == 1
Expand All @@ -360,8 +447,7 @@ def test_llm_returns_array_with_extra_text(self):
content='Here is the analysis:\n[{"goal_id": "goal-a", "found": true, "value": 4000, "reasoning": "saved more"}]\nHope this helps!'
)

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "I now have $4000 saved")
result = _run_with_llm(mock_llm, "uid-1", "I now have $4000 saved")

assert result["status"] == "updated"
assert result["updates"][0]["new_value"] == 4000
Expand All @@ -379,8 +465,7 @@ def test_one_bad_result_doesnt_block_others(self):
)
)

with patch("utils.llm.goals.llm_mini", mock_llm):
result = _import_fn()("uid-1", "I ran 50 miles and saved some money")
result = _run_with_llm(mock_llm, "uid-1", "I ran 50 miles and saved some money")

assert result["status"] == "updated"
assert len(result["updates"]) == 1
Expand Down
Loading