From 2e9f2f8c388e42798ee3cce0a2851d3a64f6fff7 Mon Sep 17 00:00:00 2001 From: John Wu Date: Sat, 30 May 2026 19:29:54 -0500 Subject: [PATCH 1/4] transfer FHIR pipeline to branch --- docs/api/datasets.rst | 2 + .../pyhealth.datasets.FHIRDataset.rst | 306 +++++++ .../datasets/pyhealth.datasets.MIMIC4FHIR.rst | 78 ++ docs/api/models.rst | 1 + .../models/pyhealth.models.EHRMambaCEHR.rst | 12 + docs/api/tasks.rst | 1 + ...pyhealth.tasks.mpf_clinical_prediction.rst | 12 + examples/mimic4fhir_mpf_ehrmamba.py | 56 ++ pyhealth/datasets/__init__.py | 1 + pyhealth/datasets/base_dataset.py | 10 +- pyhealth/datasets/fhir/__init__.py | 16 + pyhealth/datasets/fhir/base.py | 415 ++++++++++ .../datasets/fhir/configs/mimic4fhir.yaml | 210 +++++ pyhealth/datasets/fhir/mimic4.py | 42 + pyhealth/datasets/fhir/utils.py | 720 +++++++++++++++++ pyhealth/models/__init__.py | 1 + pyhealth/models/cehr_embeddings.py | 112 +++ pyhealth/models/ehrmamba_cehr.py | 117 +++ pyhealth/models/utils.py | 28 + pyhealth/processors/__init__.py | 3 + pyhealth/processors/cehr_processor.py | 175 ++++ pyhealth/tasks/__init__.py | 8 + pyhealth/tasks/mpf_clinical_prediction.py | 310 +++++++ tests/core/test_ehrmamba_cehr.py | 126 +++ tests/core/test_fhir_dataset.py | 764 ++++++++++++++++++ tests/core/test_fhir_ndjson_fixtures.py | 110 +++ tests/core/test_mpf_task.py | 99 +++ 27 files changed, 3733 insertions(+), 2 deletions(-) create mode 100644 docs/api/datasets/pyhealth.datasets.FHIRDataset.rst create mode 100644 docs/api/datasets/pyhealth.datasets.MIMIC4FHIR.rst create mode 100644 docs/api/models/pyhealth.models.EHRMambaCEHR.rst create mode 100644 docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst create mode 100644 examples/mimic4fhir_mpf_ehrmamba.py create mode 100644 pyhealth/datasets/fhir/__init__.py create mode 100644 pyhealth/datasets/fhir/base.py create mode 100644 pyhealth/datasets/fhir/configs/mimic4fhir.yaml create mode 100644 pyhealth/datasets/fhir/mimic4.py create mode 100644 
pyhealth/datasets/fhir/utils.py create mode 100644 pyhealth/models/cehr_embeddings.py create mode 100644 pyhealth/models/ehrmamba_cehr.py create mode 100644 pyhealth/processors/cehr_processor.py create mode 100644 pyhealth/tasks/mpf_clinical_prediction.py create mode 100644 tests/core/test_ehrmamba_cehr.py create mode 100644 tests/core/test_fhir_dataset.py create mode 100644 tests/core/test_fhir_ndjson_fixtures.py create mode 100644 tests/core/test_mpf_task.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..ad34a7a0a 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -224,6 +224,8 @@ Available Datasets datasets/pyhealth.datasets.SampleDataset datasets/pyhealth.datasets.MIMIC3Dataset datasets/pyhealth.datasets.MIMIC4Dataset + datasets/pyhealth.datasets.FHIRDataset + datasets/pyhealth.datasets.MIMIC4FHIR datasets/pyhealth.datasets.MedicalTranscriptionsDataset datasets/pyhealth.datasets.CardiologyDataset datasets/pyhealth.datasets.eICUDataset diff --git a/docs/api/datasets/pyhealth.datasets.FHIRDataset.rst b/docs/api/datasets/pyhealth.datasets.FHIRDataset.rst new file mode 100644 index 000000000..16dbefc13 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.FHIRDataset.rst @@ -0,0 +1,306 @@ +pyhealth.datasets.FHIRDataset +===================================== + +A generic, config-driven NDJSON ingest for `HL7 FHIR +<https://hl7.org/fhir/>`_ datasets. The whole pipeline is described by **a +single YAML config** with three top-level sections — what files to read, how to +turn each FHIR resource into a flat row, and how those rows appear as events +downstream. A custom FHIR ingest is "point at a YAML" — no Python required. + +The bundled :class:`~pyhealth.datasets.MIMIC4FHIR` subclass uses this engine +with the ``pyhealth/datasets/fhir/configs/mimic4fhir.yaml`` config tuned for +PhysioNet's MIMIC-IV on FHIR export. See the sub-page below for the quick-start. + +.. contents:: On this page + :local: + :depth: 1 + + +Quick start +----------- + +.. 
code-block:: python + + from pyhealth.datasets import MIMIC4FHIR, get_dataloader, split_by_patient + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + from pyhealth.models import EHRMambaCEHR + from pyhealth.trainer import Trainer + + def main(): + ds = MIMIC4FHIR(root="/data/mimic-iv-fhir") + sample_ds = ds.set_task(MPFClinicalPredictionTask(), num_workers=1) + train, val, test = split_by_patient(sample_ds, [0.7, 0.1, 0.2]) + vocab_size = sample_ds.input_processors["concept_ids"].vocab.vocab_size + model = EHRMambaCEHR(dataset=sample_ds, vocab_size=vocab_size) + Trainer(model=model).train( + train_dataloader=get_dataloader(train, batch_size=8, shuffle=True), + val_dataloader=get_dataloader(val, batch_size=8), + epochs=2, + ) + + if __name__ == "__main__": + main() + +(``if __name__ == "__main__":`` matters — :meth:`~pyhealth.datasets.BaseDataset.set_task` +forks Dask workers; without the guard the workers re-import and re-spawn.) + + +Pipeline at a glance +-------------------- + +:: + + NDJSON shards on disk + | + | (Phase A) — stream line by line, route by resourceType, + | project via the YAML's resource_specs + v + flattened_tables/<table>.parquet <- cache #1 + | + | (Phase B) — load_table, dd.concat, sort by patient_id (Dask) + v + global_event_df.parquet/part-*.parquet <- cache #2 + | + | (Phase C) — task_transform per-patient sample emit + v + task_df.ld/ <- cache #3a + | + | fit CehrProcessor vocab via SampleBuilder.fit(dataset) + | proc_transform per-sample tensorisation + v + samples_*.ld/ <- cache #3b ──> SampleDataset + +Each of the three cache tiers has its own existence check; re-running with +identical inputs skips every phase. Cache identity hashes the YAML byte digest, +glob patterns, ``max_patients``, and engine schema version — any meaningful +config change invalidates everything below it. See +:class:`~pyhealth.datasets.BaseDataset` for the Phase B/C internals that are +shared with all other PyHealth datasets. 
+ + +The unified YAML config +----------------------- + +A FHIR ingest YAML has three top-level sections. The bundled +``mimic4fhir.yaml`` is the canonical worked example; what follows is the +section-by-section reference. + +Section 1: ``glob_patterns:`` (which files to read) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: yaml + + glob_patterns: + - "**/MimicPatient*.ndjson.gz" + - "**/MimicEncounter*.ndjson.gz" + # ... one pattern per resource-type shard family + +Defaults to ``["**/*.ndjson.gz"]`` when omitted. Only worth setting when your +export has a per-resource-type file-naming convention you want to exploit for +speed — PhysioNet MIMIC-IV FHIR ships shards as ``MimicPatient*.ndjson.gz``, +``MimicEncounter*.ndjson.gz``, etc., and filtering at the file level avoids +decompressing ~10% of the export that contains only unconfigured resource +types. For a generic export where everything is in ``bundles.ndjson.gz``, omit +this block and the streamer will filter by ``resourceType`` after parsing. + +Override at runtime via ``MIMIC4FHIR(glob_pattern=...)`` or +``MIMIC4FHIR(glob_patterns=[...])``. + +Section 2: ``resource_specs:`` (how to project JSON into rows) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Keys are FHIR ``resourceType`` strings. For each, declare a ``table`` name and +an ordered ``columns`` mapping: + +.. 
code-block:: yaml + + resource_specs: + + Patient: + table: patient + columns: + patient_id: { locate: ["id"], required: true } + birth_date: { locate: ["birthDate"] } + gender: { locate: ["gender"] } + deceased_boolean: { locate: ["deceasedBoolean"], transform: bool_norm } + + Observation: + table: observation + columns: + patient_id: { locate: ["subject.reference"], transform: ref_id, required: true } + resource_id: { locate: ["id"] } + encounter_id: { locate: ["encounter.reference"], transform: ref_id } + event_time: { locate: ["effectiveDateTime", "effectivePeriod.start", "issued"] } + concept_key: { locate: ["code"], transform: coding_key } + +Each column entry has three fields: + +``locate`` *(required, list of dotted paths)* + Ordered JSON paths into the resource; the first that resolves to a non-null + value wins. This is how FHIR choice-types (``onset[x]``, ``effective[x]``, + ``performed[x]``, …) are handled — list every variant explicitly. A single + string is accepted as shorthand for a one-element list. + +``transform`` *(optional, name of a built-in transform, default ``identity``)* + Maps the located leaf to a flat scalar string. See the registry below. + +``required`` *(optional, bool, default false)* + When ``true``, a resource whose ``locate`` cannot be resolved is **dropped** + (and logged) rather than emitted with a null. Use this on the patient + reference column so events without a discoverable patient never reach the + global event frame. + +Transform registry +^^^^^^^^^^^^^^^^^^ + +Available transforms (defined in +``pyhealth/datasets/fhir/utils.py`` ``TRANSFORMS`` dict): + +================== =========================================================== +``identity`` Pass the value through. Stringifies non-string scalars. +``ref_id`` Reference object or ``"Patient/p1"`` -> ``"p1"``. +``coding_key`` CodeableConcept -> ``"system|code"`` of its first coding. +``bool_norm`` JSON boolean / ``"true"``/``"false"`` -> ``"true"``/``"false"``/None. 
+``med_concept`` MedicationRequest medication[x] -> codeable-concept or + ``"MedicationRequest/reference|"`` fallback. +================== =========================================================== + +Adding a new transform is a one-liner: register a callable in ``TRANSFORMS`` +in ``utils.py`` and reference it by name from the YAML. + +Section 3: ``tables:`` (how rows are exposed as events) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Keys here must match the ``table:`` values from Section 2. Each entry tells +:meth:`~pyhealth.datasets.BaseDataset.load_table` how to read the flat parquet: + +.. code-block:: yaml + + tables: + patient: + file_path: "patient.parquet" + patient_id: "patient_id" + timestamp: "birth_date" + attributes: ["birth_date", "gender", "deceased_boolean"] + + observation: + file_path: "observation.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: ["resource_id", "encounter_id", "event_time", "concept_key"] + +``file_path`` is the parquet filename inside the cached +``flattened_tables/`` directory. ``patient_id`` and ``timestamp`` name the +columns to surface as the normalised ``patient_id`` and ``timestamp`` on each +event. ``attributes`` is the list of columns surfaced as event attributes — in +the global event frame they're renamed to ``{table}/{attr}`` and later show up +on ``patient.get_events(event_type=...).attr_name``. + +Cross-section validation +~~~~~~~~~~~~~~~~~~~~~~~~ + +At load time the dataset checks that every ``table:`` value declared in +Section 2 has a matching ``tables.<table>`` block in Section 3. Typos surface +as a config error at startup, not silent empty parquets. + + +Customising for a non-MIMIC FHIR export +--------------------------------------- + +Step 1 — write your YAML. +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Copy ``pyhealth/datasets/fhir/configs/mimic4fhir.yaml`` and adapt the +``resource_specs:`` and ``tables:`` blocks for the resources you care about. 
+For an export that adds Immunizations: + +.. code-block:: yaml + + resource_specs: + Patient: + table: patient + columns: + patient_id: { locate: ["id"], required: true } + birth_date: { locate: ["birthDate"] } + Immunization: + table: immunization + columns: + patient_id: { locate: ["patient.reference"], transform: ref_id, required: true } + resource_id: { locate: ["id"] } + event_time: { locate: ["occurrenceDateTime", "recorded"] } + concept_key: { locate: ["vaccineCode"], transform: coding_key } + + tables: + patient: + file_path: "patient.parquet" + patient_id: "patient_id" + timestamp: "birth_date" + attributes: ["birth_date"] + immunization: + file_path: "immunization.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: ["resource_id", "event_time", "concept_key"] + +Step 2 — instantiate +~~~~~~~~~~~~~~~~~~~~ + +Either pass ``config_path=...`` directly: + +.. code-block:: python + + from pyhealth.datasets import FHIRDataset + + ds = FHIRDataset( + root="/data/my_fhir_export", + config_path="/path/to/my_export.yaml", + ) + +or write a 3-line subclass that bundles your config: + +.. code-block:: python + + from pyhealth.datasets import FHIRDataset + + class MyFHIR(FHIRDataset): + DEFAULT_CONFIG_PATH = "/path/to/my_export.yaml" + + ds = MyFHIR(root="/data/my_fhir_export") + +Step 3 — that's it. +~~~~~~~~~~~~~~~~~~~ + +Everything downstream — :meth:`~pyhealth.datasets.BaseDataset.set_task`, +:meth:`~pyhealth.datasets.BaseDataset.iter_patients`, +:meth:`~pyhealth.datasets.BaseDataset.get_patient` — works the same as for any +other PyHealth dataset. + + +Notes on resource use +--------------------- + +Streaming ingest avoids loading the whole NDJSON corpus into RAM, but downstream +steps still scale with cohort size. For a **smoke run** the bundled example +fixtures fit on any laptop. 
For a **laptop-scale real subset**, set +``max_patients=`` and/or narrow ``glob_patterns`` to keep cache and task passes +manageable; ≥16 GB system RAM is a comfort target for Polars + the trainer. +For the **full PhysioNet export**, prefer fast SSD, large disk, and plenty of +RAM — total work scales with the corpus size even if RAM ingest is bounded. + + +Bundled FHIR datasets +--------------------- + +.. toctree:: + :maxdepth: 1 + + pyhealth.datasets.MIMIC4FHIR + + +API reference +------------- + +.. autoclass:: pyhealth.datasets.FHIRDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/datasets/pyhealth.datasets.MIMIC4FHIR.rst b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIR.rst new file mode 100644 index 000000000..344f60cf7 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIR.rst @@ -0,0 +1,78 @@ +pyhealth.datasets.MIMIC4FHIR +============================ + +A pre-bundled :class:`~pyhealth.datasets.FHIRDataset` for the PhysioNet +`MIMIC-IV on FHIR <https://physionet.org/content/mimic-iv-fhir/>`_ export +(R4, demo 2.1.0 and full release). All ingest logic — file globs, per-resource +projection, downstream event schema — is described by the bundled YAML at +``pyhealth/datasets/fhir/configs/mimic4fhir.yaml``; this class only points at +that path. + +For everything outside the MIMIC-specific defaults (transform registry, +``Col`` / ``ResourceSpec`` syntax, the three-tier cache story), see the parent +page: :doc:`pyhealth.datasets.FHIRDataset`. + +Quick start +----------- + +.. code-block:: python + + from pyhealth.datasets import MIMIC4FHIR + from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + + def main(): + ds = MIMIC4FHIR(root="/data/mimic-iv-fhir") + sample_ds = ds.set_task(MPFClinicalPredictionTask(), num_workers=1) + # ... split / dataloader / model / trainer ... + + if __name__ == "__main__": + main() + +For the full end-to-end demo (training EHR-Mamba on MPF samples) see +``examples/mimic4fhir_mpf_ehrmamba.py``. 
+ +Resource coverage +----------------- + +The bundled config flattens six FHIR resource types out of the PhysioNet +export: + +========================== ============================ =============================== +FHIR resourceType Output table Key columns +========================== ============================ =============================== +``Patient`` ``patient.parquet`` ``patient_id``, ``birth_date``, ``gender``, ``deceased_*`` +``Encounter`` ``encounter.parquet`` ``patient_id``, ``encounter_id``, ``event_time``, ``encounter_class`` +``Condition`` ``condition.parquet`` ``patient_id``, ``encounter_id``, ``event_time``, ``concept_key`` +``Observation`` ``observation.parquet`` ``patient_id``, ``encounter_id``, ``event_time``, ``concept_key`` +``MedicationRequest`` ``medication_request.parquet`` ``patient_id``, ``encounter_id``, ``event_time``, ``concept_key`` +``Procedure`` ``procedure.parquet`` ``patient_id``, ``encounter_id``, ``event_time``, ``concept_key`` +========================== ============================ =============================== + +PhysioNet shards that contain only other resource types +(``MedicationAdministration``, ``Specimen``, ``Organization``, …) are skipped +at the file level by the bundled ``glob_patterns``. To include them, override +``glob_patterns=`` at the constructor and add a ``resource_specs:`` entry plus +matching ``tables:`` entry in a copy of the YAML. + +Customising +----------- + +The bundled config is the easiest starting point for authoring a similar ingest +for other FHIR exports. Copy +``pyhealth/datasets/fhir/configs/mimic4fhir.yaml``, edit the +``resource_specs:`` and ``tables:`` blocks for the resources you care about, +and either: + +* pass ``config_path=...`` directly to ``FHIRDataset(root=..., config_path=...)``, or +* subclass ``FHIRDataset`` and set ``DEFAULT_CONFIG_PATH`` on the subclass. 
+ +See the "Customising for a non-MIMIC FHIR export" section of +:doc:`pyhealth.datasets.FHIRDataset` for the step-by-step. + +API reference +------------- + +.. autoclass:: pyhealth.datasets.MIMIC4FHIR + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..3a29c8773 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -186,6 +186,7 @@ API Reference models/pyhealth.models.MoleRec models/pyhealth.models.Deepr models/pyhealth.models.EHRMamba + models/pyhealth.models.EHRMambaCEHR models/pyhealth.models.JambaEHR models/pyhealth.models.ContraWR models/pyhealth.models.SparcNet diff --git a/docs/api/models/pyhealth.models.EHRMambaCEHR.rst b/docs/api/models/pyhealth.models.EHRMambaCEHR.rst new file mode 100644 index 000000000..c15a09962 --- /dev/null +++ b/docs/api/models/pyhealth.models.EHRMambaCEHR.rst @@ -0,0 +1,12 @@ +pyhealth.models.EHRMambaCEHR +=================================== + +EHRMambaCEHR applies CEHR-style embeddings (:class:`~pyhealth.models.cehr_embeddings.MambaEmbeddingsForCEHR`) +and a stack of :class:`~pyhealth.models.MambaBlock` layers to a single FHIR token stream, for use with +:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` and +:class:`~pyhealth.datasets.fhir.FHIRDataset`. + +.. 
autoclass:: pyhealth.models.EHRMambaCEHR + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..17d07026b 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -214,6 +214,7 @@ Available Tasks Drug Recommendation Length of Stay Prediction Medical Transcriptions Classification + MPF Clinical Prediction (FHIR) Mortality Prediction (Next Visit) Mortality Prediction (StageNet MIMIC-IV) Patient Linkage (MIMIC-III) diff --git a/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst b/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst new file mode 100644 index 000000000..569f75a38 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst @@ -0,0 +1,12 @@ +pyhealth.tasks.mpf_clinical_prediction +====================================== + +Multitask Prompted Fine-tuning (MPF) style binary clinical prediction on FHIR +token timelines, paired with :class:`~pyhealth.datasets.FHIRDataset` and +:class:`~pyhealth.models.EHRMambaCEHR`. Based on CEHR / EHRMamba ideas; see the +paper linked in the course replication PR. + +.. autoclass:: pyhealth.tasks.MPFClinicalPredictionTask + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/mimic4fhir_mpf_ehrmamba.py b/examples/mimic4fhir_mpf_ehrmamba.py new file mode 100644 index 000000000..d6d83ad88 --- /dev/null +++ b/examples/mimic4fhir_mpf_ehrmamba.py @@ -0,0 +1,56 @@ +"""EHRMambaCEHR on the local MIMIC-IV FHIR demo. + +Barebones path: Dataset -> task -> model -> trainer -> evaluate. + +Runs against the bundled demo at +``datasets/physionet.org/mimic-iv-fhir-demo/2.1.0/fhir`` and persists the +flattened-table cache under ``datasets/.cache/pyhealth/fhir-demo`` so a +second run hits the cache. + + PYTHONPATH=. 
python examples/mimic4fhir_mpf_ehrmamba.py +""" + +from __future__ import annotations + +from pathlib import Path + +from pyhealth.datasets import MIMIC4FHIR, get_dataloader, split_by_patient +from pyhealth.models import EHRMambaCEHR +from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask +from pyhealth.trainer import Trainer + +REPO_ROOT = Path(__file__).resolve().parents[3] +DEMO_ROOT = REPO_ROOT / "datasets" / "physionet.org" / "mimic-iv-fhir-demo" / "2.1.0" / "fhir" +CACHE_DIR = REPO_ROOT / "datasets" / ".cache" / "pyhealth" / "fhir-demo" + + +def main() -> None: + dataset = MIMIC4FHIR(root=str(DEMO_ROOT), cache_dir=str(CACHE_DIR)) + sample_dataset = dataset.set_task(MPFClinicalPredictionTask(), num_workers=1) + + train_ds, val_ds, test_ds = split_by_patient(sample_dataset, [0.7, 0.1, 0.2]) + train_loader = get_dataloader(train_ds, batch_size=8, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=8, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=8, shuffle=False) + + vocab_size = sample_dataset.input_processors["concept_ids"].vocab.vocab_size + model = EHRMambaCEHR( + dataset=sample_dataset, + vocab_size=vocab_size, + embedding_dim=32, + num_layers=2, + dropout=0.1, + ) + + trainer = Trainer(model=model, metrics=["roc_auc", "pr_auc"]) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=2, + monitor="roc_auc", + ) + print(trainer.evaluate(test_loader)) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 50b1b3887..c29955e7d 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -59,6 +59,7 @@ def __init__(self, *args, **kwargs): from .medical_transcriptions import MedicalTranscriptionsDataset from .mimic3 import MIMIC3Dataset from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset +from .fhir import FHIRDataset, MIMIC4FHIR from .mimicextract import 
MIMICExtractDataset from .omop import OMOPDataset from .physionet_deid import PhysioNetDeIDDataset diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 0e4280aab..54de1b961 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -420,10 +420,16 @@ def create_tmpdir(self) -> Path: return tmp_dir def clean_tmpdir(self) -> None: - """Cleans up the temporary directory within the cache.""" + """Cleans up the temporary directory within the cache. + + ``ignore_errors=True`` tolerates polars/pyarrow stream-writer + finalizers that may still be flushing into ``flattened_fhir_tables/`` + when we get here. The tmp dir is not load-bearing -- leftover bytes + will be reclaimed on the next ``clean_tmpdir`` or by the OS. + """ tmp_dir = self.cache_dir / "tmp" if tmp_dir.exists(): - shutil.rmtree(tmp_dir) + shutil.rmtree(tmp_dir, ignore_errors=True) def _scan_csv_tsv_gz(self, source_path: str) -> dd.DataFrame: """Scans a CSV/TSV file (possibly gzipped) and returns a Dask DataFrame. diff --git a/pyhealth/datasets/fhir/__init__.py b/pyhealth/datasets/fhir/__init__.py new file mode 100644 index 000000000..adcbbea85 --- /dev/null +++ b/pyhealth/datasets/fhir/__init__.py @@ -0,0 +1,16 @@ +"""FHIR datasets: a generic engine + per-source subclasses. + +- :class:`~pyhealth.datasets.fhir.base.FHIRDataset` — generic, config-driven base. +- :class:`~pyhealth.datasets.fhir.mimic4.MIMIC4FHIR` — MIMIC-IV-on-FHIR (R4). +- :mod:`~pyhealth.datasets.fhir.utils` — the stateless flattening engine + (``Col``, ``ResourceSpec``, ``flatten_resource``, …). 
+ +Authors: + John Wu and Evan Febrianto +""" + +from .base import FHIRDataset +from .mimic4 import MIMIC4FHIR +from .utils import Col, ResourceSpec + +__all__ = ["FHIRDataset", "MIMIC4FHIR", "Col", "ResourceSpec"] diff --git a/pyhealth/datasets/fhir/base.py b/pyhealth/datasets/fhir/base.py new file mode 100644 index 000000000..517d062bb --- /dev/null +++ b/pyhealth/datasets/fhir/base.py @@ -0,0 +1,415 @@ +"""Generic FHIR ingestion using flattened resource tables. + +Architecture +------------ +1. Stream NDJSON/NDJSON.GZ FHIR resources from disk. +2. Normalize each resource type into a 2D table via a declarative + :class:`~pyhealth.datasets.fhir.utils.ResourceSpec` registry + (``self.resource_specs``) — see :mod:`~pyhealth.datasets.fhir.utils`. +3. Feed those tables through the standard YAML-driven + :class:`~pyhealth.datasets.BaseDataset` pipeline so downstream task + processing operates on :class:`~pyhealth.data.Patient` and + ``global_event_df`` rows. + +``FHIRDataset`` is generic: it owns the streaming/cache/validation machinery but +no specific resource specs or config. Use it directly by passing +``config_path=`` (the YAML's ``resource_specs:`` block supplies the specs), or +subclass it for a concrete source (e.g. :class:`~pyhealth.datasets.fhir.mimic4.MIMIC4FHIR`) +that bundles a default config via ``DEFAULT_CONFIG_PATH``. 
+ +Authors: + John Wu and Evan Febrianto +""" + +from __future__ import annotations + +import functools +import hashlib +import logging +import operator +import shutil +import uuid +from pathlib import Path +from typing import Any, Dict, List, Mapping, Optional, Sequence + +import dask.dataframe as dd +import narwhals as nw +import orjson +import pandas as pd +import platformdirs +from yaml import safe_load + +from ..base_dataset import BaseDataset +from .utils import ( + FHIR_SCHEMA_VERSION, + SUPPORTED_OUTPUT_FORMATS, + ResourceSpec, + filter_flat_tables_by_patient_ids, + load_resource_specs_from_yaml, + sorted_patient_ids_from_flat_tables, + stream_fhir_ndjson_to_flat_tables, + table_file_name, + tables_from_specs, +) + +logger = logging.getLogger(__name__) + + +def read_fhir_settings_yaml(path: str) -> Dict[str, Any]: + with open(path, encoding="utf-8") as stream: + data = safe_load(stream) + return data if isinstance(data, dict) else {} + + +def _strip_tz_to_naive_ms(part: pd.Series) -> pd.Series: + if getattr(part.dtype, "tz", None) is not None: + part = part.dt.tz_localize(None) + return part.astype("datetime64[ms]") + + +class FHIRDataset(BaseDataset): + """FHIR resources flattened into per-type tables, then the standard pipeline. + + Streams raw FHIR NDJSON/NDJSON.GZ exports into flattened tables (one per + configured resource type) and pipelines them through + :class:`~pyhealth.datasets.BaseDataset` for downstream task processing + (global event dataframe, patient iteration, task sampling). + + The entire ingest is driven by a single YAML config with three top-level + sections — ``glob_patterns:`` (which NDJSON files to open), + ``resource_specs:`` (how to project each FHIR resource type into a flat + row), and ``tables:`` (how those rows are exposed as events downstream). + See ``pyhealth/datasets/fhir/configs/mimic4fhir.yaml`` for a complete + worked example and the FHIRDataset rst page for a section-by-section guide. 
+ + Pass ``config_path=...`` directly, or subclass and set + ``DEFAULT_CONFIG_PATH`` to bundle a default (see + :class:`~pyhealth.datasets.fhir.mimic4.MIMIC4FHIR`). + + Args: + root: Path to the NDJSON/NDJSON.GZ export directory. + config_path: Path to the FHIR ingest YAML. Defaults to the class + attribute ``DEFAULT_CONFIG_PATH``. The YAML must contain a + ``resource_specs:`` block; any ``glob_patterns:`` and ``tables:`` + blocks are also read from here. + glob_pattern: Single glob for NDJSON files; overrides the YAML's + ``glob_patterns``. Mutually exclusive with *glob_patterns*. + glob_patterns: Multiple glob patterns; overrides the YAML's + ``glob_patterns``. Mutually exclusive with *glob_pattern*. + output_format: Flat-table format, one of ``parquet`` (default), + ``csv``, ``tsv``. Defaults to the class attribute + ``DEFAULT_OUTPUT_FORMAT``. + max_patients: Limit ingest to the first *N* unique patient IDs. + ingest_num_shards: Ignored; retained for API compatibility. + cache_dir: Cache directory root (UUID subdir appended per config). + num_workers: Worker processes for task sampling. + dev: Development mode; limits to 1000 patients if *max_patients* is + ``None``. + + Examples: + >>> # ad-hoc, no subclass + >>> ds = FHIRDataset( + ... root="/data/fhir", + ... config_path="my_fhir.yaml", + ... ) + >>> # or a preconfigured source subclass + >>> from pyhealth.datasets import MIMIC4FHIR + >>> ds = MIMIC4FHIR(root="/data/mimic-iv-fhir", max_patients=500) + """ + + #: Default ingest YAML path; set by source subclasses to bundle a config. + DEFAULT_CONFIG_PATH: Optional[str] = None + #: Default flat-table output format. + DEFAULT_OUTPUT_FORMAT: str = "parquet" + #: Dataset name used for cache identity / logging. 
+ DATASET_NAME: str = "fhir" + + def __init__( + self, + root: str, + config_path: Optional[str] = None, + glob_pattern: Optional[str] = None, + glob_patterns: Optional[Sequence[str]] = None, + output_format: Optional[str] = None, + max_patients: Optional[int] = None, + ingest_num_shards: Optional[int] = None, + cache_dir: Optional[str | Path] = None, + num_workers: int = 1, + dev: bool = False, + ) -> None: + del ingest_num_shards + + resolved_config = config_path or type(self).DEFAULT_CONFIG_PATH + if resolved_config is None: + raise ValueError( + "FHIRDataset requires config_path: pass config_path=... or use a " + "subclass that defines DEFAULT_CONFIG_PATH." + ) + self._fhir_config_path = str(Path(resolved_config).resolve()) + self._fhir_settings = read_fhir_settings_yaml(self._fhir_config_path) + + # Section 2 of the YAML: how each FHIR resource type projects into a row. + self.resource_specs: Mapping[str, ResourceSpec] = ( + load_resource_specs_from_yaml(self._fhir_settings) + ) + + # Cross-validate: every table the specs declare must have a downstream + # `tables:` block (Section 3). Catches typos at startup. + spec_tables = set(tables_from_specs(self.resource_specs)) + declared_tables = set((self._fhir_settings.get("tables") or {}).keys()) + missing = spec_tables - declared_tables + if missing: + raise ValueError( + f"config {self._fhir_config_path}: resource_specs references " + f"table(s) {sorted(missing)} not declared in the 'tables:' " + f"block. Add a matching tables. entry (patient_id, " + f"timestamp, attributes) for each." + ) + + self.output_format = output_format or type(self).DEFAULT_OUTPUT_FORMAT + if self.output_format not in SUPPORTED_OUTPUT_FORMATS: + raise ValueError( + f"Unsupported output_format {self.output_format!r}; " + f"expected one of {SUPPORTED_OUTPUT_FORMATS}." 
+ ) + + if glob_pattern is not None and glob_patterns is not None: + raise ValueError("Pass at most one of glob_pattern and glob_patterns.") + if glob_patterns is not None: + self.glob_patterns: List[str] = list(glob_patterns) + elif glob_pattern is not None: + self.glob_patterns = [glob_pattern] + else: + raw_list = self._fhir_settings.get("glob_patterns") + if raw_list: + if not isinstance(raw_list, list): + raise TypeError("config glob_patterns must be a list of strings.") + self.glob_patterns = [str(x) for x in raw_list] + elif self._fhir_settings.get("glob_pattern") is not None: + self.glob_patterns = [str(self._fhir_settings["glob_pattern"])] + else: + self.glob_patterns = ["**/*.ndjson.gz"] + + self.glob_pattern = ( + self.glob_patterns[0] + if len(self.glob_patterns) == 1 + else "; ".join(self.glob_patterns) + ) + self.max_patients = 1000 if dev and max_patients is None else max_patients + + self._fhir_tables = tables_from_specs(self.resource_specs) + + resolved_root = str(Path(root).expanduser().resolve()) + super().__init__( + root=resolved_root, + tables=list(self._fhir_tables), + dataset_name=type(self).DATASET_NAME, + config_path=self._fhir_config_path, + cache_dir=cache_dir, + num_workers=num_workers, + dev=dev, + ) + + # ------------------------------------------------------------------ + # Cache identity + # ------------------------------------------------------------------ + + def _init_cache_dir(self, cache_dir: str | Path | None) -> Path: + try: + yaml_digest = hashlib.sha256( + Path(self._fhir_config_path).read_bytes() + ).hexdigest()[:16] + except OSError: + yaml_digest = "missing" + identity = orjson.dumps( + { + "root": self.root, + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + "glob_patterns": self.glob_patterns, + "max_patients": self.max_patients, + "output_format": self.output_format, + "fhir_schema_version": FHIR_SCHEMA_VERSION, + "fhir_yaml_digest16": yaml_digest, + }, + 
option=orjson.OPT_SORT_KEYS, + ).decode("utf-8") + cache_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, identity)) + out = ( + Path(platformdirs.user_cache_dir(appname="pyhealth")) / cache_id + if cache_dir is None + else Path(cache_dir) / cache_id + ) + out.mkdir(parents=True, exist_ok=True) + logger.info(f"Cache dir: {out}") + return out + + # ------------------------------------------------------------------ + # NDJSON -> flat tables ingest + # ------------------------------------------------------------------ + + @property + def prepared_tables_dir(self) -> Path: + return self.cache_dir / "flattened_tables" + + def _ensure_prepared_tables(self) -> None: + root = Path(self.root) + if not root.is_dir(): + raise FileNotFoundError(f"FHIR root not found: {root}") + + expected = [ + self.prepared_tables_dir / table_file_name(t, self.output_format) + for t in self._fhir_tables + ] + if all(p.is_file() for p in expected): + return + if self.prepared_tables_dir.exists(): + shutil.rmtree(self.prepared_tables_dir) + + try: + staging_root = self.create_tmpdir() + staging = staging_root / "flattened_fhir_tables" + staging.mkdir(parents=True, exist_ok=True) + stream_fhir_ndjson_to_flat_tables( + root, + self.glob_patterns, + staging, + self.resource_specs, + self.output_format, + ) + if self.max_patients is None: + shutil.move(str(staging), str(self.prepared_tables_dir)) + return + + filtered_root = self.create_tmpdir() + filtered = filtered_root / "filtered" + pids = sorted_patient_ids_from_flat_tables( + staging, self._fhir_tables, self.output_format + ) + filter_flat_tables_by_patient_ids( + staging, + filtered, + pids[: self.max_patients], + self._fhir_tables, + self.output_format, + ) + shutil.move(str(filtered), str(self.prepared_tables_dir)) + finally: + self.clean_tmpdir() + + def _event_transform(self, output_dir: Path) -> None: + self._ensure_prepared_tables() + super()._event_transform(output_dir) + + # ------------------------------------------------------------------ + 
# Table loading (flat tables instead of source CSVs) + # ------------------------------------------------------------------ + + def _read_flat_table(self, path: Path) -> dd.DataFrame: + if self.output_format == "parquet": + return dd.read_parquet( + str(path), split_row_groups=True, blocksize="64MB" + ).replace("", pd.NA) + sep = "\t" if self.output_format == "tsv" else "," + return dd.read_csv( + str(path), sep=sep, dtype=str, blocksize="64MB" + ).replace("", pd.NA) + + def load_table(self, table_name: str) -> dd.DataFrame: + """Load one flattened table into the standard event schema. + + Deviations from ``BaseDataset.load_table`` (CSV via ``_scan_csv_tsv_gz``): + + * Reads pre-built flat tables (parquet/csv/tsv) under + ``prepared_tables_dir``. + * Timestamp parsing uses ``errors="coerce"`` + ``utc=True`` (FHIR ISO + strings include timezone suffix or partial dates). + * Strips tz-aware timestamps to naive UTC for Dask compat. + * Drops rows with null ``patient_id`` before returning. + """ + assert self.config is not None + if table_name not in self.config.tables: + raise ValueError(f"Table {table_name} not found in config") + + table_cfg = self.config.tables[table_name] + path = self.prepared_tables_dir / table_file_name( + table_name, self.output_format + ) + if not path.exists(): + raise FileNotFoundError(f"Flattened table not found: {path}") + + logger.info(f"Scanning FHIR flattened table: {table_name} from {path}") + df: dd.DataFrame = self._read_flat_table(path) + df = df.rename(columns=str.lower) + + preprocess_func = getattr(self, f"preprocess_{table_name}", None) + if preprocess_func is not None: + logger.info( + f"Preprocessing FHIR table: {table_name} " + f"with {preprocess_func.__name__}" + ) + df = preprocess_func(nw.from_native(df)).to_native() # type: ignore[union-attr] + + for join_cfg in table_cfg.join: + join_path = self.prepared_tables_dir / Path(join_cfg.file_path).name + if not join_path.exists(): + raise FileNotFoundError(f"FHIR join table 
not found: {join_path}") + logger.info(f"Joining FHIR table {table_name} with {join_path}") + join_df: dd.DataFrame = self._read_flat_table(join_path) + join_df = join_df.rename(columns=str.lower) + join_key = join_cfg.on.lower() + cols = [c.lower() for c in join_cfg.columns] + df = df.merge(join_df[[join_key] + cols], on=join_key, how=join_cfg.how) + + ts_col = table_cfg.timestamp + if ts_col: + ts = ( + functools.reduce( + operator.add, + (df[c].astype("string") for c in ts_col), + ) + if isinstance(ts_col, list) + else df[ts_col].astype("string") + ) + ts = dd.to_datetime( + ts, format=table_cfg.timestamp_format, errors="coerce", utc=True + ) + df = df.assign(timestamp=ts.map_partitions(_strip_tz_to_naive_ms)) + else: + df = df.assign(timestamp=pd.NaT) + + if table_cfg.patient_id: + df = df.assign(patient_id=df[table_cfg.patient_id].astype("string")) + else: + df = df.reset_index(drop=True) + df = df.assign(patient_id=df.index.astype("string")) + + df = df.dropna(subset=["patient_id"]) + df = df.assign(event_type=table_name) + rename_attr = { + attr.lower(): f"{table_name}/{attr}" for attr in table_cfg.attributes + } + df = df.rename(columns=rename_attr) + return df[ + ["patient_id", "event_type", "timestamp"] + + [rename_attr[a.lower()] for a in table_cfg.attributes] + ] + + # ------------------------------------------------------------------ + # Patient IDs (deterministic sorted order) + # ------------------------------------------------------------------ + + @property + def unique_patient_ids(self) -> List[str]: + if self._unique_patient_ids is None: + self._unique_patient_ids = ( + self.global_event_df.select("patient_id") + .unique() + .sort("patient_id") + .collect(engine="streaming") + .to_series() + .to_list() + ) + logger.info(f"Found {len(self._unique_patient_ids)} unique patient IDs") + return self._unique_patient_ids diff --git a/pyhealth/datasets/fhir/configs/mimic4fhir.yaml b/pyhealth/datasets/fhir/configs/mimic4fhir.yaml new file mode 100644 index 
000000000..1f0f1b697 --- /dev/null +++ b/pyhealth/datasets/fhir/configs/mimic4fhir.yaml @@ -0,0 +1,210 @@ +# MIMIC-IV-on-FHIR (R4) ingest config — single source of truth +# ============================================================ +# +# Authors: John Wu and Evan Febrianto +# +# This YAML drives the entire ingest pipeline for one FHIR export: +# +# 1. ``glob_patterns:`` which NDJSON files on disk to read +# 2. ``resource_specs:`` how to project each FHIR resource type into a row +# 3. ``tables:`` how those rows are exposed as events downstream +# +# Use this file as a complete worked example when authoring a YAML for any +# other FHIR export (BigQuery dumps, Synthea, etc.). To use it as-is, just +# instantiate :class:`~pyhealth.datasets.MIMIC4FHIR` with no extra arguments. +# +# To customise for a different export: +# * subclass FHIRDataset and point ``DEFAULT_CONFIG_PATH`` at your YAML, or +# * pass ``config_path=`` directly to ``FHIRDataset(...)``. + +version: "fhir_r4_flattened" + + +# --------------------------------------------------------------------------- +# Section 1: glob patterns — which NDJSON files to open +# --------------------------------------------------------------------------- +# +# Defaults to ``["**/*.ndjson.gz"]`` when omitted. Only useful when the export +# has per-resource-type file naming (PhysioNet MIMIC-IV FHIR ships shards as +# ``MimicPatient*.ndjson.gz``, ``MimicEncounter*.ndjson.gz``, etc.). Filtering +# at the file level avoids decompressing ~10% of the export that contains only +# unconfigured resource types (MedicationAdministration, Specimen, Organization). +# +# For a generic export where everything lives in ``bundles.ndjson.gz`` / +# ``**/*.ndjson.gz``, leave this commented out and the streamer will filter +# resources by ``resourceType`` after parsing — correct, just slower. +# +# Override at runtime via ``MIMIC4FHIR(glob_pattern=...)`` or +# ``MIMIC4FHIR(glob_patterns=[...])``. 
+ +glob_patterns: + - "**/MimicPatient*.ndjson.gz" + - "**/MimicEncounter*.ndjson.gz" + - "**/MimicCondition*.ndjson.gz" + - "**/MimicObservation*.ndjson.gz" + - "**/MimicMedicationRequest*.ndjson.gz" + - "**/MimicProcedure*.ndjson.gz" + + +# --------------------------------------------------------------------------- +# Section 2: resource_specs — how to turn one FHIR JSON document into a row +# --------------------------------------------------------------------------- +# +# Keys are FHIR ``resourceType`` strings. For each resource type we declare: +# +# table: output table name (also the per-type Parquet filename stem). +# columns: ordered mapping of output column name -> Col spec. +# +# Each Col spec lists ordered JSON paths (``locate``); the first path that +# resolves to a non-null value wins (this is how FHIR choice-types like +# ``onset[x]`` and ``effective[x]`` are handled). ``transform`` names a function +# in the engine's TRANSFORMS registry that maps the located leaf to a flat +# string; ``required: true`` drops the resource if the leaf can't be resolved. +# +# Available transforms (defined in pyhealth/datasets/fhir/utils.py): +# +# identity Pass through (default; stringifies non-string scalars). +# ref_id "{ "reference": "Patient/p1" }" -> "p1". +# coding_key CodeableConcept -> "system|code" of its first coding. +# bool_norm JSON boolean / "true"/"false" string -> "true"/"false"/None. +# med_concept MedicationRequest.medication[x] -> codeable-concept or +# "MedicationRequest/reference|<id>" fallback. +# +# Adding a new transform: register it in TRANSFORMS in utils.py; reference it +# by name here. 
+ +resource_specs: + + Patient: + table: patient + columns: + patient_id: { locate: ["id"], required: true } + patient_fhir_id: { locate: ["id"] } + birth_date: { locate: ["birthDate"] } + gender: { locate: ["gender"] } + deceased_boolean: { locate: ["deceasedBoolean"], transform: bool_norm } + deceased_datetime: { locate: ["deceasedDateTime"] } + + Encounter: + table: encounter + columns: + patient_id: { locate: ["subject.reference"], transform: ref_id, required: true } + resource_id: { locate: ["id"] } + encounter_id: { locate: ["id"] } + event_time: { locate: ["period.start"] } + encounter_class: { locate: ["class.code"] } + encounter_end: { locate: ["period.end"] } + + Condition: + table: condition + columns: + patient_id: { locate: ["subject.reference"], transform: ref_id, required: true } + resource_id: { locate: ["id"] } + encounter_id: { locate: ["encounter.reference"], transform: ref_id } + event_time: { locate: ["onsetDateTime", "onsetPeriod.start", "recordedDate"] } + concept_key: { locate: ["code"], transform: coding_key } + + Observation: + table: observation + columns: + patient_id: { locate: ["subject.reference"], transform: ref_id, required: true } + resource_id: { locate: ["id"] } + encounter_id: { locate: ["encounter.reference"], transform: ref_id } + event_time: { locate: ["effectiveDateTime", "effectivePeriod.start", "issued"] } + concept_key: { locate: ["code"], transform: coding_key } + + MedicationRequest: + table: medication_request + columns: + patient_id: { locate: ["subject.reference"], transform: ref_id, required: true } + resource_id: { locate: ["id"] } + encounter_id: { locate: ["encounter.reference"], transform: ref_id } + event_time: { locate: ["authoredOn"] } + concept_key: { locate: ["medicationCodeableConcept", "medicationReference"], transform: med_concept } + + Procedure: + table: procedure + columns: + patient_id: { locate: ["subject.reference"], transform: ref_id, required: true } + resource_id: { locate: ["id"] } + 
encounter_id: { locate: ["encounter.reference"], transform: ref_id } + event_time: { locate: ["performedDateTime", "performedPeriod.start", "recordedDate"] } + concept_key: { locate: ["code"], transform: coding_key } + + +# --------------------------------------------------------------------------- +# Section 3: tables — how flat rows are exposed as events downstream +# --------------------------------------------------------------------------- +# +# Keys here must match the ``table:`` values in section 2. Each entry declares +# how BaseDataset.load_table reads the flat parquet: +# +# file_path: parquet filename (relative to the cached flattened_tables dir). +# patient_id: column name that holds the patient id. +# timestamp: column name to parse as the event timestamp. +# attributes: list of columns to surface as event attributes; they are +# renamed to ``{table}/{attr}`` in the global event frame and +# later show up on ``Patient.get_events(...).attr_name``. + +tables: + patient: + file_path: "patient.parquet" + patient_id: "patient_id" + timestamp: "birth_date" + attributes: + - "patient_fhir_id" + - "birth_date" + - "gender" + - "deceased_boolean" + - "deceased_datetime" + + encounter: + file_path: "encounter.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "encounter_class" + - "encounter_end" + + condition: + file_path: "condition.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" + + observation: + file_path: "observation.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" + + medication_request: + file_path: "medication_request.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" + + 
procedure: + file_path: "procedure.parquet" + patient_id: "patient_id" + timestamp: "event_time" + attributes: + - "resource_id" + - "encounter_id" + - "event_time" + - "concept_key" diff --git a/pyhealth/datasets/fhir/mimic4.py b/pyhealth/datasets/fhir/mimic4.py new file mode 100644 index 000000000..5732a6776 --- /dev/null +++ b/pyhealth/datasets/fhir/mimic4.py @@ -0,0 +1,42 @@ +"""MIMIC-IV-on-FHIR (R4) dataset. + +A thin :class:`~pyhealth.datasets.fhir.base.FHIRDataset` wrapper that points at +the bundled YAML for the PhysioNet MIMIC-IV on FHIR export. The whole ingest +contract (resource projection + downstream table schema + glob hints) lives in +the YAML; this class only names its default path. + +Use this YAML as the worked example when authoring a config for a different +FHIR export — copy ``pyhealth/datasets/fhir/configs/mimic4fhir.yaml`` and +adapt the ``resource_specs:`` and ``tables:`` blocks. + +Authors: + John Wu and Evan Febrianto +""" + +from __future__ import annotations + +import os + +from .base import FHIRDataset + + +class MIMIC4FHIR(FHIRDataset): + """MIMIC-IV-on-FHIR (R4) dataset. + + Streams the PhysioNet MIMIC-IV on FHIR NDJSON.GZ export into flattened + Patient/Encounter/Condition/Observation/MedicationRequest/Procedure tables, + then runs the standard :class:`~pyhealth.datasets.BaseDataset` pipeline. + + The bundled config at ``pyhealth/datasets/fhir/configs/mimic4fhir.yaml`` + matches both the PhysioNet 2.1.0 demo and the full release. Override + ``config_path=`` to point at a customised copy. 
+ + Examples: + >>> ds = MIMIC4FHIR(root="/data/mimic-iv-fhir", max_patients=500) + >>> sample_ds = ds.set_task(task, num_workers=4) + """ + + DEFAULT_CONFIG_PATH = os.path.join( + os.path.dirname(__file__), "configs", "mimic4fhir.yaml" + ) + DATASET_NAME = "mimic4fhir" diff --git a/pyhealth/datasets/fhir/utils.py b/pyhealth/datasets/fhir/utils.py new file mode 100644 index 000000000..d67f19077 --- /dev/null +++ b/pyhealth/datasets/fhir/utils.py @@ -0,0 +1,720 @@ +"""FHIR NDJSON parsing, generic flattening, and tabular table writing. + +This module is the **stateless engine** behind FHIR-to-tabular conversion. It +knows nothing about any specific FHIR source or resource type: the per-resource +projection is supplied as a declarative registry of :class:`ResourceSpec` objects +(see ``MIMIC4FHIR.RESOURCE_SPECS`` for an example) and applied generically by +:func:`flatten_resource`. + +Key public API +-------------- +flatten_resource(resource, specs) + Project one FHIR resource dict into ``(table_name, row_dict)`` using a spec + registry, or ``None`` if the resource is unconfigured / missing a required + field. + +stream_fhir_ndjson_to_flat_tables(root, glob_pattern, out_dir, specs, output_format) + Stream all matching NDJSON/NDJSON.GZ resources into per-type flat tables + (parquet/csv/tsv), validating + counting drops along the way. + +sorted_ndjson_files(root, glob_pattern) + List matching NDJSON files under root (deduplicated, sorted). + +filter_flat_tables_by_patient_ids(source_dir, out_dir, keep_ids, tables, output_format) + Subset existing flattened tables to a specific patient cohort. 
+ +Authors: + John Wu and Evan Febrianto +""" + +from __future__ import annotations + +import gzip +import logging +from collections import Counter +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple + +import orjson +import polars as pl +import pyarrow as pa +import pyarrow.csv as pa_csv +import pyarrow.parquet as pq + +logger = logging.getLogger(__name__) + +GlobPatternArg = str | Sequence[str] +"""Single glob string or sequence of strings for NDJSON file discovery.""" + +__all__ = [ + # Types + "GlobPatternArg", + "Col", + "ResourceSpec", + # Constants + "FHIR_SCHEMA_VERSION", + "SUPPORTED_OUTPUT_FORMATS", + "TRANSFORMS", + # Spec helpers + "tables_from_specs", + "columns_from_specs", + "table_file_name", + "load_resource_specs_from_yaml", + # Datetime helpers + "parse_dt", + "as_naive", + # FHIR iteration + "iter_ndjson_objects", + "iter_resources_from_ndjson_obj", + # Extraction + "flatten_resource", + # Pipeline + "sorted_ndjson_files", + "stream_fhir_ndjson_to_flat_tables", + "filter_flat_tables_by_patient_ids", + "sorted_patient_ids_from_flat_tables", +] + +# Bump when the flattening engine or its output layout changes; folded into the +# dataset cache identity so stale caches rebuild automatically. +FHIR_SCHEMA_VERSION = 4 + +SUPPORTED_OUTPUT_FORMATS = ("parquet", "csv", "tsv") + + +# --------------------------------------------------------------------------- +# Declarative extraction spec (the registry's value type) +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class Col: + """How to project one flat column out of a FHIR resource. + + Typically constructed indirectly by :meth:`ResourceSpec.from_dict` while + loading the dataset's YAML config; direct instantiation is supported for + programmatic use. 
+ + Attributes: + locate: Ordered dotted paths into the resource; the first that resolves + to a non-null value wins. This is how FHIR choice-types (``onset[x]``, + ``effective[x]``, …) are handled — list every variant explicitly. + transform: Name of a value transform in :data:`TRANSFORMS` that converts + the located leaf into a flat scalar string. + required: When ``True``, a resource whose ``locate`` cannot be resolved is + dropped (and counted) rather than emitted with a null. + """ + + locate: Tuple[str, ...] + transform: str = "identity" + required: bool = False + + @classmethod + def from_dict(cls, data: Mapping[str, Any], *, ctx: str = "") -> "Col": + """Build a :class:`Col` from a YAML-style dict. + + Expected shape:: + + { locate: ["path.a", "path.b"], transform: "ref_id", required: false } + + ``transform`` defaults to ``"identity"`` and must name an entry in + :data:`TRANSFORMS`. ``required`` defaults to ``False``. A missing or + empty ``locate`` field raises ``ValueError``. + + Args: + data: Mapping containing ``locate`` (required) and the optional + ``transform`` / ``required`` keys. + ctx: Optional context string used in error messages + (e.g. ``"Patient.patient_id"``). + """ + if not isinstance(data, Mapping): + raise ValueError( + f"{ctx or 'Col'}: expected a mapping, got {type(data).__name__}." + ) + raw_locate = data.get("locate") + if not raw_locate: + raise ValueError( + f"{ctx or 'Col'}: missing required field 'locate'." + ) + if isinstance(raw_locate, str): + locate: Tuple[str, ...] = (raw_locate,) + else: + locate = tuple(str(p) for p in raw_locate) + if not locate: + raise ValueError( + f"{ctx or 'Col'}: 'locate' must list at least one path." + ) + transform = str(data.get("transform", "identity")) + if transform not in TRANSFORMS: + allowed = ", ".join(sorted(TRANSFORMS.keys())) + raise ValueError( + f"{ctx or 'Col'}: unknown transform {transform!r}. " + f"Allowed: {allowed}." 
+ ) + required = bool(data.get("required", False)) + return cls(locate=locate, transform=transform, required=required) + + +@dataclass(frozen=True) +class ResourceSpec: + """How to project one FHIR resource type into a flat table. + + Typically constructed indirectly by :func:`load_resource_specs_from_yaml` + while loading the dataset's YAML config; direct instantiation is supported + for programmatic use. + + Attributes: + table: Output table name (also the per-type file stem). + columns: Mapping of output column name -> :class:`Col`. Insertion order + defines the table's column order. + """ + + table: str + columns: Mapping[str, Col] + + @classmethod + def from_dict( + cls, resource_type: str, data: Mapping[str, Any] + ) -> "ResourceSpec": + """Build a :class:`ResourceSpec` from a YAML-style dict. + + Expected shape:: + + { + table: "patient", + columns: { + patient_id: { locate: ["id"], required: true }, + birth_date: { locate: ["birthDate"] }, + ... + }, + } + + Args: + resource_type: FHIR resourceType string this spec describes + (e.g. ``"Patient"``). Used only for error messages. + data: Mapping containing ``table`` (required, str) and + ``columns`` (required, mapping of column name -> Col-shaped + mapping). + """ + if not isinstance(data, Mapping): + raise ValueError( + f"resource_specs.{resource_type}: expected a mapping, " + f"got {type(data).__name__}." + ) + table = data.get("table") + if not isinstance(table, str) or not table: + raise ValueError( + f"resource_specs.{resource_type}: missing required field " + f"'table' (string)." + ) + raw_columns = data.get("columns") + if not isinstance(raw_columns, Mapping) or not raw_columns: + raise ValueError( + f"resource_specs.{resource_type}: missing required field " + f"'columns' (non-empty mapping)." 
+ ) + columns: Dict[str, Col] = {} + for col_name, col_data in raw_columns.items(): + columns[str(col_name)] = Col.from_dict( + col_data, + ctx=f"resource_specs.{resource_type}.columns.{col_name}", + ) + return cls(table=str(table), columns=columns) + + +def load_resource_specs_from_yaml( + raw: Mapping[str, Any], +) -> Dict[str, ResourceSpec]: + """Build the spec registry from a parsed YAML's ``resource_specs:`` block. + + Args: + raw: The full parsed YAML mapping (top-level dict). The + ``resource_specs`` key, if present, must be a mapping of FHIR + resourceType -> ResourceSpec-shaped dict. + + Returns: + Insertion-ordered mapping of resourceType to :class:`ResourceSpec`. + + Raises: + ValueError: If the ``resource_specs`` block is missing, empty, or + contains a malformed entry. + """ + block = raw.get("resource_specs") + if not isinstance(block, Mapping) or not block: + raise ValueError( + "config: missing or empty top-level 'resource_specs:' block. " + "Declare at least one FHIR resourceType -> spec mapping." 
+ ) + specs: Dict[str, ResourceSpec] = {} + for resource_type, data in block.items(): + specs[str(resource_type)] = ResourceSpec.from_dict( + str(resource_type), data + ) + return specs + + +def tables_from_specs(specs: Mapping[str, ResourceSpec]) -> List[str]: + """Ordered, de-duplicated list of output table names declared by *specs*.""" + seen: Dict[str, None] = {} + for spec in specs.values(): + seen.setdefault(spec.table, None) + return list(seen.keys()) + + +def columns_from_specs(specs: Mapping[str, ResourceSpec]) -> Dict[str, List[str]]: + """Map each output table name to its ordered column names.""" + return {spec.table: list(spec.columns.keys()) for spec in specs.values()} + + +def table_file_name(table_name: str, output_format: str = "parquet") -> str: + """Filename for a flattened table given the output format.""" + ext = "parquet" if output_format == "parquet" else output_format + return f"{table_name}.{ext}" + + +# --------------------------------------------------------------------------- +# Datetime helpers (kept for external callers) +# --------------------------------------------------------------------------- + + +def parse_dt(s: Optional[str]) -> Optional[datetime]: + if not s: + return None + try: + dt = datetime.fromisoformat(s.replace("Z", "+00:00")) + except ValueError: + dt = None + if dt is None and len(s) >= 10: + try: + dt = datetime.strptime(s[:10], "%Y-%m-%d") + except ValueError: + return None + if dt is None: + return None + return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt + + +def as_naive(dt: Optional[datetime]) -> Optional[datetime]: + if dt is None: + return None + return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt + + +# --------------------------------------------------------------------------- +# FHIR JSON helpers +# --------------------------------------------------------------------------- + + +def _coding_key(coding: Dict[str, Any]) -> str: + return f"{coding.get('system') or 
'unknown'}|{coding.get('code') or 'unknown'}" + + +def _first_coding(obj: Optional[Dict[str, Any]]) -> Optional[str]: + """CodeableConcept -> ``"system|code"`` for its first coding (or None).""" + if not isinstance(obj, dict): + return None + codings = obj.get("coding") or [] + if not codings and "concept" in obj: + codings = (obj.get("concept") or {}).get("coding") or [] + return _coding_key(codings[0]) if codings else None + + +def _ref_id(ref: Optional[Any]) -> Optional[str]: + """``{"reference": "Patient/p1"}`` or ``"Patient/p1"`` -> ``"p1"``.""" + if isinstance(ref, dict): + ref = ref.get("reference") + if not ref: + return None + return ref.rsplit("/", 1)[-1] if "/" in ref else ref + + +def _normalize_deceased_boolean_for_storage(value: Any) -> Optional[str]: + """Map Patient.deceasedBoolean to stored "true"/"false"/None. + + FHIR JSON uses real booleans; some exports use strings. Python's + bool("false") is True, so we must not coerce with bool(). + """ + if value is None: + return None + if value is True: + return "true" + if value is False: + return "false" + if isinstance(value, str): + key = value.strip().lower() + if key in ("true", "1", "yes", "y", "t"): + return "true" + if key in ("false", "0", "no", "n", "f", ""): + return "false" + return None + if isinstance(value, (int, float)) and not isinstance(value, bool): + if value == 0: + return "false" + if value == 1: + return "true" + return None + return None + + +def _medication_concept_key(value: Any) -> Optional[str]: + """MedicationRequest medication[x] -> a stable concept key. + + Accepts either a ``medicationCodeableConcept`` (-> ``"system|code"``) or a + ``medicationReference`` (-> ``"MedicationRequest/reference|"``). 
+ """ + if not isinstance(value, dict): + return None + if "coding" in value or "concept" in value: + key = _first_coding(value) + if key: + return key + ref = value.get("reference") + if ref: + return f"MedicationRequest/reference|{_ref_id(ref) or ref}" + return None + + +def _identity(value: Any) -> Optional[str]: + if value is None or isinstance(value, str): + return value + return str(value) + + +# Transform registry: how a located leaf becomes a flat scalar string. +TRANSFORMS = { + "identity": _identity, + "ref_id": _ref_id, + "coding_key": _first_coding, + "bool_norm": _normalize_deceased_boolean_for_storage, + "med_concept": _medication_concept_key, +} + + +def _unwrap_resource_dict(raw: Any) -> Optional[Dict[str, Any]]: + if not isinstance(raw, dict): + return None + resource = raw.get("resource") if "resource" in raw else raw + return resource if isinstance(resource, dict) else None + + +def iter_resources_from_ndjson_obj(obj: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + """Yield resource dicts from one parsed NDJSON object (Bundle or bare resource).""" + if "entry" in obj: + for entry in obj.get("entry") or []: + resource = entry.get("resource") + if isinstance(resource, dict): + yield resource + return + resource = _unwrap_resource_dict(obj) + if resource is not None: + yield resource + + +def iter_ndjson_objects(path: Path) -> Iterator[Dict[str, Any]]: + """Yield parsed JSON objects from a plain or gzip-compressed NDJSON file.""" + opener = ( + gzip.open(path, "rt", encoding="utf-8", errors="replace") + if path.suffix == ".gz" + else open(path, encoding="utf-8", errors="replace") + ) + with opener as stream: + for line in stream: + line = line.strip() + if not line: + continue + parsed = orjson.loads(line) + if isinstance(parsed, dict): + yield parsed + + +# --------------------------------------------------------------------------- +# Generic extraction engine +# --------------------------------------------------------------------------- + + +def 
_get_path(obj: Any, path: str) -> Any: + """Walk a dotted path (e.g. ``"encounter.reference"``) safely; None if absent.""" + cur = obj + for part in path.split("."): + if not isinstance(cur, dict): + return None + cur = cur.get(part) + return cur + + +def _first_located(resource: Dict[str, Any], paths: Tuple[str, ...]) -> Any: + """First non-null value among the ordered ``paths`` (choice-type resolution).""" + for path in paths: + value = _get_path(resource, path) + if value is not None: + return value + return None + + +def flatten_resource( + resource: Dict[str, Any], + specs: Mapping[str, ResourceSpec], +) -> Optional[Tuple[str, Dict[str, Optional[str]]]]: + """Project one FHIR resource into ``(table_name, row)`` via *specs*. + + Returns ``None`` if the resource type is not configured in *specs*, or if a + column marked ``required`` cannot be resolved (a dropped/corrupted resource). + """ + spec = specs.get(resource.get("resourceType")) + if spec is None: + return None + row: Dict[str, Optional[str]] = {} + for name, col in spec.columns.items(): + raw = _first_located(resource, col.locate) + if raw is None and col.required: + return None + row[name] = TRANSFORMS[col.transform](raw) + return spec.table, row + + +# --------------------------------------------------------------------------- +# Tabular writer (parquet / csv / tsv) +# --------------------------------------------------------------------------- + + +def _table_schema(columns: Sequence[str]) -> pa.Schema: + return pa.schema([(col, pa.string()) for col in columns]) + + +class _BufferedTableWriter: + """Buffered, streaming writer for one flat table in parquet/csv/tsv.""" + + def __init__( + self, + path: Path, + schema: pa.Schema, + output_format: str = "parquet", + batch_size: int = 50_000, + ) -> None: + self.path = path + self.schema = schema + self.output_format = output_format + self.batch_size = batch_size + self.rows: List[Dict[str, Any]] = [] + self._pq_writer: Optional[pq.ParquetWriter] = None + 
self._fh = None + self._csv_header_written = False + self._delimiter = "\t" if output_format == "tsv" else "," + self.path.parent.mkdir(parents=True, exist_ok=True) + + def add(self, row: Dict[str, Any]) -> None: + self.rows.append(row) + if len(self.rows) >= self.batch_size: + self.flush() + + def flush(self) -> None: + if not self.rows: + return + table = pa.Table.from_pylist(self.rows, schema=self.schema) + if self.output_format == "parquet": + if self._pq_writer is None: + self._pq_writer = pq.ParquetWriter(str(self.path), self.schema) + self._pq_writer.write_table(table) + else: + if self._fh is None: + self._fh = open(self.path, "wb") + pa_csv.write_csv( + table, + self._fh, + write_options=pa_csv.WriteOptions( + include_header=not self._csv_header_written, + delimiter=self._delimiter, + ), + ) + self._csv_header_written = True + self.rows.clear() + + def close(self) -> None: + self.flush() + if self.output_format == "parquet": + if self._pq_writer is None: + pq.write_table( + pa.Table.from_pylist([], schema=self.schema), str(self.path) + ) + else: + self._pq_writer.close() + return + if self._fh is None: + # Empty table: still write a header-only file for a stable schema. + self._fh = open(self.path, "wb") + pa_csv.write_csv( + pa.Table.from_pylist([], schema=self.schema), + self._fh, + write_options=pa_csv.WriteOptions( + include_header=True, delimiter=self._delimiter + ), + ) + self._fh.close() + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + + +def sorted_ndjson_files(root: Path, glob_pattern: GlobPatternArg) -> List[Path]: + """Return sorted unique file paths under root matching glob pattern(s). + + Args: + root: Root directory to search under. + glob_pattern: Single glob string or sequence of glob strings. + + Returns: + Sorted list of matching files. Empty if no matches. 
+ """ + patterns = [glob_pattern] if isinstance(glob_pattern, str) else list(glob_pattern) + files: set[Path] = set() + for pat in patterns: + files.update(p for p in root.glob(pat) if p.is_file()) + return sorted(files, key=lambda p: str(p)) + + +def stream_fhir_ndjson_to_flat_tables( + root: Path, + glob_pattern: GlobPatternArg, + out_dir: Path, + specs: Mapping[str, ResourceSpec], + output_format: str = "parquet", +) -> None: + """Stream NDJSON resources into normalized per-resource flat tables. + + Resources are validated as they stream: anything whose type is not in + *specs*, or which is missing a ``required`` field, is dropped and counted; a + summary is logged at the end so corruption is visible rather than silent. + + Args: + root: Root directory containing NDJSON/NDJSON.GZ files. + glob_pattern: Single glob string or sequence of glob strings. + out_dir: Output directory for per-resource-type tables. + specs: Registry mapping FHIR resourceType -> :class:`ResourceSpec`. + output_format: One of :data:`SUPPORTED_OUTPUT_FORMATS`. + """ + if output_format not in SUPPORTED_OUTPUT_FORMATS: + raise ValueError( + f"Unsupported output_format {output_format!r}; " + f"expected one of {SUPPORTED_OUTPUT_FORMATS}." 
+ ) + out_dir.mkdir(parents=True, exist_ok=True) + tables = tables_from_specs(specs) + columns = columns_from_specs(specs) + writers = { + name: _BufferedTableWriter( + path=out_dir / table_file_name(name, output_format), + schema=_table_schema(columns[name]), + output_format=output_format, + ) + for name in tables + } + + ingested: Counter = Counter() + dropped: Counter = Counter() + skipped_unconfigured: Counter = Counter() + try: + for file_path in sorted_ndjson_files(root, glob_pattern): + for ndjson_obj in iter_ndjson_objects(file_path): + for resource in iter_resources_from_ndjson_obj(ndjson_obj): + resource_type = resource.get("resourceType") + result = flatten_resource(resource, specs) + if result is None: + if resource_type in specs: + dropped[resource_type] += 1 + else: + skipped_unconfigured[resource_type] += 1 + continue + table_name, row = result + writers[table_name].add(row) + ingested[table_name] += 1 + finally: + for writer in writers.values(): + writer.close() + + logger.info( + "FHIR flatten complete (%s): %s", + output_format, + {name: ingested[name] for name in tables}, + ) + for resource_type, count in dropped.items(): + logger.warning( + "FHIR flatten: dropped %d %s resource(s) missing a required field " + "(e.g. patient reference).", + count, + resource_type, + ) + if skipped_unconfigured: + total = sum(skipped_unconfigured.values()) + logger.info( + "FHIR flatten: skipped %d resource(s) of %d unconfigured type(s): %s", + total, + len(skipped_unconfigured), + dict(skipped_unconfigured), + ) + + +def _scan_flat_table(path: Path, output_format: str) -> pl.LazyFrame: + if output_format == "parquet": + return pl.scan_parquet(str(path)) + sep = "\t" if output_format == "tsv" else "," + # infer_schema_length=0 keeps every column as Utf8 (flat tables are all strings). 
+ return pl.scan_csv(str(path), separator=sep, infer_schema_length=0) + + +def sorted_patient_ids_from_flat_tables( + table_dir: Path, + tables: Sequence[str], + output_format: str = "parquet", +) -> List[str]: + """Return sorted unique patient IDs from a directory of flattened tables.""" + patient_path = table_dir / table_file_name("patient", output_format) + if "patient" in tables and patient_path.exists(): + return ( + _scan_flat_table(patient_path, output_format) + .select("patient_id") + .unique() + .sort("patient_id") + .collect(engine="streaming")["patient_id"] + .to_list() + ) + frames = [ + _scan_flat_table( + table_dir / table_file_name(t, output_format), output_format + ).select("patient_id") + for t in tables + if t != "patient" + ] + return ( + pl.concat(frames) + .unique() + .sort("patient_id") + .collect(engine="streaming")["patient_id"] + .to_list() + ) + + +def filter_flat_tables_by_patient_ids( + source_dir: Path, + out_dir: Path, + keep_ids: Sequence[str], + tables: Sequence[str], + output_format: str = "parquet", +) -> None: + """Filter all flattened tables to only include rows for the given patient IDs.""" + out_dir.mkdir(parents=True, exist_ok=True) + keep_set = set(keep_ids) + for name in tables: + src = source_dir / table_file_name(name, output_format) + dst = out_dir / table_file_name(name, output_format) + lf = _scan_flat_table(src, output_format).filter( + pl.col("patient_id").is_in(keep_set) + ) + if output_format == "parquet": + lf.sink_parquet(str(dst)) + else: + sep = "\t" if output_format == "tsv" else "," + lf.sink_csv(str(dst), separator=sep) diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 4c168d3e3..cba1a04c2 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -39,6 +39,7 @@ from .transformer import Transformer, TransformerLayer from .transformers_model import TransformersModel from .ehrmamba import EHRMamba, MambaBlock +from .ehrmamba_cehr import EHRMambaCEHR from .vae import 
VAE from .vision_embedding import VisionEmbeddingModel from .text_embedding import TextEmbedding diff --git a/pyhealth/models/cehr_embeddings.py b/pyhealth/models/cehr_embeddings.py new file mode 100644 index 000000000..7974a699e --- /dev/null +++ b/pyhealth/models/cehr_embeddings.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 Vector Institute / Odyssey authors +# +# Derived from Odyssey (https://github.com/VectorInstitute/odyssey): +# odyssey/models/embeddings.py — MambaEmbeddingsForCEHR, TimeEmbeddingLayer, VisitEmbedding +# Modifications: removed HuggingFace MambaConfig dependency; explicit constructor args. + +from __future__ import annotations + +from typing import Any, Optional + +import torch +from torch import nn + + +class TimeEmbeddingLayer(nn.Module): + """Embedding layer for time features (sinusoidal).""" + + def __init__(self, embedding_size: int, is_time_delta: bool = False): + super().__init__() + self.embedding_size = embedding_size + self.is_time_delta = is_time_delta + self.w = nn.Parameter(torch.empty(1, self.embedding_size)) + self.phi = nn.Parameter(torch.empty(1, self.embedding_size)) + nn.init.xavier_uniform_(self.w) + nn.init.xavier_uniform_(self.phi) + + def forward(self, time_stamps: torch.Tensor) -> torch.Tensor: + if self.is_time_delta: + time_stamps = torch.cat( + (time_stamps[:, 0:1] * 0, time_stamps[:, 1:] - time_stamps[:, :-1]), + dim=-1, + ) + time_stamps = time_stamps.float() + next_input = time_stamps.unsqueeze(-1) * self.w + self.phi + return torch.sin(next_input) + + +class VisitEmbedding(nn.Module): + """Embedding layer for visit segments.""" + + def __init__(self, visit_order_size: int, embedding_size: int): + super().__init__() + self.embedding = nn.Embedding(visit_order_size, embedding_size) + + def forward(self, visit_segments: torch.Tensor) -> torch.Tensor: + return self.embedding(visit_segments) + + +class MambaEmbeddingsForCEHR(nn.Module): + """CEHR-style combined embeddings for Mamba 
(concept + type + time + age + visit).""" + + def __init__( + self, + vocab_size: int, + hidden_size: int, + pad_token_id: int = 0, + type_vocab_size: int = 9, + max_num_visits: int = 512, + time_embeddings_size: int = 32, + visit_order_size: int = 3, + layer_norm_eps: float = 1e-12, + hidden_dropout_prob: float = 0.1, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.pad_token_id = pad_token_id + self.type_vocab_size = type_vocab_size + self.max_num_visits = max_num_visits + self.word_embeddings = nn.Embedding( + vocab_size, hidden_size, padding_idx=pad_token_id + ) + self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + self.visit_order_embeddings = nn.Embedding(max_num_visits, hidden_size) + self.time_embeddings = TimeEmbeddingLayer( + embedding_size=time_embeddings_size, is_time_delta=True + ) + self.age_embeddings = TimeEmbeddingLayer( + embedding_size=time_embeddings_size, is_time_delta=False + ) + self.visit_segment_embeddings = VisitEmbedding( + visit_order_size=visit_order_size, embedding_size=hidden_size + ) + self.scale_back_concat_layer = nn.Linear( + hidden_size + 2 * time_embeddings_size, hidden_size + ) + self.tanh = nn.Tanh() + self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward( + self, + input_ids: torch.Tensor, + token_type_ids_batch: torch.Tensor, + time_stamps: torch.Tensor, + ages: torch.Tensor, + visit_orders: torch.Tensor, + visit_segments: torch.Tensor, + ) -> torch.Tensor: + inputs_embeds = self.word_embeddings(input_ids) + time_stamps_embeds = self.time_embeddings(time_stamps) + ages_embeds = self.age_embeddings(ages) + visit_segments_embeds = self.visit_segment_embeddings(visit_segments) + visit_order_embeds = self.visit_order_embeddings(visit_orders) + token_type_embeds = self.token_type_embeddings(token_type_ids_batch) + concat_in = torch.cat( + (inputs_embeds, time_stamps_embeds, ages_embeds), dim=-1 + ) + h = 
self.tanh(self.scale_back_concat_layer(concat_in)) + embeddings = h + token_type_embeds + visit_order_embeds + visit_segments_embeds + embeddings = self.dropout(embeddings) + return self.LayerNorm(embeddings) diff --git a/pyhealth/models/ehrmamba_cehr.py b/pyhealth/models/ehrmamba_cehr.py new file mode 100644 index 000000000..cd555629c --- /dev/null +++ b/pyhealth/models/ehrmamba_cehr.py @@ -0,0 +1,117 @@ +"""EHRMamba with CEHR-style embeddings for single-stream FHIR token sequences.""" + +from __future__ import annotations + +from typing import Any, Dict, Optional + +import torch +from torch import nn + +from pyhealth.datasets import SampleDataset + +from .base_model import BaseModel +from .cehr_embeddings import MambaEmbeddingsForCEHR +from .ehrmamba import MambaBlock +from .utils import get_rightmost_masked_timestep + + +class EHRMambaCEHR(BaseModel): + """Mamba backbone over CEHR embeddings (FHIR / MPF pipeline). + + Args: + dataset: Fitted :class:`~pyhealth.datasets.SampleDataset` with MPF task schema. + vocab_size: Concept embedding vocabulary size (typically ``task.vocab.vocab_size``). + embedding_dim: Hidden size (``hidden_size`` in CEHR embeddings). + num_layers: Number of :class:`~pyhealth.models.ehrmamba.MambaBlock` layers. + pad_token_id: Padding id for masking (default 0). + state_size: SSM state size per channel. + conv_kernel: Causal conv kernel in each block. + dropout: Dropout before classifier. 
+ """ + + def __init__( + self, + dataset: SampleDataset, + vocab_size: int, + embedding_dim: int = 128, + num_layers: int = 2, + pad_token_id: int = 0, + state_size: int = 16, + conv_kernel: int = 4, + dropout: float = 0.1, + type_vocab_size: int = 16, + max_num_visits: int = 512, + time_embeddings_size: int = 32, + visit_segment_vocab: int = 3, + ): + super().__init__(dataset=dataset) + self.embedding_dim = embedding_dim + self.num_layers = num_layers + self.pad_token_id = pad_token_id + self.vocab_size = vocab_size + + assert len(self.label_keys) == 1, "EHRMambaCEHR supports single label key only" + self.label_key = self.label_keys[0] + self.mode = self.dataset.output_schema[self.label_key] + + self.embeddings = MambaEmbeddingsForCEHR( + vocab_size=vocab_size, + hidden_size=embedding_dim, + pad_token_id=pad_token_id, + type_vocab_size=type_vocab_size, + max_num_visits=max_num_visits, + time_embeddings_size=time_embeddings_size, + visit_order_size=visit_segment_vocab, + ) + self.blocks = nn.ModuleList( + [ + MambaBlock( + d_model=embedding_dim, + state_size=state_size, + conv_kernel=conv_kernel, + ) + for _ in range(num_layers) + ] + ) + self.dropout = nn.Dropout(dropout) + out_dim = self.get_output_size() + self.fc = nn.Linear(embedding_dim, out_dim) + self._forecasting_head: Optional[nn.Module] = None + + def forward_forecasting(self, **kwargs: Any) -> Optional[torch.Tensor]: + """Optional next-token / forecasting head (extension point; not implemented).""" + + return None + + def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]: + concept_ids = kwargs["concept_ids"].to(self.device).long() + token_type_ids = kwargs["token_type_ids"].to(self.device).long() + time_stamps = kwargs["time_stamps"].to(self.device).float() + ages = kwargs["ages"].to(self.device).float() + visit_orders = kwargs["visit_orders"].to(self.device).long() + visit_segments = kwargs["visit_segments"].to(self.device).long() + + x = self.embeddings( + input_ids=concept_ids, + 
token_type_ids_batch=token_type_ids, + time_stamps=time_stamps, + ages=ages, + visit_orders=visit_orders, + visit_segments=visit_segments, + ) + mask = concept_ids != self.pad_token_id + for blk in self.blocks: + x = blk(x) + pooled = get_rightmost_masked_timestep(x, mask) + logits = self.fc(self.dropout(pooled)) + y_true = kwargs[self.label_key].to(self.device).float() + if y_true.dim() == 1: + y_true = y_true.unsqueeze(-1) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + return { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + "logit": logits, + } diff --git a/pyhealth/models/utils.py b/pyhealth/models/utils.py index 67edc010e..45cd6608d 100644 --- a/pyhealth/models/utils.py +++ b/pyhealth/models/utils.py @@ -44,3 +44,31 @@ def get_last_visit(hidden_states, mask): last_hidden_states = torch.gather(hidden_states, 1, last_visit) last_hidden_state = last_hidden_states[:, 0, :] return last_hidden_state + + +def get_rightmost_masked_timestep(hidden_states, mask): + """Gather hidden state at the last True position in ``mask`` per row. + + Unlike :func:`get_last_visit`, this does **not** assume valid tokens form a + contiguous prefix; it picks the maximum index where ``mask`` is True. + Use for MPF / CEHR layouts where padding can appear between boundary tokens. + + Args: + hidden_states: ``[batch, seq_len, hidden_size]``. + mask: ``[batch, seq_len]`` bool. + + Returns: + Tensor ``[batch, hidden_size]``. 
+ """ + if mask is None: + return hidden_states[:, -1, :] + batch, seq_len, hidden = hidden_states.shape + device = hidden_states.device + idx = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand( + batch, -1 + ) + idx_m = torch.where(mask, idx, torch.full_like(idx, -1)) + last_idx = idx_m.max(dim=1).values.clamp(min=0) + last_idx = last_idx.view(batch, 1, 1).expand(batch, 1, hidden) + gathered = torch.gather(hidden_states, 1, last_idx) + return gathered[:, 0, :] diff --git a/pyhealth/processors/__init__.py b/pyhealth/processors/__init__.py index b48072270..4568a5ece 100644 --- a/pyhealth/processors/__init__.py +++ b/pyhealth/processors/__init__.py @@ -50,6 +50,7 @@ def get_processor(name: str): from .ignore_processor import IgnoreProcessor from .temporal_timeseries_processor import TemporalTimeseriesProcessor from .tuple_time_text_processor import TupleTimeTextProcessor +from .cehr_processor import CehrProcessor, ConceptVocab # Expose public API from .base_processor import ( @@ -79,4 +80,6 @@ def get_processor(name: str): "GraphProcessor", "AudioProcessor", "TupleTimeTextProcessor", + "CehrProcessor", + "ConceptVocab", ] diff --git a/pyhealth/processors/cehr_processor.py b/pyhealth/processors/cehr_processor.py new file mode 100644 index 000000000..9199f51d7 --- /dev/null +++ b/pyhealth/processors/cehr_processor.py @@ -0,0 +1,175 @@ +"""Concept vocabulary and CEHR feature processor for FHIR timelines. + +Public API +---------- +ConceptVocab + Token-to-dense-id mapping with PAD/UNK reserved at 0 and 1. JSON-serialisable. +ensure_special_tokens(vocab) + Add CEHR/MPF specials (````, ````, ````, ````) and + return their ids. +CehrProcessor + Standard :class:`~pyhealth.processors.FeatureProcessor` that maps a sample's + list of concept-key strings (already boundary-padded by the task) to a 1-D + ``torch.long`` tensor of token ids. Vocab growth happens inside the standard + ``SampleBuilder.fit(samples)`` loop -- no warm-up or freeze flag needed. 
+ +The per-patient timeline-extraction helpers (`collect_cehr_timeline_events`, +`build_cehr_sequences`, `infer_mortality_label`, etc.) live with the task that +owns that logic: :mod:`pyhealth.tasks.mpf_clinical_prediction`. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Iterable, List + +import orjson +import torch + +from . import register_processor +from .base_processor import FeatureProcessor + +DEFAULT_PAD = 0 +DEFAULT_UNK = 1 +PAD_TOKEN = "" +UNK_TOKEN = "" + +__all__ = [ + "DEFAULT_PAD", + "DEFAULT_UNK", + "PAD_TOKEN", + "UNK_TOKEN", + "ConceptVocab", + "ensure_special_tokens", + "CehrProcessor", +] + + +# --------------------------------------------------------------------------- +# Vocabulary +# --------------------------------------------------------------------------- + + +@dataclass +class ConceptVocab: + """Maps concept keys to dense ids with PAD/UNK reserved at 0 and 1.""" + + token_to_id: Dict[str, int] = field(default_factory=dict) + pad_id: int = DEFAULT_PAD + unk_id: int = DEFAULT_UNK + _next_id: int = 2 + + def __post_init__(self) -> None: + if not self.token_to_id: + self.token_to_id = {PAD_TOKEN: self.pad_id, UNK_TOKEN: self.unk_id} + self._next_id = 2 + + def add_token(self, key: str) -> int: + if key in self.token_to_id: + return self.token_to_id[key] + tid = self._next_id + self.token_to_id[key] = tid + self._next_id += 1 + return tid + + def __getitem__(self, key: str) -> int: + return self.token_to_id.get(key, self.unk_id) + + @property + def vocab_size(self) -> int: + return self._next_id + + def to_json(self) -> Dict[str, Any]: + return { + "token_to_id": self.token_to_id, + "next_id": self._next_id, + "pad_id": self.pad_id, + "unk_id": self.unk_id, + } + + @classmethod + def from_json(cls, data: Dict[str, Any]) -> "ConceptVocab": + pad_id = int(data.get("pad_id", DEFAULT_PAD)) + unk_id = int(data.get("unk_id", DEFAULT_UNK)) + vocab = 
cls(pad_id=pad_id, unk_id=unk_id) + loaded = dict(data.get("token_to_id") or {}) + if loaded: + vocab.token_to_id = loaded + vocab._next_id = int(data.get("next_id", max(loaded.values()) + 1)) + else: + vocab._next_id = int(data.get("next_id", 2)) + return vocab + + def save(self, path: str) -> None: + Path(path).parent.mkdir(parents=True, exist_ok=True) + Path(path).write_bytes(orjson.dumps(self.to_json(), option=orjson.OPT_SORT_KEYS)) + + @classmethod + def load(cls, path: str) -> "ConceptVocab": + return cls.from_json(orjson.loads(Path(path).read_bytes())) + + +def ensure_special_tokens(vocab: ConceptVocab) -> Dict[str, int]: + """Add EHRMamba/CEHR special tokens and return their ids.""" + return {name: vocab.add_token(name) for name in ("", "", "", "")} + + +# --------------------------------------------------------------------------- +# Processor +# --------------------------------------------------------------------------- + + +@register_processor("cehr") +class CehrProcessor(FeatureProcessor): + """Map a sample's list of concept-key strings to a 1-D LongTensor of ids. + + The task is expected to have already done all boundary-token insertion + (```` / ```` / ````) and left-padding with ````. This + processor's only state is a :class:`ConceptVocab`, grown during the + standard :meth:`~pyhealth.datasets.sample_dataset.SampleBuilder.fit` + pass over cached samples. 
+ """ + + def __init__(self, max_len: int = 512) -> None: + self.vocab = ConceptVocab() + ensure_special_tokens(self.vocab) + self.max_len = max_len + + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> "CehrProcessor": + for sample in samples: + keys = sample.get(field) + if not keys: + continue + for key in keys: + if isinstance(key, str): + self.vocab.add_token(key) + return self + + def process(self, value: List[Any]) -> torch.Tensor: + ids = [ + self.vocab[k] if isinstance(k, str) else int(k) + for k in value + ] + return torch.tensor(ids, dtype=torch.long) + + def save(self, path: str) -> None: + self.vocab.save(path) + + def load(self, path: str) -> None: + self.vocab = ConceptVocab.load(path) + + def is_token(self) -> bool: + return True + + def schema(self) -> tuple[str, ...]: + return ("value",) + + def dim(self) -> tuple[int, ...]: + return (1,) + + def spatial(self) -> tuple[bool, ...]: + return (True,) + + def __repr__(self) -> str: + return f"CehrProcessor(max_len={self.max_len})" diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..3729cf5f5 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -67,3 +67,11 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task + + +def __getattr__(name: str): + if name == "MPFClinicalPredictionTask": + from .mpf_clinical_prediction import MPFClinicalPredictionTask + + return MPFClinicalPredictionTask + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyhealth/tasks/mpf_clinical_prediction.py b/pyhealth/tasks/mpf_clinical_prediction.py new file mode 100644 index 000000000..8e387b512 --- /dev/null +++ b/pyhealth/tasks/mpf_clinical_prediction.py @@ -0,0 +1,310 @@ +"""Multitask Prompted Fine-tuning (MPF) clinical prediction on FHIR timelines. 
+ +The task reads per-patient events via :meth:`pyhealth.data.Patient.get_events` +and :class:`~pyhealth.data.Event` attribute access (the standard PyHealth +idiom). It builds six aligned CEHR feature sequences, inserts MPF boundary +specials, and left-pads to ``max_len``. + +Concept-key → integer-id mapping happens later, inside the standard pipeline: +``SampleBuilder.fit`` walks the cached ``task_df.ld`` and fits a +:class:`~pyhealth.processors.CehrProcessor` on the ``concept_ids`` field; +that processor's vocab is then applied per sample by ``_proc_transform``. +The other five sequences are plain numeric lists handled by the standard +tensor processor. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +import torch + +import polars as pl + +from pyhealth.data import Event, Patient +from pyhealth.processors.cehr_processor import PAD_TOKEN + +from .base_task import BaseTask + +__all__ = [ + "EVENT_TYPE_TO_TOKEN_TYPE", + "MPFClinicalPredictionTask", + "collect_cehr_timeline_events", + "infer_mortality_label", +] + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +EVENT_TYPE_TO_TOKEN_TYPE: Dict[str, int] = { + "encounter": 1, + "condition": 2, + "medication_request": 3, + "observation": 4, + "procedure": 5, +} + +_CLINICAL_EVENT_TYPES: Tuple[str, ...] 
= ( + "condition", + "observation", + "medication_request", + "procedure", +) + + +# --------------------------------------------------------------------------- +# Small pure helpers +# --------------------------------------------------------------------------- + + +def _deceased_boolean_column_means_dead(value: Any) -> bool: + """True only for an explicit ``"true"`` flag (not Python truthiness).""" + if value is None: + return False + return str(value).strip().lower() == "true" + + +def _encounter_concept_key(event: Any) -> str: + enc_class = getattr(event, "encounter_class", None) + if enc_class: + return f"encounter|{enc_class}" + return "encounter|unknown" + + +def _sequential_visit_idx_for_time( + event_time: Optional[datetime], + visit_encounters: List[Tuple[datetime, int]], +) -> int: + """Bucket an unlinked event into the nearest preceding encounter's index.""" + if not visit_encounters: + return 0 + if event_time is None: + return visit_encounters[-1][1] + chosen = visit_encounters[0][1] + for encounter_start, visit_idx in visit_encounters: + if encounter_start <= event_time: + chosen = visit_idx + else: + break + return chosen + + +def _birth_datetime_from_patient(patient: Patient) -> Optional[datetime]: + """Patient's birth date. + + The ``patient`` table's yaml entry declares ``timestamp: birth_date``, so + the Event's ``timestamp`` field is the birth date itself. + """ + events = patient.get_events(event_type="patient") + return events[0].timestamp if events else None + + +# --------------------------------------------------------------------------- +# Timeline extraction +# --------------------------------------------------------------------------- + + +def collect_cehr_timeline_events( + patient: Patient, +) -> List[Tuple[datetime, str, str, int]]: + """Collect ``(time, concept_key, event_type, visit_idx)`` tuples for one patient. + + Encounters define the visit boundaries. 
Clinical events that reference a + known encounter id are linked directly; events without a matching + encounter reference are bucketed into the chronologically nearest + preceding visit. + """ + # Only well-formed encounters (real id + non-null timestamp) define visit + # indices. We have to inspect the raw polars frame here: + # ``Event.__init__`` silently coerces ``timestamp=None`` to + # ``datetime.now()`` (data.py:43-45), so by the time we get back an Event + # we can no longer tell which encounters were timestamp-less. + encounters_df = patient.get_events(event_type="encounter", return_df=True) + valid_encounters = [ + Event.from_dict(row) + for row in encounters_df.filter( + pl.col("timestamp").is_not_null() + & pl.col("encounter/encounter_id").is_not_null() + ).iter_rows(named=True) + ] + + encounter_visit_idx: Dict[str, int] = {} + encounter_start_by_id: Dict[str, datetime] = {} + visit_encounters: List[Tuple[datetime, int]] = [] + for idx, enc in enumerate(valid_encounters): + enc_id = enc.encounter_id + encounter_visit_idx[enc_id] = idx + encounter_start_by_id[enc_id] = enc.timestamp + visit_encounters.append((enc.timestamp, idx)) + + events: List[Tuple[datetime, str, str, int]] = [] + unlinked: List[Tuple[Optional[datetime], str, str]] = [] + + for enc in valid_encounters: + events.append( + ( + enc.timestamp, + _encounter_concept_key(enc), + "encounter", + encounter_visit_idx[enc.encounter_id], + ) + ) + + for et in _CLINICAL_EVENT_TYPES: + for ev in patient.get_events(event_type=et): + concept_key = getattr(ev, "concept_key", None) or f"{et}|unknown" + enc_id = getattr(ev, "encounter_id", None) + t = ev.timestamp + if enc_id and enc_id in encounter_visit_idx: + if t is None: + t = encounter_start_by_id.get(enc_id) + if t is None: + continue + events.append((t, concept_key, et, encounter_visit_idx[enc_id])) + else: + unlinked.append((t, concept_key, et)) + + for t, concept_key, et in unlinked: + idx = _sequential_visit_idx_for_time(t, visit_encounters) 
+ if t is None: + if not visit_encounters: + continue + # Use the start of the chosen visit; fall back to the latest encounter. + t = next( + (start for start, v_idx in visit_encounters if v_idx == idx), + visit_encounters[-1][0], + ) + events.append((t, concept_key, et, idx)) + + events.sort(key=lambda item: item[0]) + return events + + +# --------------------------------------------------------------------------- +# Label +# --------------------------------------------------------------------------- + + +def infer_mortality_label(patient: Patient) -> int: + """Heuristic binary mortality label from flattened patient rows.""" + for ev in patient.get_events(event_type="patient"): + if _deceased_boolean_column_means_dead(getattr(ev, "deceased_boolean", None)): + return 1 + if getattr(ev, "deceased_datetime", None): + return 1 + for ev in patient.get_events(event_type="condition"): + ck = (getattr(ev, "concept_key", None) or "").lower() + if any(token in ck for token in ("death", "deceased", "mortality")): + return 1 + return 0 + + +# --------------------------------------------------------------------------- +# Task +# --------------------------------------------------------------------------- + + +class MPFClinicalPredictionTask(BaseTask): + """Binary mortality prediction from FHIR CEHR sequences with optional MPF tokens. + + The task does timeline extraction and emits **raw** per-event lists, + including concept keys as strings. Tokenization is the + :class:`~pyhealth.processors.CehrProcessor`'s job, fit during the + standard ``SampleBuilder.fit(dataset)`` pass. + + Attributes: + max_len: Output sequence length (must be >= 2 for boundary tokens). + use_mpf: If True, prepend ```` to the sequence; else ````. + The closing ```` is always emitted. 
+ """ + + task_name: str = "MPFClinicalPredictionFHIR" + output_schema: Dict[str, str] = {"label": "binary"} + + def __init__(self, max_len: int = 512, use_mpf: bool = True) -> None: + if max_len < 2: + raise ValueError("max_len must be >= 2 for MPF boundary tokens") + self.max_len = max_len + self.use_mpf = use_mpf + self.boundary_start = "" if use_mpf else "" + self.boundary_end = "" + self.input_schema: Dict[str, Any] = { + "concept_ids": ("cehr", {"max_len": max_len}), + "token_type_ids": ("tensor", {"dtype": torch.long}), + "time_stamps": ("tensor", {"dtype": torch.float32}), + "ages": ("tensor", {"dtype": torch.float32}), + "visit_orders": ("tensor", {"dtype": torch.long}), + "visit_segments": ("tensor", {"dtype": torch.long}), + } + + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: + """Build one labeled sample dict per patient.""" + timeline = collect_cehr_timeline_events(patient) + birth = _birth_datetime_from_patient(patient) + + clinical_cap = self.max_len - 2 + tail = timeline[-clinical_cap:] if clinical_cap > 0 else [] + base_time = tail[0][0] if tail else None + + # Build the six aligned sequences in a single pass. 
+ keys: List[str] = [self.boundary_start] + token_types: List[int] = [0] + time_stamps: List[float] = [0.0] + ages: List[float] = [0.0] + vis_o: List[int] = [0] + vis_s: List[int] = [0] + + for event_time, concept_key, event_type, visit_idx in tail: + time_delta = ( + float((event_time - base_time).total_seconds()) + if base_time is not None and event_time is not None + else 0.0 + ) + age_years = ( + (event_time - birth).days / 365.25 + if birth is not None and event_time is not None + else 0.0 + ) + keys.append(concept_key) + token_types.append(EVENT_TYPE_TO_TOKEN_TYPE.get(event_type, 0)) + time_stamps.append(time_delta) + ages.append(age_years) + vis_o.append(min(visit_idx, 511)) + vis_s.append(visit_idx % 2) + + keys.append(self.boundary_end) + token_types.append(0) + time_stamps.append(0.0) + ages.append(0.0) + vis_o.append(0) + vis_s.append(0) + + ml = self.max_len + keys = _left_pad(keys, ml, PAD_TOKEN) + token_types = _left_pad(token_types, ml, 0) + time_stamps = _left_pad(time_stamps, ml, 0.0) + ages = _left_pad(ages, ml, 0.0) + vis_o = _left_pad(vis_o, ml, 0) + vis_s = _left_pad(vis_s, ml, 0) + + return [ + { + "patient_id": patient.patient_id, + "concept_ids": keys, + "token_type_ids": token_types, + "time_stamps": time_stamps, + "ages": ages, + "visit_orders": vis_o, + "visit_segments": vis_s, + "label": infer_mortality_label(patient), + } + ] + + +def _left_pad(seq: List[Any], max_len: int, pad: Any) -> List[Any]: + if len(seq) >= max_len: + return seq[-max_len:] + return [pad] * (max_len - len(seq)) + seq diff --git a/tests/core/test_ehrmamba_cehr.py b/tests/core/test_ehrmamba_cehr.py new file mode 100644 index 000000000..81c995f28 --- /dev/null +++ b/tests/core/test_ehrmamba_cehr.py @@ -0,0 +1,126 @@ +import unittest + +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import EHRMambaCEHR +from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask + + +def _tiny_samples(seq: int 
class TestEHRMambaCEHR(unittest.TestCase):
    # Smoke tests for ``EHRMambaCEHR`` plus the masked-timestep readout helpers
    # in ``pyhealth.models.utils``.

    def test_readout_pools_rightmost_non_pad(self) -> None:
        """MPF padding between tokens must not make pooling pick a pad position."""

        from pyhealth.models.utils import (
            get_last_visit,
            get_rightmost_masked_timestep,
        )

        # One sequence of four timesteps; the mask marks index 2 invalid and
        # index 3 (value 99.0) valid, i.e. a gap sits before the last valid step.
        h = torch.tensor([[[1.0, 0.0], [2.0, 0.0], [0.0, 0.0], [99.0, 0.0]]])
        m = torch.tensor([[True, True, False, True]])
        out = get_rightmost_masked_timestep(h, m)
        self.assertTrue(torch.allclose(out[0], torch.tensor([99.0, 0.0])))
        # ``get_last_visit`` is expected to return a different vector for this
        # interleaved mask, which is what motivates the dedicated helper.
        wrong = get_last_visit(h, m)
        self.assertFalse(torch.allclose(out[0], wrong[0]))

    def test_end_to_end_fhir_pipeline(self) -> None:
        # NDJSON fixture -> MIMIC4FHIR -> task samples -> sample dataset ->
        # one forward/backward pass through the model.
        import tempfile
        from pathlib import Path

        # NOTE(review): create_sample_dataset / get_dataloader are already
        # imported at module level in this test file; these local re-imports
        # look redundant — confirm before removing.
        from pyhealth.datasets import MIMIC4FHIR, create_sample_dataset
        from pyhealth.datasets import get_dataloader

        from tests.core.test_fhir_ndjson_fixtures import run_task, write_two_class_ndjson

        task = MPFClinicalPredictionTask(max_len=32, use_mpf=True)
        with tempfile.TemporaryDirectory() as tmp:
            write_two_class_ndjson(Path(tmp))
            # The temp dir serves as both the data root and the cache dir.
            ds = MIMIC4FHIR(
                root=tmp, glob_pattern="*.ndjson", cache_dir=tmp
            )
            # Apply the task patient-by-patient instead of going through set_task.
            samples = run_task(ds, task)
            sample_ds = create_sample_dataset(
                samples=samples,
                input_schema=task.input_schema,
                output_schema=task.output_schema,
                dataset_name="fhir_test",
            )
            # The vocabulary size comes from the processor fitted on "concept_ids".
            vocab_size = sample_ds.input_processors["concept_ids"].vocab.vocab_size
            model = EHRMambaCEHR(
                dataset=sample_ds,
                vocab_size=vocab_size,
                embedding_dim=64,
                num_layers=1,
            )
            batch = next(
                iter(get_dataloader(sample_ds, batch_size=2, shuffle=False))
            )
            out = model(**batch)
            self.assertIn("loss", out)
            # Backward must run without error on the returned loss.
            out["loss"].backward()

    def test_forward_backward(self) -> None:
        # Hand-crafted two-sample dataset (labels 0 and 1) from ``_tiny_samples``.
        samples, task = _tiny_samples()
        ds = create_sample_dataset(
            samples=samples,
            input_schema=task.input_schema,
            output_schema=task.output_schema,
        )
        vocab_size = ds.input_processors["concept_ids"].vocab.vocab_size
        model = EHRMambaCEHR(
            dataset=ds,
            vocab_size=vocab_size,
            embedding_dim=64,
            num_layers=1,
            state_size=8,
        )
        batch = next(iter(get_dataloader(ds, batch_size=2, shuffle=False)))
        out = model(**batch)
        # One logit row per sample in the batch of two.
        self.assertEqual(out["logit"].shape[0], 2)
        out["loss"].backward()

    def test_eval_mode(self) -> None:
        # In eval mode under ``torch.no_grad`` the output must include "y_prob".
        samples, task = _tiny_samples()
        ds = create_sample_dataset(
            samples=samples,
            input_schema=task.input_schema,
            output_schema=task.output_schema,
        )
        vocab_size = ds.input_processors["concept_ids"].vocab.vocab_size
        model = EHRMambaCEHR(dataset=ds, vocab_size=vocab_size, embedding_dim=32, num_layers=1)
        model.eval()
        with torch.no_grad():
            batch = next(iter(get_dataloader(ds, batch_size=2, shuffle=False)))
            out = model(**batch)
        self.assertIn("y_prob", out)
pyhealth.datasets.fhir.utils import ( + flatten_resource, + load_resource_specs_from_yaml, +) +from pyhealth.processors.cehr_processor import ConceptVocab +from pyhealth.tasks.mpf_clinical_prediction import ( + MPFClinicalPredictionTask, + collect_cehr_timeline_events, + infer_mortality_label, +) + + +def _mimic4_specs(): + """Load the bundled MIMIC4 ResourceSpec registry from its YAML.""" + import yaml as _yaml + with open(MIMIC4FHIR.DEFAULT_CONFIG_PATH, encoding="utf-8") as _f: + return load_resource_specs_from_yaml(_yaml.safe_load(_f)) + + +_MIMIC4_SPECS = _mimic4_specs() + + +def _flatten_resource_to_table_row(resource): + """Flatten one resource via the MIMIC4 spec registry (test convenience).""" + return flatten_resource(resource, _MIMIC4_SPECS) + + +def _clinical_slice(sample: Dict[str, object]) -> Tuple[List[str], List[int], List[int]]: + """Drop ```` and the leading/trailing boundary tokens from a sample. + + Returns the per-event ``(concept_keys, visit_orders, visit_segments)`` + lists for the patient's clinical events only. + """ + keys = list(sample["concept_ids"]) # type: ignore[arg-type] + v_o = list(sample["visit_orders"]) # type: ignore[arg-type] + v_s = list(sample["visit_segments"]) # type: ignore[arg-type] + non_pad = [ + (k, o, s) for k, o, s in zip(keys, v_o, v_s) if k != "" + ] + # Strip leading boundary (/) and trailing . 
+ middle = non_pad[1:-1] if len(non_pad) >= 2 else [] + return ( + [k for k, _, _ in middle], + [o for _, o, _ in middle], + [s for _, _, s in middle], + ) + +from tests.core.test_fhir_ndjson_fixtures import ( + ndjson_two_class_text, + write_one_patient_ndjson, + write_two_class_ndjson, +) + + +def _third_patient_loinc_resources() -> List[Dict[str, object]]: + return [ + { + "resourceType": "Patient", + "id": "p-synth-3", + "birthDate": "1960-01-01", + }, + { + "resourceType": "Encounter", + "id": "e3", + "subject": {"reference": "Patient/p-synth-3"}, + "period": {"start": "2020-08-01T10:00:00Z"}, + "class": {"code": "IMP"}, + }, + { + "resourceType": "Observation", + "id": "o3", + "subject": {"reference": "Patient/p-synth-3"}, + "encounter": {"reference": "Encounter/e3"}, + "effectiveDateTime": "2020-08-01T12:00:00Z", + "code": {"coding": [{"system": "http://loinc.org", "code": "999-9"}]}, + }, + ] + + +def write_two_class_plus_third_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path: + lines = ndjson_two_class_text().strip().split("\n") + lines.extend(orjson.dumps(r).decode("utf-8") for r in _third_patient_loinc_resources()) + path = directory / name + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + return path + + +def _patient_from_rows(patient_id: str, rows: List[Dict[str, object]]) -> Patient: + """Build a Patient whose ``timestamp`` column is a real datetime, matching + the shape ``FHIRDataset.load_table`` produces in production. 
+ """ + df = pl.DataFrame(rows).with_columns( + pl.col("timestamp").str.to_datetime(strict=False) + ) + return Patient(patient_id=patient_id, data_source=df) + + +class TestDeceasedBooleanFlattening(unittest.TestCase): + def test_string_false_not_coerced_by_python_bool(self) -> None: + """Non-conformant ``\"false\"`` string must not become stored ``\"true\"``.""" + row = _flatten_resource_to_table_row( + { + "resourceType": "Patient", + "id": "p-str-false", + "deceasedBoolean": "false", + } + ) + self.assertIsNotNone(row) + _table, payload = row + self.assertEqual(payload.get("deceased_boolean"), "false") + + def test_string_true_parsed(self) -> None: + row = _flatten_resource_to_table_row( + { + "resourceType": "Patient", + "id": "p-str-true", + "deceasedBoolean": "true", + } + ) + self.assertIsNotNone(row) + self.assertEqual(row[1].get("deceased_boolean"), "true") + + def test_json_booleans_unchanged(self) -> None: + for raw, expected in ((True, "true"), (False, "false")): + with self.subTest(raw=raw): + row = _flatten_resource_to_table_row( + { + "resourceType": "Patient", + "id": "p-bool", + "deceasedBoolean": raw, + } + ) + self.assertIsNotNone(row) + self.assertEqual(row[1].get("deceased_boolean"), expected) + + def test_unknown_deceased_type_stored_as_none(self) -> None: + row = _flatten_resource_to_table_row( + { + "resourceType": "Patient", + "id": "p-garbage", + "deceasedBoolean": {"unexpected": "object"}, + } + ) + self.assertIsNotNone(row) + self.assertIsNone(row[1].get("deceased_boolean")) + + def test_infer_mortality_respects_string_false_row(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "event_type": "patient", + "timestamp": "2020-01-01T00:00:00", + "patient/deceased_boolean": "false", + }, + ], + ) + self.assertEqual(infer_mortality_label(patient), 0) + + +class TestFHIRDataset(unittest.TestCase): + def test_concept_vocab_from_json_empty_token_to_id(self) -> None: + v = ConceptVocab.from_json({"token_to_id": {}}) + 
self.assertIn("", v.token_to_id) + self.assertIn("", v.token_to_id) + self.assertEqual(v._next_id, 2) + + def test_concept_vocab_from_json_empty_respects_next_id(self) -> None: + v = ConceptVocab.from_json({"token_to_id": {}, "next_id": 50}) + self.assertEqual(v._next_id, 50) + + def test_sorted_ndjson_files_accepts_sequence_and_dedupes(self) -> None: + from pyhealth.datasets.fhir.utils import sorted_ndjson_files + + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + (root / "MimicPatient.ndjson.gz").write_text("x", encoding="utf-8") + (root / "MimicMedication.ndjson.gz").write_text("y", encoding="utf-8") + (root / "notes.txt").write_text("z", encoding="utf-8") + wide = sorted_ndjson_files(root, "**/*.ndjson.gz") + narrow = sorted_ndjson_files( + root, + ["MimicPatient*.ndjson.gz", "**/MimicPatient*.ndjson.gz"], + ) + self.assertEqual(len(wide), 2) + self.assertEqual(len(narrow), 1) + self.assertEqual(narrow[0].name, "MimicPatient.ndjson.gz") + + def test_dataset_accepts_glob_patterns_kwarg(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + write_one_patient_ndjson(Path(tmp)) + ds = MIMIC4FHIR( + root=tmp, glob_patterns=["*.ndjson"], cache_dir=tmp + ) + self.assertEqual(ds.glob_patterns, ["*.ndjson"]) + + def test_dataset_rejects_both_glob_kwargs(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + with self.assertRaises(ValueError): + MIMIC4FHIR( + root=tmp, + glob_pattern="*.ndjson", + glob_patterns=["*.ndjson"], + cache_dir=tmp, + ) + + def test_disk_ingest_gz_and_max_patients(self) -> None: + """gzip ingest path + ``max_patients`` cap, covered in one build. + + The heavier build/schema/set_task/pre_filter assertions now live in + ``TestFHIRSharedWorkflow`` (one shared build), so this is the only + ingest-variant build left in this class. 
+ """ + with tempfile.TemporaryDirectory() as tmp: + gz_path = Path(tmp) / "fixture.ndjson.gz" + with gzip.open(gz_path, "wt", encoding="utf-8") as gz: + gz.write(ndjson_two_class_text()) + ds = MIMIC4FHIR(root=tmp, glob_pattern="*.ndjson.gz", max_patients=5) + self.assertEqual(len(ds.unique_patient_ids), 2) + + def test_encounter_reference_requires_exact_id(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-07-02T10:00:00", + "encounter/encounter_id": "e10", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-07-02T11:00:00", + "condition/encounter_id": "e10", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I99", + }, + ], + ) + sample = MPFClinicalPredictionTask(max_len=64, use_mpf=True)(patient)[0] + self.assertEqual( + sample["concept_ids"].count("http://hl7.org/fhir/sid/icd-10-cm|I99"), + 1, + ) + + def test_unlinked_condition_emitted_once_with_two_encounters(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "ea", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-07-01T10:00:00", + "encounter/encounter_id": "eb", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-06-15T12:00:00", + "condition/concept_key": 
"http://hl7.org/fhir/sid/icd-10-cm|Z00", + }, + ], + ) + sample = MPFClinicalPredictionTask(max_len=64, use_mpf=True)(patient)[0] + self.assertEqual( + sample["concept_ids"].count("http://hl7.org/fhir/sid/icd-10-cm|Z00"), + 1, + ) + + def test_max_len_two_keeps_only_boundary_tokens(self) -> None: + """``max_len=2`` leaves room for only the two boundary tokens; the + clinical timeline is truncated away. + """ + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-06-01T11:00:00", + "condition/encounter_id": "e1", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10", + }, + ], + ) + sample = MPFClinicalPredictionTask(max_len=2, use_mpf=True)(patient)[0] + self.assertEqual(sample["concept_ids"], ["", ""]) + self.assertEqual(sample["visit_segments"], [0, 0]) + + def test_visit_segments_alternate_by_visit_index(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e0", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-06-01T11:00:00", + "condition/encounter_id": "e0", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-07-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-07-01T11:00:00", + 
"condition/encounter_id": "e1", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I20", + }, + ], + ) + sample = MPFClinicalPredictionTask(max_len=64, use_mpf=True)(patient)[0] + _, _, visit_segments = _clinical_slice(sample) + self.assertEqual(visit_segments, [0, 0, 1, 1]) + + def test_unlinked_visit_idx_matches_sequential_counter(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": None, + "encounter/encounter_id": "e_bad", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-03-01T10:00:00", + "encounter/encounter_id": "e_ok", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-03-05T11:00:00", + "condition/encounter_id": "e_ok", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-03-15T12:00:00", + "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|Z00", + }, + ], + ) + sample = MPFClinicalPredictionTask(max_len=64, use_mpf=True)(patient)[0] + keys = sample["concept_ids"] + i_link = keys.index("http://hl7.org/fhir/sid/icd-10-cm|I10") + i_free = keys.index("http://hl7.org/fhir/sid/icd-10-cm|Z00") + self.assertEqual(sample["visit_orders"][i_link], sample["visit_orders"][i_free]) + self.assertEqual(sample["visit_segments"][i_link], sample["visit_segments"][i_free]) + + def test_medication_request_uses_medication_codeable_concept(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e1", + 
"encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "medication_request", + "timestamp": "2020-06-01T11:00:00", + "medication_request/encounter_id": "e1", + "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|111", + }, + { + "patient_id": "p1", + "event_type": "medication_request", + "timestamp": "2020-06-01T12:00:00", + "medication_request/encounter_id": "e1", + "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|222", + }, + ], + ) + sample = MPFClinicalPredictionTask(max_len=64, use_mpf=True)(patient)[0] + keys = sample["concept_ids"] + ka = "http://www.nlm.nih.gov/research/umls/rxnorm|111" + kb = "http://www.nlm.nih.gov/research/umls/rxnorm|222" + self.assertEqual(keys.count(ka), 1) + self.assertEqual(keys.count(kb), 1) + + def test_medication_request_medication_reference_token(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": "2020-06-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "IMP", + }, + { + "patient_id": "p1", + "event_type": "medication_request", + "timestamp": "2020-06-01T11:00:00", + "medication_request/encounter_id": "e1", + "medication_request/concept_key": "MedicationRequest/reference|med-abc", + }, + ], + ) + sample = MPFClinicalPredictionTask(max_len=64, use_mpf=True)(patient)[0] + key = "MedicationRequest/reference|med-abc" + self.assertIn(key, sample["concept_ids"]) + self.assertEqual(sample["concept_ids"].count(key), 1) + + def test_collect_cehr_timeline_events_orders_by_timestamp(self) -> None: + patient = _patient_from_rows( + "p1", + [ + { + "patient_id": "p1", + "event_type": "patient", + "timestamp": None, + "patient/birth_date": "1950-01-01", + }, + { + "patient_id": "p1", + "event_type": "encounter", + "timestamp": 
"2020-06-01T10:00:00", + "encounter/encounter_id": "e1", + "encounter/encounter_class": "AMB", + }, + { + "patient_id": "p1", + "event_type": "condition", + "timestamp": "2020-06-01T11:00:00", + "condition/encounter_id": "e1", + "condition/concept_key": "a|1", + }, + { + "patient_id": "p1", + "event_type": "observation", + "timestamp": "2020-06-01T12:00:00", + "observation/encounter_id": "e1", + "observation/concept_key": "b|2", + }, + ], + ) + events = collect_cehr_timeline_events(patient) + self.assertEqual([event[1] for event in events], ["encounter|AMB", "a|1", "b|2"]) + + def test_observation_effective_period_start_yields_event_time(self) -> None: + """Choice-type fix: an Observation carrying only ``effectivePeriod.start`` + (no ``effectiveDateTime``) must still resolve a non-null event_time. + The pre-refactor extractor silently dropped this variant. + """ + row = _flatten_resource_to_table_row( + { + "resourceType": "Observation", + "id": "o-period", + "subject": {"reference": "Patient/p"}, + "effectivePeriod": {"start": "2022-02-02T00:00:00Z"}, + "code": {"coding": [{"system": "http://loinc.org", "code": "1-1"}]}, + } + ) + self.assertIsNotNone(row) + self.assertEqual(row[1]["event_time"], "2022-02-02T00:00:00Z") + + def test_new_resource_type_via_registry_flows_through(self) -> None: + """A resource type absent from MIMIC4's specs flows end-to-end purely by + adding a YAML entry — no engine change. + + Also exercises the directly-usable generic ``FHIRDataset`` (whole + ingest contract authored in a single YAML, no subclass). 
+ """ + from pyhealth.datasets import FHIRDataset + + resources = [ + {"resourceType": "Patient", "id": "imm-1", "birthDate": "1970-01-01"}, + { + "resourceType": "Immunization", + "id": "i1", + "patient": {"reference": "Patient/imm-1"}, + "occurrenceDateTime": "2021-01-01T00:00:00Z", + "vaccineCode": { + "coding": [{"system": "http://hl7.org/fhir/sid/cvx", "code": "208"}] + }, + }, + ] + config_yaml = ( + "version: test\n" + "resource_specs:\n" + " Patient:\n" + " table: patient\n" + " columns:\n" + " patient_id: { locate: [id], required: true }\n" + " birth_date: { locate: [birthDate] }\n" + " Immunization:\n" + " table: immunization\n" + " columns:\n" + " patient_id: { locate: [patient.reference], transform: ref_id, required: true }\n" + " resource_id: { locate: [id] }\n" + " encounter_id: { locate: [encounter.reference], transform: ref_id }\n" + " event_time: { locate: [occurrenceDateTime, recorded] }\n" + " concept_key: { locate: [vaccineCode], transform: coding_key }\n" + "tables:\n" + " patient:\n" + " file_path: patient.parquet\n" + " patient_id: patient_id\n" + " timestamp: birth_date\n" + " attributes: [birth_date]\n" + " immunization:\n" + " file_path: immunization.parquet\n" + " patient_id: patient_id\n" + " timestamp: event_time\n" + " attributes: [resource_id, encounter_id, event_time, concept_key]\n" + ) + with tempfile.TemporaryDirectory() as tmp: + tmpp = Path(tmp) + (tmpp / "fx.ndjson").write_text( + "\n".join(orjson.dumps(r).decode("utf-8") for r in resources) + "\n", + encoding="utf-8", + ) + cfg = tmpp / "immun.yaml" + cfg.write_text(config_yaml, encoding="utf-8") + ds = FHIRDataset( + root=str(tmpp), + config_path=str(cfg), + glob_pattern="*.ndjson", + cache_dir=str(tmpp), + ) + df = ds.global_event_df.collect(engine="streaming") + self.assertIn("immunization/concept_key", df.columns) + keys = ( + df.filter(pl.col("event_type") == "immunization")[ + "immunization/concept_key" + ] + .to_list() + ) + 
self.assertIn("http://hl7.org/fhir/sid/cvx|208", keys) + + def test_fhir_dataset_requires_specs(self) -> None: + """Bare ``FHIRDataset`` (no specs, no subclass) errors clearly.""" + from pyhealth.datasets import FHIRDataset + + with tempfile.TemporaryDirectory() as tmp: + with self.assertRaises(ValueError): + FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp) + + +class TestFHIRSharedWorkflow(unittest.TestCase): + """Build the dataset and run ``set_task`` ONCE, then assert over the shared + artifacts. Mirrors a realistic "ingest once, do many things" workflow and + keeps the suite fast: a single Dask build (plus one canonical ``set_task``) + shared by every assertion, instead of rebuilding per test. + """ + + @classmethod + def setUpClass(cls) -> None: + cls._tmp = tempfile.mkdtemp() + write_two_class_plus_third_ndjson(Path(cls._tmp)) + cls.ds = MIMIC4FHIR( + root=cls._tmp, glob_pattern="*.ndjson", cache_dir=cls._tmp, num_workers=1 + ) + # The single Dask build for the whole class. + cls.global_df = cls.ds.global_event_df.collect(engine="streaming") + # The canonical set_task (reuses the build above; no rebuild). + cls.sample_ds = cls.ds.set_task( + MPFClinicalPredictionTask(max_len=48, use_mpf=True), num_workers=1 + ) + cls.samples = sorted( + [cls.sample_ds[i] for i in range(len(cls.sample_ds))], + key=lambda s: s["patient_id"], + ) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls._tmp, ignore_errors=True) + + def test_build_produces_expected_tables_and_schema(self) -> None: + """Flat parquet tables exist, and the global event frame has the + expected long-format + namespaced columns with a patient's events. 
+ """ + prepared = self.ds.prepared_tables_dir + for name in ("patient", "encounter", "condition", "observation"): + self.assertTrue((prepared / f"{name}.parquet").is_file()) + for col in ( + "patient_id", + "timestamp", + "event_type", + "condition/concept_key", + "observation/concept_key", + "patient/deceased_boolean", + ): + self.assertIn(col, self.global_df.columns) + sub = self.global_df.filter(pl.col("patient_id") == "p-synth-1") + self.assertGreaterEqual(len(sub), 2) + + def test_set_task_builds_vocab(self) -> None: + vocab = self.sample_ds.input_processors["concept_ids"].vocab + self.assertGreater(vocab.vocab_size, 6) + + def test_set_task_produces_correct_samples(self) -> None: + self.assertEqual(len(self.samples), 3) + self.assertEqual( + {s["patient_id"] for s in self.samples}, + {"p-synth-1", "p-synth-2", "p-synth-3"}, + ) + for s in self.samples: + self.assertIn("concept_ids", s) + self.assertIn("label", s) + self.assertEqual({int(s["label"]) for s in self.samples}, {0, 1}) + + def test_cehr_sequence_shapes(self) -> None: + patient = self.ds.get_patient("p-synth-1") + sample = MPFClinicalPredictionTask(max_len=32, use_mpf=True)(patient)[0] + n = len(sample["concept_ids"]) + self.assertEqual(n, 32) + for key in ( + "token_type_ids", "time_stamps", "ages", "visit_orders", "visit_segments", + ): + self.assertEqual(len(sample[key]), n) + non_special = { + k for k in sample["concept_ids"] + if k not in ("", "", "", "") + } + self.assertGreater(len(non_special), 0) + + def test_mpf_pre_filter_single_patient_limits_effective_workers(self) -> None: + """Pre-filter yielding one patient caps effective_workers to 1 (the + formula is verified directly; a 1-patient ``set_task`` would raise on a + single label class). 
+ """ + class OnePatientMPFTask(MPFClinicalPredictionTask): + def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: + return df.filter(pl.col("patient_id") == "p-synth-1") + + warmup_pids = ( + OnePatientMPFTask(max_len=48, use_mpf=True) + .pre_filter(self.ds.global_event_df) + .select("patient_id") + .unique() + .collect(engine="streaming") + .to_series() + .sort() + .to_list() + ) + self.assertEqual(warmup_pids, ["p-synth-1"]) + effective_workers = min(2, len(warmup_pids)) if warmup_pids else 1 + self.assertEqual(effective_workers, 1) + + def test_mpf_pre_filter_excludes_dropped_patients_from_vocab(self) -> None: + """A task ``pre_filter`` that drops a patient also drops their concept + keys from the fitted vocab. Reuses the shared build; a distinct + ``task_name`` keeps this run's sample cache separate from the canonical + ``set_task`` in setUpClass (identical params would otherwise collide on + the shared ``cache_dir``). + """ + class TwoPatientMPFTask(MPFClinicalPredictionTask): + task_name = "MPFClinicalPredictionFHIR_prefilter" + + def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: + return df.filter( + pl.col("patient_id").is_in(["p-synth-1", "p-synth-2"]) + ) + + sample_ds = self.ds.set_task( + TwoPatientMPFTask(max_len=48, use_mpf=True), num_workers=1 + ) + vocab = sample_ds.input_processors["concept_ids"].vocab + self.assertNotIn("http://loinc.org|999-9", vocab.token_to_id) + self.assertIn("http://loinc.org|789-0", vocab.token_to_id) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_fhir_ndjson_fixtures.py b/tests/core/test_fhir_ndjson_fixtures.py new file mode 100644 index 000000000..7311d4802 --- /dev/null +++ b/tests/core/test_fhir_ndjson_fixtures.py @@ -0,0 +1,110 @@ +"""NDJSON file bodies for :mod:`tests.core.test_fhir_dataset` (disk-only ingest).""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, List + +import orjson + + +# 
--------------------------------------------------------------------------- +# Synthetic in-memory FHIR resources +# --------------------------------------------------------------------------- + + +def _one_patient_resources() -> List[Dict[str, Any]]: + return [ + {"resourceType": "Patient", "id": "p-synth-1", "birthDate": "1950-01-01", "gender": "female"}, + { + "resourceType": "Encounter", + "id": "e1", + "subject": {"reference": "Patient/p-synth-1"}, + "period": {"start": "2020-06-01T10:00:00Z"}, + "class": {"code": "IMP"}, + }, + { + "resourceType": "Condition", + "id": "c1", + "subject": {"reference": "Patient/p-synth-1"}, + "encounter": {"reference": "Encounter/e1"}, + "code": {"coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "I10"}]}, + "onsetDateTime": "2020-06-01T11:00:00Z", + }, + ] + + +def _two_patient_resources() -> List[Dict[str, Any]]: + return [ + *_one_patient_resources(), + {"resourceType": "Patient", "id": "p-synth-2", "birthDate": "1940-05-05", "deceasedBoolean": True}, + { + "resourceType": "Encounter", + "id": "e-dead", + "subject": {"reference": "Patient/p-synth-2"}, + "period": {"start": "2020-07-01T10:00:00Z"}, + "class": {"code": "IMP"}, + }, + { + "resourceType": "Observation", + "id": "o-dead", + "subject": {"reference": "Patient/p-synth-2"}, + "encounter": {"reference": "Encounter/e-dead"}, + "effectiveDateTime": "2020-07-01T12:00:00Z", + "code": {"coding": [{"system": "http://loinc.org", "code": "789-0"}]}, + }, + ] + + +# --------------------------------------------------------------------------- +# Text serialisers +# --------------------------------------------------------------------------- + + +def ndjson_one_patient_text() -> str: + return "\n".join(orjson.dumps(r).decode("utf-8") for r in _one_patient_resources()) + "\n" + + +def ndjson_two_class_text() -> str: + return "\n".join(orjson.dumps(r).decode("utf-8") for r in _two_patient_resources()) + "\n" + + +# 
def run_task(ds: Any, task: Any) -> List[Dict[str, Any]]:
    """Apply *task* to every patient of *ds* and collect the emitted samples.

    Patients are taken from ``ds.iter_patients()`` and passed to the task
    directly, one at a time, so the dataset's ``set_task`` pipeline is not
    involved. The helper lives in the shared fixture module so the FHIR test
    files can import it instead of each keeping a copy.

    Args:
        ds: Object exposing ``iter_patients()``.
        task: Callable that maps one patient to an iterable of sample dicts.

    Returns:
        All samples in patient iteration order, flattened into a single list.
    """
    collected: List[Dict[str, Any]] = []
    for patient in ds.iter_patients():
        collected.extend(task(patient))
    return collected
+ cls.ds = MIMIC4FHIR( + root=cls._tmp, glob_pattern="*.ndjson", cache_dir=cls._tmp + ) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls._tmp, ignore_errors=True) + + def test_max_len_validation(self) -> None: + with self.assertRaises(ValueError): + MPFClinicalPredictionTask(max_len=1, use_mpf=True) + + def test_mpf_sets_boundary_tokens(self) -> None: + task = MPFClinicalPredictionTask(max_len=32, use_mpf=True) + samples = run_task(self.ds, task) + self.assertGreater(len(samples), 0) + keys = samples[0]["concept_ids"] + first = next(i for i, x in enumerate(keys) if x != PAD_TOKEN) + last_nz = next( + i for i in range(len(keys) - 1, -1, -1) if keys[i] != PAD_TOKEN + ) + self.assertEqual(keys[first], "") + self.assertEqual(keys[last_nz], "") + self.assertEqual(keys[-1], "") + + def test_no_mpf_uses_cls_reg(self) -> None: + task = MPFClinicalPredictionTask(max_len=32, use_mpf=False) + samples = run_task(self.ds, task) + keys = samples[0]["concept_ids"] + first = next(i for i, x in enumerate(keys) if x != PAD_TOKEN) + last_nz = next( + i for i in range(len(keys) - 1, -1, -1) if keys[i] != PAD_TOKEN + ) + self.assertEqual(keys[first], "") + self.assertEqual(keys[last_nz], "") + self.assertEqual(keys[-1], "") + + def test_schema_keys(self) -> None: + task = MPFClinicalPredictionTask(max_len=16, use_mpf=True) + samples = run_task(self.ds, task) + for k in task.input_schema: + self.assertIn(k, samples[0]) + self.assertIn("label", samples[0]) + + def test_max_len_two_keeps_boundary_tokens(self) -> None: + """At ``max_len=2`` the sequence is exactly ``[, ]``.""" + + task = MPFClinicalPredictionTask(max_len=2, use_mpf=True) + samples = run_task(self.ds, task) + for s in samples: + keys = s["concept_ids"] + self.assertEqual(len(keys), 2) + self.assertEqual(keys[0], "") + self.assertEqual(keys[1], "") + + def test_fixed_length_alignment(self) -> None: + """All six per-event lists must be the same length (max_len).""" + + task = 
MPFClinicalPredictionTask(max_len=24, use_mpf=True) + samples = run_task(self.ds, task) + for s in samples: + self.assertEqual(len(s["concept_ids"]), 24) + self.assertEqual(len(s["token_type_ids"]), 24) + self.assertEqual(len(s["time_stamps"]), 24) + self.assertEqual(len(s["ages"]), 24) + self.assertEqual(len(s["visit_orders"]), 24) + self.assertEqual(len(s["visit_segments"]), 24) + + +if __name__ == "__main__": + unittest.main() From 7ce0206e1a7b065c576bcb595dad840dddcd1bc1 Mon Sep 17 00:00:00 2001 From: John Wu Date: Sat, 30 May 2026 19:43:25 -0500 Subject: [PATCH 2/4] fix --- examples/mimic4fhir_mpf_ehrmamba.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/mimic4fhir_mpf_ehrmamba.py b/examples/mimic4fhir_mpf_ehrmamba.py index d6d83ad88..33df598af 100644 --- a/examples/mimic4fhir_mpf_ehrmamba.py +++ b/examples/mimic4fhir_mpf_ehrmamba.py @@ -19,9 +19,14 @@ from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask from pyhealth.trainer import Trainer -REPO_ROOT = Path(__file__).resolve().parents[3] -DEMO_ROOT = REPO_ROOT / "datasets" / "physionet.org" / "mimic-iv-fhir-demo" / "2.1.0" / "fhir" -CACHE_DIR = REPO_ROOT / "datasets" / ".cache" / "pyhealth" / "fhir-demo" +# Absolute paths to the bundled PhysioNet MIMIC-IV-on-FHIR demo and its cache. 
+DEMO_ROOT = Path( + "/home/johnwu3/projects/PyHealth_Branch_Testing/datasets/" + "physionet.org/mimic-iv-fhir-demo/2.1.0/fhir" +) +CACHE_DIR = Path( + "/home/johnwu3/projects/PyHealth_Branch_Testing/datasets/.cache/pyhealth/fhir-demo" +) def main() -> None: From f620aeb837db0863236c78c0f9032af8af4c9eb4 Mon Sep 17 00:00:00 2001 From: John Wu Date: Sat, 30 May 2026 22:34:51 -0500 Subject: [PATCH 3/4] fix unit test using fast json readers --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 98f88d47b..1dac44895 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "dask[complete]~=2025.11.0", "litdata~=0.2.59", "pyarrow~=22.0.0", + "orjson~=3.10", "narwhals~=2.13.0", "more-itertools~=10.8.0", "einops>=0.8.0", From b82c6448a92f8cae529e8467d93eb87d225b9959 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 31 May 2026 21:45:00 +0000 Subject: [PATCH 4/4] Replace editdistance with rapidfuzz for Python 3.13 compatibility editdistance 0.8.1 only ships cp311 wheels and has no Python 3.13 binary, causing CI installs to fail on Linux. rapidfuzz>=3.0.0 ships wheels for all major platforms including cp313 and provides an equivalent Levenshtein.distance() API. 
https://claude.ai/code/session_01L5qHpvAZQSgmZyc6tMTX6d --- pyhealth/nlp/metrics.py | 6 +++--- pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyhealth/nlp/metrics.py b/pyhealth/nlp/metrics.py index 667a6665b..a61ec3a2d 100644 --- a/pyhealth/nlp/metrics.py +++ b/pyhealth/nlp/metrics.py @@ -353,13 +353,13 @@ class LevenshteinDistanceScoreMethod(ScoreMethod): """ @classmethod def _get_external_modules(cls: Type) -> Tuple[str, ...]: - return ('editdistance~=0.8.1',) + return ('rapidfuzz>=3.0.0',) def _score(self, meth: str, context: ScoreContext) -> Iterable[FloatScore]: - import editdistance + from rapidfuzz.distance import Levenshtein for s1, s2 in context.pairs: - val: int = editdistance.eval(s1, s2) + val: int = Levenshtein.distance(s1, s2) if self.normalize: text_len: int = max(len(s1), len(s2)) val = 1. - (val / text_len) diff --git a/pyproject.toml b/pyproject.toml index 1dac44895..65e0e2757 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ graph = [ "torch-geometric>=2.6.0", ] nlp = [ - "editdistance~=0.8.1", + "rapidfuzz>=3.0.0", "rouge_score~=0.1.2", "nltk~=3.9.1", ]