From 23e972425720649834b3a47623e393ce1481cda7 Mon Sep 17 00:00:00 2001 From: Sajad Ebrahimi Date: Tue, 2 Jun 2026 13:35:09 -0400 Subject: [PATCH 1/9] Initiate test modules --- tests/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 tests/__init__.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + From 2da904f6e1b2226a2cafa69c6a8d70bb8491259f Mon Sep 17 00:00:00 2001 From: Sajad Ebrahimi Date: Tue, 2 Jun 2026 13:37:04 -0400 Subject: [PATCH 2/9] Add test for core components --- tests/test_core.py | 168 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 tests/test_core.py diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000..8a261b1 --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,168 @@ +"""Tests for the core infrastructure (registry, base classes, results).""" + +from __future__ import annotations + +import pytest +import torch + +from detectzoo.core.base import BaseDetector, DetectionResult +from detectzoo.core.registry import ( + _ALIASES, + _REGISTRY, + list_detectors, + load_detector, +) + +VALID_MODALITIES = {"text", "image", "audio"} + + +class TestDetectionResult: + def test_fields(self): + r = DetectionResult(score=0.75, label="ai", confidence=0.6) + assert r.score == 0.75 + assert r.label == "ai" + assert r.confidence == 0.6 + assert r.metadata == {} + + def test_default_confidence_and_metadata(self): + r = DetectionResult(score=0.5, label="human") + assert r.confidence == 0.0 + assert r.metadata == {} + + def test_metadata_is_independent_per_instance(self): + a = DetectionResult(score=0.1, label="human") + b = DetectionResult(score=0.9, label="ai") + a.metadata["k"] = 1 + assert b.metadata == {} + + def test_repr(self): + r = DetectionResult(score=1.0, label="human", confidence=0.5) + assert "DetectionResult" in repr(r) + assert "1.0000" in repr(r) + + 
+class TestRegistry: + def test_detectors_registered(self): + names = list_detectors() + assert len(names) >= 24, f"Expected >=24 detectors, got {len(names)}: {names}" + + def test_registry_invariants(self): + """Every registered class must expose its registry name and a valid modality.""" + for name, cls in _REGISTRY.items(): + assert issubclass(cls, BaseDetector), f"{name} is not a BaseDetector" + assert cls.name == name, f"{name}: cls.name={cls.name!r} mismatches key" + assert cls.modality in VALID_MODALITIES, f"{name}: bad modality {cls.modality!r}" + + def test_text_detectors_present(self): + # Text detectors have no heavy optional deps, so they always load. + text = set(list_detectors("text")) + assert len(text) >= 18, f"Expected >=18 text detectors, got {sorted(text)}" + # A representative, stable subset that should always exist. + expected = { + "log_likelihood", "log_rank", "rank", "entropy", "detectgpt", + "fast_detectgpt", "binoculars", "lrr", "npr", "dna_gpt", + "revise_detect", "imbd", "lastde", "lastde_pp", "radar", + "text_fluoroscopy", "coco", "roberta_base", "roberta_large", + } + missing = expected - text + assert not missing, f"Missing expected text detectors: {missing}" + + def test_load_unknown_raises(self): + with pytest.raises(ValueError, match="Unknown detector"): + load_detector("nonexistent_detector_xyz") + + def test_alias_resolution(self): + # roberta aliases are pure-text and resolve without any download. + assert _ALIASES.get("roberta_openai_base") == "roberta_base" + assert _ALIASES.get("roberta_openai_large") == "roberta_large" + # Every alias must point at a real, registered detector. 
+ for alias, target in _ALIASES.items(): + assert target in _REGISTRY, f"Alias {alias!r} -> unknown target {target!r}" + + def test_list_by_modality_filters(self): + for name in list_detectors("text"): + assert _REGISTRY[name].modality == "text" + + def test_list_detectors_sorted(self): + names = list_detectors() + assert names == sorted(names) + + +class TestBaseDetector: + def _dummy(self, score: float, threshold: float = 0.5): + class _Dummy(BaseDetector): + name = "dummy_core" + modality = "text" + + def predict(self, input_data): + return self._make_result(score) + + return _Dummy(threshold=threshold) + + def test_make_result_above_threshold(self): + r = self._dummy(0.8).predict("hello") + assert r.label == "ai" + assert r.score == 0.8 + + def test_make_result_at_threshold_is_ai(self): + # label uses score >= threshold. + r = self._dummy(0.5, threshold=0.5).predict("x") + assert r.label == "ai" + + def test_make_result_below_threshold(self): + r = self._dummy(0.2).predict("hello") + assert r.label == "human" + + def test_confidence_in_unit_interval(self): + r = self._dummy(0.8).predict("hello") + assert 0.0 <= r.confidence <= 1.0 + assert r.confidence > 0.0 + + def test_make_result_passes_metadata(self): + class _Dummy(BaseDetector): + name = "dummy_meta" + modality = "text" + + def predict(self, input_data): + return self._make_result(0.9, extra="info", n=3) + + r = _Dummy().predict("x") + assert r.metadata == {"extra": "info", "n": 3} + + def test_predict_batch(self): + class _LenDummy(BaseDetector): + name = "dummy_len" + modality = "text" + + def predict(self, input_data): + return self._make_result(float(len(str(input_data))) / 100.0) + + results = _LenDummy().predict_batch(["a", "bb", "ccc"]) + assert len(results) == 3 + assert all(isinstance(r, DetectionResult) for r in results) + + def test_device_property_and_to(self): + d = self._dummy(0.5) + assert d.device == torch.device("cpu") + d.to("cpu") + assert d.device == torch.device("cpu") + + def 
test_unload_clears_modules(self): + class _ModelDummy(BaseDetector): + name = "dummy_model" + modality = "text" + + def __init__(self, **kw): + super().__init__(**kw) + self.net = torch.nn.Linear(2, 2) + + def predict(self, input_data): + return self._make_result(0.5) + + d = _ModelDummy() + assert isinstance(d.net, torch.nn.Module) + d.unload() + assert d.net is None + + def test_repr(self): + assert "dummy_core" in repr(self._dummy(0.5)) From fa90144887cad103b7ee9fd92bc6da3f36d66a4c Mon Sep 17 00:00:00 2001 From: Sajad Ebrahimi Date: Tue, 2 Jun 2026 13:37:38 -0400 Subject: [PATCH 3/9] Add test for utils functions --- tests/test_utils.py | 118 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 tests/test_utils.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..682653f --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,118 @@ +"""Tests for utility modules (io, metrics, logger).""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import numpy as np +import pytest + +from detectzoo.utils.io import load_text +from detectzoo.utils.logger import get_logger +from detectzoo.utils.metrics import compute_metrics + + +class TestLoadText: + def test_raw_string(self): + assert load_text("hello world") == "hello world" + + def test_file_path(self, tmp_path: Path): + p = tmp_path / "sample.txt" + p.write_text("file content", encoding="utf-8") + assert load_text(str(p)) == "file content" + + def test_pathlib_input(self, tmp_path: Path): + p = tmp_path / "sample.txt" + p.write_text("via path object", encoding="utf-8") + assert load_text(p) == "via path object" + + def test_nonexistent_path_treated_as_text(self): + assert load_text("/no/such/file/here.txt") == "/no/such/file/here.txt" + + +class TestLoadImage: + def test_loads_rgb(self, tmp_path: Path): + Image = pytest.importorskip("PIL.Image") + from detectzoo.utils.io import load_image + + src = tmp_path 
/ "img.png" + Image.new("L", (8, 8), color=128).save(src) + img = load_image(src) + assert img.mode == "RGB" + assert img.size == (8, 8) + + +class TestMetrics: + def test_perfect_predictions(self): + labels = [0, 0, 1, 1] + scores = [0.1, 0.2, 0.8, 0.9] + m = compute_metrics(labels, scores, threshold=0.5) + assert m["accuracy"] == 1.0 + assert m["f1"] == 1.0 + assert m["roc_auc"] == 1.0 + assert m["pr_auc"] == pytest.approx(1.0) + assert m["eer"] == pytest.approx(0.0, abs=1e-9) + + def test_all_wrong(self): + labels = [0, 0, 1, 1] + scores = [0.9, 0.8, 0.1, 0.2] + m = compute_metrics(labels, scores, threshold=0.5) + assert m["accuracy"] == 0.0 + assert m["roc_auc"] == 0.0 + assert m["eer"] == pytest.approx(1.0) + + def test_threshold_dependent_keys_present(self): + m = compute_metrics([0, 1], [0.2, 0.8], threshold=0.5) + for key in ( + "accuracy", "precision", "recall", "f1", "tpr", "fpr", + "roc_auc", "pr_auc", "avg_precision", "eer", + ): + assert key in m + + def test_tpr_equals_recall(self): + labels = [0, 0, 1, 1] + scores = [0.4, 0.6, 0.4, 0.9] # one FP, one FN + m = compute_metrics(labels, scores, threshold=0.5) + assert m["tpr"] == pytest.approx(m["recall"]) + + def test_single_class_auc_is_nan(self): + labels = [1, 1, 1] + scores = [0.9, 0.8, 0.7] + m = compute_metrics(labels, scores, threshold=0.5) + assert np.isnan(m["roc_auc"]) + assert np.isnan(m["pr_auc"]) + assert np.isnan(m["eer"]) + # threshold metrics are still computable for a single class + assert m["accuracy"] == 1.0 + + def test_non_finite_scores_are_dropped(self): + labels = [0, 1, 0, 1] + scores = [0.1, 0.9, float("nan"), float("inf")] + m = compute_metrics(labels, scores, threshold=0.5) + # Only the two finite, correctly-classified samples remain. 
+ assert m["accuracy"] == 1.0 + assert m["roc_auc"] == 1.0 + + def test_all_non_finite_returns_nan(self): + m = compute_metrics([0, 1], [float("nan"), float("inf")], threshold=0.5) + assert np.isnan(m["accuracy"]) + assert np.isnan(m["roc_auc"]) + assert np.isnan(m["eer"]) + + def test_threshold_changes_predictions(self): + labels = [0, 1] + scores = [0.4, 0.6] + assert compute_metrics(labels, scores, threshold=0.5)["accuracy"] == 1.0 + # With a threshold above both scores, everything is predicted "human". + assert compute_metrics(labels, scores, threshold=0.99)["accuracy"] == 0.5 + + +class TestLogger: + def test_returns_logger(self): + log = get_logger("test_logger") + assert log.name == "test_logger" + assert isinstance(log, logging.Logger) + + def test_same_name_returns_same_instance(self): + assert get_logger("dz_shared") is get_logger("dz_shared") From 14371996d9256eb46e75dbbf12b675f3d8424b48 Mon Sep 17 00:00:00 2001 From: Sajad Ebrahimi Date: Tue, 2 Jun 2026 13:38:04 -0400 Subject: [PATCH 4/9] Add conftest --- tests/conftest.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d95d46f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,41 @@ +"""Shared pytest fixtures and helpers for the DetectZoo test-suite.""" + +from __future__ import annotations + +import importlib + +import pytest + +import detectzoo # noqa: F401 (ensures registries are populated) +from detectzoo.core.base import BaseDetector, DetectionResult + + +def require_modality(modality: str) -> None: + """Skip the current test if a modality's detector package is unavailable. + + DetectZoo loads modality subpackages on a best-effort basis (see + ``detectzoo/__init__.py``): if an optional heavy dependency such as + ``diffusers`` or ``timm`` is missing, the whole subpackage is skipped + with a warning rather than failing import. 
Tests that assert on a + modality's detectors must therefore skip gracefully when that package + could not be imported, so the suite stays green on partial installs. + """ + try: + importlib.import_module(f"detectzoo.detectors.{modality}") + except ImportError as exc: # pragma: no cover - depends on environment + pytest.skip(f"{modality} detectors unavailable ({exc})") + + +class DummyDetector(BaseDetector): + """Lightweight detector that scores text by its length (no models).""" + + name = "dummy" + modality = "text" + + def predict(self, input_data) -> DetectionResult: + return self._make_result(min(len(str(input_data)) / 100.0, 1.0)) + + +@pytest.fixture +def dummy_detector() -> DummyDetector: + return DummyDetector(threshold=0.5) From bb145968e2ff0393c3c16159740fb5b1945feb7d Mon Sep 17 00:00:00 2001 From: Sajad Ebrahimi Date: Tue, 2 Jun 2026 13:39:12 -0400 Subject: [PATCH 5/9] Add tests for a few text detectors to check workflow health --- tests/test_text_detectors.py | 78 ++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/test_text_detectors.py diff --git a/tests/test_text_detectors.py b/tests/test_text_detectors.py new file mode 100644 index 0000000..66c121e --- /dev/null +++ b/tests/test_text_detectors.py @@ -0,0 +1,78 @@ +"""Tests for text-modality detectors. + +Tests marked ``@pytest.mark.slow`` download HuggingFace models (``gpt2``) +and are skipped by default. Run with ``pytest -m slow`` to include them. +All slow tests pin the (tiny) ``gpt2`` model on CPU to stay practical — +detectors that use a separate reference model default to multi-billion +parameter models, so those must be overridden explicitly. +""" + +from __future__ import annotations + +import pytest + +from detectzoo.core.base import DetectionResult +from detectzoo.core.registry import _REGISTRY, list_detectors + +_SAMPLE = "The quick brown fox jumps over the lazy dog." 
+ + +class TestTextRegistryQuick: + """Lightweight checks that need no model download.""" + + def test_zero_shot_detectors_registered(self): + names = set(list_detectors("text")) + for n in ("log_likelihood", "log_rank", "entropy", "fast_detectgpt"): + assert n in names + + def test_classes_are_text_detectors(self): + for n in ("log_likelihood", "log_rank", "entropy"): + assert _REGISTRY[n].modality == "text" + + +@pytest.mark.slow +class TestLogLikelihoodDetector: + def test_predict(self): + from detectzoo.detectors.text.log_likelihood import LogLikelihoodDetector + + det = LogLikelihoodDetector(model_name="gpt2", device="cpu") + result = det.predict(_SAMPLE) + assert isinstance(result, DetectionResult) + assert isinstance(result.score, float) + + +@pytest.mark.slow +class TestLogRankDetector: + def test_predict(self): + from detectzoo.detectors.text.log_rank import LogRankDetector + + det = LogRankDetector(model_name="gpt2", device="cpu") + result = det.predict(_SAMPLE) + assert isinstance(result, DetectionResult) + + +@pytest.mark.slow +class TestEntropyDetector: + def test_predict(self): + from detectzoo.detectors.text.entropy import EntropyDetector + + det = EntropyDetector(model_name="gpt2", device="cpu") + result = det.predict(_SAMPLE) + assert isinstance(result, DetectionResult) + + +@pytest.mark.slow +class TestFastDetectGPT: + def test_predict_single_model(self): + from detectzoo.detectors.text.fast_detect_gpt import FastDetectGPTDetector + + # Override BOTH models to gpt2 — the default reference model is + # gpt-j-6B which is impractical to download for a test. 
+ det = FastDetectGPTDetector( + model_name="gpt2", + reference_model_name="gpt2", + device="cpu", + ) + result = det.predict(_SAMPLE) + assert isinstance(result, DetectionResult) + assert "mean_log_prob" in result.metadata From 7bd2e253c8c839cd3b7269c93f426b976e702ca3 Mon Sep 17 00:00:00 2001 From: Sajad Ebrahimi Date: Tue, 2 Jun 2026 13:41:45 -0400 Subject: [PATCH 6/9] Add test for two image detectors to check workflow health --- tests/test_image_detectors.py | 57 +++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 tests/test_image_detectors.py diff --git a/tests/test_image_detectors.py b/tests/test_image_detectors.py new file mode 100644 index 0000000..d25e764 --- /dev/null +++ b/tests/test_image_detectors.py @@ -0,0 +1,57 @@ +"""Tests for image-modality detectors. + +The current image detectors all load pretrained checkpoints at +construction time (see e.g. ``CNNSpotDetector``), so running an actual +prediction requires a network download and is marked ``@pytest.mark.slow``. +The non-slow tests verify registration and interface invariants only and +skip automatically when the image subpackage cannot be imported (missing +optional deps such as ``diffusers`` / ``timm`` / ``open_clip``). 
+""" + +from __future__ import annotations + +import numpy as np +import pytest + +from detectzoo.core.base import BaseDetector, DetectionResult +from detectzoo.core.registry import _REGISTRY, list_detectors, load_detector + +from .conftest import require_modality + + +class TestImageRegistry: + def test_image_detectors_registered(self): + require_modality("image") + names = set(list_detectors("image")) + assert names, "No image detectors registered" + expected = {"cnnspot", "univfd", "aide", "freqnet", "patchcraft"} + missing = expected - names + assert not missing, f"Missing expected image detectors: {missing}" + + def test_image_detector_invariants(self): + require_modality("image") + for name in list_detectors("image"): + cls = _REGISTRY[name] + assert issubclass(cls, BaseDetector) + assert cls.modality == "image" + + def test_cnn_spot_alias(self): + require_modality("image") + from detectzoo.core.registry import _ALIASES + + assert _ALIASES.get("cnn_spot") == "cnnspot" + + +@pytest.mark.slow +class TestCNNSpotDetector: + def test_predict_on_random_image(self): + require_modality("image") + from PIL import Image + + det = load_detector("cnnspot", device="cpu") + rng = np.random.default_rng(0) + img = Image.fromarray(rng.integers(0, 256, (256, 256, 3), dtype=np.uint8)) # high is exclusive: 256 covers the full uint8 range + result = det.predict(img) + assert isinstance(result, DetectionResult) + assert 0.0 <= result.score <= 1.0 + assert result.label in ("ai", "human") From e44ecba8e9576bd8db2445c782affec0c670c416 Mon Sep 17 00:00:00 2001 From: Sajad Ebrahimi Date: Tue, 2 Jun 2026 13:42:56 -0400 Subject: [PATCH 7/9] Add test for audio detectors to check workflow health --- tests/test_audio_detectors.py | 55 +++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 tests/test_audio_detectors.py diff --git a/tests/test_audio_detectors.py b/tests/test_audio_detectors.py new file mode 100644 index 0000000..464abdc --- /dev/null +++ b/tests/test_audio_detectors.py @@ -0,0 +1,55 @@ +"""Tests 
for audio-modality detectors. + +Audio detectors load pretrained checkpoints at construction time, so an +actual prediction requires a network download and is marked +``@pytest.mark.slow``. The non-slow tests verify registration and +interface invariants only, skipping when the audio subpackage cannot be +imported. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from detectzoo.core.base import BaseDetector, DetectionResult +from detectzoo.core.registry import _REGISTRY, list_detectors, load_detector + +from .conftest import require_modality + + +class TestAudioRegistry: + def test_audio_detectors_registered(self): + require_modality("audio") + names = set(list_detectors("audio")) + assert names, "No audio detectors registered" + expected = {"aasist", "rawnet2", "res_tssdnet", "samo"} + missing = expected - names + assert not missing, f"Missing expected audio detectors: {missing}" + + def test_audio_detector_invariants(self): + require_modality("audio") + for name in list_detectors("audio"): + cls = _REGISTRY[name] + assert issubclass(cls, BaseDetector) + assert cls.modality == "audio" + + def test_rawnet2_alias(self): + require_modality("audio") + from detectzoo.core.registry import _ALIASES + + assert _ALIASES.get("rawnet2_audio") == "rawnet2" + + +@pytest.mark.slow +class TestAASISTDetector: + def test_predict_with_synthetic_audio(self): + require_modality("audio") + + det = load_detector("aasist", device="cpu") + rng = np.random.default_rng(0) + waveform = rng.standard_normal(16000).astype(np.float32) + result = det.predict(waveform) + assert isinstance(result, DetectionResult) + assert 0.0 <= result.score <= 1.0 + assert "score_spoof" in result.metadata From cc741245117c1e6aa0fef678e433ab68d3f1ea52 Mon Sep 17 00:00:00 2001 From: Sajad Ebrahimi Date: Tue, 2 Jun 2026 13:43:33 -0400 Subject: [PATCH 8/9] Add test for datasets components --- tests/test_datasets.py | 206 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 206 
insertions(+) create mode 100644 tests/test_datasets.py diff --git a/tests/test_datasets.py b/tests/test_datasets.py new file mode 100644 index 0000000..17177d1 --- /dev/null +++ b/tests/test_datasets.py @@ -0,0 +1,206 @@ +"""Tests for the dataset module (base classes + registry).""" + +from __future__ import annotations + +import csv +from pathlib import Path + +import pytest + +from detectzoo.core.registry import ( + _DATASET_ALIASES, + _DATASET_REGISTRY, + list_datasets, + load_dataset, +) +from detectzoo.datasets.base import ( + BaseDataset, + CSVDataset, + DatasetItem, + SimpleDirectoryDataset, +) + + +class TestDatasetItem: + def test_fields(self): + item = DatasetItem(data="hello", label=1, metadata={"src": "gpt"}) + assert item.label == 1 + assert item.metadata["src"] == "gpt" + + def test_default_metadata(self): + item = DatasetItem(data="x", label=0) + assert item.metadata == {} + + +class TestSimpleDirectoryDataset: + def test_loads_files(self, tmp_path: Path): + real = tmp_path / "real" + fake = tmp_path / "fake" + real.mkdir() + fake.mkdir() + (real / "a.txt").write_text("r1") + (real / "b.txt").write_text("r2") + (fake / "c.txt").write_text("f1") + + ds = SimpleDirectoryDataset(real, fake) + items = ds.load() + assert len(items) == 3 + labels = {it.label for it in items} + assert labels == {0, 1} + + def test_labels_match_directory(self, tmp_path: Path): + real = tmp_path / "real" + fake = tmp_path / "fake" + real.mkdir() + fake.mkdir() + (real / "a.txt").write_text("r") + (fake / "b.txt").write_text("f") + + items = {Path(it.data).name: it.label for it in SimpleDirectoryDataset(real, fake).load()} + assert items["a.txt"] == 0 + assert items["b.txt"] == 1 + + def test_extension_filter(self, tmp_path: Path): + real = tmp_path / "real" + fake = tmp_path / "fake" + real.mkdir() + fake.mkdir() + (real / "a.png").write_text("") + (real / "b.jpg").write_text("") + (real / "c.txt").write_text("") + (fake / "d.png").write_text("") + + ds = 
SimpleDirectoryDataset(real, fake, extensions=[".png"]) + assert len(ds.load()) == 2 + + def test_iter_and_len(self, tmp_path: Path): + real = tmp_path / "real" + fake = tmp_path / "fake" + real.mkdir() + fake.mkdir() + (real / "a.txt").write_text("r") + (fake / "b.txt").write_text("f") + + ds = SimpleDirectoryDataset(real, fake) + assert len(ds) == 2 + assert sum(1 for _ in ds) == 2 + + def test_caches_items(self, tmp_path: Path): + real = tmp_path / "real" + fake = tmp_path / "fake" + real.mkdir() + fake.mkdir() + (real / "a.txt").write_text("r") + (fake / "b.txt").write_text("f") + + ds = SimpleDirectoryDataset(real, fake) + assert ds.load() is ds.load() + + +class TestMaxSamples: + def _make(self, tmp_path: Path, n_real: int, n_fake: int, **kw): + real = tmp_path / "real" + fake = tmp_path / "fake" + real.mkdir() + fake.mkdir() + for i in range(n_real): + (real / f"r{i}.txt").write_text("r") + for i in range(n_fake): + (fake / f"f{i}.txt").write_text("f") + return SimpleDirectoryDataset(real, fake, **kw) + + def test_balanced_truncation(self, tmp_path: Path): + ds = self._make(tmp_path, 10, 10, max_samples=4) + items = ds.load() + assert len(items) == 4 + labels = [it.label for it in items] + assert labels.count(0) == 2 + assert labels.count(1) == 2 + + def test_fills_from_other_class_when_short(self, tmp_path: Path): + # Only 1 real sample, 10 fake; ask for 6 -> 1 real + 5 fake. 
+ ds = self._make(tmp_path, 1, 10, max_samples=6) + items = ds.load() + assert len(items) == 6 + labels = [it.label for it in items] + assert labels.count(0) == 1 + assert labels.count(1) == 5 + + +class TestCSVDataset: + def test_loads_csv(self, tmp_path: Path): + csv_path = tmp_path / "data.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["text", "label"]) + writer.writeheader() + writer.writerow({"text": "hello world", "label": "0"}) + writer.writerow({"text": "ai text here", "label": "1"}) + + ds = CSVDataset(csv_path) + items = ds.load() + assert len(items) == 2 + assert items[0].label == 0 + assert items[1].data == "ai text here" + + def test_custom_columns(self, tmp_path: Path): + csv_path = tmp_path / "custom.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["body", "y"]) + writer.writeheader() + writer.writerow({"body": "content", "y": "1"}) + + ds = CSVDataset(csv_path, text_column="body", label_column="y") + items = ds.load() + assert items[0].data == "content" + assert items[0].label == 1 + + +class TestFromFactoryMethods: + def test_from_directory(self, tmp_path: Path): + real = tmp_path / "real" + fake = tmp_path / "fake" + real.mkdir() + fake.mkdir() + (real / "a.txt").write_text("r") + (fake / "b.txt").write_text("f") + + ds = BaseDataset.from_directory(real, fake) + assert isinstance(ds, SimpleDirectoryDataset) + assert len(ds.load()) == 2 + + def test_from_csv(self, tmp_path: Path): + csv_path = tmp_path / "test.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["text", "label"]) + writer.writeheader() + writer.writerow({"text": "sample", "label": "0"}) + + ds = BaseDataset.from_csv(csv_path) + assert isinstance(ds, CSVDataset) + assert len(ds.load()) == 1 + + +class TestDatasetRegistry: + def test_datasets_registered(self): + names = set(list_datasets()) + assert names, "No datasets registered" + # Text datasets have no heavy 
optional deps and should be present. + for n in ("hc3", "raid", "m4"): + assert n in names, f"{n} not registered; got {sorted(names)}" + + def test_registry_invariants(self): + for name, cls in _DATASET_REGISTRY.items(): + assert cls.name == name, f"{name}: cls.name mismatch ({cls.name!r})" + + def test_alias_targets_exist(self): + for alias, target in _DATASET_ALIASES.items(): + assert target in _DATASET_REGISTRY, f"Alias {alias!r} -> unknown {target!r}" + + def test_load_unknown_raises(self): + with pytest.raises(ValueError, match="Unknown dataset"): + load_dataset("nonexistent_dataset_xyz") + + def test_list_by_modality(self): + text_names = list_datasets("text") + for n in text_names: + assert _DATASET_REGISTRY[n].modality == "text" From fa9a5fa9fa23469bf124c6094bf4cc461f14fd85 Mon Sep 17 00:00:00 2001 From: Sajad Ebrahimi Date: Tue, 2 Jun 2026 13:44:05 -0400 Subject: [PATCH 9/9] Add test for benchmarking components --- tests/test_benchmarks.py | 96 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 tests/test_benchmarks.py diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py new file mode 100644 index 0000000..3da0bc8 --- /dev/null +++ b/tests/test_benchmarks.py @@ -0,0 +1,96 @@ +"""Tests for the BenchmarkEvaluator (no model downloads). + +A trivial in-memory dataset and a length-based dummy detector are used so +the evaluator's orchestration, metric aggregation, and persistence can be +exercised without any heavy dependencies. 
+""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import List + +import pytest + +from detectzoo.benchmarks.evaluator import BenchmarkEvaluator +from detectzoo.core.base import BaseDetector, DetectionResult +from detectzoo.datasets.base import BaseDataset, DatasetItem + + +class _MemoryDataset(BaseDataset): + name = "memory" + modality = "text" + + def __init__(self, items: List[DatasetItem], **kw): + super().__init__(**kw) + self._mem = items + + def _load_all(self) -> List[DatasetItem]: + return self._mem + + +class _KeywordDetector(BaseDetector): + """Scores 0.9 if 'ai' appears in the text, else 0.1 — perfectly separable.""" + + name = "keyword" + modality = "text" + + def predict(self, input_data) -> DetectionResult: + return self._make_result(0.9 if "ai" in str(input_data).lower() else 0.1) + + +@pytest.fixture +def dataset() -> _MemoryDataset: + items = [ + DatasetItem(data="a human wrote this", label=0), + DatasetItem(data="another genuine note", label=0), + DatasetItem(data="this is ai generated", label=1), + DatasetItem(data="ai produced output", label=1), + ] + return _MemoryDataset(items) + + +class TestBenchmarkEvaluator: + def test_evaluate_single(self, dataset): + ev = BenchmarkEvaluator(dataset) + metrics = ev.evaluate_single(_KeywordDetector()) + assert metrics["detector"] == "keyword" + assert metrics["n_samples"] == 4 + assert metrics["accuracy"] == 1.0 + assert metrics["roc_auc"] == 1.0 + + def test_save_scores(self, dataset): + ev = BenchmarkEvaluator(dataset) + metrics = ev.evaluate_single(_KeywordDetector(), save_scores=True) + assert "samples" in metrics + assert len(metrics["samples"]) == 4 + assert {"label", "score"} <= set(metrics["samples"][0]) + + def test_run_multiple(self, dataset): + ev = BenchmarkEvaluator(dataset) + results = ev.run([_KeywordDetector()]) + assert "keyword" in results + assert results["keyword"]["accuracy"] == 1.0 + + def test_run_and_save(self, dataset, tmp_path: Path): 
+ ev = BenchmarkEvaluator(dataset) + out = tmp_path / "nested" / "results.json" + ev.run_and_save([_KeywordDetector()], out) + assert out.is_file() + payload = json.loads(out.read_text()) + assert payload["keyword"]["n_samples"] == 4 + + def test_run_and_save_with_meta(self, dataset, tmp_path: Path): + ev = BenchmarkEvaluator(dataset) + out = tmp_path / "results.json" + ev.run_and_save([_KeywordDetector()], out, meta={"run": "test"}) + payload = json.loads(out.read_text()) + assert payload["meta"] == {"run": "test"} + assert "keyword" in payload["results"] + + def test_modality_inferred_from_dataset(self, dataset): + assert BenchmarkEvaluator(dataset).modality == "text" + + def test_modality_override(self, dataset): + assert BenchmarkEvaluator(dataset, modality="audio").modality == "audio"