diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 987407c..c94cfa9 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,11 @@ Releases ======== +Unreleased +---------- + +* `#547 `_: Added ``SpyType`` for annotating ``mocker.spy`` results. + 3.15.1 ------ diff --git a/docs/usage.rst b/docs/usage.rst index 80d7c66..5dfb7c3 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -75,10 +75,8 @@ also tracks function/method calls, return values and exceptions raised. assert spy.call_count == 1 assert spy.spy_return == 42 -The object returned by ``mocker.spy`` is a ``MagicMock`` object, so all standard checking functions -are available (like ``assert_called_once_with`` or ``call_count`` in the examples above). - -In addition, spy objects contain four extra attributes: +The object returned by ``mocker.spy`` is a ``pytest_mock.SpyType`` object which subclasses ``MagicMock``, so all standard checking functions +are available (like ``assert_called_once_with`` or ``call_count`` in the examples above), in addition to four extra attributes: * ``spy_return``: contains the last returned value of the spied function. * ``spy_return_iter``: contains a duplicate of the last returned value of the spied function if the value was an iterator and spy was created using ``.spy(..., duplicate_iterators=True)``. Uses `tee `__) to duplicate the iterator. diff --git a/src/pytest_mock/__init__.py b/src/pytest_mock/__init__.py index 75fd27a..b130ef5 100644 --- a/src/pytest_mock/__init__.py +++ b/src/pytest_mock/__init__.py @@ -2,6 +2,7 @@ from pytest_mock.plugin import MockerFixture from pytest_mock.plugin import MockType from pytest_mock.plugin import PytestMockWarning +from pytest_mock.plugin import SpyType from pytest_mock.plugin import class_mocker from pytest_mock.plugin import mocker from pytest_mock.plugin import module_mocker @@ -18,6 +19,7 @@ "MockFixture", "MockType", "PytestMockWarning", + "SpyType", "pytest_addoption", "pytest_configure", "session_mocker", diff --git a/src/pytest_mock/plugin.py b/src/pytest_mock/plugin.py index ef99612..79f6ef1 100644 --- a/src/pytest_mock/plugin.py +++ b/src/pytest_mock/plugin.py @@ -33,6 +33,17 @@ ] +class SpyType(unittest.mock.Mock): + """ + Type stub used to annotate the result of ``mocker.spy``. + """ + + spy_return: Any + spy_return_iter: Optional[Iterator[Any]] + spy_return_list: list[Any] + spy_exception: Optional[BaseException] + + class PytestMockWarning(UserWarning): """Base class for all warnings emitted by pytest-mock.""" @@ -157,9 +168,7 @@ def stop(self, mock: unittest.mock.MagicMock) -> None: """ self._mock_cache.remove(mock) - def spy( - self, obj: object, name: str, duplicate_iterators: bool = False - ) -> MockType: + def spy(self, obj: object, name: str, duplicate_iterators: bool = False) -> SpyType: """ Create a spy of method. It will run method normally, but it is now possible to use `mock` call features with it, like call count. @@ -210,7 +219,10 @@ async def async_wrapper(*args, **kwargs): autospec = inspect.ismethod(method) or inspect.isfunction(method) - spy_obj = self.patch.object(obj, name, side_effect=wrapped, autospec=autospec) + spy_obj = cast( + SpyType, + self.patch.object(obj, name, side_effect=wrapped, autospec=autospec), + ) spy_obj.spy_return = None spy_obj.spy_return_iter = None spy_obj.spy_return_list = [] diff --git a/tests/test_pytest_mock.py b/tests/test_pytest_mock.py index 54baf06..04352ef 100644 --- a/tests/test_pytest_mock.py +++ b/tests/test_pytest_mock.py @@ -16,6 +16,7 @@ from pytest_mock import MockerFixture from pytest_mock import PytestMockWarning +from pytest_mock import SpyType pytest_plugins = "pytester" @@ -283,6 +284,29 @@ def bar(self, arg): assert spy.spy_return_list == [20, 22, 24] +def assert_spy_has_no_return(spy: SpyType) -> None: + assert spy.spy_return is None + assert spy.spy_return_iter is None + assert spy.spy_return_list == [] + + +def test_spy_type(mocker: MockerFixture) -> None: + class Foo: + def bar(self) -> str: + return "ok" + + foo = Foo() + spy: SpyType = mocker.spy(foo, "bar") + + assert_spy_has_no_return(spy) + assert spy.spy_exception is None + spy.assert_not_called() + + assert foo.bar() == "ok" + assert spy.spy_return == "ok" + assert spy.spy_return_list == ["ok"] + + # Ref: https://docs.python.org/3/library/exceptions.html#exception-hierarchy @pytest.mark.parametrize( "exc_cls", @@ -357,14 +381,12 @@ def bar(self, x): return x * 3 spy = mocker.spy(Foo, "bar") - assert spy.spy_return is None - assert spy.spy_return_iter is None - assert spy.spy_return_list == [] + assert_spy_has_no_return(spy) assert spy.spy_exception is None Foo().bar(10) assert spy.spy_return == 30 - assert spy.spy_return_iter is None # type:ignore[unreachable] + assert spy.spy_return_iter is None assert spy.spy_return_list == [30] assert spy.spy_exception is None @@ -373,9 +395,7 @@ def bar(self, x): with pytest.raises(ValueError): Foo().bar(0) - assert spy.spy_return is None - assert spy.spy_return_iter is None - assert spy.spy_return_list == [] + assert_spy_has_no_return(spy) assert str(spy.spy_exception) == "invalid x" Foo().bar(15) @@ -624,6 +644,7 @@ def bar(self) -> Any: result_iterator = list(foo.bar()) assert result_iterator == [0, 1, 2] + assert spy.spy_return_iter is not None assert list(spy.spy_return_iter) == result_iterator assert foo.bar() == 99