diff --git a/ais_bench/benchmark/datasets/swebench.py b/ais_bench/benchmark/datasets/swebench.py index a5818e74..dee0d2fe 100644 --- a/ais_bench/benchmark/datasets/swebench.py +++ b/ais_bench/benchmark/datasets/swebench.py @@ -52,14 +52,46 @@ def _parquet_data_files_from_dir( @LOAD_DATASET.register_module() class SWEBenchDataset(BaseDataset): + def _load_instance_ids_file(self, instance_ids_file: str) -> set[str]: + path = Path(instance_ids_file).expanduser() + if not path.is_file(): + raise FileOperationError( + SWEB_CODES.LOCAL_PATH_RESOLVE_FAILED, + f"SWE-Bench instance ids file does not exist: {instance_ids_file!r}", + ) + if path.suffix.lower() != ".txt": + raise FileOperationError( + SWEB_CODES.LOCAL_PATH_RESOLVE_FAILED, + f"SWE-Bench instance ids file must be a .txt file: {instance_ids_file!r}", + ) + + try: + instance_ids = { + line.strip() + for line in path.read_text(encoding="utf-8").splitlines() + if line.strip() + } + except OSError as e: + raise FileOperationError( + SWEB_CODES.LOCAL_PATH_RESOLVE_FAILED, + f"Failed to read SWE-Bench instance ids file {instance_ids_file!r}: {e}", + ) + return instance_ids + def filter_instances( - self, instances: list[dict], *, filter_spec: str, shuffle: bool = False + self, + instances: list[dict], + *, + filter_spec: str, + instance_ids: set[str] | None = None, + shuffle: bool = False, ) -> list[dict]: """Filter and slice a list of SWEBench instances.""" if shuffle: instances = sorted(instances.copy(), key=lambda x: x["instance_id"]) random.seed(42) random.shuffle(instances) + before_filter = len(instances) instances = [ instance @@ -70,6 +102,26 @@ def filter_instances( self.logger.info( f"Instance filter: {before_filter} -> {after_filter} instances" ) + + if instance_ids is not None: + available_ids = {instance["instance_id"] for instance in instances} + missing_ids = instance_ids - available_ids + before_ids_filter = len(instances) + instances = [ + instance + for instance in instances + if instance["instance_id"] in instance_ids + ] + if (after_ids_filter := len(instances)) != before_ids_filter: + self.logger.info( + f"Instance ids file filter: {before_ids_filter} -> {after_ids_filter} instances" + ) + if missing_ids: + self.logger.warning( + "Instance ids file contains %d ids not present after dataset/filter_spec selection: %s", + len(missing_ids), + ", ".join(sorted(missing_ids)[:10]), + ) return instances def load( @@ -78,6 +130,7 @@ def load( path: str = "", split: str = "test", filter_spec: str = "", + instance_ids_file: str = "", shuffle: bool = False, **kwargs, ): @@ -87,6 +140,7 @@ def load( path: The path to the dataset. split (str): The split of the dataset to load. filter_spec (str): The filter specification to apply to the dataset. + instance_ids_file (str): Text file containing one instance_id per line. shuffle (bool): Whether to shuffle the dataset. **kwargs: Additional keyword arguments. @@ -146,5 +200,19 @@ def load( SWEB_CODES.LOCAL_PARQUET_LOAD_FAILED, f"Failed to load local swebench parquet from {root}: {e}", ) - dataset = self.filter_instances(list(dataset), filter_spec=filter_spec, shuffle=shuffle) + instance_ids = None + if instance_ids_file: + instance_ids = self._load_instance_ids_file(instance_ids_file) + self.logger.info( + "Loaded %d SWE-Bench instance ids from %s", + len(instance_ids), + instance_ids_file, + ) + + dataset = self.filter_instances( + list(dataset), + filter_spec=filter_spec, + instance_ids=instance_ids, + shuffle=shuffle, + ) return Dataset.from_list(dataset) diff --git a/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_full.py b/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_full.py index bf19fef6..bb0dd5d3 100644 --- a/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_full.py +++ b/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_full.py @@ -29,6 +29,7 @@ split="test", step_limit=STEP_LIMIT, filter_spec="", + instance_ids_file="", shuffle=False, ), ] diff --git a/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_lite.py b/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_lite.py index ff6acaa9..41189756 100644 --- a/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_lite.py +++ b/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_lite.py @@ -28,6 +28,7 @@ name="lite", split="test", filter_spec="", + instance_ids_file="", shuffle=False, step_limit=STEP_LIMIT, ), diff --git a/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_multilingual.py b/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_multilingual.py index dc7e5881..a48215ad 100644 --- a/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_multilingual.py +++ b/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_multilingual.py @@ -29,6 +29,7 @@ split="test", step_limit=STEP_LIMIT, filter_spec="", + instance_ids_file="", shuffle=False, ), ] diff --git a/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_multilingual_mini.py b/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_multilingual_mini.py index 79baf4cd..0609c2ac 100644 --- a/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_multilingual_mini.py +++ b/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_multilingual_mini.py @@ -29,6 +29,7 @@ split="test", step_limit=STEP_LIMIT, filter_spec="", + instance_ids_file="", shuffle=False, ), ] diff --git a/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_verified.py b/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_verified.py index 38be2127..93b242d9 100644 --- a/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_verified.py +++ b/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_verified.py @@ -29,6 +29,7 @@ split="test", step_limit=STEP_LIMIT, filter_spec="", + instance_ids_file="", shuffle=False, ), ] diff --git a/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_verified_mini.py b/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_verified_mini.py index e63aa0e3..b1e8c083 100644 --- a/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_verified_mini.py +++ b/ais_bench/configs/swe_bench_examples/mini_swe_agent_swe_bench_verified_mini.py @@ -29,6 +29,7 @@ split="test", step_limit=STEP_LIMIT, filter_spec="", + instance_ids_file="", shuffle=False, ), ] diff --git a/tests/UT/datasets/test_swebench.py b/tests/UT/datasets/test_swebench.py new file mode 100644 index 00000000..1f1ead33 --- /dev/null +++ b/tests/UT/datasets/test_swebench.py @@ -0,0 +1,49 @@ +import tempfile +import unittest +from pathlib import Path +from unittest import mock + +from ais_bench.benchmark.datasets.swebench import SWEBenchDataset +from ais_bench.benchmark.utils.logging.exceptions import FileOperationError + + +class TestSWEBenchDataset(unittest.TestCase): + def setUp(self): + self.dataset = object.__new__(SWEBenchDataset) + self.dataset.logger = mock.MagicMock() + + def test_load_instance_ids_file(self): + with tempfile.TemporaryDirectory() as temp_dir: + ids_file = Path(temp_dir) / "ids.txt" + ids_file.write_text("django__django-1\n\nsympy__sympy-2\nsympy__sympy-2\n", encoding="utf-8") + + instance_ids = self.dataset._load_instance_ids_file(str(ids_file)) + + self.assertEqual(instance_ids, {"django__django-1", "sympy__sympy-2"}) + + def test_load_instance_ids_file_requires_txt_suffix(self): + with tempfile.TemporaryDirectory() as temp_dir: + ids_file = Path(temp_dir) / "ids.csv" + ids_file.write_text("django__django-1\n", encoding="utf-8") + + with self.assertRaises(FileOperationError): + self.dataset._load_instance_ids_file(str(ids_file)) + + def test_filter_instances_by_filter_spec_and_instance_ids(self): + instances = [ + {"instance_id": "django__django-1"}, + {"instance_id": "django__django-2"}, + {"instance_id": "sympy__sympy-1"}, + ] + + filtered = self.dataset.filter_instances( + instances, + filter_spec=r"^django__", + instance_ids={"django__django-2", "sympy__sympy-1"}, + ) + + self.assertEqual(filtered, [{"instance_id": "django__django-2"}]) + + +if __name__ == "__main__": + unittest.main()