diff --git a/src/evaluate/loading.py b/src/evaluate/loading.py index 505015b1..6e065f7f 100644 --- a/src/evaluate/loading.py +++ b/src/evaluate/loading.py @@ -617,13 +617,13 @@ def evaluation_module_factory( if path.endswith(filename): if os.path.isfile(path): return LocalEvaluationModuleFactory( - path, download_mode=download_mode, dynamic_modules_path=dynamic_modules_path + path, download_mode=download_mode, download_config=download_config, dynamic_modules_path=dynamic_modules_path ).get_module() else: raise FileNotFoundError(f"Couldn't find a metric script at {relative_to_absolute_path(path)}") elif os.path.isfile(combined_path): return LocalEvaluationModuleFactory( - combined_path, download_mode=download_mode, dynamic_modules_path=dynamic_modules_path + combined_path, download_mode=download_mode, download_config=download_config, dynamic_modules_path=dynamic_modules_path ).get_module() elif is_relative_path(path) and path.count("/") <= 1 and not force_local_path: try: diff --git a/tests/test_load.py b/tests/test_load.py index e20ea671..d82f02e1 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -2,6 +2,7 @@ import os import tempfile from unittest import TestCase +from unittest.mock import patch import pytest from datasets import DownloadConfig @@ -100,6 +101,35 @@ def test_LocalMetricModuleFactory(self): module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None + def test_evaluation_module_factory_local_py_path_passes_download_config(self): + # Regression test for https://github.com/huggingface/evaluate/issues/709: + # evaluation_module_factory must forward download_config to LocalEvaluationModuleFactory + # when a direct .py path is given (path.endswith(filename) branch). + py_path = os.path.join(self._metric_loading_script_dir, f"{METRIC_LOADING_SCRIPT_NAME}.py") + with patch("evaluate.loading.LocalEvaluationModuleFactory", wraps=LocalEvaluationModuleFactory) as spy: + evaluation_module_factory( + py_path, + download_config=self.download_config, + dynamic_modules_path=self.dynamic_modules_path, + ) + spy.assert_called_once() + _, kwargs = spy.call_args + assert kwargs.get("download_config") is self.download_config + + def test_evaluation_module_factory_local_dir_path_passes_download_config(self): + # Regression test for https://github.com/huggingface/evaluate/issues/709: + # evaluation_module_factory must forward download_config to LocalEvaluationModuleFactory + # when a directory path is given (combined_path branch). + with patch("evaluate.loading.LocalEvaluationModuleFactory", wraps=LocalEvaluationModuleFactory) as spy: + evaluation_module_factory( + self._metric_loading_script_dir, + download_config=self.download_config, + dynamic_modules_path=self.dynamic_modules_path, + ) + spy.assert_called_once() + _, kwargs = spy.call_args + assert kwargs.get("download_config") is self.download_config + def test_CachedMetricModuleFactory(self): path = os.path.join(self._metric_loading_script_dir, f"{METRIC_LOADING_SCRIPT_NAME}.py") factory = LocalEvaluationModuleFactory(