diff --git a/bot/config.py b/bot/config.py index cdcd568..314fa5d 100644 --- a/bot/config.py +++ b/bot/config.py @@ -30,6 +30,12 @@ def extensions(self) -> Iterator[Extension]: from bot import extensions for name in find_all_importable(extensions): + if any(part.startswith("test") for part in name.split(".")): + continue + + if not name.endswith("_extension"): + continue + imported: ModuleType = import_module(name) if not name.endswith("_extension"): diff --git a/tests/test_config.py b/tests/test_config.py index c2f9d5b..f81bf60 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,7 @@ +import sys import pytest from pathlib import Path +from unittest.mock import MagicMock, patch from bot.config import BotConfig, DEFAULT_LOG_LEVEL, DEFAULT_ENVIRONMENT @@ -21,6 +23,13 @@ def bot_config(config_file: Path) -> BotConfig: return BotConfig(str(config_file)) +@pytest.fixture +def extension_module() -> MagicMock: + module = MagicMock() + module.extension = MagicMock() + return module + + def test_config_loads_username(bot_config: BotConfig) -> None: assert bot_config.username == "@ada:matrix.org" @@ -68,3 +77,86 @@ def test_config_overrides_environment(tmp_path: Path) -> None: ) config = BotConfig(str(config_file)) assert config.environment == "production" + + +def test_extensions_yields_valid_extension( + bot_config: BotConfig, extension_module: MagicMock +) -> None: + with ( + patch.dict(sys.modules, {"bot.extensions": MagicMock()}), + patch( + "bot.config.find_all_importable", + return_value={"bot.extensions.foo_extension"}, + ), + patch("bot.config.import_module", return_value=extension_module), + ): + result = list(bot_config.extensions) + + assert result == [extension_module.extension] + + +def test_extensions_skips_tests_folder( + bot_config: BotConfig, extension_module: MagicMock +) -> None: + with ( + patch.dict(sys.modules, {"bot.extensions": MagicMock()}), + patch( + "bot.config.find_all_importable", + return_value={"bot.extensions.tests.foo_extension"}, + ), + patch("bot.config.import_module") as mock_import, + ): + result = list(bot_config.extensions) + + assert result == [] + mock_import.assert_not_called() + + +def test_extensions_skips_test_modules( + bot_config: BotConfig, extension_module: MagicMock +) -> None: + with ( + patch.dict(sys.modules, {"bot.extensions": MagicMock()}), + patch( + "bot.config.find_all_importable", + return_value={"bot.extensions.test_foo_extension"}, + ), + patch("bot.config.import_module") as mock_import, + ): + result = list(bot_config.extensions) + + assert result == [] + mock_import.assert_not_called() + + +def test_extensions_skips_non_extension_modules( + bot_config: BotConfig, extension_module: MagicMock +) -> None: + with ( + patch.dict(sys.modules, {"bot.extensions": MagicMock()}), + patch( + "bot.config.find_all_importable", return_value={"bot.extensions.helpers"} + ), + patch("bot.config.import_module") as mock_import, + ): + result = list(bot_config.extensions) + + assert result == [] + mock_import.assert_not_called() + + +def test_extensions_raises_if_missing_extension_attribute( + bot_config: BotConfig, +) -> None: + module = MagicMock(spec=[]) + + with ( + patch.dict(sys.modules, {"bot.extensions": MagicMock()}), + patch( + "bot.config.find_all_importable", + return_value={"bot.extensions.foo_extension"}, + ), + patch("bot.config.import_module", return_value=module), + ): + with pytest.raises(RuntimeError, match="does not define an extension"): + list(bot_config.extensions)