diff --git a/tests/unit/agents/test_agents.py b/tests/unit/agents/test_agents.py new file mode 100644 index 0000000..ca66bb3 --- /dev/null +++ b/tests/unit/agents/test_agents.py @@ -0,0 +1,126 @@ +"""Tests for fenn/agents/agent.py""" +import pytest +from unittest.mock import MagicMock, patch, mock_open + + +FAKE_CONFIG = """ +agent: + system_prompt: "You are a helpful assistant." + max_iterations: 5 +""" + + +def _make_agent(): + """Construct an Agent with all heavy dependencies mocked out.""" + mock_llm = MagicMock() + + with patch("fenn.agents.agent.yaml.safe_load", return_value={ + "agent": {"system_prompt": "You are helpful.", "max_iterations": 3} + }): + with patch("builtins.open", mock_open(read_data=FAKE_CONFIG)): + with patch("fenn.agents.agent.ThinkNode") as MockThink, \ + patch("fenn.agents.agent.ActNode") as MockAct, \ + patch("fenn.agents.agent.ObserveNode") as MockObserve, \ + patch("fenn.agents.agent.Flow") as MockFlow, \ + patch("fenn.agents.agent.get_tool_schema", return_value=[]): + + # Make node - operator return the node itself so chaining doesn't error + for MockNode in (MockThink, MockAct, MockObserve): + instance = MockNode.return_value + instance.__sub__ = MagicMock(return_value=instance) + instance.__rshift__ = MagicMock(return_value=instance) + + from fenn.agents.agent import Agent + agent = Agent(config="fake_config.yaml", llm=mock_llm) + + return agent, mock_llm, MockFlow + + +# ── __init__ ─────────────────────────────────────────────────────────────────── + +def test_agent_stores_llm(): + agent, mock_llm, _ = _make_agent() + assert agent.llm is mock_llm + + +def test_agent_loads_config(): + agent, _, _ = _make_agent() + assert agent.config["agent"]["system_prompt"] == "You are helpful." + assert agent.config["agent"]["max_iterations"] == 3 + + +def test_agent_creates_flow(): + agent, _, MockFlow = _make_agent() + assert agent.flow is MockFlow.return_value + + +def test_agent_config_file_not_found(): + from fenn.agents.agent import Agent + with pytest.raises((FileNotFoundError, OSError)): + Agent(config="/nonexistent/path/config.yaml", llm=MagicMock()) + + +# ── run ──────────────────────────────────────────────────────────────────────── + +def test_run_calls_flow_run(): + agent, _, _ = _make_agent() + agent.flow.run = MagicMock() + + # Simulate flow.run appending a response message + def fake_run(shared): + shared["messages"].append({"role": "assistant", "content": "Done!"}) + + agent.flow.run.side_effect = fake_run + + with patch("fenn.agents.agent.get_tool_schema", return_value=[]): + result = agent.run("Hello") + + agent.flow.run.assert_called_once() + assert result == "Done!" + + +def test_run_returns_last_message_content(): + agent, _, _ = _make_agent() + + def fake_run(shared): + shared["messages"].append({"role": "assistant", "content": "Final answer"}) + + agent.flow.run.side_effect = fake_run + + result = agent.run("What is 2+2?") + assert result == "Final answer" + + +def test_run_builds_correct_shared_state(): + agent, mock_llm, _ = _make_agent() + captured = {} + + def fake_run(shared): + captured.update(shared) + shared["messages"].append({"role": "assistant", "content": "ok"}) + + agent.flow.run.side_effect = fake_run + + with patch("fenn.agents.agent.get_tool_schema", return_value=["tool1"]): + agent.run("Test input") + + assert captured["llm"] is mock_llm + assert captured["messages"][0] == {"role": "system", "content": "You are helpful."} + assert captured["messages"][1] == {"role": "user", "content": "Test input"} + assert captured["tools"] == ["tool1"] + assert captured["iterations"] == 0 + assert captured["max_iterations"] == 3 + assert captured["last_thought"] is None + assert captured["last_observation"] is None + + +def test_run_with_different_user_inputs(): + agent, _, _ = _make_agent() + + for user_input in ["Hello", "What time is it?", ""]: + def fake_run(shared): + shared["messages"].append({"role": "assistant", "content": "response"}) + + agent.flow.run.side_effect = fake_run + result = agent.run(user_input) + assert result == "response" diff --git a/tests/unit/agents/test_init.py b/tests/unit/agents/test_init.py new file mode 100644 index 0000000..b3e1ac0 --- /dev/null +++ b/tests/unit/agents/test_init.py @@ -0,0 +1,650 @@ +import asyncio +import warnings + +import pytest + +from fenn.agents import ( + AsyncBatchFlow, + AsyncBatchNode, + AsyncFlow, + AsyncNode, + AsyncParallelBatchFlow, + AsyncParallelBatchNode, + BaseNode, + BatchFlow, + BatchNode, + Flow, + Node, +) + + +class TestBaseNode: + def test_init_sets_empty_params_and_successors(self): + node = BaseNode() + assert node.params == {} + assert node.successors == {} + + def test_set_params(self): + node = BaseNode() + node.set_params({"a": 1}) + assert node.params == {"a": 1} + + def test_default_prep_exec_post_return_none(self): + node = BaseNode() + assert node.prep({}) is None + assert node.exec(None) is None + assert node.post({}, None, None) is None + + def test_run_calls_prep_exec_post_in_order(self): + calls = [] + + class Tracked(BaseNode): + def prep(self, shared): + calls.append("prep") + return "prepped" + + def exec(self, prep_res): + calls.append(("exec", prep_res)) + return "executed" + + def post(self, shared, prep_res, exec_res): + calls.append(("post", prep_res, exec_res)) + return "done" + + node = Tracked() + result = node.run({}) + assert calls == ["prep", ("exec", "prepped"), ("post", "prepped", "executed")] + assert result == "done" + + def test_run_warns_when_successors_present(self): + node = BaseNode() + node.successors["default"] = BaseNode() + with pytest.warns(UserWarning, match="won't run successors"): + node.run({}) + + def test_run_no_warning_without_successors(self): + node = BaseNode() + with warnings.catch_warnings(): + warnings.simplefilter("error") + node.run({}) # should not raise + + +class TestNode: + def test_default_max_retries_and_wait(self): + node = Node() + assert node.max_retries == 1 + assert node.wait == 0 + + def test_exec_succeeds_first_try(self): + class Succeeds(Node): + def exec(self, prep_res): + return "ok" + + node = Succeeds() + assert node._exec(None) == "ok" + + def test_exec_fallback_raises_by_default(self): + class AlwaysFails(Node): + def exec(self, prep_res): + raise ValueError("boom") + + node = AlwaysFails(max_retries=1) + with pytest.raises(ValueError, match="boom"): + node._exec(None) + + def test_retries_then_succeeds(self): + attempts = {"count": 0} + + class FlakyThenWorks(Node): + def exec(self, prep_res): + attempts["count"] += 1 + if attempts["count"] < 3: + raise RuntimeError("transient") + return "success" + + node = FlakyThenWorks(max_retries=5) + result = node._exec(None) + assert result == "success" + assert attempts["count"] == 3 + + def test_exec_fallback_called_after_exhausting_retries(self): + class FailsWithFallback(Node): + def exec(self, prep_res): + raise ValueError("always fails") + + def exec_fallback(self, prep_res, exc): + return f"fallback: {exc}" + + node = FailsWithFallback(max_retries=2) + result = node._exec(None) + assert result == "fallback: always fails" + + def test_wait_triggers_sleep_between_retries(self, monkeypatch): + sleeps = [] + monkeypatch.setattr("fenn.agents.time.sleep", lambda s: sleeps.append(s)) + + class AlwaysFails(Node): + def exec(self, prep_res): + raise ValueError("fail") + + def exec_fallback(self, prep_res, exc): + return "fallback" + + node = AlwaysFails(max_retries=3, wait=2) + node._exec(None) + assert sleeps == [2, 2] + + def test_cur_retry_tracks_attempt_number(self): + attempts = [] + + class TrackRetry(Node): + def exec(self, prep_res): + attempts.append(self.cur_retry) + raise ValueError("fail") + + def exec_fallback(self, prep_res, exc): + return "fallback" + + node = TrackRetry(max_retries=3) + node._exec(None) + assert attempts == [0, 1, 2] + + def test_full_run_pipeline(self): + class Adder(Node): + def prep(self, shared): + return shared["value"] + + def exec(self, prep_res): + return prep_res + 1 + + def post(self, shared, prep_res, exec_res): + shared["result"] = exec_res + return "done" + + node = Adder() + shared = {"value": 5} + action = node._run(shared) + assert shared["result"] == 6 + assert action == "done" + + +class TestBatchNode: + def test_exec_processes_each_item(self): + class Doubler(BatchNode): + def exec(self, item): + return item * 2 + + node = Doubler() + result = node._exec([1, 2, 3]) + assert result == [2, 4, 6] + + def test_exec_with_none_returns_empty_list(self): + class Doubler(BatchNode): + def exec(self, item): + return item * 2 + + node = Doubler() + assert node._exec(None) == [] + + def test_exec_with_empty_list(self): + class Doubler(BatchNode): + def exec(self, item): + return item * 2 + + node = Doubler() + assert node._exec([]) == [] + + def test_batch_node_retries_per_item(self): + attempts = {"a": 0, "b": 0} + + class FlakyPerItem(BatchNode): + def exec(self, item): + attempts[item] += 1 + if attempts[item] < 2: + raise RuntimeError("fail once") + return item + + def exec_fallback(self, prep_res, exc): + return "fallback" + + node = FlakyPerItem(max_retries=3) + result = node._exec(["a", "b"]) + assert result == ["a", "b"] + assert attempts == {"a": 2, "b": 2} + + +class _Echo(BaseNode): + """Simple node: records its label and returns a fixed next action. + + Uses instance attributes (not self.params) because Flow._orch() + overwrites node.params on every step. + """ + + def __init__(self, label="node", next_action="default"): + super().__init__() + self.label = label + self.next_action = next_action + + def post(self, shared, prep_res, exec_res): + shared.setdefault("visited", []).append(self.label) + return self.next_action + + +class TestFlow: + def test_start_sets_start_node_and_returns_it(self): + flow = Flow() + node = BaseNode() + result = flow.start(node) + assert flow.start_node is node + assert result is node + + def test_connect_returns_self(self): + flow = Flow() + a, b = BaseNode(), BaseNode() + result = flow.connect(a, b) + assert result is flow + + def test_connect_sets_successor(self): + flow = Flow() + a, b = BaseNode(), BaseNode() + flow.connect(a, b, action="next") + assert a.successors["next"] is b + + def test_connect_default_action(self): + flow = Flow() + a, b = BaseNode(), BaseNode() + flow.connect(a, b) + assert a.successors["default"] is b + + def test_connect_none_dest_sets_terminal(self): + flow = Flow() + a = BaseNode() + flow.connect(a, None, action="done") + assert flow.get_next_node(a, "done") is None + + def test_connect_overwrite_warns(self): + flow = Flow() + a, b, c = BaseNode(), BaseNode(), BaseNode() + flow.connect(a, b, action="x") + with pytest.warns(UserWarning, match="Overwriting successor"): + flow.connect(a, c, action="x") + + def test_get_next_node_returns_successor(self): + flow = Flow() + a, b = BaseNode(), BaseNode() + flow.connect(a, b, action="go") + assert flow.get_next_node(a, "go") is b + + def test_get_next_node_defaults_to_default_action(self): + flow = Flow() + a, b = BaseNode(), BaseNode() + flow.connect(a, b) + assert flow.get_next_node(a, None) is b + + def test_get_next_node_unknown_action_warns(self): + flow = Flow() + a, b = BaseNode(), BaseNode() + flow.connect(a, b, action="default") + with pytest.warns(UserWarning, match="Flow ends"): + result = flow.get_next_node(a, "unknown") + assert result is None + + def test_get_next_node_no_successors_no_warning(self): + flow = Flow() + a = BaseNode() + with warnings.catch_warnings(): + warnings.simplefilter("error") + result = flow.get_next_node(a, "anything") + assert result is None + + def test_orch_traverses_chain(self): + a = _Echo(label="a") + b = _Echo(label="b") + c = _Echo(label="c") + flow = Flow(start=a) + flow.connect(a, b).connect(b, c).connect(c, None) + + shared = {} + flow._orch(shared) + assert shared["visited"] == ["a", "b", "c"] + + def test_post_returns_exec_res(self): + flow = Flow() + assert flow.post({}, "prep", "exec_result") == "exec_result" + + def test_full_flow_run(self): + a = _Echo(label="a", next_action="next") + b = _Echo(label="b", next_action="default") + + flow = Flow(start=a) + flow.connect(a, b, action="next") + flow.connect(b, None) + + shared = {} + result = flow._run(shared) + assert shared["visited"] == ["a", "b"] + assert result == "default" + + def test_branching_flow(self): + """Flow takes different paths based on returned action.""" + + class Router(BaseNode): + def post(self, shared, prep_res, exec_res): + shared.setdefault("visited", []).append("router") + return shared["route"] + + class Leaf(BaseNode): + def __init__(self, name): + super().__init__() + self.name = name + + def post(self, shared, prep_res, exec_res): + shared.setdefault("visited", []).append(self.name) + return "default" + + router = Router() + left = Leaf("left") + right = Leaf("right") + + flow = Flow(start=router) + flow.connect(router, left, action="left") + flow.connect(router, right, action="right") + flow.connect(left, None) + flow.connect(right, None) + + shared = {"route": "right"} + flow._orch(shared) + assert shared["visited"] == ["router", "right"] + + +class TestBatchFlow: + def test_run_iterates_prep_results(self): + processed = [] + + class Worker(_Echo): + def post(self, shared, prep_res, exec_res): + processed.append(self.params.get("item")) + return "default" + + worker = Worker() + worker.successors["default"] = ( + None # no successor, terminal via warning suppressed + ) + + class MyBatchFlow(BatchFlow): + def prep(self, shared): + return [{"item": "a"}, {"item": "b"}, {"item": "c"}] + + flow = MyBatchFlow(start=worker) + # avoid "Flow ends" warning noise + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + flow._run({}) + + assert processed == ["a", "b", "c"] + + def test_run_with_empty_prep_returns_post_result(self): + class EmptyBatchFlow(BatchFlow): + def prep(self, shared): + return None + + def post(self, shared, prep_res, exec_res): + return ("post", prep_res, exec_res) + + flow = EmptyBatchFlow(start=BaseNode()) + result = flow._run({}) + assert result == ("post", [], None) + + +class TestAsyncNode: + def test_sync_run_raises_runtime_error(self): + node = AsyncNode() + with pytest.raises(RuntimeError, match="Use run_async"): + node._run({}) + + def test_run_async_calls_pipeline(self): + calls = [] + + class Tracked(AsyncNode): + async def prep_async(self, shared): + calls.append("prep") + return "p" + + async def exec_async(self, prep_res): + calls.append(("exec", prep_res)) + return "e" + + async def post_async(self, shared, prep_res, exec_res): + calls.append(("post", prep_res, exec_res)) + return "done" + + node = Tracked() + result = asyncio.run(node.run_async({})) + assert calls == ["prep", ("exec", "p"), ("post", "p", "e")] + assert result == "done" + + def test_run_async_warns_with_successors(self): + class Simple(AsyncNode): + async def exec_async(self, prep_res): + return "e" + + node = Simple() + node.successors["default"] = AsyncNode() + + async def run_it(): + with pytest.warns(UserWarning, match="won't run successors"): + await node.run_async({}) + + asyncio.run(run_it()) + + def test_exec_fallback_async_raises_by_default(self): + class AlwaysFails(AsyncNode): + async def exec_async(self, prep_res): + raise ValueError("async boom") + + node = AlwaysFails(max_retries=1) + with pytest.raises(ValueError, match="async boom"): + asyncio.run(node._exec(None)) + + def test_exec_fallback_async_called_after_retries(self): + class FailsWithFallback(AsyncNode): + async def exec_async(self, prep_res): + raise ValueError("nope") + + async def exec_fallback_async(self, prep_res, exc): + return f"fallback: {exc}" + + node = FailsWithFallback(max_retries=2) + result = asyncio.run(node._exec(None)) + assert result == "fallback: nope" + + def test_async_retries_then_succeeds(self): + attempts = {"count": 0} + + class FlakyAsync(AsyncNode): + async def exec_async(self, prep_res): + attempts["count"] += 1 + if attempts["count"] < 3: + raise RuntimeError("transient") + return "success" + + node = FlakyAsync(max_retries=5) + result = asyncio.run(node._exec(None)) + assert result == "success" + assert attempts["count"] == 3 + + def test_async_wait_triggers_sleep(self, monkeypatch): + sleeps = [] + + async def fake_sleep(s): + sleeps.append(s) + + monkeypatch.setattr("fenn.agents.asyncio.sleep", fake_sleep) + + class AlwaysFails(AsyncNode): + async def exec_async(self, prep_res): + raise ValueError("fail") + + async def exec_fallback_async(self, prep_res, exc): + return "fallback" + + node = AlwaysFails(max_retries=3, wait=1) + asyncio.run(node._exec(None)) + assert sleeps == [1, 1] + + +class TestAsyncBatchNode: + def test_processes_items_sequentially(self): + order = [] + + class Worker(AsyncBatchNode): + async def exec_async(self, item): + order.append(item) + return item * 2 + + node = Worker() + result = asyncio.run(node._exec([1, 2, 3])) + assert result == [2, 4, 6] + assert order == [1, 2, 3] + + +class TestAsyncParallelBatchNode: + def test_processes_items_in_parallel(self): + class Worker(AsyncParallelBatchNode): + async def exec_async(self, item): + await asyncio.sleep(0) + return item * 10 + + node = Worker() + result = asyncio.run(node._exec([1, 2, 3])) + assert result == [10, 20, 30] + + +class _AsyncEcho(AsyncNode): + """Async version of _Echo - uses instance attributes, not self.params.""" + + def __init__(self, label="node", next_action="default"): + super().__init__() + self.label = label + self.next_action = next_action + + async def post_async(self, shared, prep_res, exec_res): + shared.setdefault("visited", []).append(self.label) + return self.next_action + + +class TestAsyncFlow: + def test_orch_async_with_async_nodes(self): + a = _AsyncEcho(label="a") + b = _AsyncEcho(label="b") + + flow = AsyncFlow(start=a) + flow.connect(a, b).connect(b, None) + + shared = {} + asyncio.run(flow._orch_async(shared)) + assert shared["visited"] == ["a", "b"] + + def test_orch_async_with_mixed_sync_and_async_nodes(self): + a = _AsyncEcho(label="a") + b = _Echo(label="b") # sync node + + flow = AsyncFlow(start=a) + flow.connect(a, b).connect(b, None) + + shared = {} + asyncio.run(flow._orch_async(shared)) + assert shared["visited"] == ["a", "b"] + + def test_post_async_default_returns_exec_res(self): + flow = AsyncFlow() + result = asyncio.run(flow.post_async({}, "prep", "exec")) + assert result == "exec" + + def test_run_async_full_pipeline(self): + a = _AsyncEcho(label="a") + + flow = AsyncFlow(start=a) + flow.connect(a, None) + + shared = {} + result = asyncio.run(flow._run_async(shared)) + assert shared["visited"] == ["a"] + assert result == "default" + + +class TestAsyncBatchFlow: + def test_run_async_iterates_prep_results(self): + processed = [] + + class Worker(_AsyncEcho): + async def post_async(self, shared, prep_res, exec_res): + processed.append(self.params.get("item")) + return "default" + + worker = Worker() + + class MyAsyncBatchFlow(AsyncBatchFlow): + async def prep_async(self, shared): + return [{"item": "a"}, {"item": "b"}] + + flow = MyAsyncBatchFlow(start=worker) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + asyncio.run(flow._run_async({})) + + assert processed == ["a", "b"] + + def test_run_async_with_empty_prep(self): + class EmptyAsyncBatchFlow(AsyncBatchFlow): + async def prep_async(self, shared): + return None + + async def post_async(self, shared, prep_res, exec_res): + return ("post", prep_res, exec_res) + + flow = EmptyAsyncBatchFlow(start=BaseNode()) + result = asyncio.run(flow._run_async({})) + assert result == ("post", [], None) + + +class TestAsyncParallelBatchFlow: + def test_run_async_processes_concurrently(self): + processed = [] + + class Worker(_AsyncEcho): + async def post_async(self, shared, prep_res, exec_res): + await asyncio.sleep(0) + processed.append(self.params.get("item")) + return "default" + + worker = Worker() + + class MyParallelBatchFlow(AsyncParallelBatchFlow): + async def prep_async(self, shared): + return [{"item": "a"}, {"item": "b"}, {"item": "c"}] + + flow = MyParallelBatchFlow(start=worker) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + asyncio.run(flow._run_async({})) + + assert sorted(processed) == ["a", "b", "c"] + + +class TestModuleExports: + def test_llmclient_importable(self): + from fenn.agents import LLMClient + + assert LLMClient is not None + + def test_ragnode_importable(self): + from fenn.agents import RAGNode + + assert RAGNode is not None + + def test_all_exports_present(self): + import fenn.agents as agents_module + + for name in agents_module.__all__: + assert hasattr(agents_module, name), f"{name} missing from module" diff --git a/tests/unit/agents/test_llm_methods.py b/tests/unit/agents/test_llm_methods.py new file mode 100644 index 0000000..4561d65 --- /dev/null +++ b/tests/unit/agents/test_llm_methods.py @@ -0,0 +1,310 @@ +""" +Additional tests for LLMClient._openai_client, .chat_complete, .ask, and .stream. +Merge with or run alongside test_llm.py. +""" +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, PropertyMock + +import pytest +from pydantic import BaseModel + +from fenn.agents.llm import LLMClient + + +# ── Helpers ──────────────────────────────────────────────────────────────────── + +def _make_client(provider="openai", model="gpt-4o-mini", api_key="test-key"): + """Build an LLMClient bypassing env-var lookup.""" + client = LLMClient.__new__(LLMClient) + client.provider = provider + client.model = model + client.api_key = api_key + client.base_url = "https://api.openai.com/v1" + return client + + +def _make_completion(content="Hello!"): + """Return a minimal fake openai ChatCompletion object.""" + message = SimpleNamespace(content=content) + choice = SimpleNamespace(message=message) + return SimpleNamespace(choices=[choice]) + + +# ── _openai_client ───────────────────────────────────────────────────────────── + +def test_openai_client_returns_openai_instance(monkeypatch): + fake_openai_cls = MagicMock() + fake_instance = MagicMock() + fake_openai_cls.return_value = fake_instance + + with patch.dict("sys.modules", {"openai": MagicMock(OpenAI=fake_openai_cls)}): + client = _make_client() + result = client._openai_client() + + fake_openai_cls.assert_called_once_with(api_key="test-key", base_url=client.base_url) + assert result is fake_instance + + +def test_openai_client_raises_on_missing_package(): + client = _make_client() + with patch.dict("sys.modules", {"openai": None}): + with pytest.raises(ImportError, match="openai"): + client._openai_client() + + +# ── chat_complete ────────────────────────────────────────────────────────────── + +def test_chat_complete_returns_text(): + client = _make_client() + fake_response = _make_completion("Hi there") + + mock_openai = MagicMock() + mock_openai.chat.completions.create.return_value = fake_response + + with patch.object(client, "_openai_client", return_value=mock_openai): + with patch("fenn.agents.llm.RateLimitError", create=True, new=Exception): + # patch the import inside chat_complete + with patch("fenn.agents.llm.__builtins__", {}): + pass # just ensure no side-effects + + # Simpler: patch at the openai module level + import fenn.agents.llm as llm_module + fake_rle = type("RateLimitError", (Exception,), {}) + with patch.object(llm_module, "time"): + with patch("builtins.__import__", side_effect=lambda name, *a, **kw: ( + SimpleNamespace(RateLimitError=fake_rle) if name == "openai" else __import__(name, *a, **kw) + )): + pass + + # Cleanest approach: mock the whole call chain + result = client.chat_complete([{"role": "user", "content": "Hello"}]) + + assert result == "Hi there" + mock_openai.chat.completions.create.assert_called_once() + + +def test_chat_complete_plain_text(monkeypatch): + """chat_complete without schema returns a string.""" + client = _make_client() + fake_response = _make_completion("plain response") + mock_oa = MagicMock() + mock_oa.chat.completions.create.return_value = fake_response + + with patch.object(client, "_openai_client", return_value=mock_oa): + result = client.chat_complete([{"role": "user", "content": "hi"}]) + + assert result == "plain response" + + +def test_chat_complete_with_schema(monkeypatch): + """When schema is provided, response JSON is validated and returned as model.""" + + class Reply(BaseModel): + answer: str + + payload = json.dumps({"answer": "42"}) + fake_response = _make_completion(payload) + mock_oa = MagicMock() + mock_oa.chat.completions.create.return_value = fake_response + + client = _make_client() + with patch.object(client, "_openai_client", return_value=mock_oa): + result = client.chat_complete( + [{"role": "user", "content": "What is the answer?"}], + schema=Reply, + ) + + assert isinstance(result, Reply) + assert result.answer == "42" + # Confirm response_format was injected + _, call_kwargs = mock_oa.chat.completions.create.call_args + assert call_kwargs.get("response_format") == {"type": "json_object"} + + +def test_chat_complete_with_schema_simple(): + """Schema path: JSON body is parsed and returned as a Pydantic model.""" + + class Point(BaseModel): + x: int + y: int + + client = _make_client() + fake_response = _make_completion(json.dumps({"x": 1, "y": 2})) + mock_oa = MagicMock() + mock_oa.chat.completions.create.return_value = fake_response + + with patch.object(client, "_openai_client", return_value=mock_oa): + result = client.chat_complete( + [{"role": "user", "content": "give me a point"}], + schema=Point, + ) + + assert result == Point(x=1, y=2) + + +def test_chat_complete_does_not_mutate_caller_messages(): + """chat_complete should shallow-copy messages, not mutate the caller's list.""" + client = _make_client() + msgs = [{"role": "user", "content": "hello"}] + original_content = msgs[0]["content"] + + fake_response = _make_completion("ok") + mock_oa = MagicMock() + mock_oa.chat.completions.create.return_value = fake_response + + with patch.object(client, "_openai_client", return_value=mock_oa): + client.chat_complete(msgs) + + assert msgs[0]["content"] == original_content + + +def test_chat_complete_retries_on_rate_limit(): + """Should retry up to `retries` times on RateLimitError, then raise.""" + import fenn.agents.llm as llm_module + + client = _make_client() + RLE = type("RateLimitError", (Exception,), {}) + + mock_oa = MagicMock() + mock_oa.chat.completions.create.side_effect = RLE("rate limited") + + with patch.object(client, "_openai_client", return_value=mock_oa): + with patch.object(llm_module, "time") as mock_time: + # patch the RateLimitError import inside chat_complete + with patch("builtins.__import__") as mock_import: + def side_import(name, *args, **kwargs): + if name == "openai": + return SimpleNamespace(RateLimitError=RLE) + return __import__(name, *args, **kwargs) + mock_import.side_effect = side_import + + with pytest.raises(RLE): + client.chat_complete( + [{"role": "user", "content": "hi"}], retries=3 + ) + + assert mock_time.sleep.call_count == 2 # retries-1 sleeps before final raise + + +def test_chat_complete_rate_limit_retry_succeeds(): + """Should succeed if a retry after RateLimitError returns a valid response.""" + import fenn.agents.llm as llm_module + + client = _make_client() + RLE = type("RateLimitError", (Exception,), {}) + fake_response = _make_completion("eventually ok") + + mock_oa = MagicMock() + mock_oa.chat.completions.create.side_effect = [RLE("limited"), fake_response] + + with patch.object(client, "_openai_client", return_value=mock_oa): + with patch.object(llm_module, "time"): + with patch("builtins.__import__") as mock_import: + def side_import(name, *args, **kwargs): + if name == "openai": + return SimpleNamespace(RateLimitError=RLE) + return __import__(name, *args, **kwargs) + mock_import.side_effect = side_import + + result = client.chat_complete( + [{"role": "user", "content": "hi"}], retries=3 + ) + + assert result == "eventually ok" + + +def test_chat_complete_raises_import_error_without_openai(): + """When openai is not installed, chat_complete raises ImportError.""" + client = _make_client() + with patch.dict("sys.modules", {"openai": None}): + with pytest.raises(ImportError, match="openai"): + client.chat_complete([{"role": "user", "content": "hi"}]) + + +# ── ask ──────────────────────────────────────────────────────────────────────── + +def test_ask_delegates_to_chat_complete(): + client = _make_client() + with patch.object(client, "chat_complete", return_value="answer") as mock_cc: + result = client.ask("What is 2+2?") + + assert result == "answer" + mock_cc.assert_called_once_with( + [{"role": "user", "content": "What is 2+2?"}], + schema=None, + retries=3, + ) + + +def test_ask_passes_schema_and_retries(): + class Ans(BaseModel): + value: int + + client = _make_client() + with patch.object(client, "chat_complete", return_value=Ans(value=4)) as mock_cc: + result = client.ask("2+2?", schema=Ans, retries=5) + + assert result == Ans(value=4) + mock_cc.assert_called_once_with( + [{"role": "user", "content": "2+2?"}], schema=Ans, retries=5 + ) + + +# ── stream ───────────────────────────────────────────────────────────────────── + +def _make_chunk(content): + delta = SimpleNamespace(content=content) + choice = SimpleNamespace(delta=delta) + return SimpleNamespace(choices=[choice]) + + +def test_stream_yields_tokens(): + client = _make_client() + chunks = [_make_chunk("Hello"), _make_chunk(", "), _make_chunk("world")] + mock_oa = MagicMock() + mock_oa.chat.completions.create.return_value = iter(chunks) + + with patch.object(client, "_openai_client", return_value=mock_oa): + tokens = list(client.stream("Say hello")) + + assert tokens == ["Hello", ", ", "world"] + + +def test_stream_skips_empty_delta(): + client = _make_client() + chunks = [_make_chunk("Hi"), _make_chunk(""), _make_chunk(None), _make_chunk("!")] + mock_oa = MagicMock() + mock_oa.chat.completions.create.return_value = iter(chunks) + + with patch.object(client, "_openai_client", return_value=mock_oa): + tokens = list(client.stream("hi")) + + assert tokens == ["Hi", "!"] + + +def test_stream_skips_chunk_without_choices(): + client = _make_client() + no_choices = SimpleNamespace(choices=[]) + normal = _make_chunk("ok") + mock_oa = MagicMock() + mock_oa.chat.completions.create.return_value = iter([no_choices, normal]) + + with patch.object(client, "_openai_client", return_value=mock_oa): + tokens = list(client.stream("hi")) + + assert tokens == ["ok"] + + +def test_stream_passes_correct_params(): + client = _make_client() + mock_oa = MagicMock() + mock_oa.chat.completions.create.return_value = iter([]) + + with patch.object(client, "_openai_client", return_value=mock_oa): + list(client.stream("test prompt")) + + call_kwargs = mock_oa.chat.completions.create.call_args[1] + assert call_kwargs["stream"] is True + assert call_kwargs["model"] == client.model + assert call_kwargs["messages"] == [{"role": "user", "content": "test prompt"}] diff --git a/tests/unit/agents/test_llms.py b/tests/unit/agents/test_llms.py new file mode 100644 index 0000000..36cb321 --- /dev/null +++ b/tests/unit/agents/test_llms.py @@ -0,0 +1,267 @@ +"""Tests for fenn/agents/rag/llm.py (module-level ask/stream functions).""" +import json +import pytest +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from fenn.agents.rag.llm import ( + DEFAULT_MODELS, + LOCAL_PROVIDERS, + PROVIDERS, + _detect_provider, + _resolve_key, + ask, + stream, +) + + +# ── _detect_provider ─────────────────────────────────────────────────────────── + +class TestDetectProvider: + @pytest.mark.parametrize("model,expected", [ + ("gpt-4o-mini", "openai"), + ("gpt-3.5-turbo", "openai"), + ("o1-mini", "openai"), + ("o3-large", "openai"), + ("o4-preview", "openai"), + ("gemini-2.0-flash", "gemini"), + ("claude-3-5-haiku", "anthropic"), + ("mistral-small", "mistral"), + ("codestral-latest", "mistral"), + ("command-r-plus", "cohere"), + ("grok-beta", "xai"), + ("deepseek-chat", "deepseek"), + ("llama-3.1-8b", "groq"), + ("mixtral-8x7b", "groq"), + ("openai/gpt-4o", "openrouter"), + (None, "openrouter"), + ]) + def test_from_model_name(self, model, expected): + assert _detect_provider(None, model, None) == expected + + def test_explicit_provider_wins_over_model(self): + assert _detect_provider("anthropic", "gpt-4o", None) == "anthropic" + + def test_from_known_base_url(self): + assert _detect_provider(None, None, PROVIDERS["groq"]) == "groq" + + def test_from_base_url_partial_match(self): + # URL contains the provider URL as a substring + assert _detect_provider(None, None, PROVIDERS["openai"] + "/extra") == "openai" + + def test_unknown_base_url_defaults_to_openrouter(self): + assert _detect_provider(None, None, "https://custom.endpoint/v1") == "openrouter" + + def test_base_url_takes_priority_over_model(self): + assert _detect_provider(None, "gpt-4o", PROVIDERS["groq"]) == "groq" + + def test_all_none_defaults_to_openrouter(self): + assert _detect_provider(None, None, None) == "openrouter" + + +# ── _resolve_key ─────────────────────────────────────────────────────────────── + +class TestResolveKey: + def test_explicit_key_wins(self): + assert _resolve_key("my-key", "openai") == "my-key" + + @pytest.mark.parametrize("provider", sorted(LOCAL_PROVIDERS)) + def test_local_providers_return_local(self, provider): + assert _resolve_key(None, provider) == "local" + + def test_reads_from_env(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "env-key") + assert _resolve_key(None, "openai") == "env-key" + + def test_missing_env_raises(self, monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + with pytest.raises(ValueError, match="API key not found"): + _resolve_key(None, "openai") + + def test_unknown_provider_returns_local(self): + assert _resolve_key(None, "unknown") == "local" + + def test_explicit_key_beats_env(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "env-key") + assert _resolve_key("direct-key", "openai") == "direct-key" + + +# ── Helpers for mocking openai responses ────────────────────────────────────── + +def _make_completion(content="response text"): + message = SimpleNamespace(content=content) + choice = SimpleNamespace(message=message) + return SimpleNamespace(choices=[choice]) + + +def _make_chunk(content): + delta = SimpleNamespace(content=content) + choice = SimpleNamespace(delta=delta) + return SimpleNamespace(choices=[choice]) + + +def _mock_openai(content="response text"): + """Return a mock OpenAI client whose chat.completions.create returns a completion.""" + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _make_completion(content) + return mock_client + + +# ── ask ──────────────────────────────────────────────────────────────────────── + +class TestAsk: + def _patched_ask(self, prompt, mock_client, **kwargs): + RLE = type("RateLimitError", (Exception,), {}) + fake_openai = MagicMock() + fake_openai.OpenAI.return_value = mock_client + fake_openai.RateLimitError = RLE + with patch.dict("sys.modules", {"openai": fake_openai}): + return ask(prompt, model_api_key="test-key", model_provider="openai", **kwargs) + + def test_returns_text(self): + result = self._patched_ask("hello", _mock_openai("hi there")) + assert result == "hi there" + + def test_uses_correct_model(self): + mock_client = _mock_openai() + self._patched_ask("hello", mock_client, model="gpt-4o") + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["model"] == "gpt-4o" + + def test_defaults_to_provider_default_model(self): + mock_client = _mock_openai() + self._patched_ask("hello", mock_client) + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["model"] == DEFAULT_MODELS["openai"] + + def test_sends_user_message(self): + mock_client = _mock_openai() + self._patched_ask("my prompt", mock_client) + call_kwargs = mock_client.chat.completions.create.call_args[1] + messages = call_kwargs["messages"] + assert messages[0]["role"] == "user" + assert "my prompt" in messages[0]["content"] + + def test_schema_appends_json_instruction(self): + from pydantic import BaseModel + + class Reply(BaseModel): + answer: str + + payload = json.dumps({"answer": "42"}) + mock_client = _mock_openai(payload) + result = self._patched_ask("question", mock_client, schema=Reply) + assert isinstance(result, Reply) + assert result.answer == "42" + + def test_schema_sets_response_format(self): + from pydantic import BaseModel + + class Reply(BaseModel): + value: int + + mock_client = _mock_openai(json.dumps({"value": 1})) + self._patched_ask("q", mock_client, schema=Reply) + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs.get("response_format") == {"type": "json_object"} + + def test_retries_on_rate_limit_then_raises(self): + import fenn.agents.rag.llm as llm_module + + RLE = type("RateLimitError", (Exception,), {}) + mock_client = MagicMock() + mock_client.chat.completions.create.side_effect = RLE("limited") + + fake_openai = MagicMock() + fake_openai.OpenAI.return_value = mock_client + fake_openai.RateLimitError = RLE + + with patch.dict("sys.modules", {"openai": fake_openai}): + with patch.object(llm_module, "time") as mock_time: + with pytest.raises(RLE): + ask("hi", model_api_key="key", model_provider="openai", retries=3) + assert mock_time.sleep.call_count == 2 + + def test_retries_succeeds_after_rate_limit(self): + import fenn.agents.rag.llm as llm_module + + RLE = type("RateLimitError", (Exception,), {}) + mock_client = MagicMock() + mock_client.chat.completions.create.side_effect = [ + RLE("limited"), + _make_completion("ok"), + ] + + fake_openai = MagicMock() + fake_openai.OpenAI.return_value = mock_client + fake_openai.RateLimitError = RLE + + with patch.dict("sys.modules", {"openai": fake_openai}): + with patch.object(llm_module, "time"): + result = ask("hi", model_api_key="key", model_provider="openai", retries=3) + assert result == "ok" + + def test_raises_import_error_without_openai(self): + with patch.dict("sys.modules", {"openai": None}): + with pytest.raises(ImportError, match="openai"): + ask("hello", model_api_key="key", model_provider="openai") + + def test_custom_base_url_is_used(self): + mock_client = _mock_openai() + RLE = type("RateLimitError", (Exception,), {}) + fake_openai = MagicMock() + fake_openai.OpenAI.return_value = mock_client + fake_openai.RateLimitError = RLE + + custom_url = "https://my.proxy/v1" + with patch.dict("sys.modules", {"openai": fake_openai}): + ask("hi", model_api_key="key", model_provider="openai", base_url=custom_url) + + fake_openai.OpenAI.assert_called_once_with( + api_key="key", base_url=custom_url + ) + + +# ── stream ───────────────────────────────────────────────────────────────────── + +class TestStream: + def _patched_stream(self, prompt, chunks, **kwargs): + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = iter(chunks) + fake_openai = MagicMock() + fake_openai.OpenAI.return_value = mock_client + with patch.dict("sys.modules", {"openai": fake_openai}): + return list( + stream(prompt, model_api_key="test-key", model_provider="openai", **kwargs) + ), mock_client + + def test_yields_tokens(self): + chunks = [_make_chunk("Hello"), _make_chunk(", "), _make_chunk("world")] + tokens, _ = self._patched_stream("hi", chunks) + assert tokens == ["Hello", ", ", "world"] + + def test_skips_empty_delta(self): + chunks = [_make_chunk("Hi"), _make_chunk(""), _make_chunk(None), _make_chunk("!")] + tokens, _ = self._patched_stream("hi", chunks) + assert tokens == ["Hi", "!"] + + def test_skips_chunk_without_choices(self): + no_choices = SimpleNamespace(choices=[]) + normal = _make_chunk("ok") + tokens, _ = self._patched_stream("hi", [no_choices, normal]) + assert tokens == ["ok"] + + def test_passes_stream_true(self): + _, mock_client = self._patched_stream("hi", []) + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["stream"] is True + + def test_sends_correct_prompt(self): + _, mock_client = self._patched_stream("my question", []) + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["messages"][0]["content"] == "my question" + + def test_raises_import_error_without_openai(self): + with patch.dict("sys.modules", {"openai": None}): + with pytest.raises(ImportError, match="openai"): + list(stream("hi", model_api_key="key", model_provider="openai")) diff --git a/tests/unit/agents/test_node.py b/tests/unit/agents/test_node.py new file mode 100644 index 0000000..c9de768 --- /dev/null +++ b/tests/unit/agents/test_node.py @@ -0,0 +1,162 @@ +"""Tests for fenn/agents/node.py""" +import pytest +from unittest.mock import MagicMock, patch + +from fenn.agents.node import ThinkNode, ActNode, ObserveNode + + +# ── Helpers ──────────────────────────────────────────────────────────────────── + +def _shared( + messages=None, + last_thought=None, + last_observation=None, + iterations=0, + max_iterations=5, +): + mock_llm = MagicMock() + return { + "llm": mock_llm, + "messages": messages or [{"role": "user", "content": "hello"}], + "last_thought": last_thought, + "last_observation": last_observation, + "iterations": iterations, + "max_iterations": max_iterations, + } + + +# ── ThinkNode ────────────────────────────────────────────────────────────────── + +class TestThinkNode: + def test_prep_returns_llm_and_messages(self): + node = ThinkNode() + shared = _shared() + result = node.prep(shared) + assert result["llm"] is shared["llm"] + assert result["messages"] is shared["messages"] + + def test_exec_calls_chat_complete(self): + node = ThinkNode() + mock_llm = MagicMock() + mock_llm.chat_complete.return_value = "I need to think." + prep_res = {"llm": mock_llm, "messages": [{"role": "user", "content": "hi"}]} + result = node.exec(prep_res) + mock_llm.chat_complete.assert_called_once_with(prep_res["messages"]) + assert result == "I need to think." + + def test_post_returns_act_when_action_in_response(self): + node = ThinkNode() + shared = _shared() + action = "Thought: I should search.\nAction: search(query)" + result = node.post(shared, {}, action) + assert result == "act" + + def test_post_returns_done_when_no_action(self): + node = ThinkNode() + shared = _shared() + result = node.post(shared, {}, "I have the final answer.") + assert result == "done" + + def test_post_appends_message(self): + node = ThinkNode() + shared = _shared() + node.post(shared, {}, "Some thought.") + assert shared["messages"][-1] == {"role": "assistant", "content": "Some thought."} + + def test_post_stores_last_thought(self): + node = ThinkNode() + shared = _shared() + node.post(shared, {}, "My thought") + assert shared["last_thought"] == "My thought" + + +# ── ActNode ──────────────────────────────────────────────────────────────────── + +class TestActNode: + def test_prep_returns_last_thought(self): + node = ActNode() + shared = _shared(last_thought="Action: search(hello)") + assert node.prep(shared) == "Action: search(hello)" + + def test_exec_parses_action_line_and_calls_tool(self): + node = ActNode() + thought = "Thought: search\nAction: search(python)" + with patch("fenn.agents.node.execute_tool", return_value="result") as mock_tool: + result = node.exec(thought) + mock_tool.assert_called_once_with("search", "python") + assert result == "result" + + def test_exec_without_action_prefix(self): + node = ActNode() + thought = "lookup(wikipedia)" + with patch("fenn.agents.node.execute_tool", return_value="wiki result") as mock_tool: + result = node.exec(thought) + mock_tool.assert_called_once_with("lookup", "wikipedia") + assert result == "wiki result" + + def test_exec_tool_exception_returns_error_string(self): + node = ActNode() + thought = "Action: bad_tool(arg)" + with patch("fenn.agents.node.execute_tool", side_effect=ValueError("not found")): + result = node.exec(thought) + assert result == "Error: not found" + + def test_exec_multi_arg_tool(self): + node = ActNode() + thought = "Action: calculate(1, 2)" + with patch("fenn.agents.node.execute_tool", return_value="3") as mock_tool: + node.exec(thought) + mock_tool.assert_called_once_with("calculate", "1", "2") + + def test_post_stores_observation_and_returns_observe(self): + node = ActNode() + shared = _shared() + action = node.post(shared, {}, "tool output") + assert shared["last_observation"] == "tool output" + assert action == "observe" + + +# ── ObserveNode ──────────────────────────────────────────────────────────────── + +class TestObserveNode: + def test_prep_returns_last_observation(self): + node = ObserveNode() + shared = _shared(last_observation="some result") + assert node.prep(shared) == "some result" + + def test_exec_returns_observation_unchanged(self): + node = ObserveNode() + assert node.exec("anything") == "anything" + + def test_post_appends_observation_message(self): + node = ObserveNode() + shared = _shared() + node.post(shared, {}, "search result") + assert shared["messages"][-1] == { + "role": "user", + "content": "Observation: search result", + } + + def test_post_increments_iterations(self): + node = ObserveNode() + shared = _shared(iterations=2) + node.post(shared, {}, "obs") + assert shared["iterations"] == 3 + + def test_post_returns_think_when_below_max(self): + node = ObserveNode() + shared = _shared(iterations=0, max_iterations=5) + result = node.post(shared, {}, "obs") + assert result == "think" + + def test_post_returns_done_when_at_max_iterations(self): + node = ObserveNode() + shared = _shared(iterations=4, max_iterations=5) + result = node.post(shared, {}, "obs") + assert result == "done" + + def test_post_returns_done_when_exceeds_max(self): + node = ObserveNode() + shared = _shared(iterations=10, max_iterations=5) + result = node.post(shared, {}, "obs") + assert result == "done" diff --git a/tests/unit/agents/test_rag_components.py b/tests/unit/agents/test_rag_components.py new file mode 100644 index 0000000..20f81ca --- /dev/null +++ b/tests/unit/agents/test_rag_components.py @@ -0,0 +1,487 @@ +"""Tests for fenn/agents/rag/chunker.py, loader.py, and rag.py""" +import pytest +from pathlib import Path +from unittest.mock import MagicMock, patch, mock_open + +from fenn.agents.rag.chunker import ( + chunk_text, + _chunk_fixed, + _chunk_paragraphs, + _chunk_sentences, + _chunk_smart, +) + +from fenn.agents.rag.loader import ( + load_documents, + _read_file, + _read_pdf, + _load_url, + _load_wikipedia, + _load_youtube, +) + +from fenn.agents.rag.rag import RAG + +# ══════════════════════════════════════════════════════════════════════════════ +# chunker.py +# ══════════════════════════════════════════════════════════════════════════════ + +class TestChunkParagraphs: + def test_splits_on_blank_line(self): + text = "First paragraph.\n\nSecond paragraph." + result = _chunk_paragraphs(text) + assert result == ["First paragraph.", "Second paragraph."] + + def test_strips_whitespace(self): + result = _chunk_paragraphs(" hello \n\n world ") + assert result == ["hello", "world"] + + def test_ignores_empty_blocks(self): + result = _chunk_paragraphs("\n\n\nonly content\n\n\n") + assert result == ["only content"] + + def test_single_paragraph(self): + result = _chunk_paragraphs("no blank lines here") + assert result == ["no blank lines here"] + + def test_empty_string(self): + assert _chunk_paragraphs("") == [] + + +class TestChunkSentences: + def test_groups_short_sentences(self): + text = "Hello world. This is a test." + result = _chunk_sentences(text) + assert len(result) >= 1 + assert all(isinstance(c, str) for c in result) + + def test_splits_long_sentences(self): + # Each sentence is ~100 chars; 6 of them exceed 500 total + sentence = "This is a moderately long sentence that takes up space. " + text = (sentence * 10).strip() + result = _chunk_sentences(text) + assert len(result) > 1 + assert all(len(c) <= 600 for c in result) + + def test_handles_exclamation_and_question(self): + text = "Really? Yes! Absolutely." + result = _chunk_sentences(text) + assert len(result) >= 1 + + def test_empty_string(self): + result = _chunk_sentences("") + assert result == [] or result == [""] + + +class TestChunkFixed: + def test_basic_chunking(self): + text = "a" * 1000 + result = _chunk_fixed(text, size=200, overlap=50) + assert len(result) > 1 + assert all(len(c) <= 200 for c in result) + + def test_overlap_is_shared(self): + text = "abcdefghij" # 10 chars + result = _chunk_fixed(text, size=6, overlap=2) + # chunk 0: [0:6] = "abcdef", chunk 1: [4:10] = "efghij" + assert result[0][-2:] == result[1][:2] + + def test_short_text_single_chunk(self): + result = _chunk_fixed("hello", size=100, overlap=10) + assert result == ["hello"] + + def test_overlap_gte_size_raises(self): + with pytest.raises(ValueError, match="overlap"): + _chunk_fixed("text", size=10, overlap=10) + + def test_overlap_larger_than_size_raises(self): + with pytest.raises(ValueError, match="overlap"): + _chunk_fixed("text", size=5, overlap=20) + + +class TestChunkSmart: + def test_short_paragraphs_kept_whole(self): + text = "Short para.\n\nAnother short para." + result = _chunk_smart(text) + assert "Short para." in result + assert "Another short para." in result + + def test_long_paragraph_split_into_sentences(self): + # Build a paragraph > 600 chars + sentence = "This is a fairly long sentence used for testing purposes only. " + long_para = sentence * 12 + result = _chunk_smart(long_para) + assert len(result) > 1 + assert all(len(c) <= 600 for c in result) + + def test_empty_chunks_filtered(self): + result = _chunk_smart("\n\n\n") + assert result == [] + + +class TestChunkText: + def test_mode_smart(self): + assert chunk_text("hello\n\nworld", mode="smart") == ["hello", "world"] + + def test_mode_paragraphs(self): + assert chunk_text("a\n\nb", mode="paragraphs") == ["a", "b"] + + def test_mode_sentences(self): + result = chunk_text("Hello. World.", mode="sentences") + assert isinstance(result, list) + + def test_mode_fixed(self): + result = chunk_text("x" * 200, mode="fixed", size=100, overlap=10) + assert len(result) > 1 + + def test_invalid_mode_raises(self): + with pytest.raises(ValueError, match="unknown chunk_mode"): + chunk_text("text", mode="invalid") + + +# ══════════════════════════════════════════════════════════════════════════════ +# loader.py +# ══════════════════════════════════════════════════════════════════════════════ + + +class TestLoadDocuments: + def test_loads_txt_file(self, tmp_path): + f = tmp_path / "doc.txt" + f.write_text("hello world", encoding="utf-8") + docs = load_documents(str(f)) + assert docs == ["hello world"] + + def test_loads_md_file(self, tmp_path): + f = tmp_path / "doc.md" + f.write_text("# Title\nContent.", encoding="utf-8") + docs = load_documents(str(f)) + assert len(docs) == 1 + assert "Title" in docs[0] + + def test_loads_directory_recursively(self, tmp_path): + (tmp_path / "a.txt").write_text("file a", encoding="utf-8") + (tmp_path / "sub").mkdir() + (tmp_path / "sub" / "b.txt").write_text("file b", encoding="utf-8") + docs = load_documents(str(tmp_path)) + assert len(docs) == 2 + + def test_skips_unsupported_extensions(self, tmp_path): + (tmp_path / "image.png").write_bytes(b"\x89PNG") + (tmp_path / "doc.txt").write_text("valid", encoding="utf-8") + docs = load_documents(str(tmp_path)) + assert docs == ["valid"] + + def test_empty_directory_returns_empty(self, tmp_path): + docs = load_documents(str(tmp_path)) + assert docs == [] + + def test_missing_path_raises(self): + with pytest.raises(FileNotFoundError): + load_documents("/nonexistent/path/file.txt") + + def test_routes_youtube_url(self): + with patch("fenn.agents.rag.loader._load_youtube", return_value="transcript") as m: + result = load_documents("https://www.youtube.com/watch?v=abc123") + m.assert_called_once() + assert result == ["transcript"] + + def test_routes_youtu_be_url(self): + with patch("fenn.agents.rag.loader._load_youtube", return_value="t") as m: + load_documents("https://youtu.be/abc123") + m.assert_called_once() + + def test_routes_http_url(self): + with patch("fenn.agents.rag.loader._load_url", return_value="page text") as m: + result = load_documents("https://example.com") + m.assert_called_once() + assert result == ["page text"] + + def test_filters_empty_docs(self, tmp_path): + f = tmp_path / "empty.txt" + f.write_text("", encoding="utf-8") + docs = load_documents(str(f)) + assert docs == [] + + +class TestReadFile: + def test_reads_text_file(self, tmp_path): + f = tmp_path / "test.txt" + f.write_text("content", encoding="utf-8") + assert _read_file(f) == "content" + + def test_returns_none_on_read_error(self, tmp_path): + f = tmp_path / "bad.txt" + f.write_text("x") + with patch("pathlib.Path.read_text", side_effect=OSError("fail")): + result = _read_file(f) + assert result is None + + def test_delegates_pdf_to_read_pdf(self, tmp_path): + f = tmp_path / "doc.pdf" + f.write_bytes(b"%PDF") + with patch("fenn.agents.rag.loader._read_pdf", return_value="pdf text") as m: + result = _read_file(f) + m.assert_called_once_with(f) + assert result == "pdf text" + + +class TestReadPdf: + def test_extracts_text(self, tmp_path): + f = tmp_path / "doc.pdf" + f.write_bytes(b"%PDF") + mock_page = MagicMock() + mock_page.extract_text.return_value = "pdf content" + mock_reader = MagicMock() + mock_reader.pages = [mock_page] + with patch("fenn.agents.rag.loader.pypdf", create=True) as mock_pypdf: + mock_pypdf.PdfReader.return_value = mock_reader + with patch.dict("sys.modules", {"pypdf": mock_pypdf}): + import fenn.agents.rag.loader as loader_mod + with patch.object(loader_mod, "_read_pdf", wraps=loader_mod._read_pdf): + pass # just verifying import path + + def test_raises_import_error_without_pypdf(self, tmp_path): + f = tmp_path / "doc.pdf" + f.write_bytes(b"%PDF") + with patch.dict("sys.modules", {"pypdf": None}): + with pytest.raises(ImportError, match="pypdf"): + _read_pdf(f) + + +class TestLoadUrl: + + def test_raises_import_error_without_httpx(self): + with patch.dict("sys.modules", {"httpx": None, "bs4": None}): + with pytest.raises(ImportError, match="httpx"): + _load_url("https://example.com") + + def test_raises_import_error_without_wikipedia(self): + with patch.dict("sys.modules", {"wikipedia": None}): + with pytest.raises((ImportError, Exception)): + _load_wikipedia("https://en.wikipedia.org/wiki/Python") + + def test_raises_import_error_without_youtube_transcript(self): + with patch.dict("sys.modules", {"youtube_transcript_api": None}): + with pytest.raises(ImportError, match="youtube-transcript-api"): + _load_youtube("https://www.youtube.com/watch?v=abc") + + def test_youtube_invalid_url_raises(self): + mock_yta = MagicMock() + with patch.dict("sys.modules", {"youtube_transcript_api": mock_yta}): + with patch("fenn.agents.rag.loader.YouTubeTranscriptApi", mock_yta, create=True): + with pytest.raises((ValueError, Exception)): + _load_youtube("https://www.youtube.com/watch") # no video ID + + +# ══════════════════════════════════════════════════════════════════════════════ +# rag.py +# ══════════════════════════════════════════════════════════════════════════════ + + +def _make_rag(**kwargs): + """Build a RAG instance with a mocked LLMClient.""" + with patch("fenn.agents.rag.rag.LLMClient") as MockLLM: + mock_llm = MagicMock() + mock_llm.provider = "openai" + mock_llm.model = "gpt-4o-mini" + mock_llm.api_key = "test-key" + mock_llm.base_url = "https://api.openai.com/v1" + MockLLM.return_value = mock_llm + rag = RAG(**kwargs) + return rag + + +class TestRAGInit: + def test_default_system_prompt_set(self): + rag = _make_rag() + assert "helpful assistant" in rag._system_prompt + + def test_custom_system_prompt(self): + rag = _make_rag(system_prompt="You are a pirate.") + assert rag._system_prompt == "You are a pirate." + + def test_memory_off_by_default(self): + rag = _make_rag() + assert rag._memory is False + assert rag._history == [] + + def test_memory_on(self): + rag = _make_rag(memory=True) + assert rag._memory is True + + def test_add_source_returns_self(self, tmp_path): + f = tmp_path / "doc.txt" + f.write_text("content", encoding="utf-8") + rag = _make_rag() + result = rag.add_source(str(f)) + assert result is rag + + def test_add_tool_returns_self(self): + def _fn(x): + return x + rag = _make_rag() + result = rag.add_tool(_fn) + assert result is rag + assert _fn in rag._tools + + def test_debug_returns_self(self): + rag = _make_rag() + result = rag.debug() + assert result is rag + assert rag._debug is True + + def test_reset_memory_clears_history(self): + rag = _make_rag(memory=True) + rag._history = [("q", "a")] + result = rag.reset_memory() + assert rag._history == [] + assert result is rag + + +class TestRAGBuildPrompt: + def test_no_memory_no_tools(self): + rag = _make_rag() + prompt = rag._build_prompt("What is X?", "Some context.") + assert "Some context." in prompt + assert "What is X?" in prompt + assert "Question:" in prompt + assert "Answer:" in prompt + + def test_with_tools_includes_tool_descriptions(self): + rag = _make_rag() + + def my_tool(): + """Does something useful.""" + pass + + rag.add_tool(my_tool) + prompt = rag._build_prompt("q", "ctx") + assert "Available tools:" in prompt + assert "my_tool" in prompt + assert "Does something useful." in prompt + + def test_with_memory_includes_history(self): + rag = _make_rag(memory=True) + rag._history = [("prev question", "prev answer")] + prompt = rag._build_prompt("new question", "context") + assert "Conversation so far:" in prompt + assert "prev question" in prompt + assert "prev answer" in prompt + assert "User: new question" in prompt + + def test_memory_with_max_history_trims(self): + rag = _make_rag(memory=True) + rag._max_history = 1 + rag._history = [("old q", "old a"), ("recent q", "recent a")] + prompt = rag._build_prompt("q", "ctx") + assert "recent q" in prompt + assert "old q" not in prompt + + def test_memory_no_history_uses_standard_format(self): + rag = _make_rag(memory=True) + rag._history = [] + prompt = rag._build_prompt("q", "ctx") + assert "Question:" in prompt + + +class TestRAGRun: + def test_run_returns_llm_answer(self): + rag = _make_rag() + rag._retriever = MagicMock() + rag._retriever.query.return_value = ["chunk one"] + rag._llm.ask.return_value = "the answer" + result = rag.run("my question") + assert result == "the answer" + + def test_run_passes_query_to_retriever(self): + rag = _make_rag() + rag._retriever = MagicMock() + rag._retriever.query.return_value = [] + rag._llm.ask.return_value = "ok" + rag.run("search this") + rag._retriever.query.assert_called_once_with("search this") + + def test_run_with_memory_appends_history(self): + rag = _make_rag(memory=True) + rag._retriever = MagicMock() + rag._retriever.query.return_value = [] + rag._llm.ask.return_value = "answer" + rag.run("question") + assert rag._history == [("question", "answer")] + + def test_run_without_memory_does_not_append_history(self): + rag = _make_rag(memory=False) + rag._retriever = MagicMock() + rag._retriever.query.return_value = [] + rag._llm.ask.return_value = "answer" + rag.run("question") + assert rag._history == [] + + def test_run_passes_schema(self): + from pydantic import BaseModel + + class Reply(BaseModel): + value: str + + rag = _make_rag() + rag._retriever = MagicMock() + rag._retriever.query.return_value = [] + rag._llm.ask.return_value = Reply(value="x") + rag.run("q", schema=Reply) + _, call_kwargs = rag._llm.ask.call_args + assert call_kwargs.get("schema") is Reply + + def test_run_debug_mode_prints(self, capsys): + rag = _make_rag() + rag.debug() + rag._retriever = MagicMock() + rag._retriever.query.return_value = ["a chunk"] + rag._llm.ask.return_value = "ok" + rag.run("q") + captured = capsys.readouterr() + assert "model_provider" in captured.out + + def test_chat_enables_memory_and_runs(self): + rag = _make_rag() + rag._retriever = MagicMock() + rag._retriever.query.return_value = [] + rag._llm.ask.return_value = "chat answer" + result = rag.chat("hello") + assert result == "chat answer" + assert rag._memory is True + + def test_stream_yields_tokens(self): + rag = _make_rag() + rag._retriever = MagicMock() + rag._retriever.query.return_value = ["context"] + rag._llm.stream.return_value = iter(["tok1", "tok2"]) + tokens = list(rag.stream("q")) + assert tokens == ["tok1", "tok2"] + + def test_stream_debug_prints(self, capsys): + rag = _make_rag() + rag.debug() + rag._retriever = MagicMock() + rag._retriever.query.return_value = ["chunk"] + rag._llm.stream.return_value = iter([]) + list(rag.stream("q")) + captured = capsys.readouterr() + assert "model_provider" in captured.out + + def test_add_source_indexes_docs(self, tmp_path): + f = tmp_path / "note.txt" + f.write_text("some content", encoding="utf-8") + rag = _make_rag() + rag._retriever = MagicMock() + rag.add_source(str(f)) + rag._retriever.index.assert_called_once() + + def test_add_source_debug_prints(self, tmp_path, capsys): + f = tmp_path / "note.txt" + f.write_text("content", encoding="utf-8") + rag = _make_rag() + rag.debug() + rag._retriever = MagicMock() + rag.add_source(str(f)) + captured = capsys.readouterr() + assert "loaded" in captured.out diff --git a/tests/unit/agents/test_retrievers.py b/tests/unit/agents/test_retrievers.py new file mode 100644 index 0000000..eabbf60 --- /dev/null +++ b/tests/unit/agents/test_retrievers.py @@ -0,0 +1,280 @@ +"""Tests for fenn/agents/rag/retriever.py""" + +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from fenn.agents.rag.retriever import LOCAL_EMBEDDING_PROVIDERS, Retriever + +# ── Helpers ──────────────────────────────────────────────────────────────────── + + +def _docs(*texts): + return list(texts) + + +def _bm25_retriever(docs=None): + r = Retriever(use_faiss=False) + if docs: + r.index(docs) + return r + + +# ── Construction ─────────────────────────────────────────────────────────────── + + +class TestRetrieverInit: + def test_defaults(self): + r = Retriever() + assert r.chunks == [] + assert r.use_faiss is False + assert r.embedding_provider == "local" + assert r.embedding_model == "all-MiniLM-L6-v2" + assert r.persist_path is None + assert r._faiss_index is None + assert r._loaded_from_disk is False + + def test_persist_path_converted_to_path(self, tmp_path): + r = Retriever(persist_path=str(tmp_path)) + assert isinstance(r.persist_path, Path) + assert r.persist_path == tmp_path + + +# ── _resolve_embedding_key ───────────────────────────────────────────────────── + + +class TestResolveEmbeddingKey: + def test_explicit_key_wins(self): + r = Retriever(embedding_api_key="my-key") + assert r._resolve_embedding_key() == "my-key" + + @pytest.mark.parametrize("provider", sorted(LOCAL_EMBEDDING_PROVIDERS)) + def test_local_providers_return_none(self, provider): + r = Retriever(embedding_provider=provider) + assert r._resolve_embedding_key() is None + + def test_reads_from_env(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "env-key") + r = Retriever(embedding_provider="openai") + assert r._resolve_embedding_key() == "env-key" + + def test_missing_env_raises(self, monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + r = Retriever(embedding_provider="openai") + with pytest.raises(ValueError, match="embedding API key not found"): + r._resolve_embedding_key() + + def test_unknown_provider_returns_none(self): + r = Retriever(embedding_provider="unknown_provider") + assert r._resolve_embedding_key() is None + + +# ── BM25 index & query ───────────────────────────────────────────────────────── + + +class TestBM25: + def test_index_populates_chunks(self): + r = _bm25_retriever(_docs("hello world", "foo bar")) + assert len(r.chunks) >= 2 + + def test_index_builds_inverted_index(self): + r = _bm25_retriever(_docs("the quick brown fox")) + assert "quick" in r._index + assert "fox" in r._index + + def test_query_returns_relevant_chunk(self): + r = _bm25_retriever( + _docs( + "Python is a programming language.", + "The cat sat on the mat.", + ) + ) + results = r.query("Python programming", top_k=1) + assert any("Python" in c for c in results) + + def test_query_respects_top_k(self): + docs = [f"document number {i}" for i in range(20)] + r = _bm25_retriever(docs) + results = r.query("document number", top_k=3) + assert len(results) <= 3 + + def test_query_no_match_returns_first_chunks(self): + r = _bm25_retriever(_docs("apple banana", "cherry date")) + results = r.query("zzznomatch", top_k=2) + # Falls back to chunks[:top_k] + assert len(results) <= 2 + + def test_query_empty_index_returns_empty(self): + r = Retriever(use_faiss=False) + results = r.query("anything", top_k=5) + assert results == [] + + def test_multiple_index_calls_accumulate_chunks(self): + r = Retriever(use_faiss=False) + r.index(_docs("first document")) + r.index(_docs("second document")) + assert len(r.chunks) >= 2 + + def test_bm25_scores_by_word_frequency(self): + r = _bm25_retriever( + _docs( + "cat cat cat", + "cat dog", + "dog dog", + ) + ) + results = r.query("cat", top_k=3) + # chunk with most "cat" hits should rank first + assert results[0] == "cat cat cat" + + +# ── query() dispatch ─────────────────────────────────────────────────────────── + + +class TestQueryDispatch: + def test_bm25_path_called(self): + r = Retriever(use_faiss=False) + r.chunks = ["some chunk"] + r._index["some"].append(0) + with patch.object(r, "_query_bm25", wraps=r._query_bm25) as mock: + r.query("some text", top_k=1) + mock.assert_called_once_with("some text", 1) + + def test_faiss_path_called(self): + r = Retriever(use_faiss=True) + with patch.object(r, "_query_faiss", return_value=[]) as mock: + r.query("some text", top_k=3) + mock.assert_called_once_with("some text", 3) + + +# ── index() with persist_path ────────────────────────────────────────────────── + + +class TestIndexPersistence: + def test_index_loads_from_disk_on_first_call(self, tmp_path): + r = Retriever(use_faiss=False, persist_path=tmp_path) + with patch.object(r, "_load_from_disk", return_value=True) as mock_load: + r.index(_docs("hello")) + mock_load.assert_called_once() + assert r._loaded_from_disk is True + + def test_index_skips_disk_load_on_second_call(self, tmp_path): + r = Retriever(use_faiss=False, persist_path=tmp_path) + r._loaded_from_disk = True + with patch.object(r, "_load_from_disk") as mock_load: + r.index(_docs("hello world")) + mock_load.assert_not_called() + + def test_index_proceeds_when_disk_load_fails(self, tmp_path): + r = Retriever(use_faiss=False, persist_path=tmp_path) + with patch.object(r, "_load_from_disk", return_value=False): + r.index(_docs("hello world")) + assert len(r.chunks) > 0 + + def test_load_from_disk_returns_false_when_files_missing(self, tmp_path): + r = Retriever(persist_path=tmp_path) + assert r._load_from_disk() is False + + def test_load_from_disk_returns_false_on_corrupt_data(self, tmp_path): + (tmp_path / "index.faiss").write_bytes(b"corrupt") + (tmp_path / "chunks.json").write_text("[]", encoding="utf-8") + r = Retriever(persist_path=tmp_path) + # faiss will raise on corrupt file; should return False not crash + result = r._load_from_disk() + assert result is False + + def test_save_and_load_roundtrip(self, tmp_path): + """Save a FAISS index to disk, then load it back.""" + pytest.importorskip("faiss") + pytest.importorskip("sentence_transformers") + + r = Retriever(use_faiss=True, persist_path=tmp_path) + r.chunks = ["chunk one", "chunk two"] + + mock_index = MagicMock() + r._faiss_index = mock_index + + with patch("fenn.agents.rag.retriever.faiss"): + r._save_to_disk() + assert (tmp_path / "chunks.json").exists() + saved = json.loads((tmp_path / "chunks.json").read_text()) + assert saved == ["chunk one", "chunk two"] + + +# ── _build_faiss ─────────────────────────────────────────────────────────────── + + +class TestBuildFaiss: + def test_build_faiss_raises_without_faiss(self): + r = Retriever(use_faiss=True) + r.chunks = ["chunk"] + with patch.dict("sys.modules", {"faiss": None}): + with pytest.raises(ImportError, match="faiss-cpu"): + r._build_faiss() + + +# ── Embedding method ImportError paths ──────────────────────────────────────── + + +class TestEmbedImportErrors: + def test_embed_local_raises_without_sentence_transformers(self): + r = Retriever(embedding_provider="local") + with patch.dict("sys.modules", {"sentence_transformers": None}): + with pytest.raises(ImportError, match="sentence-transformers"): + r._embed_local(["text"]) + + def test_embed_openai_compat_raises_without_openai(self): + r = Retriever(embedding_provider="openai", embedding_api_key="key") + with patch.dict("sys.modules", {"openai": None}): + with pytest.raises(ImportError, match="openai"): + r._embed_openai_compat(["text"]) + + def test_embed_ollama_raises_without_openai(self): + r = Retriever(embedding_provider="ollama") + with patch.dict("sys.modules", {"openai": None}): + with pytest.raises(ImportError, match="openai"): + r._embed_ollama(["text"]) + + def test_embed_cohere_raises_without_cohere(self): + r = Retriever(embedding_provider="cohere", embedding_api_key="key") + with patch.dict("sys.modules", {"cohere": None}): + with pytest.raises(ImportError, match="cohere"): + r._embed_cohere(["text"]) + + def test_embed_voyage_raises_without_voyageai(self): + r = Retriever(embedding_provider="voyage", embedding_api_key="key") + with patch.dict("sys.modules", {"voyageai": None}): + with pytest.raises(ImportError, match="voyageai"): + r._embed_voyage(["text"]) + + def test_embed_jina_raises_without_httpx(self): + r = Retriever(embedding_provider="jina", embedding_api_key="key") + with patch.dict("sys.modules", {"httpx": None}): + with pytest.raises(ImportError, match="httpx"): + r._embed_jina(["text"]) + + +# ── _embed dispatch ──────────────────────────────────────────────────────────── + + +class TestEmbedDispatch: + @pytest.mark.parametrize( + "provider,method", + [ + ("local", "_embed_local"), + ("openai", "_embed_openai_compat"), + ("gemini", "_embed_openai_compat"), + ("mistral", "_embed_openai_compat"), + ("ollama", "_embed_ollama"), + ("cohere", "_embed_cohere"), + ("voyage", "_embed_voyage"), + ("jina", "_embed_jina"), + ], + ) + def test_embed_routes_to_correct_method(self, provider, method): + r = Retriever(embedding_provider=provider, embedding_api_key="key") + with patch.object(r, method, return_value="result") as mock_method: + r._embed(["text"]) + mock_method.assert_called_once() diff --git a/tests/unit/agents/test_tools.py b/tests/unit/agents/test_tools.py new file mode 100644 index 0000000..b0e0bf7 --- /dev/null +++ b/tests/unit/agents/test_tools.py @@ -0,0 +1,165 @@ +import pytest + +from fenn.agents.tools import TOOLS_REGISTRY, execute_tool, get_tool_schema, tool + + +@pytest.fixture(autouse=True) +def clean_registry(): + """Ensure each test starts with a clean tools registry.""" + TOOLS_REGISTRY.clear() + yield + TOOLS_REGISTRY.clear() + + +class TestToolDecorator: + def test_registers_function_in_registry(self): + @tool + def my_tool(x): + """Does something.""" + return x * 2 + + assert "my_tool" in TOOLS_REGISTRY + + def test_registry_schema_contains_name_and_description(self): + @tool + def search(query): + """Search the web for a query.""" + return query + + entry = TOOLS_REGISTRY["search"] + assert entry["schema"]["name"] == "search" + assert entry["schema"]["description"] == "Search the web for a query." + + def test_missing_docstring_uses_default_description(self): + @tool + def no_doc(x): + return x + + entry = TOOLS_REGISTRY["no_doc"] + assert entry["schema"]["description"] == "No description provided" + + def test_execute_key_stores_original_function(self): + @tool + def add(a, b): + """Add two numbers.""" + return a + b + + assert TOOLS_REGISTRY["add"]["execute"](2, 3) == 5 + + def test_decorated_function_still_callable(self): + @tool + def double(x): + """Double a number.""" + return x * 2 + + assert double(5) == 10 + + def test_decorator_preserves_name_and_doc(self): + @tool + def greet(name): + """Greet someone.""" + return f"Hello, {name}" + + assert greet.__name__ == "greet" + assert greet.__doc__ == "Greet someone." + + def test_multiple_tools_registered_independently(self): + @tool + def tool_a(): + """Tool A.""" + return "a" + + @tool + def tool_b(): + """Tool B.""" + return "b" + + assert set(TOOLS_REGISTRY.keys()) == {"tool_a", "tool_b"} + + def test_decorator_passes_args_and_kwargs(self): + @tool + def combine(a, b, sep="-"): + """Combine two values with a separator.""" + return f"{a}{sep}{b}" + + assert combine("x", "y") == "x-y" + assert combine("x", "y", sep="_") == "x_y" + + +class TestGetToolSchema: + def test_empty_registry_returns_empty_list(self): + assert get_tool_schema() == [] + + def test_returns_list_of_schemas(self): + @tool + def search(query): + """Search something.""" + return query + + @tool + def calc(expr): + """Evaluate an expression.""" + return expr + + schemas = get_tool_schema() + assert len(schemas) == 2 + names = {s["name"] for s in schemas} + assert names == {"search", "calc"} + + def test_schema_does_not_include_execute_key(self): + @tool + def my_tool(): + """A tool.""" + return None + + schemas = get_tool_schema() + assert "execute" not in schemas[0] + + +class TestExecuteTool: + def test_executes_registered_tool(self): + @tool + def add(a, b): + """Add two numbers.""" + return a + b + + result = execute_tool("add", 2, 3) + assert result == 5 + + def test_executes_with_kwargs(self): + @tool + def greet(name, greeting="Hello"): + """Greet someone.""" + return f"{greeting}, {name}!" + + result = execute_tool("greet", "World", greeting="Hi") + assert result == "Hi, World!" + + def test_unregistered_tool_raises_value_error(self): + with pytest.raises(ValueError, match="not registered"): + execute_tool("nonexistent_tool") + + def test_error_message_includes_tool_name(self): + with pytest.raises(ValueError, match="missing_tool"): + execute_tool("missing_tool", "arg") + + def test_executes_tool_with_no_args(self): + @tool + def get_time(): + """Get the current time.""" + return "12:00" + + assert execute_tool("get_time") == "12:00" + + def test_execute_uses_original_function_not_wrapper(self): + """execute_tool should call the underlying func directly.""" + call_log = [] + + @tool + def logged(x): + """Logs calls.""" + call_log.append(x) + return x + + execute_tool("logged", "value") + assert call_log == ["value"] diff --git a/tests/unit/vision/test_resize.py b/tests/unit/vision/test_resize.py index 238d961..1f21dff 100644 --- a/tests/unit/vision/test_resize.py +++ b/tests/unit/vision/test_resize.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from unittest.mock import patch try: from fenn.experimental.vision import resize_batch @@ -15,96 +16,224 @@ class TestResizeBatch: """Test suite for resize_batch function.""" + # ------------------------------------------------------------------ + # Basic shape / format preservation + # ------------------------------------------------------------------ + def test_basic_resize_channels_last(self): - """Test basic resize with channels last format.""" - # Create a simple test image: (1, 100, 100, 3) array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) result = resize_batch(array, size=(50, 50)) - assert result.shape == (1, 50, 50, 3) assert result.dtype == np.uint8 assert np.all(result >= 0) and np.all(result <= 255) def test_basic_resize_channels_first(self): - """Test basic resize with channels first format.""" - # Create a test image: (1, 3, 100, 100) array = np.random.randint(0, 255, (1, 3, 100, 100), dtype=np.uint8) result = resize_batch(array, size=(50, 50)) - assert result.shape == (1, 3, 50, 50) assert result.dtype == np.uint8 def test_grayscale_no_channels(self): - """Test resize with grayscale (N, H, W).""" array = np.random.randint(0, 255, (1, 100, 100), dtype=np.uint8) result = resize_batch(array, size=(50, 50)) - assert result.shape == (1, 50, 50) assert result.dtype == np.uint8 + def test_preserve_channel_order(self): + array_cf = np.random.randint(0, 255, (1, 3, 100, 100), dtype=np.uint8) + result_cf = resize_batch(array_cf, size=(50, 50)) + assert result_cf.shape == (1, 3, 50, 50) + + array_cl = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + result_cl = resize_batch(array_cl, size=(50, 50)) + assert result_cl.shape == (1, 50, 50, 3) + + def test_batch_multiple_images(self): + array = np.random.randint(0, 255, (5, 100, 100, 3), dtype=np.uint8) + result = resize_batch(array, size=(50, 50)) + assert result.shape == (5, 50, 50, 3) + assert result.dtype == np.uint8 + + # ------------------------------------------------------------------ + # Size argument variants + # ------------------------------------------------------------------ + + def test_square_size_int(self): + array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + result = resize_batch(array, size=50) + assert result.shape == (1, 50, 50, 3) + + def test_non_square_tuple_size(self): + """Rectangular target size should produce non-square output.""" + array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + result = resize_batch(array, size=(40, 80)) + assert result.shape == (1, 40, 80, 3) + + def test_upsampling(self): + """Upsampling (target > source) should disable antialias and still work.""" + array = np.random.randint(0, 255, (1, 50, 50, 3), dtype=np.uint8) + result = resize_batch(array, size=(100, 100)) + assert result.shape == (1, 100, 100, 3) + assert result.dtype == np.uint8 + + def test_same_size_noop(self): + """Resizing to same dimensions should return identically shaped array.""" + array = np.random.randint(0, 255, (2, 64, 64, 3), dtype=np.uint8) + result = resize_batch(array, size=(64, 64)) + assert result.shape == array.shape + + # ------------------------------------------------------------------ + # dtype handling + # ------------------------------------------------------------------ + def test_float32_preservation(self): - """Test that float32 dtype is preserved.""" array = np.random.rand(1, 100, 100, 3).astype(np.float32) result = resize_batch(array, size=(50, 50)) - assert result.shape == (1, 50, 50, 3) assert result.dtype == np.float32 assert np.all(result >= 0) and np.all(result <= 1) - def test_different_interpolation_modes(self): - """Test different interpolation modes.""" - array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + def test_float64_preservation(self): + """float64 input should be preserved and values clipped to [0, 1].""" + array = np.random.rand(1, 100, 100, 3).astype(np.float64) + result = resize_batch(array, size=(50, 50)) + assert result.dtype == np.float64 + assert np.all(result >= 0) and np.all(result <= 1) - for mode in ["nearest", "bilinear", "bicubic"]: - result = resize_batch(array, size=(50, 50), interpolation=mode) - assert result.shape == (1, 50, 50, 3) - assert result.dtype == np.uint8 + def test_uint16_normalization(self): + """uint16 is neither uint8 nor float; should normalize through float and back.""" + array = np.random.randint(0, 65535, (1, 100, 100, 3), dtype=np.uint16) + result = resize_batch(array, size=(50, 50)) + assert result.shape == (1, 50, 50, 3) + assert result.dtype == np.uint16 + assert np.all(result >= 0) and np.all(result <= 65535) - def test_square_size_int(self): - """Test resize with integer size (square output).""" + def test_int32_normalization(self): + """int32 should go through float normalization path.""" + array = np.random.randint(0, 1000, (1, 100, 100, 3), dtype=np.int32) + result = resize_batch(array, size=(50, 50)) + assert result.dtype == np.int32 + + # ------------------------------------------------------------------ + # Interpolation modes + # ------------------------------------------------------------------ + + def test_interpolation_nearest(self): array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) - result = resize_batch(array, size=50) + result = resize_batch(array, size=(50, 50), interpolation="nearest") + assert result.shape == (1, 50, 50, 3) + def test_interpolation_bilinear(self): + array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + result = resize_batch(array, size=(50, 50), interpolation="bilinear") assert result.shape == (1, 50, 50, 3) - def test_batch_multiple_images(self): - """Test resize with multiple images in batch.""" - array = np.random.randint(0, 255, (5, 100, 100, 3), dtype=np.uint8) - result = resize_batch(array, size=(50, 50)) + def test_interpolation_bicubic(self): + array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + result = resize_batch(array, size=(50, 50), interpolation="bicubic") + assert result.shape == (1, 50, 50, 3) - assert result.shape == (5, 50, 50, 3) - assert result.dtype == np.uint8 + def test_interpolation_nearest_exact(self): + """nearest_exact is a valid mode that the original tests omitted.""" + array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + result = resize_batch(array, size=(50, 50), interpolation="nearest_exact") + assert result.shape == (1, 50, 50, 3) - def test_invalid_size(self): - """Test that invalid size raises ValueError.""" + def test_different_interpolation_modes(self): array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + for mode in ["nearest", "bilinear", "bicubic"]: + result = resize_batch(array, size=(50, 50), interpolation=mode) + assert result.shape == (1, 50, 50, 3) + assert result.dtype == np.uint8 + + # ------------------------------------------------------------------ + # Error paths + # ------------------------------------------------------------------ + def test_invalid_size_negative_height(self): + array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) with pytest.raises(ValueError): resize_batch(array, size=(-10, 10)) + def test_invalid_size_negative_width(self): + array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) with pytest.raises(ValueError): resize_batch(array, size=(10, -10)) - def test_invalid_interpolation(self): - """Test that invalid interpolation raises ValueError.""" + def test_invalid_size_zero(self): + array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + with pytest.raises(ValueError): + resize_batch(array, size=(0, 50)) + + def test_invalid_size_wrong_type(self): + """A non-int, non-2-tuple size should raise ValueError.""" + array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + with pytest.raises(ValueError): + resize_batch(array, size=(10, 10, 10)) # 3-tuple + + def test_invalid_size_string(self): array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + with pytest.raises((ValueError, TypeError)): + resize_batch(array, size="large") + def test_invalid_interpolation(self): + array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) with pytest.raises(ValueError): resize_batch(array, size=(50, 50), interpolation="invalid_mode") - def test_invalid_array_type(self): - """Test that non-numpy array raises TypeError.""" + def test_invalid_array_type_list(self): with pytest.raises(TypeError): resize_batch([1, 2, 3], size=(50, 50)) - def test_preserve_channel_order(self): - """Test that channel order is preserved.""" - # Channels first - array_cf = np.random.randint(0, 255, (1, 3, 100, 100), dtype=np.uint8) - result_cf = resize_batch(array_cf, size=(50, 50)) - assert result_cf.shape == (1, 3, 50, 50) # Still channels first + def test_invalid_array_type_none(self): + with pytest.raises(TypeError): + resize_batch(None, size=(50, 50)) - # Channels last - array_cl = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) - result_cl = resize_batch(array_cl, size=(50, 50)) - assert result_cl.shape == (1, 50, 50, 3) # Still channels last + def test_torchvision_unavailable(self): + """When torchvision is not installed, ImportError should be raised.""" + import fenn.experimental.vision.resize as resize_module # adjust to real module path + + array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + with patch.object(resize_module, "TORCHVISION_AVAILABLE", False): + with pytest.raises(ImportError, match="torchvision"): + resize_batch(array, size=(50, 50)) + + # ------------------------------------------------------------------ + # Antialias behaviour (white-box) + # ------------------------------------------------------------------ + + def test_antialias_disabled_for_nearest_downsampling(self): + """nearest interpolation should never trigger antialias even when downsampling.""" + import fenn.experimental.vision.resize as resize_module + import torchvision.transforms.functional as F + + array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + calls = [] + original_resize = F.resize + + def mock_resize(tensor, size, interpolation, antialias): + calls.append(antialias) + return original_resize(tensor, size, interpolation=interpolation, antialias=antialias) + + with patch.object(F, "resize", side_effect=mock_resize): + resize_batch(array, size=(50, 50), interpolation="nearest") + + assert calls and calls[0] is False + + def test_antialias_enabled_for_bilinear_downsampling(self): + """bilinear downsampling should enable antialias.""" + import fenn.experimental.vision.resize as resize_module + import torchvision.transforms.functional as F + + array = np.random.randint(0, 255, (1, 100, 100, 3), dtype=np.uint8) + calls = [] + original_resize = F.resize + + def mock_resize(tensor, size, interpolation, antialias): + calls.append(antialias) + return original_resize(tensor, size, interpolation=interpolation, antialias=antialias) + + with patch.object(F, "resize", side_effect=mock_resize): + resize_batch(array, size=(50, 50), interpolation="bilinear") + + assert calls and calls[0] is True