Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions ais_bench/benchmark/datasets/swebench.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,46 @@ def _parquet_data_files_from_dir(

@LOAD_DATASET.register_module()
class SWEBenchDataset(BaseDataset):
def _load_instance_ids_file(self, instance_ids_file: str) -> set[str]:
path = Path(instance_ids_file).expanduser()
if not path.is_file():
raise FileOperationError(
SWEB_CODES.LOCAL_PATH_RESOLVE_FAILED,
f"SWE-Bench instance ids file does not exist: {instance_ids_file!r}",
)
if path.suffix.lower() != ".txt":
raise FileOperationError(
SWEB_CODES.LOCAL_PATH_RESOLVE_FAILED,
f"SWE-Bench instance ids file must be a .txt file: {instance_ids_file!r}",
)

try:
instance_ids = {
line.strip()
for line in path.read_text(encoding="utf-8").splitlines()
if line.strip()
}
except OSError as e:
raise FileOperationError(
SWEB_CODES.LOCAL_PATH_RESOLVE_FAILED,
f"Failed to read SWE-Bench instance ids file {instance_ids_file!r}: {e}",
)
return instance_ids
Comment on lines +55 to +79

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

  1. 移除不必要的 .txt 后缀限制:限制 instance_ids_file 必须为 .txt 后缀是不必要的。用户可能会使用 .log.csv.list 或无后缀的文件。只要文件是纯文本且每行一个 ID,就应该允许读取。移除此限制可以提升通用性和用户体验。
  2. 增加空文件校验:如果用户提供了一个空文件,当前代码会返回一个空集合 set(),导致后续过滤后的数据集为空,并在没有任何错误提示的情况下静默结束运行。建议在读取文件后,如果解析出的 instance_ids 为空,则抛出 FileOperationError 异常,以明确提示用户文件内容无效。
    def _load_instance_ids_file(self, instance_ids_file: str) -> set[str]:
        path = Path(instance_ids_file).expanduser()
        if not path.is_file():
            raise FileOperationError(
                SWEB_CODES.LOCAL_PATH_RESOLVE_FAILED,
                f"SWE-Bench instance ids file does not exist: {instance_ids_file!r}",
            )

        try:
            instance_ids = {
                line.strip()
                for line in path.read_text(encoding="utf-8").splitlines()
                if line.strip()
            }
        except OSError as e:
            raise FileOperationError(
                SWEB_CODES.LOCAL_PATH_RESOLVE_FAILED,
                f"Failed to read SWE-Bench instance ids file {instance_ids_file!r}: {e}",
            )

        if not instance_ids:
            raise FileOperationError(
                SWEB_CODES.LOCAL_PATH_RESOLVE_FAILED,
                f"SWE-Bench instance ids file is empty or contains no valid ids: {instance_ids_file!r}",
            )
        return instance_ids

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

增加这种约束是为了避免太宽松导致的一些不确定性问题


def filter_instances(
self, instances: list[dict], *, filter_spec: str, shuffle: bool = False
self,
instances: list[dict],
*,
filter_spec: str,
instance_ids: set[str] | None = None,
shuffle: bool = False,
) -> list[dict]:
"""Filter and slice a list of SWEBench instances."""
if shuffle:
instances = sorted(instances.copy(), key=lambda x: x["instance_id"])
random.seed(42)
random.shuffle(instances)

before_filter = len(instances)
instances = [
instance
Expand All @@ -70,6 +102,26 @@ def filter_instances(
self.logger.info(
f"Instance filter: {before_filter} -> {after_filter} instances"
)

if instance_ids is not None:
available_ids = {instance["instance_id"] for instance in instances}
missing_ids = instance_ids - available_ids
before_ids_filter = len(instances)
instances = [
instance
for instance in instances
if instance["instance_id"] in instance_ids
]
if (after_ids_filter := len(instances)) != before_ids_filter:
self.logger.info(
f"Instance ids file filter: {before_ids_filter} -> {after_ids_filter} instances"
)
if missing_ids:
self.logger.warning(
"Instance ids file contains %d ids not present after dataset/filter_spec selection: %s",
len(missing_ids),
", ".join(sorted(missing_ids)[:10]),
)
return instances

def load(
Expand All @@ -78,6 +130,7 @@ def load(
path: str = "",
split: str = "test",
filter_spec: str = "",
instance_ids_file: str = "",
shuffle: bool = False,
**kwargs,
):
Expand All @@ -87,6 +140,7 @@ def load(
path: The path to the dataset.
split (str): The split of the dataset to load.
filter_spec (str): The filter specification to apply to the dataset.
instance_ids_file (str): Text file containing one instance_id per line.
shuffle (bool): Whether to shuffle the dataset.
**kwargs: Additional keyword arguments.

Expand Down Expand Up @@ -146,5 +200,19 @@ def load(
SWEB_CODES.LOCAL_PARQUET_LOAD_FAILED,
f"Failed to load local swebench parquet from {root}: {e}",
)
dataset = self.filter_instances(list(dataset), filter_spec=filter_spec, shuffle=shuffle)
instance_ids = None
if instance_ids_file:
instance_ids = self._load_instance_ids_file(instance_ids_file)
self.logger.info(
"Loaded %d SWE-Bench instance ids from %s",
len(instance_ids),
instance_ids_file,
)

dataset = self.filter_instances(
list(dataset),
filter_spec=filter_spec,
instance_ids=instance_ids,
shuffle=shuffle,
)
return Dataset.from_list(dataset)
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
split="test",
step_limit=STEP_LIMIT,
filter_spec="",
instance_ids_file="",
shuffle=False,
),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
name="lite",
split="test",
filter_spec="",
instance_ids_file="",
shuffle=False,
step_limit=STEP_LIMIT,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
split="test",
step_limit=STEP_LIMIT,
filter_spec="",
instance_ids_file="",
shuffle=False,
),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
split="test",
step_limit=STEP_LIMIT,
filter_spec="",
instance_ids_file="",
shuffle=False,
),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
split="test",
step_limit=STEP_LIMIT,
filter_spec="",
instance_ids_file="",
shuffle=False,
),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
split="test",
step_limit=STEP_LIMIT,
filter_spec="",
instance_ids_file="",
shuffle=False,
),
]
Expand Down
49 changes: 49 additions & 0 deletions tests/UT/datasets/test_swebench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import tempfile
import unittest
from pathlib import Path
from unittest import mock

