diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 1875698ae..64c880b99 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -238,6 +238,7 @@ Available Datasets datasets/pyhealth.datasets.BMDHSDataset datasets/pyhealth.datasets.COVID19CXRDataset datasets/pyhealth.datasets.ChestXray14Dataset + datasets/pyhealth.datasets.ECGQADataset datasets/pyhealth.datasets.PhysioNetDeIDDataset datasets/pyhealth.datasets.TUABDataset datasets/pyhealth.datasets.TUEVDataset @@ -246,3 +247,4 @@ Available Datasets datasets/pyhealth.datasets.TCGAPRADDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils + datasets/pyhealth.datasets.PTBXLDataset diff --git a/docs/api/datasets/pyhealth.datasets.ECGQADataset.rst b/docs/api/datasets/pyhealth.datasets.ECGQADataset.rst new file mode 100644 index 000000000..789764aef --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.ECGQADataset.rst @@ -0,0 +1,9 @@ +pyhealth.datasets.ECGQADataset +=================================== + +The ECG-QA dataset (Oh et al., 2024) provides natural-language question-answer pairs grounded in ECG recordings from PTB-XL or MIMIC-IV-ECG, restructured for few-shot learning by Tang et al. (CHIL 2025). For more information see the `FSL_ECG_QA repository `_. + +.. autoclass:: pyhealth.datasets.ECGQADataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/datasets/pyhealth.datasets.PTBXLDataset.rst b/docs/api/datasets/pyhealth.datasets.PTBXLDataset.rst new file mode 100644 index 000000000..39e29d253 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.PTBXLDataset.rst @@ -0,0 +1,9 @@ +pyhealth.datasets.PTBXLDataset +=================================== + +The PTB-XL 1.0.3 dataset. For the original dataset see `here `_. + +.. autoclass:: pyhealth.datasets.PTBXLDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 69e5aa592..d8d16496d 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -226,6 +226,8 @@ Available Tasks ChestX-ray14 Binary Classification De-Identification NER ChestX-ray14 Multilabel Classification + ECG Question Answering + PTB-XL Signal Resampling Variant Classification (ClinVar) Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) diff --git a/docs/api/tasks/pyhealth.tasks.ECGQAPreprocessing.rst b/docs/api/tasks/pyhealth.tasks.ECGQAPreprocessing.rst new file mode 100644 index 000000000..c81fb95c2 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ECGQAPreprocessing.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.ECGQAPreprocessing +======================================= + +.. autoclass:: pyhealth.tasks.ECGQAPreprocessing + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks/pyhealth.tasks.PTBXLResampling.rst b/docs/api/tasks/pyhealth.tasks.PTBXLResampling.rst new file mode 100644 index 000000000..c6a47182e --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.PTBXLResampling.rst @@ -0,0 +1,11 @@ +PTB-XL Signal Resampling +======================== + +.. currentmodule:: pyhealth.tasks.ptbxl_resampling + +.. autoclass:: PTBXLResampling + :members: + :show-inheritance: + :exclude-members: __init__ + + .. automethod:: __call__ \ No newline at end of file diff --git a/examples/ecgqa_fsl.py b/examples/ecgqa_fsl.py new file mode 100644 index 000000000..3e4abd1b7 --- /dev/null +++ b/examples/ecgqa_fsl.py @@ -0,0 +1,135 @@ +"""ECG Question Answering with Few-Shot Learning — PyHealth example. + +This script demonstrates the full data pipeline for the FSL_ECG_QA +project (Tang et al., CHIL 2025) using PyHealth datasets and tasks. + +Pipeline: + 1. PTBXLDataset → PTBXLResampling task → resampled ECG signals (12, 2500) + 2. ECGQADataset → ECGQAPreprocessing task (with signal_loader) → multimodal QA samples + +For the full meta-learning training loop, see: + https://github.com/Tang-Jia-Lu/FSL_ECG_QA/blob/main/train.py + +Requirements: + - PTB-XL dataset (https://physionet.org/content/ptb-xl/1.0.3/) + - ECG-QA data (https://github.com/Tang-Jia-Lu/FSL_ECG_QA/tree/main/ecgqa) + - pip install wfdb + +Authors: + Jovian Wang (jovianw2@illinois.edu) + Matthew Pham (mdpham2@illinois.edu) + Yiyun Wang (yiyunw3@illinois.edu) +""" + +import json +import os +import tempfile +from pathlib import Path + +import torch +from pyhealth.datasets import PTBXLDataset, ECGQADataset +from pyhealth.tasks import PTBXLResampling, ECGQAPreprocessing + +# ---------- Configuration ---------- +# Update these paths to match your local setup +PTBXL_ROOT = "/path/to/ptb-xl/1.0.3/" # contains records500/, records100/ +ECGQA_ROOT = "/path/to/ecgqa/ptbxl/paraphrased/" # contains train/, valid/, test/ + +# Set to True for a quick test run (loads a small matched subset). +# Set to False to process the full dataset. +DEV_MODE = True + + +def _load_dev_subset(ptbxl_root, ecgqa_root): + """Load a small matched subset for quick testing. + + PTBXLDataset dev mode picks 5 random patients from the full 21K + range. To guarantee every QA sample has a matching signal, this + helper pre-filters the QA JSON files to only include records whose + ecg_id appears in the loaded PTB-XL signals, then loads the filtered + data through ECGQADataset. + """ + # Load + resample PTB-XL signals (5 in dev mode) + print(" Loading PTB-XL signals...") + ptb = PTBXLDataset(root=ptbxl_root, downsampled=False, dev=True) + signal_ds = ptb.set_task(PTBXLResampling(root=ptbxl_root)) + signal_lookup = {s["record_id"]: s["signal"] for s in signal_ds} + matched_ecg_ids = set(int(k) for k in signal_lookup.keys()) + print(f" PTB-XL: {len(signal_lookup)} signals loaded (ecg_ids: {matched_ecg_ids})") + + # Pre-filter QA JSONs to only records with matching ecg_ids + print(" Filtering QA data to matched ecg_ids...") + tmp_dir = tempfile.mkdtemp() + src = Path(ecgqa_root) + total_kept = 0 + for split in ("train", "valid", "test"): + dst = Path(tmp_dir) / split + dst.mkdir() + split_records = [] + for fpath in sorted((src / split).glob("*.json")): + with open(fpath) as f: + records = json.load(f) + split_records.extend(r for r in records if r["ecg_id"][0] in matched_ecg_ids) + if not split_records: + # Write a dummy record so _verify_data passes; it gets filtered + # out by prepare_metadata (question_type won't start with "single-") + split_records = [{"ecg_id": [0], "question": "", "answer": [""], + "question_type": "dummy", "attribute_type": "", + "template_id": 0, "question_id": 0, + "sample_id": 0, "attribute": [""]}] + else: + total_kept += len(split_records) + with open(dst / "00.json", "w") as f: + json.dump(split_records, f) + print(f" Kept {total_kept} QA records for {len(matched_ecg_ids)} ecg_ids") + + # Load filtered QA data with signal loader + def signal_loader(ecg_id): + return torch.FloatTensor(signal_lookup[ecg_id]) + + qa = ECGQADataset(root=tmp_dir) + samples = qa.set_task(ECGQAPreprocessing(signal_loader=signal_loader)) + print(f" Created {len(samples)} matched QA samples") + return samples, signal_lookup + + +def main(): + if DEV_MODE: + samples, signal_lookup = _load_dev_subset(PTBXL_ROOT, ECGQA_ROOT) + else: + # ---------- Full pipeline ---------- + # Step 1: Load + resample all PTB-XL signals + print("Loading PTB-XL dataset...") + ptb = PTBXLDataset(root=PTBXL_ROOT, downsampled=False) + signal_ds = ptb.set_task(PTBXLResampling(root=PTBXL_ROOT)) + signal_lookup = {s["record_id"]: s["signal"] for s in signal_ds} + print(f" Loaded {len(signal_lookup)} signal samples") + + # Step 2: Build signal loader + def signal_loader(ecg_id: int) -> torch.Tensor: + return torch.FloatTensor(signal_lookup[ecg_id]) + + # Step 3: Load ECG-QA data with signals + print("Loading ECG-QA dataset...") + qa = ECGQADataset(root=ECGQA_ROOT) + samples = qa.set_task(ECGQAPreprocessing(signal_loader=signal_loader)) + print(f" Created {len(samples)} QA samples") + + # ---------- Inspect a sample ---------- + if len(samples) == 0: + print("\nNo matched samples found. Check that PTBXL_ROOT and ECGQA_ROOT are correct.") + return + + sample = samples[0] + print("\n=== Sample ===") + print(f" patient_id: {sample['patient_id']}") + print(f" question: {sample['question'][:80]}...") + print(f" answer: {sample['answer']}") + print(f" question_type: {sample['question_type']}") + print(f" episode_class: {sample['episode_class']}") + if "signal" in sample: + print(f" signal shape: {sample['signal'].shape}") + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 50b1b3887..c741e1447 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -53,6 +53,7 @@ def __init__(self, *args, **kwargs): from .cosmic import COSMICDataset from .covid19_cxr import COVID19CXRDataset from .dreamt import DREAMTDataset +from .ecgqa import ECGQADataset from .ehrshot import EHRShotDataset from .eicu import eICUDataset from .isruc import ISRUCDataset @@ -61,6 +62,7 @@ def __init__(self, *args, **kwargs): from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset +from .ptbxl import PTBXLDataset from .physionet_deid import PhysioNetDeIDDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset from .shhs import SHHSDataset diff --git a/pyhealth/datasets/configs/ecgqa.yaml b/pyhealth/datasets/configs/ecgqa.yaml new file mode 100644 index 000000000..cc18708d6 --- /dev/null +++ b/pyhealth/datasets/configs/ecgqa.yaml @@ -0,0 +1,16 @@ +version: "3.0.0" +tables: + ecg_qa: + file_path: "ecg-qa-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "ecg_id" + - "question" + - "answer" + - "question_type" + - "attribute_type" + - "template_id" + - "question_id" + - "sample_id" + - "attribute" diff --git a/pyhealth/datasets/configs/ptbxl.yaml b/pyhealth/datasets/configs/ptbxl.yaml new file mode 100644 index 000000000..3b9883b26 --- /dev/null +++ b/pyhealth/datasets/configs/ptbxl.yaml @@ -0,0 +1,11 @@ +version: "1.0.3" +tables: + ptb-xl: + file_path: "ptbxl.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "load_from_path" + - "signal_file" + - "label_file" + - "save_to_path" diff --git a/pyhealth/datasets/ecgqa.py b/pyhealth/datasets/ecgqa.py new file mode 100644 index 000000000..e5dcc81d5 --- /dev/null +++ b/pyhealth/datasets/ecgqa.py @@ -0,0 +1,289 @@ +""" +ECG Question Answering dataset. + +This dataset provides natural language question-answer pairs linked to +ECG recordings via ecg_id. It is an annotation layer on top of ECG +recordings from PTB-XL or MIMIC-IV-ECG. + +The QA data originates from the ECG-QA dataset (Oh et al., 2024), +restructured for few-shot learning by Tang et al. (CHIL 2025). + +Dataset link: + Dataset is available at https://github.com/Tang-Jia-Lu/FSL_ECG_QA + +Dataset paper: + J. Tang, T. Xia, Y. Lu, C. Mascolo, and A. Saeed, "Electrocardiogram-language model + for few-shot question answering with meta learning," + arXiv preprint arXiv:2410.14464, 2024. + +Dataset paper link: + https://arxiv.org/abs/2410.14464 + +Author: + Jovian Wang (jovianw2@illinois.edu) + Matthew Pham (mdpham2@illinois.edu) + Yiyun Wang (yiyunw3@illinois.edu) +""" +import hashlib +import json +import logging +import os +import tarfile +import urllib.request +from pathlib import Path +from typing import Optional + +import pandas as pd + +from .base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + +_VALID_ECG_SOURCES = {"ptbxl": "ptbxl", "mimic": "mimic-iv-ecg"} + +""" +Three question types are supported: + - single-verify: yes/no questions about ECG findings + - single-choose: multi-choice questions (answer is one option, "both", or "none") + - single-query: open-ended questions with free-form answers + +Args: + root: directory that holds (or will hold) the paraphrased QA splits as + train/, valid/, test/ subdirectories of JSON files. + dataset_name: name of the dataset. Default is "ecg_qa". + config_path: path to the YAML config file. Default uses built-in config. + download: if True, download the chosen variant from GitHub into ``root`` + before loading. Defaults to False. + ecg_source: which underlying ECG dataset the QA pairs are grounded in. + One of ``"ptbxl"`` (PTB-XL) or ``"mimic"`` (MIMIC-IV-ECG). + Defaults to ``"ptbxl"``. + include_demographics: if True, download the modified variant whose + question text includes patient sex and age. Defaults to False + (the original Tang et al. release). + +Examples: + >>> from pyhealth.datasets import ECGQADataset + >>> # Use a pre-downloaded local copy + >>> dataset = ECGQADataset( + ... root="/path/to/ecgqa/ptbxl/paraphrased/", + ... ) + >>> # Or download the modified PTB-XL variant on the fly + >>> dataset = ECGQADataset( + ... root="./ecg_qa_ptbxl_demo", + ... download=True, + ... ecg_source="ptbxl", + ... include_demographics=True, + ... ) + >>> dataset.stats() +""" +class ECGQADataset(BaseDataset): + + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + download: bool = False, + ecg_source: str = "ptbxl", + include_demographics: bool = False, + **kwargs, + ) -> None: + if ecg_source not in _VALID_ECG_SOURCES: + raise ValueError( + f"ecg_source must be one of {sorted(_VALID_ECG_SOURCES)}, " + f"got {ecg_source!r}" + ) + + if config_path is None: + logger.info("No config path provided, using default config") + config_path = Path(__file__).parent / "configs" / "ecgqa.yaml" + + self.root = root + + if download: + self._download_data(root, ecg_source, include_demographics) + self._verify_data(root) + + self.prepare_metadata() + + super().__init__( + root=root, + tables=["ecg_qa"], + dataset_name=dataset_name or "ecg_qa", + config_path=config_path, + **kwargs, + ) + + def prepare_metadata(self) -> None: + """Build and save a metadata CSV from all ECG-QA JSON files. + + Scans train/, valid/, test/ subdirectories under root, loads all + JSON files, filters to single-* question types, and writes a + single CSV with columns: + patient_id, ecg_id, question, answer, question_type, + attribute_type, template_id, question_id, sample_id, attribute + """ + root = Path(self.root) + csv_path = root / "ecg-qa-pyhealth.csv" + if csv_path.exists(): + return + + data = [] + for split_dir in ("train", "valid", "test"): + for fpath in sorted((root / split_dir).glob("*.json")): + with open(fpath, "r") as f: + data.extend(json.load(f)) + + rows: list[dict] = [] + for record in data: + qt = record.get("question_type", "") + if not qt.startswith("single-"): + continue + + ecg_id = record["ecg_id"][0] + answer = ";".join(record["answer"]) + attribute = ";".join(record.get("attribute", [])) + + rows.append({ + "patient_id": f"{ecg_id:05d}", + "ecg_id": ecg_id, + "question": record["question"], + "answer": answer, + "question_type": qt, + "attribute_type": record.get("attribute_type", ""), + "template_id": record.get("template_id", 0), + "question_id": record.get("question_id", 0), + "sample_id": record.get("sample_id", 0), + "attribute": attribute, + }) + + if not rows: + raise ValueError("No single-* question type records found in JSON data") + + df = pd.DataFrame(rows) + df.sort_values(["patient_id", "question_type", "template_id"], inplace=True) + df.reset_index(drop=True, inplace=True) + df.to_csv(csv_path, index=False) + logger.info(f"Wrote metadata to {csv_path}") + + def _download_data( + self, root: str, ecg_source: str, include_demographics: bool + ) -> None: + """Downloads the requested ECG-QA dataset from GitHub into ``root``. + + Fetches a commit-pinned tarball from the original ``Tang-Jia-Lu`` repo + (or a modified fork when ``include_demographics`` is True), + verifies its MD5, and extracts only the + ``ecgqa//paraphrased/{train,valid,test}/`` subtree directly + into ``root``. The tarball is deleted after extraction. + + Args: + root: directory the splits will land in. + ecg_source: ``"ptbxl"`` or ``"mimic"``. + include_demographics: selects the modified variant when True. + + Raises: + ValueError: if the downloaded tarball fails MD5 verification or + if it contains an unsafe path during extraction. + """ + # URLs are pinned to specific commit SHAs so the MD5s below stay stable + # even if either repo gains new commits later. + if include_demographics: + url = ( + "https://github.com/jovianw/FSL_ECG_QA/archive/" + "2e2d4ac185d6069c741d083269ea40ca01bfd50b.tar.gz" + ) + expected_md5 = "e65c4b6ae127103ad92a33ec9246039e" + archive_prefix = "FSL_ECG_QA-2e2d4ac185d6069c741d083269ea40ca01bfd50b" + else: + url = ( + "https://github.com/Tang-Jia-Lu/FSL_ECG_QA/archive/" + "b0ec9bd84ae2337052ca977941e37a703dcb492e.tar.gz" + ) + expected_md5 = "894b4af304e99c48ecd62a914ba3ba2b" + archive_prefix = "FSL_ECG_QA-b0ec9bd84ae2337052ca977941e37a703dcb492e" + + os.makedirs(root, exist_ok=True) + archive_path = os.path.join(root, "ecgqa-download.tar.gz") + + logger.info(f"Downloading {url} -> {archive_path}") + urllib.request.urlretrieve(url, archive_path) + + logger.info(f"Checking MD5 checksum for {archive_path}...") + with open(archive_path, "rb") as f: + file_md5 = hashlib.md5(f.read()).hexdigest() + if file_md5 != expected_md5: + msg = ( + f"Invalid MD5 checksum for {archive_path}: " + f"expected {expected_md5}, got {file_md5}" + ) + logger.error(msg) + raise ValueError(msg) + + ecg_source_dir = _VALID_ECG_SOURCES[ecg_source] + prefix = f"{archive_prefix}/ecgqa/{ecg_source_dir}/paraphrased/" + abs_root = os.path.abspath(root) + + logger.info(f"Extracting {prefix}* from {archive_path} into {root}") + with tarfile.open(archive_path, "r:gz") as tar: + for member in tar.getmembers(): + if not member.name.startswith(prefix): + continue + rel = member.name[len(prefix):] + if not rel: + continue + + target_path = os.path.abspath(os.path.join(abs_root, rel)) + if os.path.commonpath([abs_root]) != os.path.commonpath( + [abs_root, target_path] + ): + msg = f"Unsafe path detected in tar file: '{member.name}'!" + logger.error(msg) + raise ValueError(msg) + + member.name = rel + tar.extract(member, path=root) + + os.remove(archive_path) + logger.info("Download complete") + + def _verify_data(self, root: str) -> None: + """Verifies the presence and structure of the dataset directory. + + Checks that ``root`` exists, that ``train/``, ``valid/``, and ``test/`` + subdirectories are present, and that each contains at least one + ``*.json`` file. + + Args: + root: directory expected to hold the dataset splits. + + Raises: + FileNotFoundError: if ``root`` or any required split directory is + missing. + ValueError: if a split directory contains no JSON files. + """ + if not os.path.exists(root): + msg = f"Dataset path does not exist: {root}" + logger.error(msg) + raise FileNotFoundError(msg) + + for split in ("train", "valid", "test"): + split_dir = os.path.join(root, split) + if not os.path.isdir(split_dir): + msg = ( + f"Dataset path must contain a '{split}' subdirectory: " + f"{split_dir}" + ) + logger.error(msg) + raise FileNotFoundError(msg) + if not list(Path(split_dir).glob("*.json")): + msg = f"Dataset '{split}' directory must contain JSON files!" + logger.error(msg) + raise ValueError(msg) + + @property + def default_task(self): + """Returns the default task for the ECG-QA dataset: ECGQAPreprocessing.""" + from pyhealth.tasks import ECGQAPreprocessing + return ECGQAPreprocessing() diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py new file mode 100644 index 000000000..c7ba88015 --- /dev/null +++ b/pyhealth/datasets/ptbxl.py @@ -0,0 +1,208 @@ +""" +Pyhealth dataset for the 1.0.3 PTB-XL dataset. + +Dataset link: + https://physionet.org/content/ptb-xl/1.0.3/ + +Dataset paper: + J. Tang, T. Xia, Y. Lu, C. Mascolo, and A. Saeed, + "Electrocardiogram-language model for few-shot question answering with meta + learning," arXiv preprint arXiv:2410.14464, 2024. + +Dataset paper link: + https://arxiv.org/abs/2410.14464 + +Author: + Jovian Wang (jovianw2@illinois.edu) + Matthew Pham (mdpham2@illinois.edu) + Yiyun Wang (yiyunw3@illinois.edu) +""" +import pandas as pd +import os +import logging +import urllib.request +import requests +import zipfile +import random +import csv +from pathlib import Path +from typing import Optional +from pyhealth.datasets.utils import hash_str, MODULE_CACHE_PATH +from . import BaseDataset + +logger = logging.getLogger(__name__) + +""" +Dataset class for the PTB-XL 1.0.3 dataset. + +Args: + dataset_name: name of the dataset. + root: root directory of the raw data. + Expected to contain folders for original (records500) or + downsampled (records100) data with determined names. + dev: whether to enable dev mode (only use a small subset of the data). + Default is False. + refresh_cache: whether to refresh the cache; if true, the dataset will + be processed from scratch and the cache will be updated. + Default is False. +""" +class PTBXLDataset(BaseDataset): + """ + Initialize the PTB-XL dataset. + + Attributes: + root (str): Root directory of the raw data. + download (bool): True iff requested to download dataset. + Default to False. + dev (bool): True iff enable dev mode. + downsampled (bool): True iff use downsampled signal data. + """ + def __init__( + self, + root: str = ".", + download: bool = False, + dev: bool = False, + downsampled: bool = False, + config_path: Optional[str] = None, + **kwargs) -> None: + + self.dev = dev + + # Determine the root path, where most of the data is stored + self.root = root + + # Determine signal path, where to fetch the signal samples + signal_folder = "records100" if downsampled else "records500" + + if download: + root_path = os.path.join(root, "ptb_xl_processed_final/\ + ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3") + else: + root_path = root + + self.signal_path: str = os.path.join(root_path, signal_folder) + + # Download the dataset from online source to root if needed + self._download(download) + + # Determine config_path if it isn't provided + if config_path is None: + logger.info("No config path provided. Using default config.") + config_path = os.path.join( + os.path.dirname(__file__), "configs", "ptbxl.yaml") + + # Validate data + self._validate() + + self._prepare_metadata() + + super().__init__( + root=root, + dataset_name="PTB-XL", + tables=["ptb-xl"], + config_path=config_path, + **kwargs, + ) + + + """ + Download PTB-XL dataset from public google drive sources. + It will contain both the original and downsampled versions, + in /records500 and /records100 folder respectively. + """ + def _download(self, download) -> None: + if download: + data_path: str = os.path.join(self.root, "ptb_xl_processed_final.zip") + zip_id = "1btbPiHEOUBLNLfUYkLnKzs50ZTmgqdI2" + response = requests.get( + f"https://drive.google.com/uc?export=download&id={zip_id}") + with open(data_path, "wb") as file: + file.write(response.content) + + with zipfile.ZipFile(data_path, "r") as z: + z.extractall(self.root) + + """ + Verifies if the dataset directory exists and its structure. + Check if specified records folders exists underneath the root, + each patient directory inside contains at least one pair of .dat and .hea files, + and there"s no other unexpected type of files in the directory. + + Raises: + FileNotFoundError: if any directory is not found. + ValueError: if a patient directory contains not .dat/.hea file + or if there"s a mismatch of .dat/.hea. + """ + def _validate(self) -> None: + if not os.path.exists(self.root): + e = f"Dataset root path doesn't exist: {self.root}" + logger.error(e) + raise FileNotFoundError(e) + + if not os.path.exists(self.signal_path): + e = f"Dataset signal path doesn't exist: {self.signal_path}" + logger.error(e) + raise FileNotFoundError(e) + + dat = set() + hea = set() + for dirpath, dirnames, filenames in os.walk(self.signal_path): + for filename in filenames: + f, suffix = filename.split(".") + if suffix == "dat": + dat.add(f) + elif suffix == "hea": + hea.add(f) + else: + e = f"Unexpected file format {suffix} in the directory" + logger.error(e) + raise ValueError(e) + + if len(dat ^ hea) != 0: + e = f".dat and .hea files mismatch for patient id {dat ^ hea}." + logger.error(e) + raise ValueError(e) + + """ + Process and return a dictionary of the requested PTB-XL data for each patient. + Each patient will have a corresponding object that contains + load_from_path, patient_id, signal_file, label_file, and save_to_path. + """ + def _prepare_metadata(self) -> None: + patients = {} + + for dirpath, dirnames, filenames in os.walk(self.signal_path): + for filename in filenames: + f = Path(filename).stem + pid = f.split("_")[0] + + if pid not in patients: + patients[pid] = { + "load_from_path": dirpath, + "patient_id": pid, + "signal_file": f + ".dat", + "label_file": f + ".hea", + "save_to_path": os.path.join( + MODULE_CACHE_PATH, hash_str(filename)), + } + + if self.dev: + keys = random.sample(list(patients), min(len(patients), 5)) + values = [patients[k] for k in keys] + patients = dict(zip(keys, values)) + + with open(os.path.join(self.root, "ptbxl.csv"), "w") as file: + w = csv.DictWriter(file, fieldnames=[ + "load_from_path", + "patient_id", + "signal_file", + "label_file", + "save_to_path"]) + w.writeheader() + w.writerows(list(patients.values())) + return None + +if __name__ == "__main__": + dataset = PTBXLDataset(root="../../", download=False, downsampled=True, dev=False) + dataset.stats() + print(dataset.load_data()) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..f9c70459e 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -10,6 +10,8 @@ cardiology_isWA_fn, ) from .chestxray14_binary_classification import ChestXray14BinaryClassification +from .ecgqa_preprocess import ECGQAPreprocessing +from .ptbxl_resampling import PTBXLResampling from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification from .deid_ner import DeIDNERTask diff --git a/pyhealth/tasks/ecgqa_preprocess.py b/pyhealth/tasks/ecgqa_preprocess.py new file mode 100644 index 000000000..8cc73d45d --- /dev/null +++ b/pyhealth/tasks/ecgqa_preprocess.py @@ -0,0 +1,100 @@ +""" +A PyHealth task for ECG Question Answering preprocessing. + +Produces one sample per QA pair, optionally attaching a resampled ECG +signal tensor via a user-provided signal_loader callable. + +Dataset link: + https://github.com/Tang-Jia-Lu/FSL_ECG_QA + +Dataset paper: + J. Tang, T. Xia, Y. Lu, C. Mascolo, and A. Saeed, + "Electrocardiogram-language model for few-shot question answering + with meta learning," arXiv preprint arXiv:2410.14464, 2024. + +Dataset paper link: + https://arxiv.org/abs/2410.14464 + +Author: + Jovian Wang (jovianw2@illinois.edu) + Matthew Pham (mdpham2@illinois.edu) + Yiyun Wang (yiyunw3@illinois.edu) +""" +import torch +from typing import Any, Callable, Dict, List, Optional + +from pyhealth.tasks import BaseTask + + +class ECGQAPreprocessing(BaseTask): + """ECG Question Answering preprocessing task. + + For each patient (ECG recording), this task returns one sample per + QA pair, containing the question, answer, question type, and an + episode_class key for episodic sampling. Optionally attaches an ECG + signal tensor via a user-provided signal_loader. + + Works with both PTB-XL and MIMIC-IV-ECG based QA data. + + Args: + signal_loader: optional callable mapping ecg_id (int) to a signal + tensor of shape (12, N). If None, samples are text-only. + + Each returned sample contains: + - "question": str, the natural language question + - "answer": str, the answer (semicolon-separated if multiple) + - "question_type": str, one of "single-verify", "single-choose", "single-query" + - "episode_class": str, class key for episodic sampling (template_id + attribute + answer) + - "signal": torch.FloatTensor (only if signal_loader is provided) + """ + + task_name: str = "ECG_QA" + input_schema: Dict[str, str] = {"question": "text"} + output_schema: Dict[str, str] = {"answer": "text"} + + def __init__( + self, + signal_loader: Optional[Callable[[int], torch.Tensor]] = None, + ) -> None: + self.signal_loader = signal_loader + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process one patient (ECG recording). Creates one sample per QA pair. + + The ECG signal is loaded once per patient via signal_loader and + shared across all QA pairs for the same ecg_id. + """ + pid = patient.patient_id + samples: List[Dict[str, Any]] = [] + + events = patient.get_events("ecg_qa") + if not events: + return samples + + # Load signal once for this patient if loader provided. + # If loading fails, skip the patient entirely. + signal = None + if self.signal_loader is not None: + try: + signal = self.signal_loader(int(pid)) + if not isinstance(signal, torch.Tensor): + signal = torch.FloatTensor(signal) + except Exception: + return samples + + for event in events: + episode_class = f"{event.template_id}_{event.attribute}_{event.answer}" + + sample = { + "patient_id": pid, + "question": event.question, + "answer": event.answer, + "question_type": event.question_type, + "episode_class": episode_class, + } + if signal is not None: + sample["signal"] = signal + + samples.append(sample) + + return samples diff --git a/pyhealth/tasks/ptbxl_resampling.py b/pyhealth/tasks/ptbxl_resampling.py new file mode 100644 index 000000000..afd26f35c --- /dev/null +++ b/pyhealth/tasks/ptbxl_resampling.py @@ -0,0 +1,67 @@ +""" +A PyHealth task that performs dynamic resampling of ECG signals. + +Dataset link: + https://physionet.org/content/ptb-xl/1.0.3/ + +Dataset paper: (please cite if you use this dataset) + Wagner, P., Strodthoff, N., Bousseljot, R. D., Samek, W., & Schaeffter, T. + "PTB-XL, a large publicly available electrocardiography dataset." + Scientific Data, 7(1), 1-15. (2020). + +Dataset paper link: + https://www.nature.com/articles/s41597-020-0495-6 + +Author: + Jovian Wang (jovianw2@illinois.edu) + Matthew Pham (mdpham2@illinois.edu) + Yiyun Wang (yiyunw3@illinois.edu) +""" +import logging +import os +import wfdb +import numpy as np +from scipy import signal +from typing import Dict, List + +from pyhealth.data import Patient +from pyhealth.tasks import BaseTask + +logger = logging.getLogger(__name__) + +class PTBXLResampling(BaseTask): + """ + Task: Downsample high-resolution (500Hz) signals to 250Hz. + This provides a balance between detail and computational efficiency. + """ + task_name: str = "PTBXLResampling" + input_schema: Dict[str, str] = {"signal": "tensor"} + output_schema: Dict[str, str] = {"signal": "tensor"} + + def __init__(self, root: str) -> None: + self.root = root + + def __call__(self, patient: Patient) -> List[Dict]: + events = patient.get_events(event_type="ptb-xl") + samples = [] + + ecg_id = int(patient.patient_id) + subfolder = f"{str((ecg_id // 1000) * 1000).zfill(5)}" + hr_path = os.path.join(self.root, "records500", subfolder, f"{str(ecg_id).zfill(5)}_hr") + + try: + record = wfdb.rdrecord(hr_path) + data_500hz = record.p_signal.T # Shape: (12, 5000) + + # Downsample to 250Hz (2500 samples for a 10s record) + num_samples_target = 2500 + data_250hz = signal.resample(data_500hz, num_samples_target, axis=1) + + samples.append({ + "signal": data_250hz.astype(np.float32), + "record_id": ecg_id + }) + except Exception: + pass + + return samples diff --git a/tests/core/test_ecgqa.py b/tests/core/test_ecgqa.py new file mode 100644 index 000000000..4aaebfcd3 --- /dev/null +++ b/tests/core/test_ecgqa.py @@ -0,0 +1,246 @@ +import unittest +import tempfile +import shutil +import json +from pathlib import Path + +import torch + +from pyhealth.datasets import ECGQADataset +from pyhealth.tasks import ECGQAPreprocessing + + +class TestECGQADataset(unittest.TestCase): + """Test ECG-QA dataset with synthetic test data.""" + + def setUp(self): + """Set up train/valid/test JSON files and a temp cache dir.""" + self.temp_dir = tempfile.mkdtemp() + self.root = Path(self.temp_dir) + + for split in ("train", "valid", "test"): + (self.root / split).mkdir() + + train_records = [ + { + "ecg_id": [1], + "question": "Does this ECG show normal sinus rhythm?", + "answer": ["yes"], + "question_type": "single-verify", + "attribute_type": "scp_code", + "template_id": 1, + "question_id": 100, + "sample_id": 1000, + "attribute": ["NORM"], + }, + { + "ecg_id": [2], + "question": "What rhythm does this ECG show?", + "answer": ["sinus rhythm", "atrial fibrillation"], + "question_type": "single-choose", + "attribute_type": "rhythm", + "template_id": 2, + "question_id": 101, + "sample_id": 1001, + "attribute": ["SR", "AFIB"], + }, + { + "ecg_id": [3], + "question": "Are both left axis deviation and right bundle branch block present?", + "answer": ["yes"], + "question_type": "comparison-verify", + "attribute_type": "scp_code", + "template_id": 3, + "question_id": 102, + "sample_id": 1002, + "attribute": ["LAD", "RBBB"], + }, + ] + valid_records = [ + { + "ecg_id": [4], + "question": "What ECG abnormalities are present?", + "answer": ["left axis deviation"], + "question_type": "single-query", + "attribute_type": "scp_code", + "template_id": 4, + "question_id": 103, + "sample_id": 1003, + "attribute": ["LAD"], + }, + ] + test_records = [ + { + "ecg_id": [5], + "question": "Is the heart rate above 100 beats per minute?", + "answer": ["no"], + "question_type": "single-verify", + "attribute_type": "heart_rate", + "template_id": 5, + "question_id": 104, + "sample_id": 1004, + "attribute": ["bradycardia"], + }, + ] + + (self.root / "train" / "00.json").write_text(json.dumps(train_records)) + (self.root / "valid" / "00.json").write_text(json.dumps(valid_records)) + (self.root / "test" / "00.json").write_text(json.dumps(test_records)) + + self._cache_tmp = tempfile.mkdtemp() + + def tearDown(self): + """Clean up temporary files""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + shutil.rmtree(self._cache_tmp, ignore_errors=True) + + def test_dataset_initialization(self): + """Test ECGQADataset initialization""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + + self.assertIsNotNone(dataset) + self.assertEqual(dataset.dataset_name, "ecg_qa") + self.assertEqual(dataset.root, str(self.root)) + + def test_patient_count(self): + """Test only single-* records are loaded (4 of 5)""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + # ecg_ids 1, 2, 4, 5 survive; ecg_id 3 is comparison-verify and is filtered. + self.assertEqual(len(dataset.unique_patient_ids), 4) + self.assertEqual( + sorted(dataset.unique_patient_ids), ["00001", "00002", "00004", "00005"] + ) + + def test_filters_non_single_question_types(self): + """Test that comparison-verify records are dropped""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + self.assertNotIn("00003", dataset.unique_patient_ids) + + def test_stats_method(self): + """Test stats method runs without error""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + dataset.stats() + + def test_get_patient(self): + """Test get_patient method""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + patient = dataset.get_patient("00001") + self.assertIsNotNone(patient) + self.assertEqual(patient.patient_id, "00001") + + def test_single_verify_event_fields(self): + """Test a single-verify event surfaces the expected attributes""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + events = dataset.get_patient("00001").get_events() + self.assertEqual(len(events), 1) + self.assertEqual(events[0]["question_type"], "single-verify") + self.assertEqual(events[0]["answer"], "yes") + self.assertEqual(events[0]["attribute"], "NORM") + + def test_single_choose_event_joins_multi_valued_fields(self): + """Test single-choose records join answer/attribute with ';'""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + events = dataset.get_patient("00002").get_events() + self.assertEqual(events[0]["answer"], "sinus rhythm;atrial fibrillation") + self.assertEqual(events[0]["attribute"], "SR;AFIB") + + +class TestECGQAVerifyData(unittest.TestCase): + """Test the structural checks performed by ECGQADataset._verify_data.""" + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.root = Path(self.temp_dir) + + def tearDown(self): + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_nonexistent_root_raises(self): + """Test FileNotFoundError when root does not exist""" + bogus = self.root / "does-not-exist" + with self.assertRaises(FileNotFoundError): + ECGQADataset(root=str(bogus)) + + def test_missing_split_dir_raises(self): + """Test FileNotFoundError when train/valid/test dirs are missing""" + with self.assertRaises(FileNotFoundError): + ECGQADataset(root=str(self.root)) + + def test_empty_split_dir_raises(self): + """Test ValueError when a split dir has no JSON files""" + for split in ("train", "valid", "test"): + (self.root / split).mkdir() + with self.assertRaises(ValueError): + ECGQADataset(root=str(self.root)) + + +class TestECGQATask(unittest.TestCase): + """Test the ECGQAPreprocessing task with synthetic data.""" + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.root = Path(self.temp_dir) + for split in ("train", "valid", "test"): + (self.root / split).mkdir() + + records = [ + { + "ecg_id": [1], + "question": "Does this ECG show normal sinus rhythm?", + "answer": ["yes"], + "question_type": "single-verify", + "attribute_type": "scp_code", + "template_id": 1, + "question_id": 100, + "sample_id": 1000, + "attribute": ["NORM"], + }, + { + "ecg_id": [1], + "question": "What rhythm does this ECG show?", + "answer": ["sinus rhythm"], + "question_type": "single-query", + "attribute_type": "rhythm", + "template_id": 2, + "question_id": 101, + "sample_id": 1001, + "attribute": ["SR"], + }, + ] + (self.root / "train" / "00.json").write_text(json.dumps(records)) + (self.root / "valid" / "00.json").write_text(json.dumps([records[0]])) + (self.root / "test" / "00.json").write_text(json.dumps([records[1]])) + + self._cache_tmp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.temp_dir, ignore_errors=True) + shutil.rmtree(self._cache_tmp, ignore_errors=True) + + def test_text_only_mode(self): + """Task with no signal_loader returns text-only samples.""" + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + samples = dataset.set_task(ECGQAPreprocessing()) + sample = samples[0] + + self.assertIn("question", sample) + self.assertIn("answer", sample) + self.assertIn("question_type", sample) + self.assertIn("episode_class", sample) + self.assertNotIn("signal", sample) + + def test_signal_loader_attaches_signal(self): + """Task with signal_loader attaches signal tensor to samples.""" + def fake_loader(ecg_id): + return torch.randn(12, 2500) + + dataset = ECGQADataset(root=str(self.root), cache_dir=self._cache_tmp) + samples = dataset.set_task(ECGQAPreprocessing(signal_loader=fake_loader)) + sample = samples[0] + + self.assertIn("signal", sample) + self.assertEqual(sample["signal"].shape, torch.Size([12, 2500])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_ptbxl.py b/tests/core/test_ptbxl.py new file mode 100644 index 000000000..973a71bf1 --- /dev/null +++ b/tests/core/test_ptbxl.py @@ -0,0 +1,98 @@ +import unittest +import tempfile +import os +import shutil +from pyhealth.datasets import PTBXLDataset +from pathlib import Path + +class TestPTBXLDataset(unittest.TestCase): + """ + Test PTB-XL 1.0.3 dataset with demo data. + """ + def setUp(self): + sample_records = { + '00001': { + 'load_from_path': 'sample/path', + 'signal_file': '00001_lr.dat', + 'label_file': '00001_lr.hea', + 'save_to_path': 'sample/path', + }, + '00002': { + 'load_from_path': 'sample/path', + 'signal_file': '00002_lr.dat', + 'label_file': '00002_lr.hea', + 'save_to_path': 'sample/path', + }, + '00003': { + 'load_from_path': 'sample/path', + 'signal_file': '00003_lr.dat', + 'label_file': '00003_lr.hea', + 'save_to_path': 'sample/path', + }, + '00004': { + 'load_from_path': 'sample/path', + 'signal_file': '00004_lr.dat', + 'label_file': '00004_lr.hea', + 'save_to_path': 'sample/path', + }, + '00005': { + 'load_from_path': 'sample/path', + 'signal_file': '00005_lr.dat', + 'label_file': '00005_lr.hea', + 'save_to_path': 'sample/path', + }, + '00006': { + 'load_from_path': 'sample/path', + 'signal_file': '00006_lr.dat', + 'label_file': '00006_lr.hea', + 'save_to_path': 'sample/path', + }, + } + + self.temp_dir = tempfile.mkdtemp() + self.root = Path(self.temp_dir) + + os.makedirs(os.path.join(self.root, 'records100')) + + for i in sample_records.keys(): + with open(os.path.join(self.root, f'records100/{i}.dat'), 'w') as f: + f.write('sample .dat data') + with open(os.path.join(self.root, f'records100/{i}.hea'), 'w') as f: + f.write('sample .hea data') + + self.dataset = PTBXLDataset(root=self.root, download=False, downsampled=True, dev=True) + + + """ + Verify if the dataset contains correct basic information + """ + def testBasicInfo(self): + self.assertEqual(self.dataset.dataset_name, 'PTB-XL') + self.dataset.stats() + + """ + Verify if the dataset contains expected data + """ + def testData(self): + + # Test if dev mode has been applied + self.assertLessEqual(len(self.dataset.unique_patient_ids), 5) + + # Test info of the first patient + patient = self.dataset.get_patient('00001') + self.assertIsNotNone(patient) + self.assertIsInstance(patient.patient_id, str) + self.assertEqual(patient.patient_id, '00001') + + events = self.dataset.get_patient('00001').get_events() + self.assertEqual(len(events), 1) + self.assertIsNotNone(events[0]['load_from_path']) + self.assertIsNotNone(events[0]['signal_file']) + self.assertIsNotNone(events[0]['label_file']) + self.assertIsNotNone(events[0]['save_to_path']) + + def tearDown(self): + shutil.rmtree(self.temp_dir) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file