From c2972348a428cd7e25cacb73f343ac8ac88d46b7 Mon Sep 17 00:00:00 2001 From: Yiyun Date: Mon, 6 Apr 2026 19:09:14 -0700 Subject: [PATCH 01/27] test --- README.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 README.txt diff --git a/README.txt b/README.txt new file mode 100644 index 000000000..71dfd5bac --- /dev/null +++ b/README.txt @@ -0,0 +1 @@ +README.txt From 4f99e8782ee0e79aa8ad48510ea33146fee3d331 Mon Sep 17 00:00:00 2001 From: jovianw Date: Mon, 6 Apr 2026 20:07:00 -0700 Subject: [PATCH 02/27] add: ecg-qa dataset --- pyhealth/datasets/__init__.py | 1 + pyhealth/datasets/configs/ecg_qa.yaml | 16 +++ pyhealth/datasets/ecg_qa.py | 160 ++++++++++++++++++++++++++ 3 files changed, 177 insertions(+) create mode 100644 pyhealth/datasets/configs/ecg_qa.yaml create mode 100644 pyhealth/datasets/ecg_qa.py diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..1b0da6a39 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -53,6 +53,7 @@ def __init__(self, *args, **kwargs): from .cosmic import COSMICDataset from .covid19_cxr import COVID19CXRDataset from .dreamt import DREAMTDataset +from .ecg_qa import ECGQADataset from .ehrshot import EHRShotDataset from .eicu import eICUDataset from .isruc import ISRUCDataset diff --git a/pyhealth/datasets/configs/ecg_qa.yaml b/pyhealth/datasets/configs/ecg_qa.yaml new file mode 100644 index 000000000..cc18708d6 --- /dev/null +++ b/pyhealth/datasets/configs/ecg_qa.yaml @@ -0,0 +1,16 @@ +version: "3.0.0" +tables: + ecg_qa: + file_path: "ecg-qa-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "ecg_id" + - "question" + - "answer" + - "question_type" + - "attribute_type" + - "template_id" + - "question_id" + - "sample_id" + - "attribute" diff --git a/pyhealth/datasets/ecg_qa.py b/pyhealth/datasets/ecg_qa.py new file mode 100644 index 000000000..99647fc97 --- /dev/null +++ b/pyhealth/datasets/ecg_qa.py @@ -0,0 +1,160 @@ +import json +import logging +import pandas as pd +from pathlib import Path +from typing import Optional + +from .base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + + +class ECGQADataset(BaseDataset): + """ECG Question Answering dataset. + + This dataset provides natural language question-answer pairs linked to + ECG recordings via ecg_id. It is an annotation layer on top of ECG + recordings from PTB-XL or MIMIC-IV-ECG. + + The QA data originates from the ECG-QA dataset (Oh et al., 2024), + restructured for few-shot learning by Tang et al. (CHIL 2025). + + Dataset is available at https://github.com/Tang-Jia-Lu/FSL_ECG_QA + + Three question types are supported: + - single-verify: yes/no questions about ECG findings + - single-choose: multi-choice questions (answer is one option, "both", or "none") + - single-query: open-ended questions with free-form answers + + Args: + root: path to the paraphrased QA directory containing train/, valid/, + test/ subdirectories with JSON files. Works with both PTB-XL + (ecgqa/ptbxl/paraphrased/) and MIMIC-IV-ECG + (ecgqa/mimic-iv-ecg/paraphrased/) data. + dataset_name: name of the dataset. Default is "ecg_qa". + config_path: path to the YAML config file. Default uses built-in config. + + Examples: + >>> from pyhealth.datasets import ECGQADataset + >>> dataset = ECGQADataset( + ... root="/path/to/ecgqa/ptbxl/paraphrased/", + ... ) + >>> dataset.stats() + """ + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + **kwargs, + ) -> None: + if config_path is None: + logger.info("No config path provided, using default config") + config_path = Path(__file__).parent / "configs" / "ecg_qa.yaml" + + self.root = root + + self.prepare_metadata() + + # Check if CSV is in cache rather than root + root_path = Path(root) + cache_dir = Path.home() / ".cache" / "pyhealth" / "ecg_qa" + csv_name = "ecg-qa-pyhealth.csv" + + use_cache = False + if not (root_path / csv_name).exists() and (cache_dir / csv_name).exists(): + use_cache = True + + if use_cache: + logger.info(f"Using cached metadata from {cache_dir}") + root = str(cache_dir) + + super().__init__( + root=root, + tables=["ecg_qa"], + dataset_name=dataset_name or "ecg_qa", + config_path=config_path, + **kwargs, + ) + + def prepare_metadata(self) -> None: + """Build and save a metadata CSV from all ECG-QA JSON files. + + Scans train/, valid/, test/ subdirectories under root, loads all + JSON files, filters to single-* question types, and writes a + single CSV with columns: + patient_id, ecg_id, question, answer, question_type, + attribute_type, template_id, question_id, sample_id, attribute + """ + root = Path(self.root) + cache_dir = Path.home() / ".cache" / "pyhealth" / "ecg_qa" + csv_name = "ecg-qa-pyhealth.csv" + + shared_csv = root / csv_name + cache_csv = cache_dir / csv_name + if shared_csv.exists() or cache_csv.exists(): + return + + # Load all JSON files from all split directories + data = [] + for split_dir in ("train", "valid", "test"): + json_dir = root / split_dir + if not json_dir.is_dir(): + logger.warning("JSON directory not found: %s", json_dir) + continue + for fpath in sorted(json_dir.glob("*.json")): + with open(fpath, "r") as f: + data.extend(json.load(f)) + + if not data: + raise FileNotFoundError( + f"No JSON files found in train/valid/test subdirectories of {root}" + ) + + # Filter to single-* question types and build rows + rows: list[dict] = [] + for record in data: + qt = record.get("question_type", "") + if not qt.startswith("single-"): + continue + + ecg_id = record["ecg_id"][0] + answer = ";".join(record["answer"]) + attribute = ";".join(record.get("attribute", [])) + + rows.append({ + "patient_id": str(ecg_id), + "ecg_id": ecg_id, + "question": record["question"], + "answer": answer, + "question_type": qt, + "attribute_type": record.get("attribute_type", ""), + "template_id": record.get("template_id", 0), + "question_id": record.get("question_id", 0), + "sample_id": record.get("sample_id", 0), + "attribute": attribute, + }) + + if not rows: + raise ValueError("No single-* question type records found in JSON data") + + df = pd.DataFrame(rows) + df.sort_values(["patient_id", "question_type", "template_id"], inplace=True) + df.reset_index(drop=True, inplace=True) + + # Try shared location first, fall back to cache + try: + shared_csv.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(shared_csv, index=False) + logger.info(f"Wrote metadata to {shared_csv}") + except (PermissionError, OSError): + cache_dir.mkdir(parents=True, exist_ok=True) + df.to_csv(cache_csv, index=False) + logger.info(f"Wrote metadata to cache: {cache_csv}") + + @property + def default_task(self): + """Returns the default task for the ECG-QA dataset: ECGQA.""" + from pyhealth.tasks import ECGQA + return ECGQA() From b65918c6fdf607a0da49b4dd734543ec97964c69 Mon Sep 17 00:00:00 2001 From: jovianw Date: Tue, 14 Apr 2026 19:46:43 -0700 Subject: [PATCH 03/27] feat: implement ECG-QA dataset download capability and testing --- pyhealth/datasets/__init__.py | 2 +- pyhealth/datasets/ecg_qa.py | 160 ----------------- pyhealth/datasets/ecgqa.py | 318 ++++++++++++++++++++++++++++++++++ pyhealth/tasks/__init__.py | 1 + tests/core/test_ecgqa.py | 216 +++++++++++++++++++++++ 5 files changed, 536 insertions(+), 161 deletions(-) delete mode 100644 pyhealth/datasets/ecg_qa.py create mode 100644 pyhealth/datasets/ecgqa.py create mode 100644 tests/core/test_ecgqa.py diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 1b0da6a39..fe2df73ce 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -53,7 +53,7 @@ def __init__(self, *args, **kwargs): from .cosmic import COSMICDataset from .covid19_cxr import COVID19CXRDataset from .dreamt import DREAMTDataset -from .ecg_qa import ECGQADataset +from .ecgqa import ECGQADataset from .ehrshot import EHRShotDataset from .eicu import eICUDataset from .isruc import ISRUCDataset diff --git a/pyhealth/datasets/ecg_qa.py b/pyhealth/datasets/ecg_qa.py deleted file mode 100644 index 99647fc97..000000000 --- a/pyhealth/datasets/ecg_qa.py +++ /dev/null @@ -1,160 +0,0 @@ -import json -import logging -import pandas as pd -from pathlib import Path -from typing import Optional - -from .base_dataset import BaseDataset - -logger = logging.getLogger(__name__) - - -class ECGQADataset(BaseDataset): - """ECG Question Answering dataset. - - This dataset provides natural language question-answer pairs linked to - ECG recordings via ecg_id. It is an annotation layer on top of ECG - recordings from PTB-XL or MIMIC-IV-ECG. - - The QA data originates from the ECG-QA dataset (Oh et al., 2024), - restructured for few-shot learning by Tang et al. (CHIL 2025). - - Dataset is available at https://github.com/Tang-Jia-Lu/FSL_ECG_QA - - Three question types are supported: - - single-verify: yes/no questions about ECG findings - - single-choose: multi-choice questions (answer is one option, "both", or "none") - - single-query: open-ended questions with free-form answers - - Args: - root: path to the paraphrased QA directory containing train/, valid/, - test/ subdirectories with JSON files. Works with both PTB-XL - (ecgqa/ptbxl/paraphrased/) and MIMIC-IV-ECG - (ecgqa/mimic-iv-ecg/paraphrased/) data. - dataset_name: name of the dataset. Default is "ecg_qa". - config_path: path to the YAML config file. Default uses built-in config. - - Examples: - >>> from pyhealth.datasets import ECGQADataset - >>> dataset = ECGQADataset( - ... root="/path/to/ecgqa/ptbxl/paraphrased/", - ... ) - >>> dataset.stats() - """ - - def __init__( - self, - root: str, - dataset_name: Optional[str] = None, - config_path: Optional[str] = None, - **kwargs, - ) -> None: - if config_path is None: - logger.info("No config path provided, using default config") - config_path = Path(__file__).parent / "configs" / "ecg_qa.yaml" - - self.root = root - - self.prepare_metadata() - - # Check if CSV is in cache rather than root - root_path = Path(root) - cache_dir = Path.home() / ".cache" / "pyhealth" / "ecg_qa" - csv_name = "ecg-qa-pyhealth.csv" - - use_cache = False - if not (root_path / csv_name).exists() and (cache_dir / csv_name).exists(): - use_cache = True - - if use_cache: - logger.info(f"Using cached metadata from {cache_dir}") - root = str(cache_dir) - - super().__init__( - root=root, - tables=["ecg_qa"], - dataset_name=dataset_name or "ecg_qa", - config_path=config_path, - **kwargs, - ) - - def prepare_metadata(self) -> None: - """Build and save a metadata CSV from all ECG-QA JSON files. - - Scans train/, valid/, test/ subdirectories under root, loads all - JSON files, filters to single-* question types, and writes a - single CSV with columns: - patient_id, ecg_id, question, answer, question_type, - attribute_type, template_id, question_id, sample_id, attribute - """ - root = Path(self.root) - cache_dir = Path.home() / ".cache" / "pyhealth" / "ecg_qa" - csv_name = "ecg-qa-pyhealth.csv" - - shared_csv = root / csv_name - cache_csv = cache_dir / csv_name - if shared_csv.exists() or cache_csv.exists(): - return - - # Load all JSON files from all split directories - data = [] - for split_dir in ("train", "valid", "test"): - json_dir = root / split_dir - if not json_dir.is_dir(): - logger.warning("JSON directory not found: %s", json_dir) - continue - for fpath in sorted(json_dir.glob("*.json")): - with open(fpath, "r") as f: - data.extend(json.load(f)) - - if not data: - raise FileNotFoundError( - f"No JSON files found in train/valid/test subdirectories of {root}" - ) - - # Filter to single-* question types and build rows - rows: list[dict] = [] - for record in data: - qt = record.get("question_type", "") - if not qt.startswith("single-"): - continue - - ecg_id = record["ecg_id"][0] - answer = ";".join(record["answer"]) - attribute = ";".join(record.get("attribute", [])) - - rows.append({ - "patient_id": str(ecg_id), - "ecg_id": ecg_id, - "question": record["question"], - "answer": answer, - "question_type": qt, - "attribute_type": record.get("attribute_type", ""), - "template_id": record.get("template_id", 0), - "question_id": record.get("question_id", 0), - "sample_id": record.get("sample_id", 0), - "attribute": attribute, - }) - - if not rows: - raise ValueError("No single-* question type records found in JSON data") - - df = pd.DataFrame(rows) - df.sort_values(["patient_id", "question_type", "template_id"], inplace=True) - df.reset_index(drop=True, inplace=True) - - # Try shared location first, fall back to cache - try: - shared_csv.parent.mkdir(parents=True, exist_ok=True) - df.to_csv(shared_csv, index=False) - logger.info(f"Wrote metadata to {shared_csv}") - except (PermissionError, OSError): - cache_dir.mkdir(parents=True, exist_ok=True) - df.to_csv(cache_csv, index=False) - logger.info(f"Wrote metadata to cache: {cache_csv}") - - @property - def default_task(self): - """Returns the default task for the ECG-QA dataset: ECGQA.""" - from pyhealth.tasks import ECGQA - return ECGQA() diff --git a/pyhealth/datasets/ecgqa.py b/pyhealth/datasets/ecgqa.py new file mode 100644 index 000000000..b04526156 --- /dev/null +++ b/pyhealth/datasets/ecgqa.py @@ -0,0 +1,318 @@ +import hashlib +import json +import logging +import os +import tarfile +import urllib.request +from pathlib import Path +from typing import Optional + +import pandas as pd + +from .base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + +# (owner, commit_sha) for each variant of the dataset. +# Pinning to a commit SHA keeps the URL and MD5 stable as either repo evolves. +_REPO_BY_VARIANT = { + False: ("Tang-Jia-Lu", "b0ec9bd84ae2337052ca977941e37a703dcb492e"), + True: ("jovianw", "2e2d4ac185d6069c741d083269ea40ca01bfd50b"), +} +_MD5_BY_VARIANT = { + False: "894b4af304e99c48ecd62a914ba3ba2b", + True: "e65c4b6ae127103ad92a33ec9246039e", +} +_VALID_ECG_SOURCES = {"ptbxl": "ptbxl", "mimic": "mimic-iv-ecg"} + + +class ECGQADataset(BaseDataset): + """ECG Question Answering dataset. + + This dataset provides natural language question-answer pairs linked to + ECG recordings via ecg_id. It is an annotation layer on top of ECG + recordings from PTB-XL or MIMIC-IV-ECG. + + The QA data originates from the ECG-QA dataset (Oh et al., 2024), + restructured for few-shot learning by Tang et al. (CHIL 2025). + + Dataset is available at https://github.com/Tang-Jia-Lu/FSL_ECG_QA + + Three question types are supported: + - single-verify: yes/no questions about ECG findings + - single-choose: multi-choice questions (answer is one option, "both", or "none") + - single-query: open-ended questions with free-form answers + + Args: + root: directory that holds (or will hold) the paraphrased QA splits as + train/, valid/, test/ subdirectories of JSON files. + dataset_name: name of the dataset. Default is "ecg_qa". + config_path: path to the YAML config file. Default uses built-in config. + download: if True, download the chosen variant from GitHub into ``root`` + before loading. Defaults to False. + ecg_source: which underlying ECG dataset the QA pairs are grounded in. + One of ``"ptbxl"`` (PTB-XL) or ``"mimic"`` (MIMIC-IV-ECG). + Defaults to ``"ptbxl"``. + include_demographics: if True, download the modified variant whose + question text includes patient sex and age. Defaults to False + (the original Tang et al. release). + + Examples: + >>> from pyhealth.datasets import ECGQADataset + >>> # Use a pre-downloaded local copy + >>> dataset = ECGQADataset( + ... root="/path/to/ecgqa/ptbxl/paraphrased/", + ... ) + >>> # Or download the modified PTB-XL variant on the fly + >>> dataset = ECGQADataset( + ... root="./ecg_qa_ptbxl_demo", + ... download=True, + ... ecg_source="ptbxl", + ... include_demographics=True, + ... ) + >>> dataset.stats() + """ + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + download: bool = False, + ecg_source: str = "ptbxl", + include_demographics: bool = False, + **kwargs, + ) -> None: + if ecg_source not in _VALID_ECG_SOURCES: + raise ValueError( + f"ecg_source must be one of {sorted(_VALID_ECG_SOURCES)}, " + f"got {ecg_source!r}" + ) + + if config_path is None: + logger.info("No config path provided, using default config") + config_path = Path(__file__).parent / "configs" / "ecg_qa.yaml" + + self.root = root + + if download: + self._download_data(root, ecg_source, include_demographics) + self._verify_data(root) + + self.prepare_metadata() + + # Check if CSV is in cache rather than root + root_path = Path(root) + cache_dir = Path.home() / ".cache" / "pyhealth" / "ecg_qa" + csv_name = "ecg-qa-pyhealth.csv" + + use_cache = False + if not (root_path / csv_name).exists() and (cache_dir / csv_name).exists(): + use_cache = True + + if use_cache: + logger.info(f"Using cached metadata from {cache_dir}") + root = str(cache_dir) + + super().__init__( + root=root, + tables=["ecg_qa"], + dataset_name=dataset_name or "ecg_qa", + config_path=config_path, + **kwargs, + ) + + def prepare_metadata(self) -> None: + """Build and save a metadata CSV from all ECG-QA JSON files. + + Scans train/, valid/, test/ subdirectories under root, loads all + JSON files, filters to single-* question types, and writes a + single CSV with columns: + patient_id, ecg_id, question, answer, question_type, + attribute_type, template_id, question_id, sample_id, attribute + """ + root = Path(self.root) + cache_dir = Path.home() / ".cache" / "pyhealth" / "ecg_qa" + csv_name = "ecg-qa-pyhealth.csv" + + shared_csv = root / csv_name + cache_csv = cache_dir / csv_name + if shared_csv.exists() or cache_csv.exists(): + return + + # Load all JSON files from all split directories + data = [] + for split_dir in ("train", "valid", "test"): + json_dir = root / split_dir + if not json_dir.is_dir(): + logger.warning("JSON directory not found: %s", json_dir) + continue + for fpath in sorted(json_dir.glob("*.json")): + with open(fpath, "r") as f: + data.extend(json.load(f)) + + if not data: + raise FileNotFoundError( + f"No JSON files found in train/valid/test subdirectories of {root}" + ) + + # Filter to single-* question types and build rows + rows: list[dict] = [] + for record in data: + qt = record.get("question_type", "") + if not qt.startswith("single-"): + continue + + ecg_id = record["ecg_id"][0] + answer = ";".join(record["answer"]) + attribute = ";".join(record.get("attribute", [])) + + rows.append({ + "patient_id": str(ecg_id), + "ecg_id": ecg_id, + "question": record["question"], + "answer": answer, + "question_type": qt, + "attribute_type": record.get("attribute_type", ""), + "template_id": record.get("template_id", 0), + "question_id": record.get("question_id", 0), + "sample_id": record.get("sample_id", 0), + "attribute": attribute, + }) + + if not rows: + raise ValueError("No single-* question type records found in JSON data") + + df = pd.DataFrame(rows) + df.sort_values(["patient_id", "question_type", "template_id"], inplace=True) + df.reset_index(drop=True, inplace=True) + + # Try shared location first, fall back to cache + try: + shared_csv.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(shared_csv, index=False) + logger.info(f"Wrote metadata to {shared_csv}") + except (PermissionError, OSError): + cache_dir.mkdir(parents=True, exist_ok=True) + df.to_csv(cache_csv, index=False) + logger.info(f"Wrote metadata to cache: {cache_csv}") + + def _download_data( + self, root: str, ecg_source: str, include_demographics: bool + ) -> None: + """Downloads the requested ECG-QA variant from GitHub into ``root``. + + Fetches a commit-pinned tarball of the upstream ``Tang-Jia-Lu`` repo + (or the user's ``jovianw`` fork when ``include_demographics`` is True), + verifies its MD5, and extracts only the + ``ecgqa//paraphrased/{train,valid,test}/`` subtree directly + into ``root``. The tarball itself is cached under ``{root}/.ecgqa-cache/`` + so a second call with the other ``ecg_source`` value can reuse it. + + Args: + root: directory the splits will land in. + ecg_source: ``"ptbxl"`` or ``"mimic"``. + include_demographics: selects the modified variant when True. + + Raises: + ValueError: if the downloaded tarball fails MD5 verification or + if it contains an unsafe path during extraction. + """ + owner, sha = _REPO_BY_VARIANT[include_demographics] + expected_md5 = _MD5_BY_VARIANT[include_demographics] + url = f"https://github.com/{owner}/FSL_ECG_QA/archive/{sha}.tar.gz" + + os.makedirs(root, exist_ok=True) + cache_dir = os.path.join(root, ".ecgqa-cache") + os.makedirs(cache_dir, exist_ok=True) + archive_path = os.path.join(cache_dir, f"{sha}.tar.gz") + + need_download = True + if os.path.isfile(archive_path): + with open(archive_path, "rb") as f: + if hashlib.md5(f.read()).hexdigest() == expected_md5: + logger.info(f"Reusing cached archive {archive_path}") + need_download = False + + if need_download: + logger.info(f"Downloading {url} -> {archive_path}") + urllib.request.urlretrieve(url, archive_path) + + logger.info(f"Checking MD5 checksum for {archive_path}...") + with open(archive_path, "rb") as f: + file_md5 = hashlib.md5(f.read()).hexdigest() + if file_md5 != expected_md5: + msg = ( + f"Invalid MD5 checksum for {archive_path}: " + f"expected {expected_md5}, got {file_md5}" + ) + logger.error(msg) + raise ValueError(msg) + + ecg_source_dir = _VALID_ECG_SOURCES[ecg_source] + prefix = f"FSL_ECG_QA-{sha}/ecgqa/{ecg_source_dir}/paraphrased/" + abs_root = os.path.abspath(root) + + logger.info(f"Extracting {prefix}* from {archive_path} into {root}") + with tarfile.open(archive_path, "r:gz") as tar: + extracted = 0 + for member in tar.getmembers(): + if not member.name.startswith(prefix): + continue + rel = member.name[len(prefix):] + if not rel: + continue + + target_path = os.path.abspath(os.path.join(abs_root, rel)) + if os.path.commonpath([abs_root]) != os.path.commonpath( + [abs_root, target_path] + ): + msg = f"Unsafe path detected in tar file: '{member.name}'!" + logger.error(msg) + raise ValueError(msg) + + member.name = rel + tar.extract(member, path=root) + extracted += 1 + + logger.info(f"Download complete ({extracted} entries extracted)") + + def _verify_data(self, root: str) -> None: + """Verifies the presence and structure of the dataset directory. + + Checks that ``root`` exists, that ``train/``, ``valid/``, and ``test/`` + subdirectories are present, and that each contains at least one + ``*.json`` file. + + Args: + root: directory expected to hold the dataset splits. + + Raises: + FileNotFoundError: if ``root`` or any required split directory is + missing. + ValueError: if a split directory contains no JSON files. + """ + if not os.path.exists(root): + msg = f"Dataset path does not exist: {root}" + logger.error(msg) + raise FileNotFoundError(msg) + + for split in ("train", "valid", "test"): + split_dir = os.path.join(root, split) + if not os.path.isdir(split_dir): + msg = ( + f"Dataset path must contain a '{split}' subdirectory: " + f"{split_dir}" + ) + logger.error(msg) + raise FileNotFoundError(msg) + if not list(Path(split_dir).glob("*.json")): + msg = f"Dataset '{split}' directory must contain JSON files!" + logger.error(msg) + raise ValueError(msg) + + @property + def default_task(self): + """Returns the default task for the ECG-QA dataset: ECGQA.""" + from pyhealth.tasks import ECGQA + return ECGQA() diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..957f23d82 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -10,6 +10,7 @@ cardiology_isWA_fn, ) from .chestxray14_binary_classification import ChestXray14BinaryClassification +from .ecgqa_preprocess import ECGQA from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification from .dka import DKAPredictionMIMIC4, T1DDKAPredictionMIMIC4 diff --git a/tests/core/test_ecgqa.py b/tests/core/test_ecgqa.py new file mode 100644 index 000000000..343c49cec --- /dev/null +++ b/tests/core/test_ecgqa.py @@ -0,0 +1,216 @@ +import unittest +import tempfile +import shutil +import json +from pathlib import Path +from unittest.mock import patch + +from pyhealth.datasets import ECGQADataset +from pyhealth.tasks import ECGQA + + +class TestECGQADataset(unittest.TestCase): + """Test ECG-QA dataset with synthetic test data.""" + + def setUp(self): + """Set up train/valid/test JSON files and a temp cache dir.""" + self.temp_dir = tempfile.mkdtemp() + self.root = Path(self.temp_dir) + + for split in ("train", "valid", "test"): + (self.root / split).mkdir() + + train_records = [ + { + "ecg_id": [1], + "question": "Does this ECG show normal sinus rhythm?", + "answer": ["yes"], + "question_type": "single-verify", + "attribute_type": "scp_code", + "template_id": 1, + "question_id": 100, + "sample_id": 1000, + "attribute": ["NORM"], + }, + { + "ecg_id": [2], + "question": "What rhythm does this ECG show?", + "answer": ["sinus rhythm", "atrial fibrillation"], + "question_type": "single-choose", + "attribute_type": "rhythm", + "template_id": 2, + "question_id": 101, + "sample_id": 1001, + "attribute": ["SR", "AFIB"], + }, + { + "ecg_id": [3], + "question": "Are both left axis deviation and right bundle branch block present?", + "answer": ["yes"], + "question_type": "comparison-verify", + "attribute_type": "scp_code", + "template_id": 3, + "question_id": 102, + "sample_id": 1002, + "attribute": ["LAD", "RBBB"], + }, + ] + valid_records = [ + { + "ecg_id": [4], + "question": "What ECG abnormalities are present?", + "answer": ["left axis deviation"], + "question_type": "single-query", + "attribute_type": "scp_code", + "template_id": 4, + "question_id": 103, + "sample_id": 1003, + "attribute": ["LAD"], + }, + ] + test_records = [ + { + "ecg_id": [5], + "question": "Is the heart rate above 100 beats per minute?", + "answer": ["no"], + "question_type": "single-verify", + "attribute_type": "heart_rate", + "template_id": 5, + "question_id": 104, + "sample_id": 1004, + "attribute": ["bradycardia"], + }, + ] + + (self.root / "train" / "00.json").write_text(json.dumps(train_records)) + (self.root / "valid" / "00.json").write_text(json.dumps(valid_records)) + (self.root / "test" / "00.json").write_text(json.dumps(test_records)) + + # Redirect Path.home() into the temp dir so that prepare_metadata's + # ~/.cache/pyhealth/ecg_qa fallback cannot find a pre-existing user + # cache and shadow the test fixture with stale data. + self._home_patch = patch.object(Path, "home", return_value=self.root) + self._home_patch.start() + + self._cache_tmp = tempfile.mkdtemp() + + def tearDown(self): + """Clean up temporary files""" + self._home_patch.stop() + shutil.rmtree(self.temp_dir, ignore_errors=True) + shutil.rmtree(self._cache_tmp, ignore_errors=True) + + def test_dataset_initialization(self): + """Test ECGQADataset initialization""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + + self.assertIsNotNone(dataset) + self.assertEqual(dataset.dataset_name, "ecg_qa") + self.assertEqual(dataset.root, str(self.root)) + + def test_metadata_file_created(self): + """Test ecg-qa-pyhealth.csv is created in root""" + ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + metadata_file = self.root / "ecg-qa-pyhealth.csv" + self.assertTrue(metadata_file.exists()) + + def test_patient_count(self): + """Test only single-* records are loaded (4 of 5)""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + # ecg_ids 1, 2, 4, 5 survive; ecg_id 3 is comparison-verify and is filtered. + self.assertEqual(len(dataset.unique_patient_ids), 4) + self.assertEqual( + sorted(dataset.unique_patient_ids), ["1", "2", "4", "5"] + ) + + def test_filters_non_single_question_types(self): + """Test that comparison-verify records are dropped""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + self.assertNotIn("3", dataset.unique_patient_ids) + + def test_stats_method(self): + """Test stats method runs without error""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + dataset.stats() + + def test_get_patient(self): + """Test get_patient method""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + patient = dataset.get_patient("1") + self.assertIsNotNone(patient) + self.assertEqual(patient.patient_id, "1") + + def test_get_patient_not_found(self): + """Test that patient not found throws error.""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + with self.assertRaises(AssertionError): + dataset.get_patient("999") + + def test_single_verify_event_fields(self): + """Test a single-verify event surfaces the expected attributes""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + events = dataset.get_patient("1").get_events() + self.assertEqual(len(events), 1) + self.assertEqual(events[0]["question_type"], "single-verify") + self.assertEqual(events[0]["answer"], "yes") + self.assertEqual(events[0]["attribute"], "NORM") + + def test_single_choose_event_joins_multi_valued_fields(self): + """Test single-choose records join answer/attribute with ';'""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + events = dataset.get_patient("2").get_events() + self.assertEqual(events[0]["answer"], "sinus rhythm;atrial fibrillation") + self.assertEqual(events[0]["attribute"], "SR;AFIB") + + def test_default_task(self): + """Test default_task returns an ECGQA instance""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + self.assertIsInstance(dataset.default_task, ECGQA) + + def test_set_task_ecgqa(self): + """Test ECGQA task yields one sample per QA pair""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + samples = dataset.set_task(ECGQA()) + self.assertEqual(len(samples), 4) + + def test_invalid_ecg_source_raises(self): + """Test ValueError on invalid ecg_source""" + with self.assertRaises(ValueError): + ECGQADataset( + root=str(self.root), + ecg_source="nope", + cache_dir=self._cache_tmp, + ) + + +class TestECGQAVerifyData(unittest.TestCase): + """Test the structural checks performed by ECGQADataset._verify_data.""" + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.root = Path(self.temp_dir) + + def tearDown(self): + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_nonexistent_root_raises(self): + """Test FileNotFoundError when root does not exist""" + bogus = self.root / "does-not-exist" + with self.assertRaises(FileNotFoundError): + ECGQADataset(root=str(bogus)) + + def test_missing_split_dir_raises(self): + """Test FileNotFoundError when train/valid/test dirs are missing""" + with self.assertRaises(FileNotFoundError): + ECGQADataset(root=str(self.root)) + + def test_empty_split_dir_raises(self): + """Test ValueError when a split dir has no JSON files""" + for split in ("train", "valid", "test"): + (self.root / split).mkdir() + with self.assertRaises(ValueError): + ECGQADataset(root=str(self.root)) + + +if __name__ == "__main__": + unittest.main() From 9aefcf411fdfa0b0750d57238c07c3de796db1ae Mon Sep 17 00:00:00 2001 From: jovianw Date: Tue, 14 Apr 2026 19:51:50 -0700 Subject: [PATCH 04/27] remove: reference to unused base task --- tests/core/test_ecgqa.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/core/test_ecgqa.py b/tests/core/test_ecgqa.py index 343c49cec..50c6bb8bf 100644 --- a/tests/core/test_ecgqa.py +++ b/tests/core/test_ecgqa.py @@ -6,7 +6,6 @@ from unittest.mock import patch from pyhealth.datasets import ECGQADataset -from pyhealth.tasks import ECGQA class TestECGQADataset(unittest.TestCase): @@ -162,17 +161,6 @@ def test_single_choose_event_joins_multi_valued_fields(self): self.assertEqual(events[0]["answer"], "sinus rhythm;atrial fibrillation") self.assertEqual(events[0]["attribute"], "SR;AFIB") - def test_default_task(self): - """Test default_task returns an ECGQA instance""" - dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) - self.assertIsInstance(dataset.default_task, ECGQA) - - def test_set_task_ecgqa(self): - """Test ECGQA task yields one sample per QA pair""" - dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) - samples = dataset.set_task(ECGQA()) - self.assertEqual(len(samples), 4) - def test_invalid_ecg_source_raises(self): """Test ValueError on invalid ecg_source""" with self.assertRaises(ValueError): From 5c2ed4e8bfac5807eb888bc4ae2a132735c2bf96 Mon Sep 17 00:00:00 2001 From: Yiyun Date: Tue, 14 Apr 2026 20:01:40 -0700 Subject: [PATCH 05/27] add new signal dataset for PTB-XL --- pyhealth/datasets/PTB-XL1.0.3.py | 116 +++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 pyhealth/datasets/PTB-XL1.0.3.py diff --git a/pyhealth/datasets/PTB-XL1.0.3.py b/pyhealth/datasets/PTB-XL1.0.3.py new file mode 100644 index 000000000..39b7d988b --- /dev/null +++ b/pyhealth/datasets/PTB-XL1.0.3.py @@ -0,0 +1,116 @@ +""" + +Pyhealth dataset for the 1.0.3 PTB-XL dataset. + +Dataset link: + https://physionet.org/content/ptb-xl/1.0.3/ + +Dataset paper: + J. Tang, T. Xia, Y. Lu, C. Mascolo, and A. Saeed, "Electrocardiogram-language model for few-shot question answering with meta learning," arXiv preprint arXiv:2410.14464, 2024. + +Dataset paper link: + https://arxiv.org/abs/2410.14464 + +Author: + Yiyun Wang (yiyunw3@illinois.edu) + +""" +import pandas as pd +import os +import urllib.request +import requests +from pyhealth.datasets import BaseSignalDataset +import zipfile +from pathlib import Path + +""" +Dataset class for the PTB-XL 1.0.3 dataset. + +Args: + dataset_name: name of the dataset. + root: root directory of the raw data (should contain many csv files). + dev: whether to enable dev mode (only use a small subset of the data). + Default is False. + refresh_cache: whether to refresh the cache; if true, the dataset will + be processed from scratch and the cache will be updated. Default is False. +""" +class PTBXLDataset(BaseSignalDataset): + """ + Initialize the PTB-XL dataset. + + Attributes: + root (str): Root directory of the raw data. + download (bool): True iff requested to download dataset. Default to False. + """ + def __init__(self, + root: str = '.', + download: bool = False, + down_sampled: bool = False) -> None: + + # Determine the root path, where most of the data is stored + # self.data_path: str = os.path.join(root, 'ptb_xl_processed_full.zip') + self.data_path: str = os.path.join(root, 'test.zip') + self.root = root + + # Determine signal path, where to fetch the signal samples + signal_folder = 'records100' if down_sampled else 'records500' + root_path = os.path.join(root, 'physionet.org/files/ptb-xl/1.0.3/') if download else root + self.signal_path: str = os.path.join(root_path, signal_folder) + + # Download the dataset from online source to root if needed + self._download(download) + + super().__init__( + root=root, + dataset_name='PTB-XL1.0.3', + ) + + + """ + Download PTB-XL dataset from public google drive sources. + It will contain both the original and downsampled versions, + in /records500 and /records100 folder respectively. + """ + def _download(self, download) -> None: + + if download: + # zip_id = '1IE-4Co1fLRoEI9jez2pwuf9HPmFRzuLX' # full + zip_id = '1Q9Ksxj4gSrsHVb6qqICI0nm0K8HWtDW2' # test + response = requests.get(f'https://drive.google.com/uc?export=download&id={zip_id}') + with open(self.data_path, 'wb') as file: + file.write(response.content) + + with zipfile.ZipFile(self.data_path, "r") as z: + z.extractall(self.root) + + """ + Process and return a dictionary of the requested PTB-XL data for each patient. + Each patient will have a corresponding object that contains + load_from_path, patient_id, signal_file, label_file, and save_to_path. + """ + def process_EEG_data(self): + patients = {} + + for dirpath, dirnames, filenames in os.walk(self.signal_path): + for filename in filenames: + f = Path(filename).stem + pid = f.split('_')[0] + + if pid not in patients: + patients[pid] = [ + { + "load_from_path": dirpath, + "patient_id": pid, + "signal_file": f + '.dat', + "label_file": f + '.hea', + "save_to_path": self.filepath, + } + ] + + return patients + +if __name__ == "__main__": + dataset = PTBXLDataset(root='../../../../', download=False, down_sampled=True) + dataset.stat() + dataset.info() + print(dataset.process_EEG_data()) From 10af81344e1d2877b407f93156b6dc8b3afb7954 Mon Sep 17 00:00:00 2001 From: Yiyun Date: Tue, 14 Apr 2026 20:07:33 -0700 Subject: [PATCH 06/27] rm read me --- README.txt | 1 - 1 file changed, 1 deletion(-) delete mode 100644 README.txt diff --git a/README.txt b/README.txt deleted file mode 100644 index 71dfd5bac..000000000 --- a/README.txt +++ /dev/null @@ -1 +0,0 @@ -README.txt From 74d70c76eebec31f31617a4e5e4e7c9bc22e1952 Mon Sep 17 00:00:00 2001 From: jovianw Date: Tue, 14 Apr 2026 20:11:49 -0700 Subject: [PATCH 07/27] refactor: remove extra caching logic --- pyhealth/datasets/ecgqa.py | 131 ++++++++++++------------------------- tests/core/test_ecgqa.py | 8 --- 2 files changed, 43 insertions(+), 96 deletions(-) diff --git a/pyhealth/datasets/ecgqa.py b/pyhealth/datasets/ecgqa.py index b04526156..60d2660f4 100644 --- a/pyhealth/datasets/ecgqa.py +++ b/pyhealth/datasets/ecgqa.py @@ -13,16 +13,6 @@ logger = logging.getLogger(__name__) -# (owner, commit_sha) for each variant of the dataset. -# Pinning to a commit SHA keeps the URL and MD5 stable as either repo evolves. -_REPO_BY_VARIANT = { - False: ("Tang-Jia-Lu", "b0ec9bd84ae2337052ca977941e37a703dcb492e"), - True: ("jovianw", "2e2d4ac185d6069c741d083269ea40ca01bfd50b"), -} -_MD5_BY_VARIANT = { - False: "894b4af304e99c48ecd62a914ba3ba2b", - True: "e65c4b6ae127103ad92a33ec9246039e", -} _VALID_ECG_SOURCES = {"ptbxl": "ptbxl", "mimic": "mimic-iv-ecg"} @@ -101,19 +91,6 @@ def __init__( self.prepare_metadata() - # Check if CSV is in cache rather than root - root_path = Path(root) - cache_dir = Path.home() / ".cache" / "pyhealth" / "ecg_qa" - csv_name = "ecg-qa-pyhealth.csv" - - use_cache = False - if not (root_path / csv_name).exists() and (cache_dir / csv_name).exists(): - use_cache = True - - if use_cache: - logger.info(f"Using cached metadata from {cache_dir}") - root = str(cache_dir) - super().__init__( root=root, tables=["ecg_qa"], @@ -132,31 +109,16 @@ def prepare_metadata(self) -> None: attribute_type, template_id, question_id, sample_id, attribute """ root = Path(self.root) - cache_dir = Path.home() / ".cache" / "pyhealth" / "ecg_qa" - csv_name = "ecg-qa-pyhealth.csv" - - shared_csv = root / csv_name - cache_csv = cache_dir / csv_name - if shared_csv.exists() or cache_csv.exists(): + csv_path = root / "ecg-qa-pyhealth.csv" + if csv_path.exists(): return - # Load all JSON files from all split directories data = [] for split_dir in ("train", "valid", "test"): - json_dir = root / split_dir - if not json_dir.is_dir(): - logger.warning("JSON directory not found: %s", json_dir) - continue - for fpath in sorted(json_dir.glob("*.json")): + for fpath in sorted((root / split_dir).glob("*.json")): with open(fpath, "r") as f: data.extend(json.load(f)) - if not data: - raise FileNotFoundError( - f"No JSON files found in train/valid/test subdirectories of {root}" - ) - - # Filter to single-* question types and build rows rows: list[dict] = [] for record in data: qt = record.get("question_type", "") @@ -186,28 +148,19 @@ def prepare_metadata(self) -> None: df = pd.DataFrame(rows) df.sort_values(["patient_id", "question_type", "template_id"], inplace=True) df.reset_index(drop=True, inplace=True) - - # Try shared location first, fall back to cache - try: - shared_csv.parent.mkdir(parents=True, exist_ok=True) - df.to_csv(shared_csv, index=False) - logger.info(f"Wrote metadata to {shared_csv}") - except (PermissionError, OSError): - cache_dir.mkdir(parents=True, exist_ok=True) - df.to_csv(cache_csv, index=False) - logger.info(f"Wrote metadata to cache: {cache_csv}") + df.to_csv(csv_path, index=False) + logger.info(f"Wrote metadata to {csv_path}") def _download_data( self, root: str, ecg_source: str, include_demographics: bool ) -> None: - """Downloads the requested ECG-QA variant from GitHub into ``root``. + """Downloads the requested ECG-QA dataset from GitHub into ``root``. - Fetches a commit-pinned tarball of the upstream ``Tang-Jia-Lu`` repo - (or the user's ``jovianw`` fork when ``include_demographics`` is True), + Fetches a commit-pinned tarball from the original ``Tang-Jia-Lu`` repo + (or a modified fork when ``include_demographics`` is True), verifies its MD5, and extracts only the ``ecgqa//paraphrased/{train,valid,test}/`` subtree directly - into ``root``. The tarball itself is cached under ``{root}/.ecgqa-cache/`` - so a second call with the other ``ecg_source`` value can reuse it. + into ``root``. The tarball is deleted after extraction. Args: root: directory the splits will land in. @@ -218,44 +171,46 @@ def _download_data( ValueError: if the downloaded tarball fails MD5 verification or if it contains an unsafe path during extraction. """ - owner, sha = _REPO_BY_VARIANT[include_demographics] - expected_md5 = _MD5_BY_VARIANT[include_demographics] - url = f"https://github.com/{owner}/FSL_ECG_QA/archive/{sha}.tar.gz" + # URLs are pinned to specific commit SHAs so the MD5s below stay stable + # even if either repo gains new commits later. + if include_demographics: + url = ( + "https://github.com/jovianw/FSL_ECG_QA/archive/" + "2e2d4ac185d6069c741d083269ea40ca01bfd50b.tar.gz" + ) + expected_md5 = "e65c4b6ae127103ad92a33ec9246039e" + archive_prefix = "FSL_ECG_QA-2e2d4ac185d6069c741d083269ea40ca01bfd50b" + else: + url = ( + "https://github.com/Tang-Jia-Lu/FSL_ECG_QA/archive/" + "b0ec9bd84ae2337052ca977941e37a703dcb492e.tar.gz" + ) + expected_md5 = "894b4af304e99c48ecd62a914ba3ba2b" + archive_prefix = "FSL_ECG_QA-b0ec9bd84ae2337052ca977941e37a703dcb492e" os.makedirs(root, exist_ok=True) - cache_dir = os.path.join(root, ".ecgqa-cache") - os.makedirs(cache_dir, exist_ok=True) - archive_path = os.path.join(cache_dir, f"{sha}.tar.gz") - - need_download = True - if os.path.isfile(archive_path): - with open(archive_path, "rb") as f: - if hashlib.md5(f.read()).hexdigest() == expected_md5: - logger.info(f"Reusing cached archive {archive_path}") - need_download = False - - if need_download: - logger.info(f"Downloading {url} -> {archive_path}") - urllib.request.urlretrieve(url, archive_path) - - logger.info(f"Checking MD5 checksum for {archive_path}...") - with open(archive_path, "rb") as f: - file_md5 = hashlib.md5(f.read()).hexdigest() - if file_md5 != expected_md5: - msg = ( - f"Invalid MD5 checksum for {archive_path}: " - f"expected {expected_md5}, got {file_md5}" - ) - logger.error(msg) - raise ValueError(msg) + archive_path = os.path.join(root, "ecgqa-download.tar.gz") + + logger.info(f"Downloading {url} -> {archive_path}") + urllib.request.urlretrieve(url, archive_path) + + logger.info(f"Checking MD5 checksum for {archive_path}...") + with open(archive_path, "rb") as f: + file_md5 = hashlib.md5(f.read()).hexdigest() + if file_md5 != expected_md5: + msg = ( + f"Invalid MD5 checksum for {archive_path}: " + f"expected {expected_md5}, got {file_md5}" + ) + logger.error(msg) + raise ValueError(msg) ecg_source_dir = _VALID_ECG_SOURCES[ecg_source] - prefix = f"FSL_ECG_QA-{sha}/ecgqa/{ecg_source_dir}/paraphrased/" + prefix = f"{archive_prefix}/ecgqa/{ecg_source_dir}/paraphrased/" abs_root = os.path.abspath(root) logger.info(f"Extracting {prefix}* from {archive_path} into {root}") with tarfile.open(archive_path, "r:gz") as tar: - extracted = 0 for member in tar.getmembers(): if not member.name.startswith(prefix): continue @@ -273,9 +228,9 @@ def _download_data( member.name = rel tar.extract(member, path=root) - extracted += 1 - logger.info(f"Download complete ({extracted} entries extracted)") + os.remove(archive_path) + logger.info("Download complete") def _verify_data(self, root: str) -> None: """Verifies the presence and structure of the dataset directory. diff --git a/tests/core/test_ecgqa.py b/tests/core/test_ecgqa.py index 50c6bb8bf..3915f6228 100644 --- a/tests/core/test_ecgqa.py +++ b/tests/core/test_ecgqa.py @@ -3,7 +3,6 @@ import shutil import json from pathlib import Path -from unittest.mock import patch from pyhealth.datasets import ECGQADataset @@ -85,17 +84,10 @@ def setUp(self): (self.root / "valid" / "00.json").write_text(json.dumps(valid_records)) (self.root / "test" / "00.json").write_text(json.dumps(test_records)) - # Redirect Path.home() into the temp dir so that prepare_metadata's - # ~/.cache/pyhealth/ecg_qa fallback cannot find a pre-existing user - # cache and shadow the test fixture with stale data. - self._home_patch = patch.object(Path, "home", return_value=self.root) - self._home_patch.start() - self._cache_tmp = tempfile.mkdtemp() def tearDown(self): """Clean up temporary files""" - self._home_patch.stop() shutil.rmtree(self.temp_dir, ignore_errors=True) shutil.rmtree(self._cache_tmp, ignore_errors=True) From 2c1c2e939ddee807896bddf172dbfefd6f8d92d9 Mon Sep 17 00:00:00 2001 From: jovianw Date: Tue, 14 Apr 2026 20:21:50 -0700 Subject: [PATCH 08/27] add: ecgqa dataset docs --- docs/api/datasets.rst | 1 + docs/api/datasets/pyhealth.datasets.ECGQADataset.rst | 9 +++++++++ docs/api/tasks.rst | 1 + docs/api/tasks/pyhealth.tasks.ECGQA.rst | 7 +++++++ 4 files changed, 18 insertions(+) create mode 100644 docs/api/datasets/pyhealth.datasets.ECGQADataset.rst create mode 100644 docs/api/tasks/pyhealth.tasks.ECGQA.rst diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..f0d45e758 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -238,6 +238,7 @@ Available Datasets datasets/pyhealth.datasets.BMDHSDataset datasets/pyhealth.datasets.COVID19CXRDataset datasets/pyhealth.datasets.ChestXray14Dataset + datasets/pyhealth.datasets.ECGQADataset datasets/pyhealth.datasets.TUABDataset datasets/pyhealth.datasets.TUEVDataset datasets/pyhealth.datasets.ClinVarDataset diff --git a/docs/api/datasets/pyhealth.datasets.ECGQADataset.rst b/docs/api/datasets/pyhealth.datasets.ECGQADataset.rst new file mode 100644 index 000000000..789764aef --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.ECGQADataset.rst @@ -0,0 +1,9 @@ +pyhealth.datasets.ECGQADataset +=================================== + +The ECG-QA dataset (Oh et al., 2024) provides natural-language question-answer pairs grounded in ECG recordings from PTB-XL or MIMIC-IV-ECG, restructured for few-shot learning by Tang et al. (CHIL 2025). For more information see the `FSL_ECG_QA repository `_. + +.. autoclass:: pyhealth.datasets.ECGQADataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..e67b20e61 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -225,6 +225,7 @@ Available Tasks Benchmark EHRShot ChestX-ray14 Binary Classification ChestX-ray14 Multilabel Classification + ECG Question Answering Variant Classification (ClinVar) Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) diff --git a/docs/api/tasks/pyhealth.tasks.ECGQA.rst b/docs/api/tasks/pyhealth.tasks.ECGQA.rst new file mode 100644 index 000000000..4350a77f7 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ECGQA.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.ECGQA +======================================= + +.. autoclass:: pyhealth.tasks.ECGQA + :members: + :undoc-members: + :show-inheritance: From 7c85f6b76c2b2fb601859c33ac1c5dcc08bd8373 Mon Sep 17 00:00:00 2001 From: Matthew Pham Date: Tue, 14 Apr 2026 21:04:10 -0700 Subject: [PATCH 09/27] add task --- ....tasks.ptbxl_diagnostic_classification.rst | 11 +++ pyhealth/tasks/__init__.py | 1 + .../tasks/ptbxl_diagnostic_classification.py | 85 +++++++++++++++++++ 3 files changed, 97 insertions(+) create mode 100644 docs/api/tasks/pyhealth.tasks.ptbxl_diagnostic_classification.rst create mode 100644 pyhealth/tasks/ptbxl_diagnostic_classification.py diff --git a/docs/api/tasks/pyhealth.tasks.ptbxl_diagnostic_classification.rst b/docs/api/tasks/pyhealth.tasks.ptbxl_diagnostic_classification.rst new file mode 100644 index 000000000..4aeef91a5 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ptbxl_diagnostic_classification.rst @@ -0,0 +1,11 @@ +PTB-XL Diagnostic Classification +================================ + +.. currentmodule:: pyhealth.tasks.ptbxl_diagnostic_classification + +.. autoclass:: PTBXLDiagnosticClassification + :members: + :show-inheritance: + :exclude-members: __init__ + + .. automethod:: __call__ \ No newline at end of file diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 957f23d82..9cbcae5b6 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -11,6 +11,7 @@ ) from .chestxray14_binary_classification import ChestXray14BinaryClassification from .ecgqa_preprocess import ECGQA +from .ptbxl_diagnostic_classification import PTBXLDiagnosticClassification from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification from .dka import DKAPredictionMIMIC4, T1DDKAPredictionMIMIC4 diff --git a/pyhealth/tasks/ptbxl_diagnostic_classification.py b/pyhealth/tasks/ptbxl_diagnostic_classification.py new file mode 100644 index 000000000..68e6e5ea5 --- /dev/null +++ b/pyhealth/tasks/ptbxl_diagnostic_classification.py @@ -0,0 +1,85 @@ +""" +PyHealth task for ECG diagnostic classification using the PTB-XL dataset. + +Dataset link: + https://physionet.org/content/ptb-xl/1.0.1/ + +Dataset paper: (please cite if you use this dataset) + Wagner, P., Strodthoff, N., Bousseljot, R. D., Samek, W., & Schaeffter, T. + "PTB-XL, a large publicly available electrocardiography dataset." + Scientific Data, 7(1), 1-15. (2020). + +Dataset paper link: + https://www.nature.com/articles/s41597-020-0495-6 + +Author: + Jovian Wang (jovianw2@illinois.edu) + Matthew Pham (mdpham2@illinois.edu) + Yiyun Wang (yiyunw3@illinois.edu) +""" +import logging +import os +import scipy.io as sio +from typing import Dict, List + +from pyhealth.data import Patient, Event +from pyhealth.tasks import BaseTask + +logger = logging.getLogger(__name__) + +class PTBXLDiagnosticClassification(BaseTask): + """ + A PyHealth task class for multi-label diagnostic classification of ECGs + in the PTB-XL dataset, utilizing resampled 500Hz signal features. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): The schema for the task input (signal). + output_schema (Dict[str, str]): The schema for the task output (multi-label). + """ + task_name: str = "PTBXLDiagnosticClassification" + # x is the resampled signal, metadata provides clinical context + input_schema: Dict[str, str] = {"signal": "signal", "metadata": "tabular"} + output_schema: Dict[str, str] = {"label": "multilabel"} + + def __init__(self, root: str) -> None: + """ + Initializes the PTBXLDiagnosticClassification task. + + Args: + root (str): The root directory where the resampled .mat files + (12 leads, 2500 samples) are stored. + """ + if not os.path.exists(root): + msg = f"Signal root path does not exist: {root}" + logger.error(msg) + raise FileNotFoundError(msg) + + self.root = root + + def __call__(self, patient: Patient) -> List[Dict]: + events: List[Event] = patient.get_events(event_type="ptbxl") + samples = [] + + for event in events: + ecg_id = int(event["ecg_id"]) + subfolder = f"{str((ecg_id // 1000) * 1000).zfill(5)}" + + file_name = f"{str(ecg_id).zfill(5)}_hr.mat" + signal_path = os.path.join(self.root, "records100", subfolder, file_name) + + try: + mat_data = sio.loadmat(signal_path) + signal_data = mat_data['feats'] + + samples.append({ + "signal": signal_data, + "metadata": [patient.age, patient.sex], + "label": event["label"], + "record_id": ecg_id + }) + except Exception as e: + logger.warning(f"Could not load signal for record {ecg_id}: {e}") + continue + + return samples \ No newline at end of file From a8c70c8a2d43cf8b06d7befac323e1440ed5e1ed Mon Sep 17 00:00:00 2001 From: Yiyun Date: Tue, 14 Apr 2026 21:31:02 -0700 Subject: [PATCH 10/27] rename and added rst --- .../pyhealth.datasets.PTBXLDataset.rst | 9 ++++++++ .../{PTB-XL1.0.3.py => ptbxl_1_0_3.py} | 22 ++++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) create mode 100644 docs/api/datasets/pyhealth.datasets.PTBXLDataset.rst rename pyhealth/datasets/{PTB-XL1.0.3.py => ptbxl_1_0_3.py} (84%) diff --git a/docs/api/datasets/pyhealth.datasets.PTBXLDataset.rst b/docs/api/datasets/pyhealth.datasets.PTBXLDataset.rst new file mode 100644 index 000000000..39e29d253 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.PTBXLDataset.rst @@ -0,0 +1,9 @@ +pyhealth.datasets.PTBXLDataset +=================================== + +The PTB-XL 1.0.3 dataset. For the original dataset see `here `_. + +.. autoclass:: pyhealth.datasets.PTBXLDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/pyhealth/datasets/PTB-XL1.0.3.py b/pyhealth/datasets/ptbxl_1_0_3.py similarity index 84% rename from pyhealth/datasets/PTB-XL1.0.3.py rename to pyhealth/datasets/ptbxl_1_0_3.py index 39b7d988b..6626f8bf6 100644 --- a/pyhealth/datasets/PTB-XL1.0.3.py +++ b/pyhealth/datasets/ptbxl_1_0_3.py @@ -19,8 +19,9 @@ import os import urllib.request import requests -from pyhealth.datasets import BaseSignalDataset import zipfile +import random +from pyhealth.datasets import BaseSignalDataset from pathlib import Path """ @@ -28,7 +29,8 @@ Args: dataset_name: name of the dataset. - root: root directory of the raw data (should contain many csv files). + root: root directory of the raw data. + Expected to contain folders for original (records500) or downsampled (records100) data with determined names. dev: whether to enable dev mode (only use a small subset of the data). Default is False. refresh_cache: whether to refresh the cache; if true, the dataset will @@ -41,11 +43,16 @@ class PTBXLDataset(BaseSignalDataset): Attributes: root (str): Root directory of the raw data. download (bool): True iff requested to download dataset. Default to False. + dev (bool): True iff enable dev mode. + downsampled (bool): True iff use downsampled signal data. """ def __init__(self, root: str = '.', download: bool = False, - down_sampled: bool = False) -> None: + dev: bool = False, + downsampled: bool = False) -> None: + + self.dev = dev # Determine the root path, where most of the data is stored # self.data_path: str = os.path.join(root, 'ptb_xl_processed_full.zip') @@ -53,7 +60,7 @@ def __init__(self, self.root = root # Determine signal path, where to fetch the signal samples - signal_folder = 'records100' if down_sampled else 'records500' + signal_folder = 'records100' if downsampled else 'records500' root_path = os.path.join(root, 'physionet.org/files/ptb-xl/1.0.3/') if download else root self.signal_path: str = os.path.join(root_path, signal_folder) @@ -107,10 +114,15 @@ def process_EEG_data(self): } ] + if self.dev: + keys = random.sample(list(patients), min(len(patients), 5)) + values = [d[k] for k in keys] + return dict(zip(keys, values)) + return patients if __name__ == "__main__": - dataset = PTBXLDataset(root='../../../../', download=False, down_sampled=True) + dataset = PTBXLDataset(root='../../../../', download=False, downsampled=True) dataset.stat() dataset.info() print(dataset.process_EEG_data()) From 74b7c59e6cc659f9b532da00ec2e2b31009b9263 Mon Sep 17 00:00:00 2001 From: Yiyun Date: Tue, 14 Apr 2026 22:06:15 -0700 Subject: [PATCH 11/27] added actual test file, grammar fixes --- pyhealth/datasets/ptbxl_1_0_3.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/pyhealth/datasets/ptbxl_1_0_3.py b/pyhealth/datasets/ptbxl_1_0_3.py index 6626f8bf6..13a268a58 100644 --- a/pyhealth/datasets/ptbxl_1_0_3.py +++ b/pyhealth/datasets/ptbxl_1_0_3.py @@ -55,13 +55,12 @@ def __init__(self, self.dev = dev # Determine the root path, where most of the data is stored - # self.data_path: str = os.path.join(root, 'ptb_xl_processed_full.zip') - self.data_path: str = os.path.join(root, 'test.zip') + self.data_path: str = os.path.join(root, 'ptb_xl_processed_final.zip') self.root = root # Determine signal path, where to fetch the signal samples signal_folder = 'records100' if downsampled else 'records500' - root_path = os.path.join(root, 'physionet.org/files/ptb-xl/1.0.3/') if download else root + root_path = os.path.join(root, 'ptb_xl_processed_final/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3') if download else root self.signal_path: str = os.path.join(root_path, signal_folder) # Download the dataset from online source to root if needed @@ -81,8 +80,7 @@ def __init__(self, def _download(self, download) -> None: if download: - # zip_id = '1IE-4Co1fLRoEI9jez2pwuf9HPmFRzuLX' # full - zip_id = '1Q9Ksxj4gSrsHVb6qqICI0nm0K8HWtDW2' # test + zip_id = '1btbPiHEOUBLNLfUYkLnKzs50ZTmgqdI2' response = requests.get(f'https://drive.google.com/uc?export=download&id={zip_id}') with open(self.data_path, 'wb') as file: file.write(response.content) @@ -114,15 +112,15 @@ def process_EEG_data(self): } ] - if self.dev: - keys = random.sample(list(patients), min(len(patients), 5)) - values = [d[k] for k in keys] - return dict(zip(keys, values)) + if self.dev: + keys = random.sample(list(patients), min(len(patients), 5)) + values = [d[k] for k in keys] + return dict(zip(keys, values)) return patients if __name__ == "__main__": - dataset = PTBXLDataset(root='../../../../', download=False, downsampled=True) + dataset = PTBXLDataset(root='../../../../', download=False, downsampled=True, dev=False) dataset.stat() dataset.info() print(dataset.process_EEG_data()) From 674f05c48856d3f7cc072fdf66f007bcf39f7a8a Mon Sep 17 00:00:00 2001 From: Yiyun Date: Tue, 14 Apr 2026 22:50:14 -0700 Subject: [PATCH 12/27] add doc file and add to init, add unit test --- docs/api/datasets.rst | 1 + pyhealth/datasets/__init__.py | 1 + tests/core/test_ptbxl_1_0_3.py | 37 ++++++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+) create mode 100644 tests/core/test_ptbxl_1_0_3.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..8ee2d1dbd 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -245,3 +245,4 @@ Available Datasets datasets/pyhealth.datasets.TCGAPRADDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils + datasets/pyhealth.datasets.PTBXLDataset diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..1ab935a27 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -61,6 +61,7 @@ def __init__(self, *args, **kwargs): from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset +from .ptbxl_1_0_3 import PTBXLDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset diff --git a/tests/core/test_ptbxl_1_0_3.py b/tests/core/test_ptbxl_1_0_3.py new file mode 100644 index 000000000..ef969857f --- /dev/null +++ b/tests/core/test_ptbxl_1_0_3.py @@ -0,0 +1,37 @@ +import unittest +from pyhealth.datasets import PTBXLDataset + +class TestPTBXLDataset(unittest.TestCase): + """ + Test PTB-XL 1.0.3 dataset with demo data. + """ + def setUp(self): + self.dataset = PTBXLDataset(root='../../../../', download=False, downsampled=True, dev=True) + + """ + Verify if the dataset contains correct basic information + """ + def testBasicInfo(self): + self.assertEqual(self.dataset.dataset_name, 'PTB-XL1.0.3') + self.assertIsInstance(self.dataset.filepath, str) + + """ + Verify if the dataset contains expected data + """ + def testData(self): + self.assertIsInstance(self.dataset.patients, dict) + self.assertLessEqual(len(self.dataset.patients), 5) + for pid, value in self.dataset.patients.items(): + self.assertIsNotNone(pid) + self.assertIsInstance(pid, str) + + self.assertIsInstance(value, list) + self.assertIsInstance(value[0], dict) + self.assertIsNotNone(value[0]['load_from_path']) + self.assertIsNotNone(value[0]['patient_id']) + self.assertIsNotNone(value[0]['signal_file']) + self.assertIsNotNone(value[0]['label_file']) + self.assertIsNotNone(value[0]['save_to_path']) + +if __name__ == "__main__": + unittest.main() From b5fd75400dc84ac2d2b741342ecddc53f5ce8678 Mon Sep 17 00:00:00 2001 From: Matthew Pham Date: Wed, 15 Apr 2026 17:47:38 -0700 Subject: [PATCH 13/27] modify task to do resampling --- docs/api/tasks.rst | 1 + ....tasks.ptbxl_diagnostic_classification.rst | 11 --- .../tasks/pyhealth.tasks.ptbxl_resampling.rst | 11 +++ pyhealth/tasks/__init__.py | 2 +- .../tasks/ptbxl_diagnostic_classification.py | 85 ------------------- pyhealth/tasks/ptbxl_resampling.py | 73 ++++++++++++++++ 6 files changed, 86 insertions(+), 97 deletions(-) delete mode 100644 docs/api/tasks/pyhealth.tasks.ptbxl_diagnostic_classification.rst create mode 100644 docs/api/tasks/pyhealth.tasks.ptbxl_resampling.rst delete mode 100644 pyhealth/tasks/ptbxl_diagnostic_classification.py create mode 100644 pyhealth/tasks/ptbxl_resampling.py diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index e67b20e61..b9de6168b 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -226,6 +226,7 @@ Available Tasks ChestX-ray14 Binary Classification ChestX-ray14 Multilabel Classification ECG Question Answering + PTB-XL Signal Resampling Variant Classification (ClinVar) Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) diff --git a/docs/api/tasks/pyhealth.tasks.ptbxl_diagnostic_classification.rst b/docs/api/tasks/pyhealth.tasks.ptbxl_diagnostic_classification.rst deleted file mode 100644 index 4aeef91a5..000000000 --- a/docs/api/tasks/pyhealth.tasks.ptbxl_diagnostic_classification.rst +++ /dev/null @@ -1,11 +0,0 @@ -PTB-XL Diagnostic Classification -================================ - -.. currentmodule:: pyhealth.tasks.ptbxl_diagnostic_classification - -.. autoclass:: PTBXLDiagnosticClassification - :members: - :show-inheritance: - :exclude-members: __init__ - - .. automethod:: __call__ \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.ptbxl_resampling.rst b/docs/api/tasks/pyhealth.tasks.ptbxl_resampling.rst new file mode 100644 index 000000000..c6a47182e --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ptbxl_resampling.rst @@ -0,0 +1,11 @@ +PTB-XL Signal Resampling +======================== + +.. currentmodule:: pyhealth.tasks.ptbxl_resampling + +.. autoclass:: PTBXLResampling + :members: + :show-inheritance: + :exclude-members: __init__ + + .. automethod:: __call__ \ No newline at end of file diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 9cbcae5b6..c54fcf00e 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -11,7 +11,7 @@ ) from .chestxray14_binary_classification import ChestXray14BinaryClassification from .ecgqa_preprocess import ECGQA -from .ptbxl_diagnostic_classification import PTBXLDiagnosticClassification +from .ptbxl_resampling import PTBXLResamplingTask from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification from .dka import DKAPredictionMIMIC4, T1DDKAPredictionMIMIC4 diff --git a/pyhealth/tasks/ptbxl_diagnostic_classification.py b/pyhealth/tasks/ptbxl_diagnostic_classification.py deleted file mode 100644 index 68e6e5ea5..000000000 --- a/pyhealth/tasks/ptbxl_diagnostic_classification.py +++ /dev/null @@ -1,85 +0,0 @@ -""" -PyHealth task for ECG diagnostic classification using the PTB-XL dataset. - -Dataset link: - https://physionet.org/content/ptb-xl/1.0.1/ - -Dataset paper: (please cite if you use this dataset) - Wagner, P., Strodthoff, N., Bousseljot, R. D., Samek, W., & Schaeffter, T. - "PTB-XL, a large publicly available electrocardiography dataset." - Scientific Data, 7(1), 1-15. (2020). - -Dataset paper link: - https://www.nature.com/articles/s41597-020-0495-6 - -Author: - Jovian Wang (jovianw2@illinois.edu) - Matthew Pham (mdpham2@illinois.edu) - Yiyun Wang (yiyunw3@illinois.edu) -""" -import logging -import os -import scipy.io as sio -from typing import Dict, List - -from pyhealth.data import Patient, Event -from pyhealth.tasks import BaseTask - -logger = logging.getLogger(__name__) - -class PTBXLDiagnosticClassification(BaseTask): - """ - A PyHealth task class for multi-label diagnostic classification of ECGs - in the PTB-XL dataset, utilizing resampled 500Hz signal features. - - Attributes: - task_name (str): The name of the task. - input_schema (Dict[str, str]): The schema for the task input (signal). - output_schema (Dict[str, str]): The schema for the task output (multi-label). - """ - task_name: str = "PTBXLDiagnosticClassification" - # x is the resampled signal, metadata provides clinical context - input_schema: Dict[str, str] = {"signal": "signal", "metadata": "tabular"} - output_schema: Dict[str, str] = {"label": "multilabel"} - - def __init__(self, root: str) -> None: - """ - Initializes the PTBXLDiagnosticClassification task. - - Args: - root (str): The root directory where the resampled .mat files - (12 leads, 2500 samples) are stored. - """ - if not os.path.exists(root): - msg = f"Signal root path does not exist: {root}" - logger.error(msg) - raise FileNotFoundError(msg) - - self.root = root - - def __call__(self, patient: Patient) -> List[Dict]: - events: List[Event] = patient.get_events(event_type="ptbxl") - samples = [] - - for event in events: - ecg_id = int(event["ecg_id"]) - subfolder = f"{str((ecg_id // 1000) * 1000).zfill(5)}" - - file_name = f"{str(ecg_id).zfill(5)}_hr.mat" - signal_path = os.path.join(self.root, "records100", subfolder, file_name) - - try: - mat_data = sio.loadmat(signal_path) - signal_data = mat_data['feats'] - - samples.append({ - "signal": signal_data, - "metadata": [patient.age, patient.sex], - "label": event["label"], - "record_id": ecg_id - }) - except Exception as e: - logger.warning(f"Could not load signal for record {ecg_id}: {e}") - continue - - return samples \ No newline at end of file diff --git a/pyhealth/tasks/ptbxl_resampling.py b/pyhealth/tasks/ptbxl_resampling.py new file mode 100644 index 000000000..ff3cac82f --- /dev/null +++ b/pyhealth/tasks/ptbxl_resampling.py @@ -0,0 +1,73 @@ +""" +A PyHealth task for ECG signal resampling/super-resolution on the PTB-XL dataset. + +Dataset link: + https://physionet.org/content/ptb-xl/1.0.1/ + +Dataset paper: (please cite if you use this dataset) + Wagner, P., Strodthoff, N., Bousseljot, R. D., Samek, W., & Schaeffter, T. + "PTB-XL, a large publicly available electrocardiography dataset." + Scientific Data, 7(1), 1-15. (2020). + +Dataset paper link: + https://www.nature.com/articles/s41597-020-0495-6 + +Author: + Jovian Wang (jovianw2@illinois.edu) + Matthew Pham (mdpham2@illinois.edu) + Yiyun Wang (yiyunw3@illinois.edu) +""" +import logging +import os +import scipy.io as sio +from typing import Dict, List + +from pyhealth.data import Patient, Event +from pyhealth.tasks import BaseTask + +logger = logging.getLogger(__name__) + +class PTBXLResamplingTask(BaseTask): + task_name: str = "PTBXLResampling" + + # Input: 100Hz Signal | Output: 500Hz Signal (the ground truth) + input_schema: Dict[str, str] = {"low_res": "signal"} + output_schema: Dict[str, str] = {"high_res": "signal"} + + def __init__(self, root: str) -> None: + """ + Args: + root (str): Root directory containing BOTH 'records100' and 'records500'. + """ + if not os.path.exists(root): + raise FileNotFoundError(f"Root path does not exist: {root}") + self.root = root + + def __call__(self, patient: Patient) -> List[Dict]: + events: List[Event] = patient.get_events(event_type="ptbxl") + samples = [] + + for event in events: + ecg_id = int(event["ecg_id"]) + subfolder = f"{str((ecg_id // 1000) * 1000).zfill(5)}" + + lr_path = os.path.join(self.root, "records100", subfolder, f"{str(ecg_id).zfill(5)}_lr") + hr_path = os.path.join(self.root, "records500", subfolder, f"{str(ecg_id).zfill(5)}_hr") + + try: + lr_record = wfdb.rdrecord(lr_path) + lr_signal = lr_record.p_signal.T # Transpose to (12, 1000) + + hr_record = wfdb.rdrecord(hr_path) + hr_signal = hr_record.p_signal.T # Transpose to (12, 2500) + + samples.append({ + "low_res": lr_signal.astype(np.float32), + "high_res": hr_signal.astype(np.float32), + "record_id": ecg_id + }) + + except Exception as e: + continue + + return samples \ No newline at end of file From 43fcfbd12c8cd751b8f3bf6eb5a2a882627665cd Mon Sep 17 00:00:00 2001 From: Matthew Pham Date: Wed, 15 Apr 2026 17:53:51 -0700 Subject: [PATCH 14/27] add comment --- pyhealth/tasks/ptbxl_resampling.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pyhealth/tasks/ptbxl_resampling.py b/pyhealth/tasks/ptbxl_resampling.py index ff3cac82f..0d5f237dc 100644 --- a/pyhealth/tasks/ptbxl_resampling.py +++ b/pyhealth/tasks/ptbxl_resampling.py @@ -28,6 +28,16 @@ logger = logging.getLogger(__name__) class PTBXLResamplingTask(BaseTask): + """ + This task maps a low-resolution (100Hz) ECG signal to its + high-resolution (500Hz) counterpart. + + Attributes: + task_name (str): The name of the task ("PTBXLResampling"). + input_schema (Dict[str, str]): "low_res": signal at 100Hz. + output_schema (Dict[str, str]): "high_res": signal at 500Hz. + """ + task_name: str = "PTBXLResampling" # Input: 100Hz Signal | Output: 500Hz Signal (the ground truth) From 13b5dc00ea986a97fbea10c15d51dbd60d59961e Mon Sep 17 00:00:00 2001 From: Yiyun Date: Wed, 15 Apr 2026 17:57:22 -0700 Subject: [PATCH 15/27] add author name, re-arranged text order --- pyhealth/datasets/ecgqa.py | 104 ++++++++++++++++++------------- pyhealth/datasets/ptbxl_1_0_3.py | 6 +- 2 files changed, 62 insertions(+), 48 deletions(-) diff --git a/pyhealth/datasets/ecgqa.py b/pyhealth/datasets/ecgqa.py index 60d2660f4..504db202f 100644 --- a/pyhealth/datasets/ecgqa.py +++ b/pyhealth/datasets/ecgqa.py @@ -1,3 +1,27 @@ +""" +ECG Question Answering dataset. + +This dataset provides natural language question-answer pairs linked to +ECG recordings via ecg_id. It is an annotation layer on top of ECG +recordings from PTB-XL or MIMIC-IV-ECG. + +The QA data originates from the ECG-QA dataset (Oh et al., 2024), +restructured for few-shot learning by Tang et al. (CHIL 2025). + +Dataset link: + Dataset is available at https://github.com/Tang-Jia-Lu/FSL_ECG_QA + +Dataset paper: + J. Tang, T. Xia, Y. Lu, C. Mascolo, and A. Saeed, "Electrocardiogram-language model for few-shot question answering with meta learning," arXiv preprint arXiv:2410.14464, 2024. + +Dataset paper link: + https://arxiv.org/abs/2410.14464 + +Author: + Jovian Wang (jovianw2@illinois.edu) + Matthew Pham (mdpham2@illinois.edu) + Yiyun Wang (yiyunw3@illinois.edu) +""" import hashlib import json import logging @@ -15,53 +39,43 @@ _VALID_ECG_SOURCES = {"ptbxl": "ptbxl", "mimic": "mimic-iv-ecg"} - +""" +Three question types are supported: + - single-verify: yes/no questions about ECG findings + - single-choose: multi-choice questions (answer is one option, "both", or "none") + - single-query: open-ended questions with free-form answers + +Args: + root: directory that holds (or will hold) the paraphrased QA splits as + train/, valid/, test/ subdirectories of JSON files. + dataset_name: name of the dataset. Default is "ecg_qa". + config_path: path to the YAML config file. Default uses built-in config. + download: if True, download the chosen variant from GitHub into ``root`` + before loading. Defaults to False. + ecg_source: which underlying ECG dataset the QA pairs are grounded in. + One of ``"ptbxl"`` (PTB-XL) or ``"mimic"`` (MIMIC-IV-ECG). + Defaults to ``"ptbxl"``. + include_demographics: if True, download the modified variant whose + question text includes patient sex and age. Defaults to False + (the original Tang et al. release). + +Examples: + >>> from pyhealth.datasets import ECGQADataset + >>> # Use a pre-downloaded local copy + >>> dataset = ECGQADataset( + ... root="/path/to/ecgqa/ptbxl/paraphrased/", + ... ) + >>> # Or download the modified PTB-XL variant on the fly + >>> dataset = ECGQADataset( + ... root="./ecg_qa_ptbxl_demo", + ... download=True, + ... ecg_source="ptbxl", + ... include_demographics=True, + ... ) + >>> dataset.stats() +""" class ECGQADataset(BaseDataset): - """ECG Question Answering dataset. - - This dataset provides natural language question-answer pairs linked to - ECG recordings via ecg_id. It is an annotation layer on top of ECG - recordings from PTB-XL or MIMIC-IV-ECG. - - The QA data originates from the ECG-QA dataset (Oh et al., 2024), - restructured for few-shot learning by Tang et al. (CHIL 2025). - - Dataset is available at https://github.com/Tang-Jia-Lu/FSL_ECG_QA - Three question types are supported: - - single-verify: yes/no questions about ECG findings - - single-choose: multi-choice questions (answer is one option, "both", or "none") - - single-query: open-ended questions with free-form answers - - Args: - root: directory that holds (or will hold) the paraphrased QA splits as - train/, valid/, test/ subdirectories of JSON files. - dataset_name: name of the dataset. Default is "ecg_qa". - config_path: path to the YAML config file. Default uses built-in config. - download: if True, download the chosen variant from GitHub into ``root`` - before loading. Defaults to False. - ecg_source: which underlying ECG dataset the QA pairs are grounded in. - One of ``"ptbxl"`` (PTB-XL) or ``"mimic"`` (MIMIC-IV-ECG). - Defaults to ``"ptbxl"``. - include_demographics: if True, download the modified variant whose - question text includes patient sex and age. Defaults to False - (the original Tang et al. release). - - Examples: - >>> from pyhealth.datasets import ECGQADataset - >>> # Use a pre-downloaded local copy - >>> dataset = ECGQADataset( - ... root="/path/to/ecgqa/ptbxl/paraphrased/", - ... ) - >>> # Or download the modified PTB-XL variant on the fly - >>> dataset = ECGQADataset( - ... root="./ecg_qa_ptbxl_demo", - ... download=True, - ... ecg_source="ptbxl", - ... include_demographics=True, - ... ) - >>> dataset.stats() - """ def __init__( self, diff --git a/pyhealth/datasets/ptbxl_1_0_3.py b/pyhealth/datasets/ptbxl_1_0_3.py index 13a268a58..05eae110a 100644 --- a/pyhealth/datasets/ptbxl_1_0_3.py +++ b/pyhealth/datasets/ptbxl_1_0_3.py @@ -1,5 +1,4 @@ """ - Pyhealth dataset for the 1.0.3 PTB-XL dataset. Dataset link: @@ -12,8 +11,9 @@ https://arxiv.org/abs/2410.14464 Author: - Yiyun Wang (yiyunw3@illinois.edu) - + Jovian Wang (jovianw2@illinois.edu) + Matthew Pham (mdpham2@illinois.edu) + Yiyun Wang (yiyunw3@illinois.edu) """ import pandas as pd import os From 1f48cb769d0d5d56bf050e17776bb8a8544e0efa Mon Sep 17 00:00:00 2001 From: Matthew Pham Date: Wed, 15 Apr 2026 17:58:27 -0700 Subject: [PATCH 16/27] perform resampling --- pyhealth/tasks/ptbxl_resampling.py | 40 ++++++++++++++---------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/pyhealth/tasks/ptbxl_resampling.py b/pyhealth/tasks/ptbxl_resampling.py index 0d5f237dc..9afad725e 100644 --- a/pyhealth/tasks/ptbxl_resampling.py +++ b/pyhealth/tasks/ptbxl_resampling.py @@ -1,5 +1,5 @@ """ -A PyHealth task for ECG signal resampling/super-resolution on the PTB-XL dataset. +A PyHealth task that performs dynamic resampling of ECG signals. Dataset link: https://physionet.org/content/ptb-xl/1.0.1/ @@ -19,7 +19,9 @@ """ import logging import os -import scipy.io as sio +import wfdb +import numpy as np +from scipy import signal from typing import Dict, List from pyhealth.data import Patient, Event @@ -27,30 +29,25 @@ logger = logging.getLogger(__name__) -class PTBXLResamplingTask(BaseTask): +class PTBXLResampling(BaseTask): """ - This task maps a low-resolution (100Hz) ECG signal to its - high-resolution (500Hz) counterpart. - - Attributes: - task_name (str): The name of the task ("PTBXLResampling"). - input_schema (Dict[str, str]): "low_res": signal at 100Hz. - output_schema (Dict[str, str]): "high_res": signal at 500Hz. + This task reads the 100Hz (low-res) raw signal and uses scipy to + mathematically generate the 500Hz (high-res) version as the target. """ - task_name: str = "PTBXLResampling" - # Input: 100Hz Signal | Output: 500Hz Signal (the ground truth) + # Input: The original 100Hz signal + # Output: The target signal generated via resampling logic input_schema: Dict[str, str] = {"low_res": "signal"} output_schema: Dict[str, str] = {"high_res": "signal"} def __init__(self, root: str) -> None: """ Args: - root (str): Root directory containing BOTH 'records100' and 'records500'. + root (str): The path to the PTB-XL dataset (containing records100). """ if not os.path.exists(root): - raise FileNotFoundError(f"Root path does not exist: {root}") + raise FileNotFoundError(f"Path not found: {root}") self.root = root def __call__(self, patient: Patient) -> List[Dict]: @@ -62,22 +59,21 @@ def __call__(self, patient: Patient) -> List[Dict]: subfolder = f"{str((ecg_id // 1000) * 1000).zfill(5)}" lr_path = os.path.join(self.root, "records100", subfolder, f"{str(ecg_id).zfill(5)}_lr") - hr_path = os.path.join(self.root, "records500", subfolder, f"{str(ecg_id).zfill(5)}_hr") try: - lr_record = wfdb.rdrecord(lr_path) - lr_signal = lr_record.p_signal.T # Transpose to (12, 1000) - - hr_record = wfdb.rdrecord(hr_path) - hr_signal = hr_record.p_signal.T # Transpose to (12, 2500) + record = wfdb.rdrecord(lr_path) + lr_data = record.p_signal.T # Shape: (12, 1000) + num_samples_target = 2500 + hr_data = signal.resample(lr_data, num_samples_target, axis=1) samples.append({ - "low_res": lr_signal.astype(np.float32), - "high_res": hr_signal.astype(np.float32), + "low_res": lr_data.astype(np.float32), + "high_res": hr_data.astype(np.float32), "record_id": ecg_id }) except Exception as e: + logger.debug(f"Skipping record {ecg_id}: {e}") continue return samples \ No newline at end of file From 6f0ed8b232cacecd611d97368ee3e0f9b6711660 Mon Sep 17 00:00:00 2001 From: Yiyun Date: Wed, 15 Apr 2026 18:03:45 -0700 Subject: [PATCH 17/27] rename ptb-xl dataset --- pyhealth/datasets/__init__.py | 2 +- pyhealth/datasets/{ptbxl_1_0_3.py => ptbxl.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename pyhealth/datasets/{ptbxl_1_0_3.py => ptbxl.py} (100%) diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index c7d88a8a3..6a2e059f7 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -62,7 +62,7 @@ def __init__(self, *args, **kwargs): from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset -from .ptbxl_1_0_3 import PTBXLDataset +from .ptbxl import PTBXLDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset diff --git a/pyhealth/datasets/ptbxl_1_0_3.py b/pyhealth/datasets/ptbxl.py similarity index 100% rename from pyhealth/datasets/ptbxl_1_0_3.py rename to pyhealth/datasets/ptbxl.py From 19516a008ceb43131249a77ad46129e55822319d Mon Sep 17 00:00:00 2001 From: Matthew Pham Date: Wed, 15 Apr 2026 18:12:50 -0700 Subject: [PATCH 18/27] change task to downsample --- pyhealth/tasks/ptbxl_resampling.py | 38 ++++++++++++------------------ 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/pyhealth/tasks/ptbxl_resampling.py b/pyhealth/tasks/ptbxl_resampling.py index 9afad725e..9eb0a57f2 100644 --- a/pyhealth/tasks/ptbxl_resampling.py +++ b/pyhealth/tasks/ptbxl_resampling.py @@ -31,49 +31,41 @@ class PTBXLResampling(BaseTask): """ - This task reads the 100Hz (low-res) raw signal and uses scipy to - mathematically generate the 500Hz (high-res) version as the target. + Task: Downsample high-resolution (500Hz) signals to 250Hz. + This provides a balance between detail and computational efficiency. """ task_name: str = "PTBXLResampling" - - # Input: The original 100Hz signal - # Output: The target signal generated via resampling logic - input_schema: Dict[str, str] = {"low_res": "signal"} - output_schema: Dict[str, str] = {"high_res": "signal"} + input_schema: Dict[str, str] = {"signal": "signal"} + output_schema: Dict[str, str] = {"label": "multilabel"} def __init__(self, root: str) -> None: - """ - Args: - root (str): The path to the PTB-XL dataset (containing records100). - """ - if not os.path.exists(root): - raise FileNotFoundError(f"Path not found: {root}") self.root = root def __call__(self, patient: Patient) -> List[Dict]: - events: List[Event] = patient.get_events(event_type="ptbxl") + events = patient.get_events(event_type="ptbxl") samples = [] for event in events: ecg_id = int(event["ecg_id"]) subfolder = f"{str((ecg_id // 1000) * 1000).zfill(5)}" - lr_path = os.path.join(self.root, "records100", subfolder, f"{str(ecg_id).zfill(5)}_lr") + # Use the 500Hz records as the source + hr_path = os.path.join(self.root, "records500", subfolder, f"{str(ecg_id).zfill(5)}_hr") try: - record = wfdb.rdrecord(lr_path) - lr_data = record.p_signal.T # Shape: (12, 1000) - num_samples_target = 2500 - hr_data = signal.resample(lr_data, num_samples_target, axis=1) + record = wfdb.rdrecord(hr_path) + data_500hz = record.p_signal.T # Shape: (12, 5000) + + # Downsample to 250Hz (2500 samples for a 10s record) + num_samples_target = 2500 + data_250hz = signal.resample(data_500hz, num_samples_target, axis=1) samples.append({ - "low_res": lr_data.astype(np.float32), - "high_res": hr_data.astype(np.float32), + "signal": data_250hz.astype(np.float32), + "label": event["label"], "record_id": ecg_id }) - except Exception as e: - logger.debug(f"Skipping record {ecg_id}: {e}") continue return samples \ No newline at end of file From eb1e650adf5223c75c4ba5ea5dc57d026a6bc6c8 Mon Sep 17 00:00:00 2001 From: Yiyun Date: Wed, 15 Apr 2026 20:39:18 -0700 Subject: [PATCH 19/27] rebase dataset to BaseDataset --- pyhealth/datasets/configs/ptbxl.yaml | 11 +++ pyhealth/datasets/ptbxl.py | 120 ++++++++++++++++++++------- 2 files changed, 102 insertions(+), 29 deletions(-) create mode 100644 pyhealth/datasets/configs/ptbxl.yaml diff --git a/pyhealth/datasets/configs/ptbxl.yaml b/pyhealth/datasets/configs/ptbxl.yaml new file mode 100644 index 000000000..3b9883b26 --- /dev/null +++ b/pyhealth/datasets/configs/ptbxl.yaml @@ -0,0 +1,11 @@ +version: "1.0.3" +tables: + ptb-xl: + file_path: "ptbxl.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "load_from_path" + - "signal_file" + - "label_file" + - "save_to_path" diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py index 05eae110a..8aab0e29b 100644 --- a/pyhealth/datasets/ptbxl.py +++ b/pyhealth/datasets/ptbxl.py @@ -17,12 +17,18 @@ """ import pandas as pd import os +import logging import urllib.request import requests import zipfile import random -from pyhealth.datasets import BaseSignalDataset from pathlib import Path +from typing import Optional +from . import BaseDataset +from pyhealth.datasets.utils import hash_str, MODULE_CACHE_PATH +import csv + +logger = logging.getLogger(__name__) """ Dataset class for the PTB-XL 1.0.3 dataset. @@ -36,7 +42,7 @@ refresh_cache: whether to refresh the cache; if true, the dataset will be processed from scratch and the cache will be updated. Default is False. """ -class PTBXLDataset(BaseSignalDataset): +class PTBXLDataset(BaseDataset): """ Initialize the PTB-XL dataset. @@ -46,29 +52,45 @@ class PTBXLDataset(BaseSignalDataset): dev (bool): True iff enable dev mode. downsampled (bool): True iff use downsampled signal data. """ - def __init__(self, - root: str = '.', + def __init__( + self, + root: str = ".", download: bool = False, dev: bool = False, - downsampled: bool = False) -> None: + downsampled: bool = False, + config_path: Optional[str] = None, + **kwargs) -> None: self.dev = dev # Determine the root path, where most of the data is stored - self.data_path: str = os.path.join(root, 'ptb_xl_processed_final.zip') + self.data_path: str = os.path.join(root, "ptb_xl_processed_final.zip") self.root = root # Determine signal path, where to fetch the signal samples - signal_folder = 'records100' if downsampled else 'records500' - root_path = os.path.join(root, 'ptb_xl_processed_final/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3') if download else root + signal_folder = "records100" if downsampled else "records500" + root_path = os.path.join(root, "ptb_xl_processed_final/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3") if download else root self.signal_path: str = os.path.join(root_path, signal_folder) # Download the dataset from online source to root if needed self._download(download) + # Determine config_path if it isn't provided + if config_path is None: + logger.info("No config path provided. Using default config.") + config_path = os.path.join(os.path.dirname(__file__), "configs", "ptbxl.yaml") + + # Validate data + self._validate() + + self._prepare_metadata() + super().__init__( root=root, - dataset_name='PTB-XL1.0.3', + dataset_name="PTB-XL", + tables=["PTB-XL"], + config_path=config_path, + **kwargs, ) @@ -78,49 +100,89 @@ def __init__(self, in /records500 and /records100 folder respectively. """ def _download(self, download) -> None: - if download: - zip_id = '1btbPiHEOUBLNLfUYkLnKzs50ZTmgqdI2' - response = requests.get(f'https://drive.google.com/uc?export=download&id={zip_id}') - with open(self.data_path, 'wb') as file: + zip_id = "1btbPiHEOUBLNLfUYkLnKzs50ZTmgqdI2" + response = requests.get(f"https://drive.google.com/uc?export=download&id={zip_id}") + with open(self.data_path, "wb") as file: file.write(response.content) with zipfile.ZipFile(self.data_path, "r") as z: z.extractall(self.root) + """ + Verifies if the dataset directory exists and its structure. + Check if specified records folders exists underneath the root, + each patient directory inside contains at least one pair of .dat and .hea files, + and there"s no other unexpected type of files in the directory. + + Raises: + FileNotFoundError: if any directory is not found. + ValueError: if a patient directory contains not .dat/.hea file or if there"s a mismatch of .dat/.hea. + """ + def _validate(self) -> None: + if not os.path.exists(self.root): + e = f"Dataset root path doesn't exist: {self.root}" + logger.error(e) + raise FileNotFoundError(e) + + if not os.path.exists(self.signal_path): + e = f"Dataset signal path doesn't exist: {self.signal_path}" + logger.error(e) + raise FileNotFoundError(e) + + dat = set() + hea = set() + for dirpath, dirnames, filenames in os.walk(self.signal_path): + for filename in filenames: + f, suffix = filename.split(".") + if suffix == "dat": + dat.add(f) + elif suffix == "hea": + hea.add(f) + else: + e = f"Unexpected file format {suffix} in the directory" + logger.error(e) + raise ValueError(e) + if len(dat - hea) != 0: + e = f".dat and .hea files mismatch for patient id {dat - hea}." + logger.error(e) + raise ValueError(e) + """ Process and return a dictionary of the requested PTB-XL data for each patient. Each patient will have a corresponding object that contains load_from_path, patient_id, signal_file, label_file, and save_to_path. """ - def process_EEG_data(self): + def _prepare_metadata(self) -> None: patients = {} for dirpath, dirnames, filenames in os.walk(self.signal_path): for filename in filenames: f = Path(filename).stem - pid = f.split('_')[0] + pid = f.split("_")[0] if pid not in patients: - patients[pid] = [ - { - "load_from_path": dirpath, - "patient_id": pid, - "signal_file": f + '.dat', - "label_file": f + '.hea', - "save_to_path": self.filepath, - } - ] + patients[pid] = { + "load_from_path": dirpath, + "patient_id": pid, + "signal_file": f + ".dat", + "label_file": f + ".hea", + "save_to_path": os.path.join(MODULE_CACHE_PATH, hash_str(filename)), + } if self.dev: keys = random.sample(list(patients), min(len(patients), 5)) values = [d[k] for k in keys] return dict(zip(keys, values)) - return patients + # patients.to_csv(os.path.join(self.root, "ptbxl.csv"), index=False) + with open(os.path.join(self.root, "ptbxl.csv"), "w") as file: + w = csv.DictWriter(file, fieldnames=["load_from_path", "patient_id", "signal_file", "label_file", "save_to_path"]) + w.writeheader() + w.writerows(list(patients.values())) + return None if __name__ == "__main__": - dataset = PTBXLDataset(root='../../../../', download=False, downsampled=True, dev=False) - dataset.stat() - dataset.info() - print(dataset.process_EEG_data()) + dataset = PTBXLDataset(root="../../", download=False, downsampled=True, dev=False) + dataset.stats() + print(dataset.load_data()) From c1995dce348e44cd81f5fcb8d70f3ec8623ab975 Mon Sep 17 00:00:00 2001 From: Yiyun Date: Wed, 15 Apr 2026 20:53:23 -0700 Subject: [PATCH 20/27] fix xor --- pyhealth/datasets/ptbxl.py | 9 +++++---- pyhealth/tasks/__init__.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py index 8aab0e29b..f37fe154c 100644 --- a/pyhealth/datasets/ptbxl.py +++ b/pyhealth/datasets/ptbxl.py @@ -22,11 +22,11 @@ import requests import zipfile import random +import csv from pathlib import Path from typing import Optional -from . import BaseDataset from pyhealth.datasets.utils import hash_str, MODULE_CACHE_PATH -import csv +from . import BaseDataset logger = logging.getLogger(__name__) @@ -143,8 +143,9 @@ def _validate(self) -> None: e = f"Unexpected file format {suffix} in the directory" logger.error(e) raise ValueError(e) - if len(dat - hea) != 0: - e = f".dat and .hea files mismatch for patient id {dat - hea}." + + if len(dat ^ hea) != 0: + e = f".dat and .hea files mismatch for patient id {dat ^ hea}." logger.error(e) raise ValueError(e) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 9cbcae5b6..29cae5023 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -10,7 +10,7 @@ cardiology_isWA_fn, ) from .chestxray14_binary_classification import ChestXray14BinaryClassification -from .ecgqa_preprocess import ECGQA +# from .ecgqa_preprocess import ECGQA from .ptbxl_diagnostic_classification import PTBXLDiagnosticClassification from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification From 9d9d8fbcf781bbc81d940311a26fcba34b475a66 Mon Sep 17 00:00:00 2001 From: Yiyun Date: Wed, 15 Apr 2026 22:03:42 -0700 Subject: [PATCH 21/27] modify test files, meet character line length --- pyhealth/datasets/ecgqa.py | 4 +- pyhealth/datasets/ptbxl.py | 45 +++++++++++----- tests/core/test_ptbxl.py | 98 ++++++++++++++++++++++++++++++++++ tests/core/test_ptbxl_1_0_3.py | 37 ------------- 4 files changed, 133 insertions(+), 51 deletions(-) create mode 100644 tests/core/test_ptbxl.py delete mode 100644 tests/core/test_ptbxl_1_0_3.py diff --git a/pyhealth/datasets/ecgqa.py b/pyhealth/datasets/ecgqa.py index 504db202f..8c31c112f 100644 --- a/pyhealth/datasets/ecgqa.py +++ b/pyhealth/datasets/ecgqa.py @@ -12,7 +12,9 @@ Dataset is available at https://github.com/Tang-Jia-Lu/FSL_ECG_QA Dataset paper: - J. Tang, T. Xia, Y. Lu, C. Mascolo, and A. Saeed, "Electrocardiogram-language model for few-shot question answering with meta learning," arXiv preprint arXiv:2410.14464, 2024. + J. Tang, T. Xia, Y. Lu, C. Mascolo, and A. Saeed, "Electrocardiogram-language model + for few-shot question answering with meta learning," + arXiv preprint arXiv:2410.14464, 2024. Dataset paper link: https://arxiv.org/abs/2410.14464 diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py index f37fe154c..8666a935d 100644 --- a/pyhealth/datasets/ptbxl.py +++ b/pyhealth/datasets/ptbxl.py @@ -5,7 +5,9 @@ https://physionet.org/content/ptb-xl/1.0.3/ Dataset paper: - J. Tang, T. Xia, Y. Lu, C. Mascolo, and A. Saeed, "Electrocardiogram-language model for few-shot question answering with meta learning," arXiv preprint arXiv:2410.14464, 2024. + J. Tang, T. Xia, Y. Lu, C. Mascolo, and A. Saeed, + "Electrocardiogram-language model for few-shot question answering with meta + learning," arXiv preprint arXiv:2410.14464, 2024. Dataset paper link: https://arxiv.org/abs/2410.14464 @@ -36,11 +38,13 @@ Args: dataset_name: name of the dataset. root: root directory of the raw data. - Expected to contain folders for original (records500) or downsampled (records100) data with determined names. + Expected to contain folders for original (records500) or + downsampled (records100) data with determined names. dev: whether to enable dev mode (only use a small subset of the data). Default is False. refresh_cache: whether to refresh the cache; if true, the dataset will - be processed from scratch and the cache will be updated. Default is False. + be processed from scratch and the cache will be updated. + Default is False. """ class PTBXLDataset(BaseDataset): """ @@ -48,7 +52,8 @@ class PTBXLDataset(BaseDataset): Attributes: root (str): Root directory of the raw data. - download (bool): True iff requested to download dataset. Default to False. + download (bool): True iff requested to download dataset. + Default to False. dev (bool): True iff enable dev mode. downsampled (bool): True iff use downsampled signal data. """ @@ -69,7 +74,13 @@ def __init__( # Determine signal path, where to fetch the signal samples signal_folder = "records100" if downsampled else "records500" - root_path = os.path.join(root, "ptb_xl_processed_final/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3") if download else root + + if download: + root_path = os.path.join(root, "ptb_xl_processed_final/\ + ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3") + else: + root_path = root + self.signal_path: str = os.path.join(root_path, signal_folder) # Download the dataset from online source to root if needed @@ -78,7 +89,8 @@ def __init__( # Determine config_path if it isn't provided if config_path is None: logger.info("No config path provided. Using default config.") - config_path = os.path.join(os.path.dirname(__file__), "configs", "ptbxl.yaml") + config_path = os.path.join( + os.path.dirname(__file__), "configs", "ptbxl.yaml") # Validate data self._validate() @@ -102,7 +114,8 @@ def __init__( def _download(self, download) -> None: if download: zip_id = "1btbPiHEOUBLNLfUYkLnKzs50ZTmgqdI2" - response = requests.get(f"https://drive.google.com/uc?export=download&id={zip_id}") + response = requests.get( + f"https://drive.google.com/uc?export=download&id={zip_id}") with open(self.data_path, "wb") as file: file.write(response.content) @@ -117,7 +130,8 @@ def _download(self, download) -> None: Raises: FileNotFoundError: if any directory is not found. - ValueError: if a patient directory contains not .dat/.hea file or if there"s a mismatch of .dat/.hea. + ValueError: if a patient directory contains not .dat/.hea file + or if there"s a mismatch of .dat/.hea. """ def _validate(self) -> None: if not os.path.exists(self.root): @@ -168,17 +182,22 @@ def _prepare_metadata(self) -> None: "patient_id": pid, "signal_file": f + ".dat", "label_file": f + ".hea", - "save_to_path": os.path.join(MODULE_CACHE_PATH, hash_str(filename)), + "save_to_path": os.path.join( + MODULE_CACHE_PATH, hash_str(filename)), } if self.dev: keys = random.sample(list(patients), min(len(patients), 5)) - values = [d[k] for k in keys] - return dict(zip(keys, values)) + values = [patients[k] for k in keys] + patients = dict(zip(keys, values)) - # patients.to_csv(os.path.join(self.root, "ptbxl.csv"), index=False) with open(os.path.join(self.root, "ptbxl.csv"), "w") as file: - w = csv.DictWriter(file, fieldnames=["load_from_path", "patient_id", "signal_file", "label_file", "save_to_path"]) + w = csv.DictWriter(file, fieldnames=[ + "load_from_path", + "patient_id", + "signal_file", + "label_file", + "save_to_path"]) w.writeheader() w.writerows(list(patients.values())) return None diff --git a/tests/core/test_ptbxl.py b/tests/core/test_ptbxl.py new file mode 100644 index 000000000..990519b5f --- /dev/null +++ b/tests/core/test_ptbxl.py @@ -0,0 +1,98 @@ +import unittest +import tempfile +import os +import shutil +from pyhealth.datasets import PTBXLDataset +from pathlib import Path + +class TestPTBXLDataset(unittest.TestCase): + """ + Test PTB-XL 1.0.3 dataset with demo data. + """ + def setUp(self): + sample_records = { + '00001': { + 'load_from_path': 'sample/path', + 'signal_file': '00001_lr.dat', + 'label_file': '00001_lr.hea', + 'save_to_path': 'sample/path', + }, + '00002': { + 'load_from_path': 'sample/path', + 'signal_file': '00002_lr.dat', + 'label_file': '00002_lr.hea', + 'save_to_path': 'sample/path', + }, + '00003': { + 'load_from_path': 'sample/path', + 'signal_file': '00003_lr.dat', + 'label_file': '00003_lr.hea', + 'save_to_path': 'sample/path', + }, + '00004': { + 'load_from_path': 'sample/path', + 'signal_file': '00004_lr.dat', + 'label_file': '00004_lr.hea', + 'save_to_path': 'sample/path', + }, + '00005': { + 'load_from_path': 'sample/path', + 'signal_file': '00005_lr.dat', + 'label_file': '00005_lr.hea', + 'save_to_path': 'sample/path', + }, + '00006': { + 'load_from_path': 'sample/path', + 'signal_file': '00006_lr.dat', + 'label_file': '00006_lr.hea', + 'save_to_path': 'sample/path', + }, + } + + self.temp_dir = tempfile.mkdtemp() + self.root = Path(self.temp_dir) + + os.makedirs(os.path.join(self.root, 'records100')) + + for i in sample_records.keys(): + with open(os.path.join(self.root, f'records100/{i}.dat'), 'w') as f: + f.write('sample .dat data') + with open(os.path.join(self.root, f'records100/{i}.hea'), 'w') as f: + f.write('sample .hea data') + + self.dataset = PTBXLDataset(root=self.root, download=False, downsampled=True, dev=True) + + + """ + Verify if the dataset contains correct basic information + """ + def testBasicInfo(self): + self.assertEqual(self.dataset.dataset_name, 'PTB-XL') + self.dataset.stats() + + """ + Verify if the dataset contains expected data + """ + def testData(self): + + # Test if dev mode has been applied + self.assertLessEqual(len(self.dataset.unique_patient_ids), 5) + + # Test info of the first patient + patient = self.dataset.get_patient('00001') + self.assertIsNotNone(patient) + self.assertIsInstance(patient.patient_id, str) + self.assertEqual(patient.patient_id, '00001') + + events = self.dataset.get_patient('00001').get_events() + self.assertEqual(len(events), 1) + self.assertIsNotNone(events[0]['load_from_path']) + self.assertIsNotNone(events[0]['signal_file']) + self.assertIsNotNone(events[0]['label_file']) + self.assertIsNotNone(events[0]['save_to_path']) + + def tearDown(self): + shutil.rmtree(self.temp_dir) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/core/test_ptbxl_1_0_3.py b/tests/core/test_ptbxl_1_0_3.py deleted file mode 100644 index ef969857f..000000000 --- a/tests/core/test_ptbxl_1_0_3.py +++ /dev/null @@ -1,37 +0,0 @@ -import unittest -from pyhealth.datasets import PTBXLDataset - -class TestPTBXLDataset(unittest.TestCase): - """ - Test PTB-XL 1.0.3 dataset with demo data. - """ - def setUp(self): - self.dataset = PTBXLDataset(root='../../../../', download=False, downsampled=True, dev=True) - - """ - Verify if the dataset contains correct basic information - """ - def testBasicInfo(self): - self.assertEqual(self.dataset.dataset_name, 'PTB-XL1.0.3') - self.assertIsInstance(self.dataset.filepath, str) - - """ - Verify if the dataset contains expected data - """ - def testData(self): - self.assertIsInstance(self.dataset.patients, dict) - self.assertLessEqual(len(self.dataset.patients), 5) - for pid, value in self.dataset.patients.items(): - self.assertIsNotNone(pid) - self.assertIsInstance(pid, str) - - self.assertIsInstance(value, list) - self.assertIsInstance(value[0], dict) - self.assertIsNotNone(value[0]['load_from_path']) - self.assertIsNotNone(value[0]['patient_id']) - self.assertIsNotNone(value[0]['signal_file']) - self.assertIsNotNone(value[0]['label_file']) - self.assertIsNotNone(value[0]['save_to_path']) - -if __name__ == "__main__": - unittest.main() From 20297dd740d270d055ab160c5fcb4b583140845a Mon Sep 17 00:00:00 2001 From: Yiyun Date: Wed, 15 Apr 2026 22:04:20 -0700 Subject: [PATCH 22/27] uncomment ecg task --- pyhealth/tasks/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 29cae5023..9cbcae5b6 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -10,7 +10,7 @@ cardiology_isWA_fn, ) from .chestxray14_binary_classification import ChestXray14BinaryClassification -# from .ecgqa_preprocess import ECGQA +from .ecgqa_preprocess import ECGQA from .ptbxl_diagnostic_classification import PTBXLDiagnosticClassification from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification From d8cfd2335ad3683b32e3e88c2d1b90b3d2530c99 Mon Sep 17 00:00:00 2001 From: Yiyun Date: Wed, 15 Apr 2026 22:09:09 -0700 Subject: [PATCH 23/27] rm trailing spaces --- pyhealth/datasets/ecgqa.py | 6 +++--- pyhealth/datasets/ptbxl.py | 32 ++++++++++++++++---------------- tests/core/test_ptbxl.py | 12 ++++++------ 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/pyhealth/datasets/ecgqa.py b/pyhealth/datasets/ecgqa.py index 8c31c112f..4b30defdb 100644 --- a/pyhealth/datasets/ecgqa.py +++ b/pyhealth/datasets/ecgqa.py @@ -8,12 +8,12 @@ The QA data originates from the ECG-QA dataset (Oh et al., 2024), restructured for few-shot learning by Tang et al. (CHIL 2025). -Dataset link: +Dataset link: Dataset is available at https://github.com/Tang-Jia-Lu/FSL_ECG_QA -Dataset paper: +Dataset paper: J. Tang, T. Xia, Y. Lu, C. Mascolo, and A. Saeed, "Electrocardiogram-language model - for few-shot question answering with meta learning," + for few-shot question answering with meta learning," arXiv preprint arXiv:2410.14464, 2024. Dataset paper link: diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py index 8666a935d..9e1b76eed 100644 --- a/pyhealth/datasets/ptbxl.py +++ b/pyhealth/datasets/ptbxl.py @@ -1,12 +1,12 @@ """ Pyhealth dataset for the 1.0.3 PTB-XL dataset. -Dataset link: +Dataset link: https://physionet.org/content/ptb-xl/1.0.3/ -Dataset paper: - J. Tang, T. Xia, Y. Lu, C. Mascolo, and A. Saeed, - "Electrocardiogram-language model for few-shot question answering with meta +Dataset paper: + J. Tang, T. Xia, Y. Lu, C. Mascolo, and A. Saeed, + "Electrocardiogram-language model for few-shot question answering with meta learning," arXiv preprint arXiv:2410.14464, 2024. Dataset paper link: @@ -37,13 +37,13 @@ Args: dataset_name: name of the dataset. - root: root directory of the raw data. - Expected to contain folders for original (records500) or + root: root directory of the raw data. + Expected to contain folders for original (records500) or downsampled (records100) data with determined names. dev: whether to enable dev mode (only use a small subset of the data). Default is False. refresh_cache: whether to refresh the cache; if true, the dataset will - be processed from scratch and the cache will be updated. + be processed from scratch and the cache will be updated. Default is False. """ class PTBXLDataset(BaseDataset): @@ -52,7 +52,7 @@ class PTBXLDataset(BaseDataset): Attributes: root (str): Root directory of the raw data. - download (bool): True iff requested to download dataset. + download (bool): True iff requested to download dataset. Default to False. dev (bool): True iff enable dev mode. downsampled (bool): True iff use downsampled signal data. @@ -107,8 +107,8 @@ def __init__( """ - Download PTB-XL dataset from public google drive sources. - It will contain both the original and downsampled versions, + Download PTB-XL dataset from public google drive sources. + It will contain both the original and downsampled versions, in /records500 and /records100 folder respectively. """ def _download(self, download) -> None: @@ -130,7 +130,7 @@ def _download(self, download) -> None: Raises: FileNotFoundError: if any directory is not found. - ValueError: if a patient directory contains not .dat/.hea file + ValueError: if a patient directory contains not .dat/.hea file or if there"s a mismatch of .dat/.hea. """ def _validate(self) -> None: @@ -165,7 +165,7 @@ def _validate(self) -> None: """ Process and return a dictionary of the requested PTB-XL data for each patient. - Each patient will have a corresponding object that contains + Each patient will have a corresponding object that contains load_from_path, patient_id, signal_file, label_file, and save_to_path. """ def _prepare_metadata(self) -> None: @@ -193,10 +193,10 @@ def _prepare_metadata(self) -> None: with open(os.path.join(self.root, "ptbxl.csv"), "w") as file: w = csv.DictWriter(file, fieldnames=[ - "load_from_path", - "patient_id", - "signal_file", - "label_file", + "load_from_path", + "patient_id", + "signal_file", + "label_file", "save_to_path"]) w.writeheader() w.writerows(list(patients.values())) diff --git a/tests/core/test_ptbxl.py b/tests/core/test_ptbxl.py index 990519b5f..973a71bf1 100644 --- a/tests/core/test_ptbxl.py +++ b/tests/core/test_ptbxl.py @@ -16,37 +16,37 @@ def setUp(self): 'signal_file': '00001_lr.dat', 'label_file': '00001_lr.hea', 'save_to_path': 'sample/path', - }, + }, '00002': { 'load_from_path': 'sample/path', 'signal_file': '00002_lr.dat', 'label_file': '00002_lr.hea', 'save_to_path': 'sample/path', - }, + }, '00003': { 'load_from_path': 'sample/path', 'signal_file': '00003_lr.dat', 'label_file': '00003_lr.hea', 'save_to_path': 'sample/path', - }, + }, '00004': { 'load_from_path': 'sample/path', 'signal_file': '00004_lr.dat', 'label_file': '00004_lr.hea', 'save_to_path': 'sample/path', - }, + }, '00005': { 'load_from_path': 'sample/path', 'signal_file': '00005_lr.dat', 'label_file': '00005_lr.hea', 'save_to_path': 'sample/path', - }, + }, '00006': { 'load_from_path': 'sample/path', 'signal_file': '00006_lr.dat', 'label_file': '00006_lr.hea', 'save_to_path': 'sample/path', - }, + }, } self.temp_dir = tempfile.mkdtemp() From bbebd73d1c85962485b9bf6199e32dc141f61557 Mon Sep 17 00:00:00 2001 From: Yiyun Date: Wed, 15 Apr 2026 22:10:22 -0700 Subject: [PATCH 24/27] minor --- pyhealth/datasets/ptbxl.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py index 9e1b76eed..071768591 100644 --- a/pyhealth/datasets/ptbxl.py +++ b/pyhealth/datasets/ptbxl.py @@ -36,15 +36,15 @@ Dataset class for the PTB-XL 1.0.3 dataset. Args: - dataset_name: name of the dataset. - root: root directory of the raw data. - Expected to contain folders for original (records500) or - downsampled (records100) data with determined names. - dev: whether to enable dev mode (only use a small subset of the data). - Default is False. - refresh_cache: whether to refresh the cache; if true, the dataset will - be processed from scratch and the cache will be updated. - Default is False. + dataset_name: name of the dataset. + root: root directory of the raw data. + Expected to contain folders for original (records500) or + downsampled (records100) data with determined names. + dev: whether to enable dev mode (only use a small subset of the data). + Default is False. + refresh_cache: whether to refresh the cache; if true, the dataset will + be processed from scratch and the cache will be updated. + Default is False. """ class PTBXLDataset(BaseDataset): """ @@ -63,7 +63,7 @@ def __init__( download: bool = False, dev: bool = False, downsampled: bool = False, - config_path: Optional[str] = None, + config_path: Optional[str] = None, **kwargs) -> None: self.dev = dev @@ -129,9 +129,9 @@ def _download(self, download) -> None: and there"s no other unexpected type of files in the directory. Raises: - FileNotFoundError: if any directory is not found. - ValueError: if a patient directory contains not .dat/.hea file - or if there"s a mismatch of .dat/.hea. + FileNotFoundError: if any directory is not found. + ValueError: if a patient directory contains not .dat/.hea file + or if there"s a mismatch of .dat/.hea. """ def _validate(self) -> None: if not os.path.exists(self.root): From 704c376c9e4be33a989a952ad74ff35d789d7587 Mon Sep 17 00:00:00 2001 From: jovianw Date: Wed, 15 Apr 2026 23:36:38 -0700 Subject: [PATCH 25/27] add: ECGQA example and update datasets and tasks --- examples/ecgqa_fsl.py | 135 +++++++++++++++++++++++++++++ pyhealth/datasets/ecgqa.py | 2 +- pyhealth/datasets/ptbxl.py | 2 +- pyhealth/tasks/__init__.py | 2 +- pyhealth/tasks/ecgqa_preprocess.py | 100 +++++++++++++++++++++ pyhealth/tasks/ptbxl_resampling.py | 50 +++++------ tests/core/test_ecgqa.py | 104 ++++++++++++++++------ 7 files changed, 338 insertions(+), 57 deletions(-) create mode 100644 examples/ecgqa_fsl.py create mode 100644 pyhealth/tasks/ecgqa_preprocess.py diff --git a/examples/ecgqa_fsl.py b/examples/ecgqa_fsl.py new file mode 100644 index 000000000..97d7cd34c --- /dev/null +++ b/examples/ecgqa_fsl.py @@ -0,0 +1,135 @@ +"""ECG Question Answering with Few-Shot Learning — PyHealth example. + +This script demonstrates the full data pipeline for the FSL_ECG_QA +project (Tang et al., CHIL 2025) using PyHealth datasets and tasks. + +Pipeline: + 1. PTBXLDataset → PTBXLResampling task → resampled ECG signals (12, 2500) + 2. ECGQADataset → ECGQA task (with signal_loader) → multimodal QA samples + +For the full meta-learning training loop, see: + https://github.com/Tang-Jia-Lu/FSL_ECG_QA/blob/main/train.py + +Requirements: + - PTB-XL dataset (https://physionet.org/content/ptb-xl/1.0.3/) + - ECG-QA data (https://github.com/Tang-Jia-Lu/FSL_ECG_QA/tree/main/ecgqa) + - pip install wfdb + +Authors: + Jovian Wang (jovianw2@illinois.edu) + Matthew Pham (mdpham2@illinois.edu) + Yiyun Wang (yiyunw3@illinois.edu) +""" + +import json +import os +import tempfile +from pathlib import Path + +import torch +from pyhealth.datasets import PTBXLDataset, ECGQADataset +from pyhealth.tasks import PTBXLResampling, ECGQA + +# ---------- Configuration ---------- +# Update these paths to match your local setup +PTBXL_ROOT = "/path/to/ptb-xl/1.0.3/" # contains records500/, records100/ +ECGQA_ROOT = "/path/to/ecgqa/ptbxl/paraphrased/" # contains train/, valid/, test/ + +# Set to True for a quick test run (loads a small matched subset). +# Set to False to process the full dataset. +DEV_MODE = True + + +def _load_dev_subset(ptbxl_root, ecgqa_root): + """Load a small matched subset for quick testing. + + PTBXLDataset dev mode picks 5 random patients from the full 21K + range. To guarantee every QA sample has a matching signal, this + helper pre-filters the QA JSON files to only include records whose + ecg_id appears in the loaded PTB-XL signals, then loads the filtered + data through ECGQADataset. + """ + # Load + resample PTB-XL signals (5 in dev mode) + print(" Loading PTB-XL signals...") + ptb = PTBXLDataset(root=ptbxl_root, downsampled=False, dev=True) + signal_ds = ptb.set_task(PTBXLResampling(root=ptbxl_root)) + signal_lookup = {s["record_id"]: s["signal"] for s in signal_ds} + matched_ecg_ids = set(int(k) for k in signal_lookup.keys()) + print(f" PTB-XL: {len(signal_lookup)} signals loaded (ecg_ids: {matched_ecg_ids})") + + # Pre-filter QA JSONs to only records with matching ecg_ids + print(" Filtering QA data to matched ecg_ids...") + tmp_dir = tempfile.mkdtemp() + src = Path(ecgqa_root) + total_kept = 0 + for split in ("train", "valid", "test"): + dst = Path(tmp_dir) / split + dst.mkdir() + split_records = [] + for fpath in sorted((src / split).glob("*.json")): + with open(fpath) as f: + records = json.load(f) + split_records.extend(r for r in records if r["ecg_id"][0] in matched_ecg_ids) + if not split_records: + # Write a dummy record so _verify_data passes; it gets filtered + # out by prepare_metadata (question_type won't start with "single-") + split_records = [{"ecg_id": [0], "question": "", "answer": [""], + "question_type": "dummy", "attribute_type": "", + "template_id": 0, "question_id": 0, + "sample_id": 0, "attribute": [""]}] + else: + total_kept += len(split_records) + with open(dst / "00.json", "w") as f: + json.dump(split_records, f) + print(f" Kept {total_kept} QA records for {len(matched_ecg_ids)} ecg_ids") + + # Load filtered QA data with signal loader + def signal_loader(ecg_id): + return torch.FloatTensor(signal_lookup[ecg_id]) + + qa = ECGQADataset(root=tmp_dir) + samples = qa.set_task(ECGQA(signal_loader=signal_loader)) + print(f" Created {len(samples)} matched QA samples") + return samples, signal_lookup + + +def main(): + if DEV_MODE: + samples, signal_lookup = _load_dev_subset(PTBXL_ROOT, ECGQA_ROOT) + else: + # ---------- Full pipeline ---------- + # Step 1: Load + resample all PTB-XL signals + print("Loading PTB-XL dataset...") + ptb = PTBXLDataset(root=PTBXL_ROOT, downsampled=False) + signal_ds = ptb.set_task(PTBXLResampling(root=PTBXL_ROOT)) + signal_lookup = {s["record_id"]: s["signal"] for s in signal_ds} + print(f" Loaded {len(signal_lookup)} signal samples") + + # Step 2: Build signal loader + def signal_loader(ecg_id: int) -> torch.Tensor: + return torch.FloatTensor(signal_lookup[ecg_id]) + + # Step 3: Load ECG-QA data with signals + print("Loading ECG-QA dataset...") + qa = ECGQADataset(root=ECGQA_ROOT) + samples = qa.set_task(ECGQA(signal_loader=signal_loader)) + print(f" Created {len(samples)} QA samples") + + # ---------- Inspect a sample ---------- + if len(samples) == 0: + print("\nNo matched samples found. Check that PTBXL_ROOT and ECGQA_ROOT are correct.") + return + + sample = samples[0] + print("\n=== Sample ===") + print(f" patient_id: {sample['patient_id']}") + print(f" question: {sample['question'][:80]}...") + print(f" answer: {sample['answer']}") + print(f" question_type: {sample['question_type']}") + print(f" episode_class: {sample['episode_class']}") + if "signal" in sample: + print(f" signal shape: {sample['signal'].shape}") + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/ecgqa.py b/pyhealth/datasets/ecgqa.py index 4b30defdb..419da4214 100644 --- a/pyhealth/datasets/ecgqa.py +++ b/pyhealth/datasets/ecgqa.py @@ -146,7 +146,7 @@ def prepare_metadata(self) -> None: attribute = ";".join(record.get("attribute", [])) rows.append({ - "patient_id": str(ecg_id), + "patient_id": f"{ecg_id:05d}", "ecg_id": ecg_id, "question": record["question"], "answer": answer, diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py index 071768591..9ff136f58 100644 --- a/pyhealth/datasets/ptbxl.py +++ b/pyhealth/datasets/ptbxl.py @@ -100,7 +100,7 @@ def __init__( super().__init__( root=root, dataset_name="PTB-XL", - tables=["PTB-XL"], + tables=["ptb-xl"], config_path=config_path, **kwargs, ) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index c54fcf00e..40e992c47 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -11,7 +11,7 @@ ) from .chestxray14_binary_classification import ChestXray14BinaryClassification from .ecgqa_preprocess import ECGQA -from .ptbxl_resampling import PTBXLResamplingTask +from .ptbxl_resampling import PTBXLResampling from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification from .dka import DKAPredictionMIMIC4, T1DDKAPredictionMIMIC4 diff --git a/pyhealth/tasks/ecgqa_preprocess.py b/pyhealth/tasks/ecgqa_preprocess.py new file mode 100644 index 000000000..067db250f --- /dev/null +++ b/pyhealth/tasks/ecgqa_preprocess.py @@ -0,0 +1,100 @@ +""" +A PyHealth task for ECG Question Answering preprocessing. + +Produces one sample per QA pair, optionally attaching a resampled ECG +signal tensor via a user-provided signal_loader callable. + +Dataset link: + https://github.com/Tang-Jia-Lu/FSL_ECG_QA + +Dataset paper: + J. Tang, T. Xia, Y. Lu, C. Mascolo, and A. Saeed, + "Electrocardiogram-language model for few-shot question answering + with meta learning," arXiv preprint arXiv:2410.14464, 2024. + +Dataset paper link: + https://arxiv.org/abs/2410.14464 + +Author: + Jovian Wang (jovianw2@illinois.edu) + Matthew Pham (mdpham2@illinois.edu) + Yiyun Wang (yiyunw3@illinois.edu) +""" +import torch +from typing import Any, Callable, Dict, List, Optional + +from pyhealth.tasks import BaseTask + + +class ECGQA(BaseTask): + """ECG Question Answering task. + + For each patient (ECG recording), this task returns one sample per + QA pair, containing the question, answer, question type, and an + episode_class key for episodic sampling. Optionally attaches an ECG + signal tensor via a user-provided signal_loader. + + Works with both PTB-XL and MIMIC-IV-ECG based QA data. + + Args: + signal_loader: optional callable mapping ecg_id (int) to a signal + tensor of shape (12, N). If None, samples are text-only. + + Each returned sample contains: + - "question": str, the natural language question + - "answer": str, the answer (semicolon-separated if multiple) + - "question_type": str, one of "single-verify", "single-choose", "single-query" + - "episode_class": str, class key for episodic sampling (template_id + attribute + answer) + - "signal": torch.FloatTensor (only if signal_loader is provided) + """ + + task_name: str = "ECG_QA" + input_schema: Dict[str, str] = {"question": "text"} + output_schema: Dict[str, str] = {"answer": "text"} + + def __init__( + self, + signal_loader: Optional[Callable[[int], torch.Tensor]] = None, + ) -> None: + self.signal_loader = signal_loader + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process one patient (ECG recording). Creates one sample per QA pair. + + The ECG signal is loaded once per patient via signal_loader and + shared across all QA pairs for the same ecg_id. + """ + pid = patient.patient_id + samples: List[Dict[str, Any]] = [] + + events = patient.get_events("ecg_qa") + if not events: + return samples + + # Load signal once for this patient if loader provided. + # If loading fails, skip the patient entirely. + signal = None + if self.signal_loader is not None: + try: + signal = self.signal_loader(int(pid)) + if not isinstance(signal, torch.Tensor): + signal = torch.FloatTensor(signal) + except Exception: + return samples + + for event in events: + episode_class = f"{event.template_id}_{event.attribute}_{event.answer}" + + sample = { + "patient_id": pid, + "question": event.question, + "answer": event.answer, + "question_type": event.question_type, + "episode_class": episode_class, + } + if signal is not None: + sample["signal"] = signal + + samples.append(sample) + + return samples diff --git a/pyhealth/tasks/ptbxl_resampling.py b/pyhealth/tasks/ptbxl_resampling.py index 9eb0a57f2..afd26f35c 100644 --- a/pyhealth/tasks/ptbxl_resampling.py +++ b/pyhealth/tasks/ptbxl_resampling.py @@ -2,11 +2,11 @@ A PyHealth task that performs dynamic resampling of ECG signals. Dataset link: - https://physionet.org/content/ptb-xl/1.0.1/ + https://physionet.org/content/ptb-xl/1.0.3/ Dataset paper: (please cite if you use this dataset) - Wagner, P., Strodthoff, N., Bousseljot, R. D., Samek, W., & Schaeffter, T. - "PTB-XL, a large publicly available electrocardiography dataset." + Wagner, P., Strodthoff, N., Bousseljot, R. D., Samek, W., & Schaeffter, T. + "PTB-XL, a large publicly available electrocardiography dataset." Scientific Data, 7(1), 1-15. (2020). Dataset paper link: @@ -24,7 +24,7 @@ from scipy import signal from typing import Dict, List -from pyhealth.data import Patient, Event +from pyhealth.data import Patient from pyhealth.tasks import BaseTask logger = logging.getLogger(__name__) @@ -35,37 +35,33 @@ class PTBXLResampling(BaseTask): This provides a balance between detail and computational efficiency. """ task_name: str = "PTBXLResampling" - input_schema: Dict[str, str] = {"signal": "signal"} - output_schema: Dict[str, str] = {"label": "multilabel"} + input_schema: Dict[str, str] = {"signal": "tensor"} + output_schema: Dict[str, str] = {"signal": "tensor"} def __init__(self, root: str) -> None: self.root = root def __call__(self, patient: Patient) -> List[Dict]: - events = patient.get_events(event_type="ptbxl") + events = patient.get_events(event_type="ptb-xl") samples = [] - for event in events: - ecg_id = int(event["ecg_id"]) - subfolder = f"{str((ecg_id // 1000) * 1000).zfill(5)}" - - # Use the 500Hz records as the source - hr_path = os.path.join(self.root, "records500", subfolder, f"{str(ecg_id).zfill(5)}_hr") + ecg_id = int(patient.patient_id) + subfolder = f"{str((ecg_id // 1000) * 1000).zfill(5)}" + hr_path = os.path.join(self.root, "records500", subfolder, f"{str(ecg_id).zfill(5)}_hr") - try: - record = wfdb.rdrecord(hr_path) - data_500hz = record.p_signal.T # Shape: (12, 5000) + try: + record = wfdb.rdrecord(hr_path) + data_500hz = record.p_signal.T # Shape: (12, 5000) - # Downsample to 250Hz (2500 samples for a 10s record) - num_samples_target = 2500 - data_250hz = signal.resample(data_500hz, num_samples_target, axis=1) + # Downsample to 250Hz (2500 samples for a 10s record) + num_samples_target = 2500 + data_250hz = signal.resample(data_500hz, num_samples_target, axis=1) - samples.append({ - "signal": data_250hz.astype(np.float32), - "label": event["label"], - "record_id": ecg_id - }) - except Exception as e: - continue + samples.append({ + "signal": data_250hz.astype(np.float32), + "record_id": ecg_id + }) + except Exception: + pass - return samples \ No newline at end of file + return samples diff --git a/tests/core/test_ecgqa.py b/tests/core/test_ecgqa.py index 3915f6228..4aea1204b 100644 --- a/tests/core/test_ecgqa.py +++ b/tests/core/test_ecgqa.py @@ -4,7 +4,10 @@ import json from pathlib import Path +import torch + from pyhealth.datasets import ECGQADataset +from pyhealth.tasks import ECGQA class TestECGQADataset(unittest.TestCase): @@ -99,25 +102,19 @@ def test_dataset_initialization(self): self.assertEqual(dataset.dataset_name, "ecg_qa") self.assertEqual(dataset.root, str(self.root)) - def test_metadata_file_created(self): - """Test ecg-qa-pyhealth.csv is created in root""" - ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) - metadata_file = self.root / "ecg-qa-pyhealth.csv" - self.assertTrue(metadata_file.exists()) - def test_patient_count(self): """Test only single-* records are loaded (4 of 5)""" dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) # ecg_ids 1, 2, 4, 5 survive; ecg_id 3 is comparison-verify and is filtered. self.assertEqual(len(dataset.unique_patient_ids), 4) self.assertEqual( - sorted(dataset.unique_patient_ids), ["1", "2", "4", "5"] + sorted(dataset.unique_patient_ids), ["00001", "00002", "00004", "00005"] ) def test_filters_non_single_question_types(self): """Test that comparison-verify records are dropped""" dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) - self.assertNotIn("3", dataset.unique_patient_ids) + self.assertNotIn("00003", dataset.unique_patient_ids) def test_stats_method(self): """Test stats method runs without error""" @@ -127,20 +124,14 @@ def test_stats_method(self): def test_get_patient(self): """Test get_patient method""" dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) - patient = dataset.get_patient("1") + patient = dataset.get_patient("00001") self.assertIsNotNone(patient) - self.assertEqual(patient.patient_id, "1") - - def test_get_patient_not_found(self): - """Test that patient not found throws error.""" - dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) - with self.assertRaises(AssertionError): - dataset.get_patient("999") + self.assertEqual(patient.patient_id, "00001") def test_single_verify_event_fields(self): """Test a single-verify event surfaces the expected attributes""" dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) - events = dataset.get_patient("1").get_events() + events = dataset.get_patient("00001").get_events() self.assertEqual(len(events), 1) self.assertEqual(events[0]["question_type"], "single-verify") self.assertEqual(events[0]["answer"], "yes") @@ -149,19 +140,10 @@ def test_single_verify_event_fields(self): def test_single_choose_event_joins_multi_valued_fields(self): """Test single-choose records join answer/attribute with ';'""" dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) - events = dataset.get_patient("2").get_events() + events = dataset.get_patient("00002").get_events() self.assertEqual(events[0]["answer"], "sinus rhythm;atrial fibrillation") self.assertEqual(events[0]["attribute"], "SR;AFIB") - def test_invalid_ecg_source_raises(self): - """Test ValueError on invalid ecg_source""" - with self.assertRaises(ValueError): - ECGQADataset( - root=str(self.root), - ecg_source="nope", - cache_dir=self._cache_tmp, - ) - class TestECGQAVerifyData(unittest.TestCase): """Test the structural checks performed by ECGQADataset._verify_data.""" @@ -192,5 +174,73 @@ def test_empty_split_dir_raises(self): ECGQADataset(root=str(self.root)) +class TestECGQATask(unittest.TestCase): + """Test the ECGQA task with synthetic data.""" + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.root = Path(self.temp_dir) + for split in ("train", "valid", "test"): + (self.root / split).mkdir() + + records = [ + { + "ecg_id": [1], + "question": "Does this ECG show normal sinus rhythm?", + "answer": ["yes"], + "question_type": "single-verify", + "attribute_type": "scp_code", + "template_id": 1, + "question_id": 100, + "sample_id": 1000, + "attribute": ["NORM"], + }, + { + "ecg_id": [1], + "question": "What rhythm does this ECG show?", + "answer": ["sinus rhythm"], + "question_type": "single-query", + "attribute_type": "rhythm", + "template_id": 2, + "question_id": 101, + "sample_id": 1001, + "attribute": ["SR"], + }, + ] + (self.root / "train" / "00.json").write_text(json.dumps(records)) + (self.root / "valid" / "00.json").write_text(json.dumps([records[0]])) + (self.root / "test" / "00.json").write_text(json.dumps([records[1]])) + + self._cache_tmp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.temp_dir, ignore_errors=True) + shutil.rmtree(self._cache_tmp, ignore_errors=True) + + def test_text_only_mode(self): + """Task with no signal_loader returns text-only samples.""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + samples = dataset.set_task(ECGQA()) + sample = samples[0] + + self.assertIn("question", sample) + self.assertIn("answer", sample) + self.assertIn("question_type", sample) + self.assertIn("episode_class", sample) + self.assertNotIn("signal", sample) + + def test_signal_loader_attaches_signal(self): + """Task with signal_loader attaches signal tensor to samples.""" + def fake_loader(ecg_id): + return torch.randn(12, 2500) + + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + samples = dataset.set_task(ECGQA(signal_loader=fake_loader)) + sample = samples[0] + + self.assertIn("signal", sample) + self.assertEqual(sample["signal"].shape, torch.Size([12, 2500])) + + if __name__ == "__main__": unittest.main() From 562f01600f8de4b9779f436abd94a10e3deb5e21 Mon Sep 17 00:00:00 2001 From: Yiyun Date: Sun, 19 Apr 2026 21:12:20 -0700 Subject: [PATCH 26/27] improve code a little --- pyhealth/datasets/ptbxl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py index 071768591..b6186f3b1 100644 --- a/pyhealth/datasets/ptbxl.py +++ b/pyhealth/datasets/ptbxl.py @@ -69,7 +69,6 @@ def __init__( self.dev = dev # Determine the root path, where most of the data is stored - self.data_path: str = os.path.join(root, "ptb_xl_processed_final.zip") self.root = root # Determine signal path, where to fetch the signal samples @@ -113,13 +112,14 @@ def __init__( """ def _download(self, download) -> None: if download: + data_path: str = os.path.join(self.root, "ptb_xl_processed_final.zip") zip_id = "1btbPiHEOUBLNLfUYkLnKzs50ZTmgqdI2" response = requests.get( f"https://drive.google.com/uc?export=download&id={zip_id}") - with open(self.data_path, "wb") as file: + with open(data_path, "wb") as file: file.write(response.content) - with zipfile.ZipFile(self.data_path, "r") as z: + with zipfile.ZipFile(data_path, "r") as z: z.extractall(self.root) """ From 04bd9630c40848a8c6abac510aa32d925839d5f0 Mon Sep 17 00:00:00 2001 From: jovianw Date: Mon, 20 Apr 2026 18:16:38 -0700 Subject: [PATCH 27/27] Rename files and task classes for consistency --- docs/api/tasks.rst | 4 ++-- ...ks.ECGQA.rst => pyhealth.tasks.ECGQAPreprocessing.rst} | 4 ++-- ..._resampling.rst => pyhealth.tasks.PTBXLResampling.rst} | 0 examples/ecgqa_fsl.py | 8 ++++---- pyhealth/datasets/configs/{ecg_qa.yaml => ecgqa.yaml} | 0 pyhealth/datasets/ecgqa.py | 8 ++++---- pyhealth/tasks/__init__.py | 2 +- pyhealth/tasks/ecgqa_preprocess.py | 4 ++-- tests/core/test_ecgqa.py | 8 ++++---- 9 files changed, 19 insertions(+), 19 deletions(-) rename docs/api/tasks/{pyhealth.tasks.ECGQA.rst => pyhealth.tasks.ECGQAPreprocessing.rst} (54%) rename docs/api/tasks/{pyhealth.tasks.ptbxl_resampling.rst => pyhealth.tasks.PTBXLResampling.rst} (100%) rename pyhealth/datasets/configs/{ecg_qa.yaml => ecgqa.yaml} (100%) diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index b9de6168b..ebbef9f45 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -225,8 +225,8 @@ Available Tasks Benchmark EHRShot ChestX-ray14 Binary Classification ChestX-ray14 Multilabel Classification - ECG Question Answering - PTB-XL Signal Resampling + ECG Question Answering + PTB-XL Signal Resampling Variant Classification (ClinVar) Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) diff --git a/docs/api/tasks/pyhealth.tasks.ECGQA.rst b/docs/api/tasks/pyhealth.tasks.ECGQAPreprocessing.rst similarity index 54% rename from docs/api/tasks/pyhealth.tasks.ECGQA.rst rename to docs/api/tasks/pyhealth.tasks.ECGQAPreprocessing.rst index 4350a77f7..c81fb95c2 100644 --- a/docs/api/tasks/pyhealth.tasks.ECGQA.rst +++ b/docs/api/tasks/pyhealth.tasks.ECGQAPreprocessing.rst @@ -1,7 +1,7 @@ -pyhealth.tasks.ECGQA +pyhealth.tasks.ECGQAPreprocessing ======================================= -.. autoclass:: pyhealth.tasks.ECGQA +.. autoclass:: pyhealth.tasks.ECGQAPreprocessing :members: :undoc-members: :show-inheritance: diff --git a/docs/api/tasks/pyhealth.tasks.ptbxl_resampling.rst b/docs/api/tasks/pyhealth.tasks.PTBXLResampling.rst similarity index 100% rename from docs/api/tasks/pyhealth.tasks.ptbxl_resampling.rst rename to docs/api/tasks/pyhealth.tasks.PTBXLResampling.rst diff --git a/examples/ecgqa_fsl.py b/examples/ecgqa_fsl.py index 97d7cd34c..3e4abd1b7 100644 --- a/examples/ecgqa_fsl.py +++ b/examples/ecgqa_fsl.py @@ -5,7 +5,7 @@ Pipeline: 1. PTBXLDataset → PTBXLResampling task → resampled ECG signals (12, 2500) - 2. ECGQADataset → ECGQA task (with signal_loader) → multimodal QA samples + 2. ECGQADataset → ECGQAPreprocessing task (with signal_loader) → multimodal QA samples For the full meta-learning training loop, see: https://github.com/Tang-Jia-Lu/FSL_ECG_QA/blob/main/train.py @@ -28,7 +28,7 @@ import torch from pyhealth.datasets import PTBXLDataset, ECGQADataset -from pyhealth.tasks import PTBXLResampling, ECGQA +from pyhealth.tasks import PTBXLResampling, ECGQAPreprocessing # ---------- Configuration ---------- # Update these paths to match your local setup @@ -88,7 +88,7 @@ def signal_loader(ecg_id): return torch.FloatTensor(signal_lookup[ecg_id]) qa = ECGQADataset(root=tmp_dir) - samples = qa.set_task(ECGQA(signal_loader=signal_loader)) + samples = qa.set_task(ECGQAPreprocessing(signal_loader=signal_loader)) print(f" Created {len(samples)} matched QA samples") return samples, signal_lookup @@ -112,7 +112,7 @@ def signal_loader(ecg_id: int) -> torch.Tensor: # Step 3: Load ECG-QA data with signals print("Loading ECG-QA dataset...") qa = ECGQADataset(root=ECGQA_ROOT) - samples = qa.set_task(ECGQA(signal_loader=signal_loader)) + samples = qa.set_task(ECGQAPreprocessing(signal_loader=signal_loader)) print(f" Created {len(samples)} QA samples") # ---------- Inspect a sample ---------- diff --git a/pyhealth/datasets/configs/ecg_qa.yaml b/pyhealth/datasets/configs/ecgqa.yaml similarity index 100% rename from pyhealth/datasets/configs/ecg_qa.yaml rename to pyhealth/datasets/configs/ecgqa.yaml diff --git a/pyhealth/datasets/ecgqa.py b/pyhealth/datasets/ecgqa.py index 419da4214..e5dcc81d5 100644 --- a/pyhealth/datasets/ecgqa.py +++ b/pyhealth/datasets/ecgqa.py @@ -97,7 +97,7 @@ def __init__( if config_path is None: logger.info("No config path provided, using default config") - config_path = Path(__file__).parent / "configs" / "ecg_qa.yaml" + config_path = Path(__file__).parent / "configs" / "ecgqa.yaml" self.root = root @@ -284,6 +284,6 @@ def _verify_data(self, root: str) -> None: @property def default_task(self): - """Returns the default task for the ECG-QA dataset: ECGQA.""" - from pyhealth.tasks import ECGQA - return ECGQA() + """Returns the default task for the ECG-QA dataset: ECGQAPreprocessing.""" + from pyhealth.tasks import ECGQAPreprocessing + return ECGQAPreprocessing() diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 40e992c47..ea9a7ae6f 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -10,7 +10,7 @@ cardiology_isWA_fn, ) from .chestxray14_binary_classification import ChestXray14BinaryClassification -from .ecgqa_preprocess import ECGQA +from .ecgqa_preprocess import ECGQAPreprocessing from .ptbxl_resampling import PTBXLResampling from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification diff --git a/pyhealth/tasks/ecgqa_preprocess.py b/pyhealth/tasks/ecgqa_preprocess.py index 067db250f..8cc73d45d 100644 --- a/pyhealth/tasks/ecgqa_preprocess.py +++ b/pyhealth/tasks/ecgqa_preprocess.py @@ -26,8 +26,8 @@ from pyhealth.tasks import BaseTask -class ECGQA(BaseTask): - """ECG Question Answering task. +class ECGQAPreprocessing(BaseTask): + """ECG Question Answering preprocessing task. For each patient (ECG recording), this task returns one sample per QA pair, containing the question, answer, question type, and an diff --git a/tests/core/test_ecgqa.py b/tests/core/test_ecgqa.py index 4aea1204b..4aaebfcd3 100644 --- a/tests/core/test_ecgqa.py +++ b/tests/core/test_ecgqa.py @@ -7,7 +7,7 @@ import torch from pyhealth.datasets import ECGQADataset -from pyhealth.tasks import ECGQA +from pyhealth.tasks import ECGQAPreprocessing class TestECGQADataset(unittest.TestCase): @@ -175,7 +175,7 @@ def test_empty_split_dir_raises(self): class TestECGQATask(unittest.TestCase): - """Test the ECGQA task with synthetic data.""" + """Test the ECGQAPreprocessing task with synthetic data.""" def setUp(self): self.temp_dir = tempfile.mkdtemp() @@ -220,7 +220,7 @@ def tearDown(self): def test_text_only_mode(self): """Task with no signal_loader returns text-only samples.""" dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) - samples = dataset.set_task(ECGQA()) + samples = dataset.set_task(ECGQAPreprocessing()) sample = samples[0] self.assertIn("question", sample) @@ -235,7 +235,7 @@ def fake_loader(ecg_id): return torch.randn(12, 2500) dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) - samples = dataset.set_task(ECGQA(signal_loader=fake_loader)) + samples = dataset.set_task(ECGQAPreprocessing(signal_loader=fake_loader)) sample = samples[0] self.assertIn("signal", sample)