from ais_bench.benchmark.datasets.swebench import SWEBenchDataset
from ais_bench.benchmark.utils.logging.exceptions import FileOperationError


class TestSWEBenchDataset(unittest.TestCase):
def setUp(self):
self.dataset = object.__new__(SWEBenchDataset)
self.dataset.logger = mock.MagicMock()

def test_load_instance_ids_file(self):
with tempfile.TemporaryDirectory() as temp_dir:
ids_file = Path(temp_dir) / "ids.txt"
ids_file.write_text("django__django-1\n\nsympy__sympy-2\nsympy__sympy-2\n", encoding="utf-8")

instance_ids = self.dataset._load_instance_ids_file(str(ids_file))

self.assertEqual(instance_ids, {"django__django-1", "sympy__sympy-2"})

def test_load_instance_ids_file_requires_txt_suffix(self):
with tempfile.TemporaryDirectory() as temp_dir:
ids_file = Path(temp_dir) / "ids.csv"
ids_file.write_text("django__django-1\n", encoding="utf-8")

with self.assertRaises(FileOperationError):
self.dataset._load_instance_ids_file(str(ids_file))
Comment on lines +24 to +30

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

由于移除了对 .txt 后缀的限制,原有的 test_load_instance_ids_file_requires_txt_suffix 测试用例将不再适用。建议将其替换为针对空文件抛出 FileOperationError 异常的测试用例,以确保新添加的空文件校验逻辑得到充分测试。

Suggested change
def test_load_instance_ids_file_requires_txt_suffix(self):
with tempfile.TemporaryDirectory() as temp_dir:
ids_file = Path(temp_dir) / "ids.csv"
ids_file.write_text("django__django-1\n", encoding="utf-8")
with self.assertRaises(FileOperationError):
self.dataset._load_instance_ids_file(str(ids_file))
def test_load_instance_ids_file_empty_raises_error(self):
with tempfile.TemporaryDirectory() as temp_dir:
ids_file = Path(temp_dir) / "ids.txt"
ids_file.write_text(" \n\n \n", encoding="utf-8")
with self.assertRaises(FileOperationError):
self.dataset._load_instance_ids_file(str(ids_file))


def test_filter_instances_by_filter_spec_and_instance_ids(self):
instances = [
{"instance_id": "django__django-1"},
{"instance_id": "django__django-2"},
{"instance_id": "sympy__sympy-1"},
]

filtered = self.dataset.filter_instances(
instances,
filter_spec=r"^django__",
instance_ids={"django__django-2", "sympy__sympy-1"},
)

self.assertEqual(filtered, [{"instance_id": "django__django-2"}])


if __name__ == "__main__":
unittest.main()
Loading