diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 8d9a59d21..ad34a7a0a 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -224,6 +224,8 @@ Available Datasets
datasets/pyhealth.datasets.SampleDataset
datasets/pyhealth.datasets.MIMIC3Dataset
datasets/pyhealth.datasets.MIMIC4Dataset
+ datasets/pyhealth.datasets.FHIRDataset
+ datasets/pyhealth.datasets.MIMIC4FHIR
datasets/pyhealth.datasets.MedicalTranscriptionsDataset
datasets/pyhealth.datasets.CardiologyDataset
datasets/pyhealth.datasets.eICUDataset
diff --git a/docs/api/datasets/pyhealth.datasets.FHIRDataset.rst b/docs/api/datasets/pyhealth.datasets.FHIRDataset.rst
new file mode 100644
index 000000000..16dbefc13
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.FHIRDataset.rst
@@ -0,0 +1,306 @@
+pyhealth.datasets.FHIRDataset
+=====================================
+
+A generic, config-driven NDJSON ingest for `HL7 FHIR
+`_ datasets. The whole pipeline is described by **a
+single YAML config** with three top-level sections — what files to read, how to
+turn each FHIR resource into a flat row, and how those rows appear as events
+downstream. A custom FHIR ingest is "point at a YAML" — no Python required.
+
+The bundled :class:`~pyhealth.datasets.MIMIC4FHIR` subclass uses this engine
+with the ``pyhealth/datasets/fhir/configs/mimic4fhir.yaml`` config tuned for
+PhysioNet's MIMIC-IV on FHIR export. See the sub-page below for the quick-start.
+
+.. contents:: On this page
+ :local:
+ :depth: 1
+
+
+Quick start
+-----------
+
+.. code-block:: python
+
+ from pyhealth.datasets import MIMIC4FHIR, get_dataloader, split_by_patient
+ from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+ from pyhealth.models import EHRMambaCEHR
+ from pyhealth.trainer import Trainer
+
+ def main():
+ ds = MIMIC4FHIR(root="/data/mimic-iv-fhir")
+ sample_ds = ds.set_task(MPFClinicalPredictionTask(), num_workers=1)
+ train, val, test = split_by_patient(sample_ds, [0.7, 0.1, 0.2])
+ vocab_size = sample_ds.input_processors["concept_ids"].vocab.vocab_size
+ model = EHRMambaCEHR(dataset=sample_ds, vocab_size=vocab_size)
+ Trainer(model=model).train(
+ train_dataloader=get_dataloader(train, batch_size=8, shuffle=True),
+ val_dataloader=get_dataloader(val, batch_size=8),
+ epochs=2,
+ )
+
+ if __name__ == "__main__":
+ main()
+
+(``if __name__ == "__main__":`` matters — :meth:`~pyhealth.datasets.BaseDataset.set_task`
+forks Dask workers; without the guard the workers re-import and re-spawn.)
+
+
+Pipeline at a glance
+--------------------
+
+::
+
+ NDJSON shards on disk
+ |
+ | (Phase A) — stream line by line, route by resourceType,
+ | project via the YAML's resource_specs
+ v
+ flattened_tables/
.parquet <- cache #1
+ |
+ | (Phase B) — load_table, dd.concat, sort by patient_id (Dask)
+ v
+ global_event_df.parquet/part-*.parquet <- cache #2
+ |
+ | (Phase C) — task_transform per-patient sample emit
+ v
+ task_df.ld/ <- cache #3a
+ |
+ | fit CehrProcessor vocab via SampleBuilder.fit(dataset)
+ | proc_transform per-sample tensorisation
+ v
+ samples_*.ld/ <- cache #3b ──> SampleDataset
+
+Each of the three cache tiers has its own existence check; re-running with
+identical inputs skips every phase. Cache identity hashes the YAML byte digest,
+glob patterns, ``max_patients``, and engine schema version — any meaningful
+config change invalidates everything below it. See
+:class:`~pyhealth.datasets.BaseDataset` for the Phase B/C internals that are
+shared with all other PyHealth datasets.
+
+
+The unified YAML config
+-----------------------
+
+A FHIR ingest YAML has three top-level sections. The bundled
+``mimic4fhir.yaml`` is the canonical worked example; what follows is the
+section-by-section reference.
+
+Section 1: ``glob_patterns:`` (which files to read)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: yaml
+
+ glob_patterns:
+ - "**/MimicPatient*.ndjson.gz"
+ - "**/MimicEncounter*.ndjson.gz"
+ # ... one pattern per resource-type shard family
+
+Defaults to ``["**/*.ndjson.gz"]`` when omitted. Only worth setting when your
+export has a per-resource-type file-naming convention you want to exploit for
+speed — PhysioNet MIMIC-IV FHIR ships shards as ``MimicPatient*.ndjson.gz``,
+``MimicEncounter*.ndjson.gz``, etc., and filtering at the file level avoids
+decompressing ~10% of the export that contains only unconfigured resource
+types. For a generic export where everything is in ``bundles.ndjson.gz``, omit
+this block and the streamer will filter by ``resourceType`` after parsing.
+
+Override at runtime via ``MIMIC4FHIR(glob_pattern=...)`` or
+``MIMIC4FHIR(glob_patterns=[...])``.
+
+Section 2: ``resource_specs:`` (how to project JSON into rows)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Keys are FHIR ``resourceType`` strings. For each, declare a ``table`` name and
+an ordered ``columns`` mapping:
+
+.. code-block:: yaml
+
+ resource_specs:
+
+ Patient:
+ table: patient
+ columns:
+ patient_id: { locate: ["id"], required: true }
+ birth_date: { locate: ["birthDate"] }
+ gender: { locate: ["gender"] }
+ deceased_boolean: { locate: ["deceasedBoolean"], transform: bool_norm }
+
+ Observation:
+ table: observation
+ columns:
+ patient_id: { locate: ["subject.reference"], transform: ref_id, required: true }
+ resource_id: { locate: ["id"] }
+ encounter_id: { locate: ["encounter.reference"], transform: ref_id }
+ event_time: { locate: ["effectiveDateTime", "effectivePeriod.start", "issued"] }
+ concept_key: { locate: ["code"], transform: coding_key }
+
+Each column entry has three fields:
+
+``locate`` *(required, list of dotted paths)*
+ Ordered JSON paths into the resource; the first that resolves to a non-null
+ value wins. This is how FHIR choice-types (``onset[x]``, ``effective[x]``,
+ ``performed[x]``, …) are handled — list every variant explicitly. A single
+ string is accepted as shorthand for a one-element list.
+
+``transform`` *(optional, name of a built-in transform, default ``identity``)*
+ Maps the located leaf to a flat scalar string. See the registry below.
+
+``required`` *(optional, bool, default false)*
+ When ``true``, a resource whose ``locate`` cannot be resolved is **dropped**
+ (and logged) rather than emitted with a null. Use this on the patient
+ reference column so events without a discoverable patient never reach the
+ global event frame.
+
+Transform registry
+^^^^^^^^^^^^^^^^^^
+
+Available transforms (defined in
+``pyhealth/datasets/fhir/utils.py`` ``TRANSFORMS`` dict):
+
+================== ===========================================================
+``identity`` Pass the value through. Stringifies non-string scalars.
+``ref_id`` Reference object or ``"Patient/p1"`` -> ``"p1"``.
+``coding_key`` CodeableConcept -> ``"system|code"`` of its first coding.
+``bool_norm`` JSON boolean / ``"true"``/``"false"`` -> ``"true"``/``"false"``/None.
+``med_concept`` MedicationRequest medication[x] -> codeable-concept or
+ ``"MedicationRequest/reference|"`` fallback.
+================== ===========================================================
+
+Adding a new transform is a one-liner: register a callable in ``TRANSFORMS``
+in ``utils.py`` and reference it by name from the YAML.
+
+Section 3: ``tables:`` (how rows are exposed as events)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Keys here must match the ``table:`` values from Section 2. Each entry tells
+:meth:`~pyhealth.datasets.BaseDataset.load_table` how to read the flat parquet:
+
+.. code-block:: yaml
+
+ tables:
+ patient:
+ file_path: "patient.parquet"
+ patient_id: "patient_id"
+ timestamp: "birth_date"
+ attributes: ["birth_date", "gender", "deceased_boolean"]
+
+ observation:
+ file_path: "observation.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes: ["resource_id", "encounter_id", "event_time", "concept_key"]
+
+``file_path`` is the parquet filename inside the cached
+``flattened_tables/`` directory. ``patient_id`` and ``timestamp`` name the
+columns to surface as the normalised ``patient_id`` and ``timestamp`` on each
+event. ``attributes`` is the list of columns surfaced as event attributes — in
+the global event frame they're renamed to ``{table}/{attr}`` and later show up
+on ``patient.get_events(event_type=...).attr_name``.
+
+Cross-section validation
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+At load time the dataset checks that every ``table:`` value declared in
+Section 2 has a matching ``tables.`` block in Section 3. Typos surface
+as a config error at startup, not silent empty parquets.
+
+
+Customising for a non-MIMIC FHIR export
+---------------------------------------
+
+Step 1 — write your YAML.
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Copy ``pyhealth/datasets/fhir/configs/mimic4fhir.yaml`` and adapt the
+``resource_specs:`` and ``tables:`` blocks for the resources you care about.
+For an export that adds Immunizations:
+
+.. code-block:: yaml
+
+ resource_specs:
+ Patient:
+ table: patient
+ columns:
+ patient_id: { locate: ["id"], required: true }
+ birth_date: { locate: ["birthDate"] }
+ Immunization:
+ table: immunization
+ columns:
+ patient_id: { locate: ["patient.reference"], transform: ref_id, required: true }
+ resource_id: { locate: ["id"] }
+ event_time: { locate: ["occurrenceDateTime", "recorded"] }
+ concept_key: { locate: ["vaccineCode"], transform: coding_key }
+
+ tables:
+ patient:
+ file_path: "patient.parquet"
+ patient_id: "patient_id"
+ timestamp: "birth_date"
+ attributes: ["birth_date"]
+ immunization:
+ file_path: "immunization.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes: ["resource_id", "event_time", "concept_key"]
+
+Step 2 — instantiate
+~~~~~~~~~~~~~~~~~~~~
+
+Either pass ``config_path=...`` directly:
+
+.. code-block:: python
+
+ from pyhealth.datasets import FHIRDataset
+
+ ds = FHIRDataset(
+ root="/data/my_fhir_export",
+ config_path="/path/to/my_export.yaml",
+ )
+
+or write a 3-line subclass that bundles your config:
+
+.. code-block:: python
+
+ from pyhealth.datasets import FHIRDataset
+
+ class MyFHIR(FHIRDataset):
+ DEFAULT_CONFIG_PATH = "/path/to/my_export.yaml"
+
+ ds = MyFHIR(root="/data/my_fhir_export")
+
+Step 3 — that's it.
+~~~~~~~~~~~~~~~~~~~
+
+Everything downstream — :meth:`~pyhealth.datasets.BaseDataset.set_task`,
+:meth:`~pyhealth.datasets.BaseDataset.iter_patients`,
+:meth:`~pyhealth.datasets.BaseDataset.get_patient` — works the same as for any
+other PyHealth dataset.
+
+
+Notes on resource use
+---------------------
+
+Streaming ingest avoids loading the whole NDJSON corpus into RAM, but downstream
+steps still scale with cohort size. For a **smoke run** the bundled example
+fixtures fit on any laptop. For a **laptop-scale real subset**, set
+``max_patients=`` and/or narrow ``glob_patterns`` to keep cache and task passes
+manageable; ≥16 GB system RAM is a comfort target for Polars + the trainer.
+For the **full PhysioNet export**, prefer fast SSD, large disk, and plenty of
+RAM — total work scales with the corpus size even if RAM ingest is bounded.
+
+
+Bundled FHIR datasets
+---------------------
+
+.. toctree::
+ :maxdepth: 1
+
+ pyhealth.datasets.MIMIC4FHIR
+
+
+API reference
+-------------
+
+.. autoclass:: pyhealth.datasets.FHIRDataset
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/datasets/pyhealth.datasets.MIMIC4FHIR.rst b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIR.rst
new file mode 100644
index 000000000..344f60cf7
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.MIMIC4FHIR.rst
@@ -0,0 +1,78 @@
+pyhealth.datasets.MIMIC4FHIR
+============================
+
+A pre-bundled :class:`~pyhealth.datasets.FHIRDataset` for the PhysioNet
+`MIMIC-IV on FHIR `_ export
+(R4, demo 2.1.0 and full release). All ingest logic — file globs, per-resource
+projection, downstream event schema — is described by the bundled YAML at
+``pyhealth/datasets/fhir/configs/mimic4fhir.yaml``; this class only points at
+that path.
+
+For everything outside the MIMIC-specific defaults (transform registry,
+``Col`` / ``ResourceSpec`` syntax, the three-tier cache story), see the parent
+page: :doc:`pyhealth.datasets.FHIRDataset`.
+
+Quick start
+-----------
+
+.. code-block:: python
+
+ from pyhealth.datasets import MIMIC4FHIR
+ from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ def main():
+ ds = MIMIC4FHIR(root="/data/mimic-iv-fhir")
+ sample_ds = ds.set_task(MPFClinicalPredictionTask(), num_workers=1)
+ # ... split / dataloader / model / trainer ...
+
+ if __name__ == "__main__":
+ main()
+
+For the full end-to-end demo (training EHR-Mamba on MPF samples) see
+``examples/mimic4fhir_mpf_ehrmamba.py``.
+
+Resource coverage
+-----------------
+
+The bundled config flattens six FHIR resource types out of the PhysioNet
+export:
+
+========================== ============================ ===============================
+FHIR resourceType Output table Key columns
+========================== ============================ ===============================
+``Patient`` ``patient.parquet`` ``patient_id``, ``birth_date``, ``gender``, ``deceased_*``
+``Encounter`` ``encounter.parquet`` ``patient_id``, ``encounter_id``, ``event_time``, ``encounter_class``
+``Condition`` ``condition.parquet`` ``patient_id``, ``encounter_id``, ``event_time``, ``concept_key``
+``Observation`` ``observation.parquet`` ``patient_id``, ``encounter_id``, ``event_time``, ``concept_key``
+``MedicationRequest`` ``medication_request.parquet`` ``patient_id``, ``encounter_id``, ``event_time``, ``concept_key``
+``Procedure`` ``procedure.parquet`` ``patient_id``, ``encounter_id``, ``event_time``, ``concept_key``
+========================== ============================ ===============================
+
+PhysioNet shards that contain only other resource types
+(``MedicationAdministration``, ``Specimen``, ``Organization``, …) are skipped
+at the file level by the bundled ``glob_patterns``. To include them, override
+``glob_patterns=`` at the constructor and add a ``resource_specs:`` entry plus
+matching ``tables:`` entry in a copy of the YAML.
+
+Customising
+-----------
+
+The bundled config is the easiest starting point for authoring a similar ingest
+for other FHIR exports. Copy
+``pyhealth/datasets/fhir/configs/mimic4fhir.yaml``, edit the
+``resource_specs:`` and ``tables:`` blocks for the resources you care about,
+and either:
+
+* pass ``config_path=...`` directly to ``FHIRDataset(root=..., config_path=...)``, or
+* subclass ``FHIRDataset`` and set ``DEFAULT_CONFIG_PATH`` on the subclass.
+
+See the "Customising for a non-MIMIC FHIR export" section of
+:doc:`pyhealth.datasets.FHIRDataset` for the step-by-step.
+
+API reference
+-------------
+
+.. autoclass:: pyhealth.datasets.MIMIC4FHIR
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/models.rst b/docs/api/models.rst
index 7c3ac7c4b..3a29c8773 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -186,6 +186,7 @@ API Reference
models/pyhealth.models.MoleRec
models/pyhealth.models.Deepr
models/pyhealth.models.EHRMamba
+ models/pyhealth.models.EHRMambaCEHR
models/pyhealth.models.JambaEHR
models/pyhealth.models.ContraWR
models/pyhealth.models.SparcNet
diff --git a/docs/api/models/pyhealth.models.EHRMambaCEHR.rst b/docs/api/models/pyhealth.models.EHRMambaCEHR.rst
new file mode 100644
index 000000000..c15a09962
--- /dev/null
+++ b/docs/api/models/pyhealth.models.EHRMambaCEHR.rst
@@ -0,0 +1,12 @@
+pyhealth.models.EHRMambaCEHR
+===================================
+
+EHRMambaCEHR applies CEHR-style embeddings (:class:`~pyhealth.models.cehr_embeddings.MambaEmbeddingsForCEHR`)
+and a stack of :class:`~pyhealth.models.MambaBlock` layers to a single FHIR token stream, for use with
+:class:`~pyhealth.tasks.mpf_clinical_prediction.MPFClinicalPredictionTask` and
+:class:`~pyhealth.datasets.fhir.FHIRDataset`.
+
+.. autoclass:: pyhealth.models.EHRMambaCEHR
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst
index 23a4e06e5..17d07026b 100644
--- a/docs/api/tasks.rst
+++ b/docs/api/tasks.rst
@@ -214,6 +214,7 @@ Available Tasks
Drug Recommendation
Length of Stay Prediction
Medical Transcriptions Classification
+ MPF Clinical Prediction (FHIR)
Mortality Prediction (Next Visit)
Mortality Prediction (StageNet MIMIC-IV)
Patient Linkage (MIMIC-III)
diff --git a/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst b/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst
new file mode 100644
index 000000000..569f75a38
--- /dev/null
+++ b/docs/api/tasks/pyhealth.tasks.mpf_clinical_prediction.rst
@@ -0,0 +1,12 @@
+pyhealth.tasks.mpf_clinical_prediction
+======================================
+
+Multitask Prompted Fine-tuning (MPF) style binary clinical prediction on FHIR
+token timelines, paired with :class:`~pyhealth.datasets.FHIRDataset` and
+:class:`~pyhealth.models.EHRMambaCEHR`. Based on CEHR / EHRMamba ideas; see the
+paper linked in the course replication PR.
+
+.. autoclass:: pyhealth.tasks.MPFClinicalPredictionTask
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/mimic4fhir_mpf_ehrmamba.py b/examples/mimic4fhir_mpf_ehrmamba.py
new file mode 100644
index 000000000..33df598af
--- /dev/null
+++ b/examples/mimic4fhir_mpf_ehrmamba.py
@@ -0,0 +1,61 @@
+"""EHRMambaCEHR on the local MIMIC-IV FHIR demo.
+
+Barebones path: Dataset -> task -> model -> trainer -> evaluate.
+
+Runs against the bundled demo at
+``datasets/physionet.org/mimic-iv-fhir-demo/2.1.0/fhir`` and persists the
+flattened-table cache under ``datasets/.cache/pyhealth/fhir-demo`` so a
+second run hits the cache.
+
+ PYTHONPATH=. python examples/mimic4fhir_mpf_ehrmamba.py
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+from pyhealth.datasets import MIMIC4FHIR, get_dataloader, split_by_patient
+from pyhealth.models import EHRMambaCEHR
+from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+from pyhealth.trainer import Trainer
+
+# Absolute paths to the bundled PhysioNet MIMIC-IV-on-FHIR demo and its cache.
+DEMO_ROOT = Path(
+ "/home/johnwu3/projects/PyHealth_Branch_Testing/datasets/"
+ "physionet.org/mimic-iv-fhir-demo/2.1.0/fhir"
+)
+CACHE_DIR = Path(
+ "/home/johnwu3/projects/PyHealth_Branch_Testing/datasets/.cache/pyhealth/fhir-demo"
+)
+
+
+def main() -> None:
+ dataset = MIMIC4FHIR(root=str(DEMO_ROOT), cache_dir=str(CACHE_DIR))
+ sample_dataset = dataset.set_task(MPFClinicalPredictionTask(), num_workers=1)
+
+ train_ds, val_ds, test_ds = split_by_patient(sample_dataset, [0.7, 0.1, 0.2])
+ train_loader = get_dataloader(train_ds, batch_size=8, shuffle=True)
+ val_loader = get_dataloader(val_ds, batch_size=8, shuffle=False)
+ test_loader = get_dataloader(test_ds, batch_size=8, shuffle=False)
+
+ vocab_size = sample_dataset.input_processors["concept_ids"].vocab.vocab_size
+ model = EHRMambaCEHR(
+ dataset=sample_dataset,
+ vocab_size=vocab_size,
+ embedding_dim=32,
+ num_layers=2,
+ dropout=0.1,
+ )
+
+ trainer = Trainer(model=model, metrics=["roc_auc", "pr_auc"])
+ trainer.train(
+ train_dataloader=train_loader,
+ val_dataloader=val_loader,
+ epochs=2,
+ monitor="roc_auc",
+ )
+ print(trainer.evaluate(test_loader))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index 50b1b3887..c29955e7d 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -59,6 +59,7 @@ def __init__(self, *args, **kwargs):
from .medical_transcriptions import MedicalTranscriptionsDataset
from .mimic3 import MIMIC3Dataset
from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
+from .fhir import FHIRDataset, MIMIC4FHIR
from .mimicextract import MIMICExtractDataset
from .omop import OMOPDataset
from .physionet_deid import PhysioNetDeIDDataset
diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py
index 0e4280aab..54de1b961 100644
--- a/pyhealth/datasets/base_dataset.py
+++ b/pyhealth/datasets/base_dataset.py
@@ -420,10 +420,16 @@ def create_tmpdir(self) -> Path:
return tmp_dir
def clean_tmpdir(self) -> None:
- """Cleans up the temporary directory within the cache."""
+ """Cleans up the temporary directory within the cache.
+
+ ``ignore_errors=True`` tolerates polars/pyarrow stream-writer
+ finalizers that may still be flushing into ``flattened_fhir_tables/``
+ when we get here. The tmp dir is not load-bearing -- leftover bytes
+ will be reclaimed on the next ``clean_tmpdir`` or by the OS.
+ """
tmp_dir = self.cache_dir / "tmp"
if tmp_dir.exists():
- shutil.rmtree(tmp_dir)
+ shutil.rmtree(tmp_dir, ignore_errors=True)
def _scan_csv_tsv_gz(self, source_path: str) -> dd.DataFrame:
"""Scans a CSV/TSV file (possibly gzipped) and returns a Dask DataFrame.
diff --git a/pyhealth/datasets/fhir/__init__.py b/pyhealth/datasets/fhir/__init__.py
new file mode 100644
index 000000000..adcbbea85
--- /dev/null
+++ b/pyhealth/datasets/fhir/__init__.py
@@ -0,0 +1,16 @@
+"""FHIR datasets: a generic engine + per-source subclasses.
+
+- :class:`~pyhealth.datasets.fhir.base.FHIRDataset` — generic, config-driven base.
+- :class:`~pyhealth.datasets.fhir.mimic4.MIMIC4FHIR` — MIMIC-IV-on-FHIR (R4).
+- :mod:`~pyhealth.datasets.fhir.utils` — the stateless flattening engine
+ (``Col``, ``ResourceSpec``, ``flatten_resource``, …).
+
+Authors:
+ John Wu and Evan Febrianto
+"""
+
+from .base import FHIRDataset
+from .mimic4 import MIMIC4FHIR
+from .utils import Col, ResourceSpec
+
+__all__ = ["FHIRDataset", "MIMIC4FHIR", "Col", "ResourceSpec"]
diff --git a/pyhealth/datasets/fhir/base.py b/pyhealth/datasets/fhir/base.py
new file mode 100644
index 000000000..517d062bb
--- /dev/null
+++ b/pyhealth/datasets/fhir/base.py
@@ -0,0 +1,415 @@
+"""Generic FHIR ingestion using flattened resource tables.
+
+Architecture
+------------
+1. Stream NDJSON/NDJSON.GZ FHIR resources from disk.
+2. Normalize each resource type into a 2D table via a declarative
+ :class:`~pyhealth.datasets.fhir.utils.ResourceSpec` registry
+ (``self.resource_specs``) — see :mod:`~pyhealth.datasets.fhir.utils`.
+3. Feed those tables through the standard YAML-driven
+ :class:`~pyhealth.datasets.BaseDataset` pipeline so downstream task
+ processing operates on :class:`~pyhealth.data.Patient` and
+ ``global_event_df`` rows.
+
+``FHIRDataset`` is generic: it owns the streaming/cache/validation machinery but
+no specific resource specs or config. Use it directly by passing
+``resource_specs=`` + ``config_path=``, or subclass it for a concrete source
+(e.g. :class:`~pyhealth.datasets.fhir.mimic4.MIMIC4FHIR`) that bakes those in as
+class attributes.
+
+Authors:
+ John Wu and Evan Febrianto
+"""
+
+from __future__ import annotations
+
+import functools
+import hashlib
+import logging
+import operator
+import shutil
+import uuid
+from pathlib import Path
+from typing import Any, Dict, List, Mapping, Optional, Sequence
+
+import dask.dataframe as dd
+import narwhals as nw
+import orjson
+import pandas as pd
+import platformdirs
+from yaml import safe_load
+
+from ..base_dataset import BaseDataset
+from .utils import (
+ FHIR_SCHEMA_VERSION,
+ SUPPORTED_OUTPUT_FORMATS,
+ ResourceSpec,
+ filter_flat_tables_by_patient_ids,
+ load_resource_specs_from_yaml,
+ sorted_patient_ids_from_flat_tables,
+ stream_fhir_ndjson_to_flat_tables,
+ table_file_name,
+ tables_from_specs,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def read_fhir_settings_yaml(path: str) -> Dict[str, Any]:
+ with open(path, encoding="utf-8") as stream:
+ data = safe_load(stream)
+ return data if isinstance(data, dict) else {}
+
+
+def _strip_tz_to_naive_ms(part: pd.Series) -> pd.Series:
+ if getattr(part.dtype, "tz", None) is not None:
+ part = part.dt.tz_localize(None)
+ return part.astype("datetime64[ms]")
+
+
+class FHIRDataset(BaseDataset):
+ """FHIR resources flattened into per-type tables, then the standard pipeline.
+
+ Streams raw FHIR NDJSON/NDJSON.GZ exports into flattened tables (one per
+ configured resource type) and pipelines them through
+ :class:`~pyhealth.datasets.BaseDataset` for downstream task processing
+ (global event dataframe, patient iteration, task sampling).
+
+ The entire ingest is driven by a single YAML config with three top-level
+ sections — ``glob_patterns:`` (which NDJSON files to open),
+ ``resource_specs:`` (how to project each FHIR resource type into a flat
+ row), and ``tables:`` (how those rows are exposed as events downstream).
+ See ``pyhealth/datasets/fhir/configs/mimic4fhir.yaml`` for a complete
+ worked example and the FHIRDataset rst page for a section-by-section guide.
+
+ Pass ``config_path=...`` directly, or subclass and set
+ ``DEFAULT_CONFIG_PATH`` to bundle a default (see
+ :class:`~pyhealth.datasets.fhir.mimic4.MIMIC4FHIR`).
+
+ Args:
+ root: Path to the NDJSON/NDJSON.GZ export directory.
+ config_path: Path to the FHIR ingest YAML. Defaults to the class
+ attribute ``DEFAULT_CONFIG_PATH``. The YAML must contain a
+ ``resource_specs:`` block; any ``glob_patterns:`` and ``tables:``
+ blocks are also read from here.
+ glob_pattern: Single glob for NDJSON files; overrides the YAML's
+ ``glob_patterns``. Mutually exclusive with *glob_patterns*.
+ glob_patterns: Multiple glob patterns; overrides the YAML's
+ ``glob_patterns``. Mutually exclusive with *glob_pattern*.
+ output_format: Flat-table format, one of ``parquet`` (default),
+ ``csv``, ``tsv``. Defaults to the class attribute
+ ``DEFAULT_OUTPUT_FORMAT``.
+ max_patients: Limit ingest to the first *N* unique patient IDs.
+ ingest_num_shards: Ignored; retained for API compatibility.
+ cache_dir: Cache directory root (UUID subdir appended per config).
+ num_workers: Worker processes for task sampling.
+ dev: Development mode; limits to 1000 patients if *max_patients* is
+ ``None``.
+
+ Examples:
+ >>> # ad-hoc, no subclass
+ >>> ds = FHIRDataset(
+ ... root="/data/fhir",
+ ... config_path="my_fhir.yaml",
+ ... )
+ >>> # or a preconfigured source subclass
+ >>> from pyhealth.datasets import MIMIC4FHIR
+ >>> ds = MIMIC4FHIR(root="/data/mimic-iv-fhir", max_patients=500)
+ """
+
+ #: Default ingest YAML path; set by source subclasses to bundle a config.
+ DEFAULT_CONFIG_PATH: Optional[str] = None
+ #: Default flat-table output format.
+ DEFAULT_OUTPUT_FORMAT: str = "parquet"
+ #: Dataset name used for cache identity / logging.
+ DATASET_NAME: str = "fhir"
+
+ def __init__(
+ self,
+ root: str,
+ config_path: Optional[str] = None,
+ glob_pattern: Optional[str] = None,
+ glob_patterns: Optional[Sequence[str]] = None,
+ output_format: Optional[str] = None,
+ max_patients: Optional[int] = None,
+ ingest_num_shards: Optional[int] = None,
+ cache_dir: Optional[str | Path] = None,
+ num_workers: int = 1,
+ dev: bool = False,
+ ) -> None:
+ del ingest_num_shards
+
+ resolved_config = config_path or type(self).DEFAULT_CONFIG_PATH
+ if resolved_config is None:
+ raise ValueError(
+ "FHIRDataset requires config_path: pass config_path=... or use a "
+ "subclass that defines DEFAULT_CONFIG_PATH."
+ )
+ self._fhir_config_path = str(Path(resolved_config).resolve())
+ self._fhir_settings = read_fhir_settings_yaml(self._fhir_config_path)
+
+ # Section 2 of the YAML: how each FHIR resource type projects into a row.
+ self.resource_specs: Mapping[str, ResourceSpec] = (
+ load_resource_specs_from_yaml(self._fhir_settings)
+ )
+
+ # Cross-validate: every table the specs declare must have a downstream
+ # `tables:` block (Section 3). Catches typos at startup.
+ spec_tables = set(tables_from_specs(self.resource_specs))
+ declared_tables = set((self._fhir_settings.get("tables") or {}).keys())
+ missing = spec_tables - declared_tables
+ if missing:
+ raise ValueError(
+ f"config {self._fhir_config_path}: resource_specs references "
+ f"table(s) {sorted(missing)} not declared in the 'tables:' "
+ f"block. Add a matching tables. entry (patient_id, "
+ f"timestamp, attributes) for each."
+ )
+
+ self.output_format = output_format or type(self).DEFAULT_OUTPUT_FORMAT
+ if self.output_format not in SUPPORTED_OUTPUT_FORMATS:
+ raise ValueError(
+ f"Unsupported output_format {self.output_format!r}; "
+ f"expected one of {SUPPORTED_OUTPUT_FORMATS}."
+ )
+
+ if glob_pattern is not None and glob_patterns is not None:
+ raise ValueError("Pass at most one of glob_pattern and glob_patterns.")
+ if glob_patterns is not None:
+ self.glob_patterns: List[str] = list(glob_patterns)
+ elif glob_pattern is not None:
+ self.glob_patterns = [glob_pattern]
+ else:
+ raw_list = self._fhir_settings.get("glob_patterns")
+ if raw_list:
+ if not isinstance(raw_list, list):
+ raise TypeError("config glob_patterns must be a list of strings.")
+ self.glob_patterns = [str(x) for x in raw_list]
+ elif self._fhir_settings.get("glob_pattern") is not None:
+ self.glob_patterns = [str(self._fhir_settings["glob_pattern"])]
+ else:
+ self.glob_patterns = ["**/*.ndjson.gz"]
+
+ self.glob_pattern = (
+ self.glob_patterns[0]
+ if len(self.glob_patterns) == 1
+ else "; ".join(self.glob_patterns)
+ )
+ self.max_patients = 1000 if dev and max_patients is None else max_patients
+
+ self._fhir_tables = tables_from_specs(self.resource_specs)
+
+ resolved_root = str(Path(root).expanduser().resolve())
+ super().__init__(
+ root=resolved_root,
+ tables=list(self._fhir_tables),
+ dataset_name=type(self).DATASET_NAME,
+ config_path=self._fhir_config_path,
+ cache_dir=cache_dir,
+ num_workers=num_workers,
+ dev=dev,
+ )
+
+ # ------------------------------------------------------------------
+ # Cache identity
+ # ------------------------------------------------------------------
+
+ def _init_cache_dir(self, cache_dir: str | Path | None) -> Path:
+ try:
+ yaml_digest = hashlib.sha256(
+ Path(self._fhir_config_path).read_bytes()
+ ).hexdigest()[:16]
+ except OSError:
+ yaml_digest = "missing"
+ identity = orjson.dumps(
+ {
+ "root": self.root,
+ "tables": sorted(self.tables),
+ "dataset_name": self.dataset_name,
+ "dev": self.dev,
+ "glob_patterns": self.glob_patterns,
+ "max_patients": self.max_patients,
+ "output_format": self.output_format,
+ "fhir_schema_version": FHIR_SCHEMA_VERSION,
+ "fhir_yaml_digest16": yaml_digest,
+ },
+ option=orjson.OPT_SORT_KEYS,
+ ).decode("utf-8")
+ cache_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, identity))
+ out = (
+ Path(platformdirs.user_cache_dir(appname="pyhealth")) / cache_id
+ if cache_dir is None
+ else Path(cache_dir) / cache_id
+ )
+ out.mkdir(parents=True, exist_ok=True)
+ logger.info(f"Cache dir: {out}")
+ return out
+
+ # ------------------------------------------------------------------
+ # NDJSON -> flat tables ingest
+ # ------------------------------------------------------------------
+
+ @property
+ def prepared_tables_dir(self) -> Path:
+ return self.cache_dir / "flattened_tables"
+
+ def _ensure_prepared_tables(self) -> None:
+ root = Path(self.root)
+ if not root.is_dir():
+ raise FileNotFoundError(f"FHIR root not found: {root}")
+
+ expected = [
+ self.prepared_tables_dir / table_file_name(t, self.output_format)
+ for t in self._fhir_tables
+ ]
+ if all(p.is_file() for p in expected):
+ return
+ if self.prepared_tables_dir.exists():
+ shutil.rmtree(self.prepared_tables_dir)
+
+ try:
+ staging_root = self.create_tmpdir()
+ staging = staging_root / "flattened_fhir_tables"
+ staging.mkdir(parents=True, exist_ok=True)
+ stream_fhir_ndjson_to_flat_tables(
+ root,
+ self.glob_patterns,
+ staging,
+ self.resource_specs,
+ self.output_format,
+ )
+ if self.max_patients is None:
+ shutil.move(str(staging), str(self.prepared_tables_dir))
+ return
+
+ filtered_root = self.create_tmpdir()
+ filtered = filtered_root / "filtered"
+ pids = sorted_patient_ids_from_flat_tables(
+ staging, self._fhir_tables, self.output_format
+ )
+ filter_flat_tables_by_patient_ids(
+ staging,
+ filtered,
+ pids[: self.max_patients],
+ self._fhir_tables,
+ self.output_format,
+ )
+ shutil.move(str(filtered), str(self.prepared_tables_dir))
+ finally:
+ self.clean_tmpdir()
+
+ def _event_transform(self, output_dir: Path) -> None:
+ self._ensure_prepared_tables()
+ super()._event_transform(output_dir)
+
+ # ------------------------------------------------------------------
+ # Table loading (flat tables instead of source CSVs)
+ # ------------------------------------------------------------------
+
+ def _read_flat_table(self, path: Path) -> dd.DataFrame:
+ if self.output_format == "parquet":
+ return dd.read_parquet(
+ str(path), split_row_groups=True, blocksize="64MB"
+ ).replace("", pd.NA)
+ sep = "\t" if self.output_format == "tsv" else ","
+ return dd.read_csv(
+ str(path), sep=sep, dtype=str, blocksize="64MB"
+ ).replace("", pd.NA)
+
+ def load_table(self, table_name: str) -> dd.DataFrame:
+ """Load one flattened table into the standard event schema.
+
+ Deviations from ``BaseDataset.load_table`` (CSV via ``_scan_csv_tsv_gz``):
+
+ * Reads pre-built flat tables (parquet/csv/tsv) under
+ ``prepared_tables_dir``.
+ * Timestamp parsing uses ``errors="coerce"`` + ``utc=True`` (FHIR ISO
+ strings include timezone suffix or partial dates).
+ * Strips tz-aware timestamps to naive UTC for Dask compat.
+ * Drops rows with null ``patient_id`` before returning.
+ """
+ assert self.config is not None
+ if table_name not in self.config.tables:
+ raise ValueError(f"Table {table_name} not found in config")
+
+ table_cfg = self.config.tables[table_name]
+ path = self.prepared_tables_dir / table_file_name(
+ table_name, self.output_format
+ )
+ if not path.exists():
+ raise FileNotFoundError(f"Flattened table not found: {path}")
+
+ logger.info(f"Scanning FHIR flattened table: {table_name} from {path}")
+ df: dd.DataFrame = self._read_flat_table(path)
+ df = df.rename(columns=str.lower)
+
+ preprocess_func = getattr(self, f"preprocess_{table_name}", None)
+ if preprocess_func is not None:
+ logger.info(
+ f"Preprocessing FHIR table: {table_name} "
+ f"with {preprocess_func.__name__}"
+ )
+ df = preprocess_func(nw.from_native(df)).to_native() # type: ignore[union-attr]
+
+ for join_cfg in table_cfg.join:
+ join_path = self.prepared_tables_dir / Path(join_cfg.file_path).name
+ if not join_path.exists():
+ raise FileNotFoundError(f"FHIR join table not found: {join_path}")
+ logger.info(f"Joining FHIR table {table_name} with {join_path}")
+ join_df: dd.DataFrame = self._read_flat_table(join_path)
+ join_df = join_df.rename(columns=str.lower)
+ join_key = join_cfg.on.lower()
+ cols = [c.lower() for c in join_cfg.columns]
+ df = df.merge(join_df[[join_key] + cols], on=join_key, how=join_cfg.how)
+
+ ts_col = table_cfg.timestamp
+ if ts_col:
+ ts = (
+ functools.reduce(
+ operator.add,
+ (df[c].astype("string") for c in ts_col),
+ )
+ if isinstance(ts_col, list)
+ else df[ts_col].astype("string")
+ )
+ ts = dd.to_datetime(
+ ts, format=table_cfg.timestamp_format, errors="coerce", utc=True
+ )
+ df = df.assign(timestamp=ts.map_partitions(_strip_tz_to_naive_ms))
+ else:
+ df = df.assign(timestamp=pd.NaT)
+
+ if table_cfg.patient_id:
+ df = df.assign(patient_id=df[table_cfg.patient_id].astype("string"))
+ else:
+ df = df.reset_index(drop=True)
+ df = df.assign(patient_id=df.index.astype("string"))
+
+ df = df.dropna(subset=["patient_id"])
+ df = df.assign(event_type=table_name)
+ rename_attr = {
+ attr.lower(): f"{table_name}/{attr}" for attr in table_cfg.attributes
+ }
+ df = df.rename(columns=rename_attr)
+ return df[
+ ["patient_id", "event_type", "timestamp"]
+ + [rename_attr[a.lower()] for a in table_cfg.attributes]
+ ]
+
+ # ------------------------------------------------------------------
+ # Patient IDs (deterministic sorted order)
+ # ------------------------------------------------------------------
+
+ @property
+ def unique_patient_ids(self) -> List[str]:
+ if self._unique_patient_ids is None:
+ self._unique_patient_ids = (
+ self.global_event_df.select("patient_id")
+ .unique()
+ .sort("patient_id")
+ .collect(engine="streaming")
+ .to_series()
+ .to_list()
+ )
+ logger.info(f"Found {len(self._unique_patient_ids)} unique patient IDs")
+ return self._unique_patient_ids
diff --git a/pyhealth/datasets/fhir/configs/mimic4fhir.yaml b/pyhealth/datasets/fhir/configs/mimic4fhir.yaml
new file mode 100644
index 000000000..1f0f1b697
--- /dev/null
+++ b/pyhealth/datasets/fhir/configs/mimic4fhir.yaml
@@ -0,0 +1,210 @@
+# MIMIC-IV-on-FHIR (R4) ingest config — single source of truth
+# ============================================================
+#
+# Authors: John Wu and Evan Febrianto
+#
+# This YAML drives the entire ingest pipeline for one FHIR export:
+#
+# 1. ``glob_patterns:`` which NDJSON files on disk to read
+# 2. ``resource_specs:`` how to project each FHIR resource type into a row
+# 3. ``tables:`` how those rows are exposed as events downstream
+#
+# Use this file as a complete worked example when authoring a YAML for any
+# other FHIR export (BigQuery dumps, Synthea, etc.). To use it as-is, just
+# instantiate :class:`~pyhealth.datasets.MIMIC4FHIR` with no extra arguments.
+#
+# To customise for a different export:
+# * subclass FHIRDataset and point ``DEFAULT_CONFIG_PATH`` at your YAML, or
+# * pass ``config_path=`` directly to ``FHIRDataset(...)``.
+
+version: "fhir_r4_flattened"
+
+
+# ---------------------------------------------------------------------------
+# Section 1: glob patterns — which NDJSON files to open
+# ---------------------------------------------------------------------------
+#
+# Defaults to ``["**/*.ndjson.gz"]`` when omitted. Only useful when the export
+# has per-resource-type file naming (PhysioNet MIMIC-IV FHIR ships shards as
+# ``MimicPatient*.ndjson.gz``, ``MimicEncounter*.ndjson.gz``, etc.). Filtering
+# at the file level avoids decompressing ~10% of the export that contains only
+# unconfigured resource types (MedicationAdministration, Specimen, Organization).
+#
+# For a generic export where everything lives in ``bundles.ndjson.gz`` /
+# ``**/*.ndjson.gz``, leave this commented out and the streamer will filter
+# resources by ``resourceType`` after parsing — correct, just slower.
+#
+# Override at runtime via ``MIMIC4FHIR(glob_pattern=...)`` or
+# ``MIMIC4FHIR(glob_patterns=[...])``.
+
+glob_patterns:
+ - "**/MimicPatient*.ndjson.gz"
+ - "**/MimicEncounter*.ndjson.gz"
+ - "**/MimicCondition*.ndjson.gz"
+ - "**/MimicObservation*.ndjson.gz"
+ - "**/MimicMedicationRequest*.ndjson.gz"
+ - "**/MimicProcedure*.ndjson.gz"
+
+
+# ---------------------------------------------------------------------------
+# Section 2: resource_specs — how to turn one FHIR JSON document into a row
+# ---------------------------------------------------------------------------
+#
+# Keys are FHIR ``resourceType`` strings. For each resource type we declare:
+#
+# table: output table name (also the per-type Parquet filename stem).
+# columns: ordered mapping of output column name -> Col spec.
+#
+# Each Col spec lists ordered JSON paths (``locate``); the first path that
+# resolves to a non-null value wins (this is how FHIR choice-types like
+# ``onset[x]`` and ``effective[x]`` are handled). ``transform`` names a function
+# in the engine's TRANSFORMS registry that maps the located leaf to a flat
+# string; ``required: true`` drops the resource if the leaf can't be resolved.
+#
+# Available transforms (defined in pyhealth/datasets/fhir/utils.py):
+#
+# identity Pass through (default; stringifies non-string scalars).
+# ref_id "{ "reference": "Patient/p1" }" -> "p1".
+# coding_key CodeableConcept -> "system|code" of its first coding.
+# bool_norm JSON boolean / "true"/"false" string -> "true"/"false"/None.
+# med_concept MedicationRequest.medication[x] -> codeable-concept or
+# "MedicationRequest/reference|" fallback.
+#
+# Adding a new transform: register it in TRANSFORMS in utils.py; reference it
+# by name here.
+
+resource_specs:
+
+ Patient:
+ table: patient
+ columns:
+ patient_id: { locate: ["id"], required: true }
+ patient_fhir_id: { locate: ["id"] }
+ birth_date: { locate: ["birthDate"] }
+ gender: { locate: ["gender"] }
+ deceased_boolean: { locate: ["deceasedBoolean"], transform: bool_norm }
+ deceased_datetime: { locate: ["deceasedDateTime"] }
+
+ Encounter:
+ table: encounter
+ columns:
+ patient_id: { locate: ["subject.reference"], transform: ref_id, required: true }
+ resource_id: { locate: ["id"] }
+ encounter_id: { locate: ["id"] }
+ event_time: { locate: ["period.start"] }
+ encounter_class: { locate: ["class.code"] }
+ encounter_end: { locate: ["period.end"] }
+
+ Condition:
+ table: condition
+ columns:
+ patient_id: { locate: ["subject.reference"], transform: ref_id, required: true }
+ resource_id: { locate: ["id"] }
+ encounter_id: { locate: ["encounter.reference"], transform: ref_id }
+ event_time: { locate: ["onsetDateTime", "onsetPeriod.start", "recordedDate"] }
+ concept_key: { locate: ["code"], transform: coding_key }
+
+ Observation:
+ table: observation
+ columns:
+ patient_id: { locate: ["subject.reference"], transform: ref_id, required: true }
+ resource_id: { locate: ["id"] }
+ encounter_id: { locate: ["encounter.reference"], transform: ref_id }
+ event_time: { locate: ["effectiveDateTime", "effectivePeriod.start", "issued"] }
+ concept_key: { locate: ["code"], transform: coding_key }
+
+ MedicationRequest:
+ table: medication_request
+ columns:
+ patient_id: { locate: ["subject.reference"], transform: ref_id, required: true }
+ resource_id: { locate: ["id"] }
+ encounter_id: { locate: ["encounter.reference"], transform: ref_id }
+ event_time: { locate: ["authoredOn"] }
+ concept_key: { locate: ["medicationCodeableConcept", "medicationReference"], transform: med_concept }
+
+ Procedure:
+ table: procedure
+ columns:
+ patient_id: { locate: ["subject.reference"], transform: ref_id, required: true }
+ resource_id: { locate: ["id"] }
+ encounter_id: { locate: ["encounter.reference"], transform: ref_id }
+ event_time: { locate: ["performedDateTime", "performedPeriod.start", "recordedDate"] }
+ concept_key: { locate: ["code"], transform: coding_key }
+
+
+# ---------------------------------------------------------------------------
+# Section 3: tables — how flat rows are exposed as events downstream
+# ---------------------------------------------------------------------------
+#
+# Keys here must match the ``table:`` values in section 2. Each entry declares
+# how BaseDataset.load_table reads the flat parquet:
+#
+# file_path: parquet filename (relative to the cached flattened_tables dir).
+# patient_id: column name that holds the patient id.
+# timestamp: column name to parse as the event timestamp.
+# attributes: list of columns to surface as event attributes; they are
+# renamed to ``{table}/{attr}`` in the global event frame and
+# later show up on ``Patient.get_events(...).attr_name``.
+
+tables:
+ patient:
+ file_path: "patient.parquet"
+ patient_id: "patient_id"
+ timestamp: "birth_date"
+ attributes:
+ - "patient_fhir_id"
+ - "birth_date"
+ - "gender"
+ - "deceased_boolean"
+ - "deceased_datetime"
+
+ encounter:
+ file_path: "encounter.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "encounter_class"
+ - "encounter_end"
+
+ condition:
+ file_path: "condition.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "concept_key"
+
+ observation:
+ file_path: "observation.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "concept_key"
+
+ medication_request:
+ file_path: "medication_request.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "concept_key"
+
+ procedure:
+ file_path: "procedure.parquet"
+ patient_id: "patient_id"
+ timestamp: "event_time"
+ attributes:
+ - "resource_id"
+ - "encounter_id"
+ - "event_time"
+ - "concept_key"
diff --git a/pyhealth/datasets/fhir/mimic4.py b/pyhealth/datasets/fhir/mimic4.py
new file mode 100644
index 000000000..5732a6776
--- /dev/null
+++ b/pyhealth/datasets/fhir/mimic4.py
@@ -0,0 +1,42 @@
+"""MIMIC-IV-on-FHIR (R4) dataset.
+
+A thin :class:`~pyhealth.datasets.fhir.base.FHIRDataset` wrapper that points at
+the bundled YAML for the PhysioNet MIMIC-IV on FHIR export. The whole ingest
+contract (resource projection + downstream table schema + glob hints) lives in
+the YAML; this class only names its default path.
+
+Use this YAML as the worked example when authoring a config for a different
+FHIR export — copy ``pyhealth/datasets/fhir/configs/mimic4fhir.yaml`` and
+adapt the ``resource_specs:`` and ``tables:`` blocks.
+
+Authors:
+ John Wu and Evan Febrianto
+"""
+
+from __future__ import annotations
+
+import os
+
+from .base import FHIRDataset
+
+
+class MIMIC4FHIR(FHIRDataset):
+ """MIMIC-IV-on-FHIR (R4) dataset.
+
+ Streams the PhysioNet MIMIC-IV on FHIR NDJSON.GZ export into flattened
+ Patient/Encounter/Condition/Observation/MedicationRequest/Procedure tables,
+ then runs the standard :class:`~pyhealth.datasets.BaseDataset` pipeline.
+
+ The bundled config at ``pyhealth/datasets/fhir/configs/mimic4fhir.yaml``
+ matches both the PhysioNet 2.1.0 demo and the full release. Override
+ ``config_path=`` to point at a customised copy.
+
+ Examples:
+ >>> ds = MIMIC4FHIR(root="/data/mimic-iv-fhir", max_patients=500)
+ >>> sample_ds = ds.set_task(task, num_workers=4)
+ """
+
+ DEFAULT_CONFIG_PATH = os.path.join(
+ os.path.dirname(__file__), "configs", "mimic4fhir.yaml"
+ )
+ DATASET_NAME = "mimic4fhir"
diff --git a/pyhealth/datasets/fhir/utils.py b/pyhealth/datasets/fhir/utils.py
new file mode 100644
index 000000000..d67f19077
--- /dev/null
+++ b/pyhealth/datasets/fhir/utils.py
@@ -0,0 +1,720 @@
+"""FHIR NDJSON parsing, generic flattening, and tabular table writing.
+
+This module is the **stateless engine** behind FHIR-to-tabular conversion. It
+knows nothing about any specific FHIR source or resource type: the per-resource
+projection is supplied as a declarative registry of :class:`ResourceSpec` objects
+(see ``MIMIC4FHIR.RESOURCE_SPECS`` for an example) and applied generically by
+:func:`flatten_resource`.
+
+Key public API
+--------------
+flatten_resource(resource, specs)
+ Project one FHIR resource dict into ``(table_name, row_dict)`` using a spec
+ registry, or ``None`` if the resource is unconfigured / missing a required
+ field.
+
+stream_fhir_ndjson_to_flat_tables(root, glob_pattern, out_dir, specs, output_format)
+ Stream all matching NDJSON/NDJSON.GZ resources into per-type flat tables
+ (parquet/csv/tsv), validating + counting drops along the way.
+
+sorted_ndjson_files(root, glob_pattern)
+ List matching NDJSON files under root (deduplicated, sorted).
+
+filter_flat_tables_by_patient_ids(source_dir, out_dir, keep_ids, tables, output_format)
+ Subset existing flattened tables to a specific patient cohort.
+
+Authors:
+ John Wu and Evan Febrianto
+"""
+
+from __future__ import annotations
+
+import gzip
+import logging
+from collections import Counter
+from dataclasses import dataclass
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple
+
+import orjson
+import polars as pl
+import pyarrow as pa
+import pyarrow.csv as pa_csv
+import pyarrow.parquet as pq
+
+logger = logging.getLogger(__name__)
+
+GlobPatternArg = str | Sequence[str]
+"""Single glob string or sequence of strings for NDJSON file discovery."""
+
+__all__ = [
+ # Types
+ "GlobPatternArg",
+ "Col",
+ "ResourceSpec",
+ # Constants
+ "FHIR_SCHEMA_VERSION",
+ "SUPPORTED_OUTPUT_FORMATS",
+ "TRANSFORMS",
+ # Spec helpers
+ "tables_from_specs",
+ "columns_from_specs",
+ "table_file_name",
+ "load_resource_specs_from_yaml",
+ # Datetime helpers
+ "parse_dt",
+ "as_naive",
+ # FHIR iteration
+ "iter_ndjson_objects",
+ "iter_resources_from_ndjson_obj",
+ # Extraction
+ "flatten_resource",
+ # Pipeline
+ "sorted_ndjson_files",
+ "stream_fhir_ndjson_to_flat_tables",
+ "filter_flat_tables_by_patient_ids",
+ "sorted_patient_ids_from_flat_tables",
+]
+
+# Bump when the flattening engine or its output layout changes; folded into the
+# dataset cache identity so stale caches rebuild automatically.
+FHIR_SCHEMA_VERSION = 4
+
+SUPPORTED_OUTPUT_FORMATS = ("parquet", "csv", "tsv")
+
+
+# ---------------------------------------------------------------------------
+# Declarative extraction spec (the registry's value type)
+# ---------------------------------------------------------------------------
+
+
+@dataclass(frozen=True)
+class Col:
+ """How to project one flat column out of a FHIR resource.
+
+ Typically constructed indirectly by :meth:`ResourceSpec.from_dict` while
+ loading the dataset's YAML config; direct instantiation is supported for
+ programmatic use.
+
+ Attributes:
+ locate: Ordered dotted paths into the resource; the first that resolves
+ to a non-null value wins. This is how FHIR choice-types (``onset[x]``,
+ ``effective[x]``, …) are handled — list every variant explicitly.
+ transform: Name of a value transform in :data:`TRANSFORMS` that converts
+ the located leaf into a flat scalar string.
+ required: When ``True``, a resource whose ``locate`` cannot be resolved is
+ dropped (and counted) rather than emitted with a null.
+ """
+
+ locate: Tuple[str, ...]
+ transform: str = "identity"
+ required: bool = False
+
+ @classmethod
+ def from_dict(cls, data: Mapping[str, Any], *, ctx: str = "") -> "Col":
+ """Build a :class:`Col` from a YAML-style dict.
+
+ Expected shape::
+
+ { locate: ["path.a", "path.b"], transform: "ref_id", required: false }
+
+ ``transform`` defaults to ``"identity"`` and must name an entry in
+ :data:`TRANSFORMS`. ``required`` defaults to ``False``. A missing or
+ empty ``locate`` field raises ``ValueError``.
+
+ Args:
+ data: Mapping containing ``locate`` (required) and the optional
+ ``transform`` / ``required`` keys.
+ ctx: Optional context string used in error messages
+ (e.g. ``"Patient.patient_id"``).
+ """
+ if not isinstance(data, Mapping):
+ raise ValueError(
+ f"{ctx or 'Col'}: expected a mapping, got {type(data).__name__}."
+ )
+ raw_locate = data.get("locate")
+ if not raw_locate:
+ raise ValueError(
+ f"{ctx or 'Col'}: missing required field 'locate'."
+ )
+ if isinstance(raw_locate, str):
+ locate: Tuple[str, ...] = (raw_locate,)
+ else:
+ locate = tuple(str(p) for p in raw_locate)
+ if not locate:
+ raise ValueError(
+ f"{ctx or 'Col'}: 'locate' must list at least one path."
+ )
+ transform = str(data.get("transform", "identity"))
+ if transform not in TRANSFORMS:
+ allowed = ", ".join(sorted(TRANSFORMS.keys()))
+ raise ValueError(
+ f"{ctx or 'Col'}: unknown transform {transform!r}. "
+ f"Allowed: {allowed}."
+ )
+ required = bool(data.get("required", False))
+ return cls(locate=locate, transform=transform, required=required)
+
+
+@dataclass(frozen=True)
+class ResourceSpec:
+ """How to project one FHIR resource type into a flat table.
+
+ Typically constructed indirectly by :func:`load_resource_specs_from_yaml`
+ while loading the dataset's YAML config; direct instantiation is supported
+ for programmatic use.
+
+ Attributes:
+ table: Output table name (also the per-type file stem).
+ columns: Mapping of output column name -> :class:`Col`. Insertion order
+ defines the table's column order.
+ """
+
+ table: str
+ columns: Mapping[str, Col]
+
+ @classmethod
+ def from_dict(
+ cls, resource_type: str, data: Mapping[str, Any]
+ ) -> "ResourceSpec":
+ """Build a :class:`ResourceSpec` from a YAML-style dict.
+
+ Expected shape::
+
+ {
+ table: "patient",
+ columns: {
+ patient_id: { locate: ["id"], required: true },
+ birth_date: { locate: ["birthDate"] },
+ ...
+ },
+ }
+
+ Args:
+ resource_type: FHIR resourceType string this spec describes
+ (e.g. ``"Patient"``). Used only for error messages.
+ data: Mapping containing ``table`` (required, str) and
+ ``columns`` (required, mapping of column name -> Col-shaped
+ mapping).
+ """
+ if not isinstance(data, Mapping):
+ raise ValueError(
+ f"resource_specs.{resource_type}: expected a mapping, "
+ f"got {type(data).__name__}."
+ )
+ table = data.get("table")
+ if not isinstance(table, str) or not table:
+ raise ValueError(
+ f"resource_specs.{resource_type}: missing required field "
+ f"'table' (string)."
+ )
+ raw_columns = data.get("columns")
+ if not isinstance(raw_columns, Mapping) or not raw_columns:
+ raise ValueError(
+ f"resource_specs.{resource_type}: missing required field "
+ f"'columns' (non-empty mapping)."
+ )
+ columns: Dict[str, Col] = {}
+ for col_name, col_data in raw_columns.items():
+ columns[str(col_name)] = Col.from_dict(
+ col_data,
+ ctx=f"resource_specs.{resource_type}.columns.{col_name}",
+ )
+ return cls(table=str(table), columns=columns)
+
+
+def load_resource_specs_from_yaml(
+ raw: Mapping[str, Any],
+) -> Dict[str, ResourceSpec]:
+ """Build the spec registry from a parsed YAML's ``resource_specs:`` block.
+
+ Args:
+ raw: The full parsed YAML mapping (top-level dict). The
+ ``resource_specs`` key, if present, must be a mapping of FHIR
+ resourceType -> ResourceSpec-shaped dict.
+
+ Returns:
+ Insertion-ordered mapping of resourceType to :class:`ResourceSpec`.
+
+ Raises:
+ ValueError: If the ``resource_specs`` block is missing, empty, or
+ contains a malformed entry.
+ """
+ block = raw.get("resource_specs")
+ if not isinstance(block, Mapping) or not block:
+ raise ValueError(
+ "config: missing or empty top-level 'resource_specs:' block. "
+ "Declare at least one FHIR resourceType -> spec mapping."
+ )
+ specs: Dict[str, ResourceSpec] = {}
+ for resource_type, data in block.items():
+ specs[str(resource_type)] = ResourceSpec.from_dict(
+ str(resource_type), data
+ )
+ return specs
+
+
+def tables_from_specs(specs: Mapping[str, ResourceSpec]) -> List[str]:
+ """Ordered, de-duplicated list of output table names declared by *specs*."""
+ seen: Dict[str, None] = {}
+ for spec in specs.values():
+ seen.setdefault(spec.table, None)
+ return list(seen.keys())
+
+
+def columns_from_specs(specs: Mapping[str, ResourceSpec]) -> Dict[str, List[str]]:
+ """Map each output table name to its ordered column names."""
+ return {spec.table: list(spec.columns.keys()) for spec in specs.values()}
+
+
+def table_file_name(table_name: str, output_format: str = "parquet") -> str:
+ """Filename for a flattened table given the output format."""
+ ext = "parquet" if output_format == "parquet" else output_format
+ return f"{table_name}.{ext}"
+
+
+# ---------------------------------------------------------------------------
+# Datetime helpers (kept for external callers)
+# ---------------------------------------------------------------------------
+
+
+def parse_dt(s: Optional[str]) -> Optional[datetime]:
+ if not s:
+ return None
+ try:
+ dt = datetime.fromisoformat(s.replace("Z", "+00:00"))
+ except ValueError:
+ dt = None
+ if dt is None and len(s) >= 10:
+ try:
+ dt = datetime.strptime(s[:10], "%Y-%m-%d")
+ except ValueError:
+ return None
+ if dt is None:
+ return None
+ return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt
+
+
+def as_naive(dt: Optional[datetime]) -> Optional[datetime]:
+ if dt is None:
+ return None
+ return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt
+
+
+# ---------------------------------------------------------------------------
+# FHIR JSON helpers
+# ---------------------------------------------------------------------------
+
+
+def _coding_key(coding: Dict[str, Any]) -> str:
+ return f"{coding.get('system') or 'unknown'}|{coding.get('code') or 'unknown'}"
+
+
+def _first_coding(obj: Optional[Dict[str, Any]]) -> Optional[str]:
+ """CodeableConcept -> ``"system|code"`` for its first coding (or None)."""
+ if not isinstance(obj, dict):
+ return None
+ codings = obj.get("coding") or []
+ if not codings and "concept" in obj:
+ codings = (obj.get("concept") or {}).get("coding") or []
+ return _coding_key(codings[0]) if codings else None
+
+
+def _ref_id(ref: Optional[Any]) -> Optional[str]:
+ """``{"reference": "Patient/p1"}`` or ``"Patient/p1"`` -> ``"p1"``."""
+ if isinstance(ref, dict):
+ ref = ref.get("reference")
+ if not ref:
+ return None
+ return ref.rsplit("/", 1)[-1] if "/" in ref else ref
+
+
+def _normalize_deceased_boolean_for_storage(value: Any) -> Optional[str]:
+ """Map Patient.deceasedBoolean to stored "true"/"false"/None.
+
+ FHIR JSON uses real booleans; some exports use strings. Python's
+ bool("false") is True, so we must not coerce with bool().
+ """
+ if value is None:
+ return None
+ if value is True:
+ return "true"
+ if value is False:
+ return "false"
+ if isinstance(value, str):
+ key = value.strip().lower()
+ if key in ("true", "1", "yes", "y", "t"):
+ return "true"
+ if key in ("false", "0", "no", "n", "f", ""):
+ return "false"
+ return None
+ if isinstance(value, (int, float)) and not isinstance(value, bool):
+ if value == 0:
+ return "false"
+ if value == 1:
+ return "true"
+ return None
+ return None
+
+
+def _medication_concept_key(value: Any) -> Optional[str]:
+ """MedicationRequest medication[x] -> a stable concept key.
+
+ Accepts either a ``medicationCodeableConcept`` (-> ``"system|code"``) or a
+ ``medicationReference`` (-> ``"MedicationRequest/reference|"``).
+ """
+ if not isinstance(value, dict):
+ return None
+ if "coding" in value or "concept" in value:
+ key = _first_coding(value)
+ if key:
+ return key
+ ref = value.get("reference")
+ if ref:
+ return f"MedicationRequest/reference|{_ref_id(ref) or ref}"
+ return None
+
+
+def _identity(value: Any) -> Optional[str]:
+ if value is None or isinstance(value, str):
+ return value
+ return str(value)
+
+
+# Transform registry: how a located leaf becomes a flat scalar string.
+TRANSFORMS = {
+ "identity": _identity,
+ "ref_id": _ref_id,
+ "coding_key": _first_coding,
+ "bool_norm": _normalize_deceased_boolean_for_storage,
+ "med_concept": _medication_concept_key,
+}
+
+
+def _unwrap_resource_dict(raw: Any) -> Optional[Dict[str, Any]]:
+ if not isinstance(raw, dict):
+ return None
+ resource = raw.get("resource") if "resource" in raw else raw
+ return resource if isinstance(resource, dict) else None
+
+
+def iter_resources_from_ndjson_obj(obj: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+ """Yield resource dicts from one parsed NDJSON object (Bundle or bare resource)."""
+ if "entry" in obj:
+ for entry in obj.get("entry") or []:
+ resource = entry.get("resource")
+ if isinstance(resource, dict):
+ yield resource
+ return
+ resource = _unwrap_resource_dict(obj)
+ if resource is not None:
+ yield resource
+
+
+def iter_ndjson_objects(path: Path) -> Iterator[Dict[str, Any]]:
+ """Yield parsed JSON objects from a plain or gzip-compressed NDJSON file."""
+ opener = (
+ gzip.open(path, "rt", encoding="utf-8", errors="replace")
+ if path.suffix == ".gz"
+ else open(path, encoding="utf-8", errors="replace")
+ )
+ with opener as stream:
+ for line in stream:
+ line = line.strip()
+ if not line:
+ continue
+ parsed = orjson.loads(line)
+ if isinstance(parsed, dict):
+ yield parsed
+
+
+# ---------------------------------------------------------------------------
+# Generic extraction engine
+# ---------------------------------------------------------------------------
+
+
+def _get_path(obj: Any, path: str) -> Any:
+ """Walk a dotted path (e.g. ``"encounter.reference"``) safely; None if absent."""
+ cur = obj
+ for part in path.split("."):
+ if not isinstance(cur, dict):
+ return None
+ cur = cur.get(part)
+ return cur
+
+
+def _first_located(resource: Dict[str, Any], paths: Tuple[str, ...]) -> Any:
+ """First non-null value among the ordered ``paths`` (choice-type resolution)."""
+ for path in paths:
+ value = _get_path(resource, path)
+ if value is not None:
+ return value
+ return None
+
+
+def flatten_resource(
+ resource: Dict[str, Any],
+ specs: Mapping[str, ResourceSpec],
+) -> Optional[Tuple[str, Dict[str, Optional[str]]]]:
+ """Project one FHIR resource into ``(table_name, row)`` via *specs*.
+
+ Returns ``None`` if the resource type is not configured in *specs*, or if a
+ column marked ``required`` cannot be resolved (a dropped/corrupted resource).
+ """
+ spec = specs.get(resource.get("resourceType"))
+ if spec is None:
+ return None
+ row: Dict[str, Optional[str]] = {}
+ for name, col in spec.columns.items():
+ raw = _first_located(resource, col.locate)
+ if raw is None and col.required:
+ return None
+ row[name] = TRANSFORMS[col.transform](raw)
+ return spec.table, row
+
+
+# ---------------------------------------------------------------------------
+# Tabular writer (parquet / csv / tsv)
+# ---------------------------------------------------------------------------
+
+
+def _table_schema(columns: Sequence[str]) -> pa.Schema:
+ return pa.schema([(col, pa.string()) for col in columns])
+
+
+class _BufferedTableWriter:
+ """Buffered, streaming writer for one flat table in parquet/csv/tsv."""
+
+ def __init__(
+ self,
+ path: Path,
+ schema: pa.Schema,
+ output_format: str = "parquet",
+ batch_size: int = 50_000,
+ ) -> None:
+ self.path = path
+ self.schema = schema
+ self.output_format = output_format
+ self.batch_size = batch_size
+ self.rows: List[Dict[str, Any]] = []
+ self._pq_writer: Optional[pq.ParquetWriter] = None
+ self._fh = None
+ self._csv_header_written = False
+ self._delimiter = "\t" if output_format == "tsv" else ","
+ self.path.parent.mkdir(parents=True, exist_ok=True)
+
+ def add(self, row: Dict[str, Any]) -> None:
+ self.rows.append(row)
+ if len(self.rows) >= self.batch_size:
+ self.flush()
+
+ def flush(self) -> None:
+ if not self.rows:
+ return
+ table = pa.Table.from_pylist(self.rows, schema=self.schema)
+ if self.output_format == "parquet":
+ if self._pq_writer is None:
+ self._pq_writer = pq.ParquetWriter(str(self.path), self.schema)
+ self._pq_writer.write_table(table)
+ else:
+ if self._fh is None:
+ self._fh = open(self.path, "wb")
+ pa_csv.write_csv(
+ table,
+ self._fh,
+ write_options=pa_csv.WriteOptions(
+ include_header=not self._csv_header_written,
+ delimiter=self._delimiter,
+ ),
+ )
+ self._csv_header_written = True
+ self.rows.clear()
+
+ def close(self) -> None:
+ self.flush()
+ if self.output_format == "parquet":
+ if self._pq_writer is None:
+ pq.write_table(
+ pa.Table.from_pylist([], schema=self.schema), str(self.path)
+ )
+ else:
+ self._pq_writer.close()
+ return
+ if self._fh is None:
+ # Empty table: still write a header-only file for a stable schema.
+ self._fh = open(self.path, "wb")
+ pa_csv.write_csv(
+ pa.Table.from_pylist([], schema=self.schema),
+ self._fh,
+ write_options=pa_csv.WriteOptions(
+ include_header=True, delimiter=self._delimiter
+ ),
+ )
+ self._fh.close()
+
+
+# ---------------------------------------------------------------------------
+# Pipeline
+# ---------------------------------------------------------------------------
+
+
+def sorted_ndjson_files(root: Path, glob_pattern: GlobPatternArg) -> List[Path]:
+ """Return sorted unique file paths under root matching glob pattern(s).
+
+ Args:
+ root: Root directory to search under.
+ glob_pattern: Single glob string or sequence of glob strings.
+
+ Returns:
+ Sorted list of matching files. Empty if no matches.
+ """
+ patterns = [glob_pattern] if isinstance(glob_pattern, str) else list(glob_pattern)
+ files: set[Path] = set()
+ for pat in patterns:
+ files.update(p for p in root.glob(pat) if p.is_file())
+ return sorted(files, key=lambda p: str(p))
+
+
+def stream_fhir_ndjson_to_flat_tables(
+ root: Path,
+ glob_pattern: GlobPatternArg,
+ out_dir: Path,
+ specs: Mapping[str, ResourceSpec],
+ output_format: str = "parquet",
+) -> None:
+ """Stream NDJSON resources into normalized per-resource flat tables.
+
+ Resources are validated as they stream: anything whose type is not in
+ *specs*, or which is missing a ``required`` field, is dropped and counted; a
+ summary is logged at the end so corruption is visible rather than silent.
+
+ Args:
+ root: Root directory containing NDJSON/NDJSON.GZ files.
+ glob_pattern: Single glob string or sequence of glob strings.
+ out_dir: Output directory for per-resource-type tables.
+ specs: Registry mapping FHIR resourceType -> :class:`ResourceSpec`.
+ output_format: One of :data:`SUPPORTED_OUTPUT_FORMATS`.
+ """
+ if output_format not in SUPPORTED_OUTPUT_FORMATS:
+ raise ValueError(
+ f"Unsupported output_format {output_format!r}; "
+ f"expected one of {SUPPORTED_OUTPUT_FORMATS}."
+ )
+ out_dir.mkdir(parents=True, exist_ok=True)
+ tables = tables_from_specs(specs)
+ columns = columns_from_specs(specs)
+ writers = {
+ name: _BufferedTableWriter(
+ path=out_dir / table_file_name(name, output_format),
+ schema=_table_schema(columns[name]),
+ output_format=output_format,
+ )
+ for name in tables
+ }
+
+ ingested: Counter = Counter()
+ dropped: Counter = Counter()
+ skipped_unconfigured: Counter = Counter()
+ try:
+ for file_path in sorted_ndjson_files(root, glob_pattern):
+ for ndjson_obj in iter_ndjson_objects(file_path):
+ for resource in iter_resources_from_ndjson_obj(ndjson_obj):
+ resource_type = resource.get("resourceType")
+ result = flatten_resource(resource, specs)
+ if result is None:
+ if resource_type in specs:
+ dropped[resource_type] += 1
+ else:
+ skipped_unconfigured[resource_type] += 1
+ continue
+ table_name, row = result
+ writers[table_name].add(row)
+ ingested[table_name] += 1
+ finally:
+ for writer in writers.values():
+ writer.close()
+
+ logger.info(
+ "FHIR flatten complete (%s): %s",
+ output_format,
+ {name: ingested[name] for name in tables},
+ )
+ for resource_type, count in dropped.items():
+ logger.warning(
+ "FHIR flatten: dropped %d %s resource(s) missing a required field "
+ "(e.g. patient reference).",
+ count,
+ resource_type,
+ )
+ if skipped_unconfigured:
+ total = sum(skipped_unconfigured.values())
+ logger.info(
+ "FHIR flatten: skipped %d resource(s) of %d unconfigured type(s): %s",
+ total,
+ len(skipped_unconfigured),
+ dict(skipped_unconfigured),
+ )
+
+
+def _scan_flat_table(path: Path, output_format: str) -> pl.LazyFrame:
+ if output_format == "parquet":
+ return pl.scan_parquet(str(path))
+ sep = "\t" if output_format == "tsv" else ","
+ # infer_schema_length=0 keeps every column as Utf8 (flat tables are all strings).
+ return pl.scan_csv(str(path), separator=sep, infer_schema_length=0)
+
+
+def sorted_patient_ids_from_flat_tables(
+ table_dir: Path,
+ tables: Sequence[str],
+ output_format: str = "parquet",
+) -> List[str]:
+ """Return sorted unique patient IDs from a directory of flattened tables."""
+ patient_path = table_dir / table_file_name("patient", output_format)
+ if "patient" in tables and patient_path.exists():
+ return (
+ _scan_flat_table(patient_path, output_format)
+ .select("patient_id")
+ .unique()
+ .sort("patient_id")
+ .collect(engine="streaming")["patient_id"]
+ .to_list()
+ )
+ frames = [
+ _scan_flat_table(
+ table_dir / table_file_name(t, output_format), output_format
+ ).select("patient_id")
+ for t in tables
+ if t != "patient"
+ ]
+ return (
+ pl.concat(frames)
+ .unique()
+ .sort("patient_id")
+ .collect(engine="streaming")["patient_id"]
+ .to_list()
+ )
+
+
+def filter_flat_tables_by_patient_ids(
+ source_dir: Path,
+ out_dir: Path,
+ keep_ids: Sequence[str],
+ tables: Sequence[str],
+ output_format: str = "parquet",
+) -> None:
+ """Filter all flattened tables to only include rows for the given patient IDs."""
+ out_dir.mkdir(parents=True, exist_ok=True)
+ keep_set = set(keep_ids)
+ for name in tables:
+ src = source_dir / table_file_name(name, output_format)
+ dst = out_dir / table_file_name(name, output_format)
+ lf = _scan_flat_table(src, output_format).filter(
+ pl.col("patient_id").is_in(keep_set)
+ )
+ if output_format == "parquet":
+ lf.sink_parquet(str(dst))
+ else:
+ sep = "\t" if output_format == "tsv" else ","
+ lf.sink_csv(str(dst), separator=sep)
diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py
index 4c168d3e3..cba1a04c2 100644
--- a/pyhealth/models/__init__.py
+++ b/pyhealth/models/__init__.py
@@ -39,6 +39,7 @@
from .transformer import Transformer, TransformerLayer
from .transformers_model import TransformersModel
from .ehrmamba import EHRMamba, MambaBlock
+from .ehrmamba_cehr import EHRMambaCEHR
from .vae import VAE
from .vision_embedding import VisionEmbeddingModel
from .text_embedding import TextEmbedding
diff --git a/pyhealth/models/cehr_embeddings.py b/pyhealth/models/cehr_embeddings.py
new file mode 100644
index 000000000..7974a699e
--- /dev/null
+++ b/pyhealth/models/cehr_embeddings.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2024 Vector Institute / Odyssey authors
+#
+# Derived from Odyssey (https://github.com/VectorInstitute/odyssey):
+# odyssey/models/embeddings.py — MambaEmbeddingsForCEHR, TimeEmbeddingLayer, VisitEmbedding
+# Modifications: removed HuggingFace MambaConfig dependency; explicit constructor args.
+
+from __future__ import annotations
+
+from typing import Any, Optional
+
+import torch
+from torch import nn
+
+
+class TimeEmbeddingLayer(nn.Module):
+ """Embedding layer for time features (sinusoidal)."""
+
+ def __init__(self, embedding_size: int, is_time_delta: bool = False):
+ super().__init__()
+ self.embedding_size = embedding_size
+ self.is_time_delta = is_time_delta
+ self.w = nn.Parameter(torch.empty(1, self.embedding_size))
+ self.phi = nn.Parameter(torch.empty(1, self.embedding_size))
+ nn.init.xavier_uniform_(self.w)
+ nn.init.xavier_uniform_(self.phi)
+
+ def forward(self, time_stamps: torch.Tensor) -> torch.Tensor:
+ if self.is_time_delta:
+ time_stamps = torch.cat(
+ (time_stamps[:, 0:1] * 0, time_stamps[:, 1:] - time_stamps[:, :-1]),
+ dim=-1,
+ )
+ time_stamps = time_stamps.float()
+ next_input = time_stamps.unsqueeze(-1) * self.w + self.phi
+ return torch.sin(next_input)
+
+
+class VisitEmbedding(nn.Module):
+ """Embedding layer for visit segments."""
+
+ def __init__(self, visit_order_size: int, embedding_size: int):
+ super().__init__()
+ self.embedding = nn.Embedding(visit_order_size, embedding_size)
+
+ def forward(self, visit_segments: torch.Tensor) -> torch.Tensor:
+ return self.embedding(visit_segments)
+
+
+class MambaEmbeddingsForCEHR(nn.Module):
+ """CEHR-style combined embeddings for Mamba (concept + type + time + age + visit)."""
+
+ def __init__(
+ self,
+ vocab_size: int,
+ hidden_size: int,
+ pad_token_id: int = 0,
+ type_vocab_size: int = 9,
+ max_num_visits: int = 512,
+ time_embeddings_size: int = 32,
+ visit_order_size: int = 3,
+ layer_norm_eps: float = 1e-12,
+ hidden_dropout_prob: float = 0.1,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.pad_token_id = pad_token_id
+ self.type_vocab_size = type_vocab_size
+ self.max_num_visits = max_num_visits
+ self.word_embeddings = nn.Embedding(
+ vocab_size, hidden_size, padding_idx=pad_token_id
+ )
+ self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
+ self.visit_order_embeddings = nn.Embedding(max_num_visits, hidden_size)
+ self.time_embeddings = TimeEmbeddingLayer(
+ embedding_size=time_embeddings_size, is_time_delta=True
+ )
+ self.age_embeddings = TimeEmbeddingLayer(
+ embedding_size=time_embeddings_size, is_time_delta=False
+ )
+ self.visit_segment_embeddings = VisitEmbedding(
+ visit_order_size=visit_order_size, embedding_size=hidden_size
+ )
+ self.scale_back_concat_layer = nn.Linear(
+ hidden_size + 2 * time_embeddings_size, hidden_size
+ )
+ self.tanh = nn.Tanh()
+ self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+ self.dropout = nn.Dropout(hidden_dropout_prob)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ token_type_ids_batch: torch.Tensor,
+ time_stamps: torch.Tensor,
+ ages: torch.Tensor,
+ visit_orders: torch.Tensor,
+ visit_segments: torch.Tensor,
+ ) -> torch.Tensor:
+ inputs_embeds = self.word_embeddings(input_ids)
+ time_stamps_embeds = self.time_embeddings(time_stamps)
+ ages_embeds = self.age_embeddings(ages)
+ visit_segments_embeds = self.visit_segment_embeddings(visit_segments)
+ visit_order_embeds = self.visit_order_embeddings(visit_orders)
+ token_type_embeds = self.token_type_embeddings(token_type_ids_batch)
+ concat_in = torch.cat(
+ (inputs_embeds, time_stamps_embeds, ages_embeds), dim=-1
+ )
+ h = self.tanh(self.scale_back_concat_layer(concat_in))
+ embeddings = h + token_type_embeds + visit_order_embeds + visit_segments_embeds
+ embeddings = self.dropout(embeddings)
+ return self.LayerNorm(embeddings)
diff --git a/pyhealth/models/ehrmamba_cehr.py b/pyhealth/models/ehrmamba_cehr.py
new file mode 100644
index 000000000..cd555629c
--- /dev/null
+++ b/pyhealth/models/ehrmamba_cehr.py
@@ -0,0 +1,117 @@
+"""EHRMamba with CEHR-style embeddings for single-stream FHIR token sequences."""
+
+from __future__ import annotations
+
+from typing import Any, Dict, Optional
+
+import torch
+from torch import nn
+
+from pyhealth.datasets import SampleDataset
+
+from .base_model import BaseModel
+from .cehr_embeddings import MambaEmbeddingsForCEHR
+from .ehrmamba import MambaBlock
+from .utils import get_rightmost_masked_timestep
+
+
+class EHRMambaCEHR(BaseModel):
+ """Mamba backbone over CEHR embeddings (FHIR / MPF pipeline).
+
+ Args:
+ dataset: Fitted :class:`~pyhealth.datasets.SampleDataset` with MPF task schema.
+ vocab_size: Concept embedding vocabulary size (typically ``task.vocab.vocab_size``).
+ embedding_dim: Hidden size (``hidden_size`` in CEHR embeddings).
+ num_layers: Number of :class:`~pyhealth.models.ehrmamba.MambaBlock` layers.
+ pad_token_id: Padding id for masking (default 0).
+ state_size: SSM state size per channel.
+ conv_kernel: Causal conv kernel in each block.
+ dropout: Dropout before classifier.
+ """
+
+ def __init__(
+ self,
+ dataset: SampleDataset,
+ vocab_size: int,
+ embedding_dim: int = 128,
+ num_layers: int = 2,
+ pad_token_id: int = 0,
+ state_size: int = 16,
+ conv_kernel: int = 4,
+ dropout: float = 0.1,
+ type_vocab_size: int = 16,
+ max_num_visits: int = 512,
+ time_embeddings_size: int = 32,
+ visit_segment_vocab: int = 3,
+ ):
+ super().__init__(dataset=dataset)
+ self.embedding_dim = embedding_dim
+ self.num_layers = num_layers
+ self.pad_token_id = pad_token_id
+ self.vocab_size = vocab_size
+
+ assert len(self.label_keys) == 1, "EHRMambaCEHR supports single label key only"
+ self.label_key = self.label_keys[0]
+ self.mode = self.dataset.output_schema[self.label_key]
+
+ self.embeddings = MambaEmbeddingsForCEHR(
+ vocab_size=vocab_size,
+ hidden_size=embedding_dim,
+ pad_token_id=pad_token_id,
+ type_vocab_size=type_vocab_size,
+ max_num_visits=max_num_visits,
+ time_embeddings_size=time_embeddings_size,
+ visit_order_size=visit_segment_vocab,
+ )
+ self.blocks = nn.ModuleList(
+ [
+ MambaBlock(
+ d_model=embedding_dim,
+ state_size=state_size,
+ conv_kernel=conv_kernel,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.dropout = nn.Dropout(dropout)
+ out_dim = self.get_output_size()
+ self.fc = nn.Linear(embedding_dim, out_dim)
+ self._forecasting_head: Optional[nn.Module] = None
+
+ def forward_forecasting(self, **kwargs: Any) -> Optional[torch.Tensor]:
+ """Optional next-token / forecasting head (extension point; not implemented)."""
+
+ return None
+
+ def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]:
+ concept_ids = kwargs["concept_ids"].to(self.device).long()
+ token_type_ids = kwargs["token_type_ids"].to(self.device).long()
+ time_stamps = kwargs["time_stamps"].to(self.device).float()
+ ages = kwargs["ages"].to(self.device).float()
+ visit_orders = kwargs["visit_orders"].to(self.device).long()
+ visit_segments = kwargs["visit_segments"].to(self.device).long()
+
+ x = self.embeddings(
+ input_ids=concept_ids,
+ token_type_ids_batch=token_type_ids,
+ time_stamps=time_stamps,
+ ages=ages,
+ visit_orders=visit_orders,
+ visit_segments=visit_segments,
+ )
+ mask = concept_ids != self.pad_token_id
+ for blk in self.blocks:
+ x = blk(x)
+ pooled = get_rightmost_masked_timestep(x, mask)
+ logits = self.fc(self.dropout(pooled))
+ y_true = kwargs[self.label_key].to(self.device).float()
+ if y_true.dim() == 1:
+ y_true = y_true.unsqueeze(-1)
+ loss = self.get_loss_function()(logits, y_true)
+ y_prob = self.prepare_y_prob(logits)
+ return {
+ "loss": loss,
+ "y_prob": y_prob,
+ "y_true": y_true,
+ "logit": logits,
+ }
diff --git a/pyhealth/models/utils.py b/pyhealth/models/utils.py
index 67edc010e..45cd6608d 100644
--- a/pyhealth/models/utils.py
+++ b/pyhealth/models/utils.py
@@ -44,3 +44,31 @@ def get_last_visit(hidden_states, mask):
last_hidden_states = torch.gather(hidden_states, 1, last_visit)
last_hidden_state = last_hidden_states[:, 0, :]
return last_hidden_state
+
+
+def get_rightmost_masked_timestep(hidden_states, mask):
+ """Gather hidden state at the last True position in ``mask`` per row.
+
+ Unlike :func:`get_last_visit`, this does **not** assume valid tokens form a
+ contiguous prefix; it picks the maximum index where ``mask`` is True.
+ Use for MPF / CEHR layouts where padding can appear between boundary tokens.
+
+ Args:
+ hidden_states: ``[batch, seq_len, hidden_size]``.
+ mask: ``[batch, seq_len]`` bool.
+
+ Returns:
+ Tensor ``[batch, hidden_size]``.
+ """
+ if mask is None:
+ return hidden_states[:, -1, :]
+ batch, seq_len, hidden = hidden_states.shape
+ device = hidden_states.device
+ idx = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(
+ batch, -1
+ )
+ idx_m = torch.where(mask, idx, torch.full_like(idx, -1))
+ last_idx = idx_m.max(dim=1).values.clamp(min=0)
+ last_idx = last_idx.view(batch, 1, 1).expand(batch, 1, hidden)
+ gathered = torch.gather(hidden_states, 1, last_idx)
+ return gathered[:, 0, :]
diff --git a/pyhealth/nlp/metrics.py b/pyhealth/nlp/metrics.py
index 667a6665b..a61ec3a2d 100644
--- a/pyhealth/nlp/metrics.py
+++ b/pyhealth/nlp/metrics.py
@@ -353,13 +353,13 @@ class LevenshteinDistanceScoreMethod(ScoreMethod):
"""
@classmethod
def _get_external_modules(cls: Type) -> Tuple[str, ...]:
- return ('editdistance~=0.8.1',)
+ return ('rapidfuzz>=3.0.0',)
def _score(self, meth: str, context: ScoreContext) -> Iterable[FloatScore]:
- import editdistance
+ from rapidfuzz.distance import Levenshtein
for s1, s2 in context.pairs:
- val: int = editdistance.eval(s1, s2)
+ val: int = Levenshtein.distance(s1, s2)
if self.normalize:
text_len: int = max(len(s1), len(s2))
val = 1. - (val / text_len)
diff --git a/pyhealth/processors/__init__.py b/pyhealth/processors/__init__.py
index b48072270..4568a5ece 100644
--- a/pyhealth/processors/__init__.py
+++ b/pyhealth/processors/__init__.py
@@ -50,6 +50,7 @@ def get_processor(name: str):
from .ignore_processor import IgnoreProcessor
from .temporal_timeseries_processor import TemporalTimeseriesProcessor
from .tuple_time_text_processor import TupleTimeTextProcessor
+from .cehr_processor import CehrProcessor, ConceptVocab
# Expose public API
from .base_processor import (
@@ -79,4 +80,6 @@ def get_processor(name: str):
"GraphProcessor",
"AudioProcessor",
"TupleTimeTextProcessor",
+ "CehrProcessor",
+ "ConceptVocab",
]
diff --git a/pyhealth/processors/cehr_processor.py b/pyhealth/processors/cehr_processor.py
new file mode 100644
index 000000000..9199f51d7
--- /dev/null
+++ b/pyhealth/processors/cehr_processor.py
@@ -0,0 +1,175 @@
+"""Concept vocabulary and CEHR feature processor for FHIR timelines.
+
+Public API
+----------
+ConceptVocab
+ Token-to-dense-id mapping with PAD/UNK reserved at 0 and 1. JSON-serialisable.
+ensure_special_tokens(vocab)
+ Add CEHR/MPF specials (````, ````, ````, ````) and
+ return their ids.
+CehrProcessor
+ Standard :class:`~pyhealth.processors.FeatureProcessor` that maps a sample's
+ list of concept-key strings (already boundary-padded by the task) to a 1-D
+ ``torch.long`` tensor of token ids. Vocab growth happens inside the standard
+ ``SampleBuilder.fit(samples)`` loop -- no warm-up or freeze flag needed.
+
+The per-patient timeline-extraction helpers (`collect_cehr_timeline_events`,
+`build_cehr_sequences`, `infer_mortality_label`, etc.) live with the task that
+owns that logic: :mod:`pyhealth.tasks.mpf_clinical_prediction`.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, Iterable, List
+
+import orjson
+import torch
+
+from . import register_processor
+from .base_processor import FeatureProcessor
+
+DEFAULT_PAD = 0
+DEFAULT_UNK = 1
+PAD_TOKEN = ""
+UNK_TOKEN = ""
+
+__all__ = [
+ "DEFAULT_PAD",
+ "DEFAULT_UNK",
+ "PAD_TOKEN",
+ "UNK_TOKEN",
+ "ConceptVocab",
+ "ensure_special_tokens",
+ "CehrProcessor",
+]
+
+
+# ---------------------------------------------------------------------------
+# Vocabulary
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class ConceptVocab:
+ """Maps concept keys to dense ids with PAD/UNK reserved at 0 and 1."""
+
+ token_to_id: Dict[str, int] = field(default_factory=dict)
+ pad_id: int = DEFAULT_PAD
+ unk_id: int = DEFAULT_UNK
+ _next_id: int = 2
+
+ def __post_init__(self) -> None:
+ if not self.token_to_id:
+ self.token_to_id = {PAD_TOKEN: self.pad_id, UNK_TOKEN: self.unk_id}
+ self._next_id = 2
+
+ def add_token(self, key: str) -> int:
+ if key in self.token_to_id:
+ return self.token_to_id[key]
+ tid = self._next_id
+ self.token_to_id[key] = tid
+ self._next_id += 1
+ return tid
+
+ def __getitem__(self, key: str) -> int:
+ return self.token_to_id.get(key, self.unk_id)
+
+ @property
+ def vocab_size(self) -> int:
+ return self._next_id
+
+ def to_json(self) -> Dict[str, Any]:
+ return {
+ "token_to_id": self.token_to_id,
+ "next_id": self._next_id,
+ "pad_id": self.pad_id,
+ "unk_id": self.unk_id,
+ }
+
+ @classmethod
+ def from_json(cls, data: Dict[str, Any]) -> "ConceptVocab":
+ pad_id = int(data.get("pad_id", DEFAULT_PAD))
+ unk_id = int(data.get("unk_id", DEFAULT_UNK))
+ vocab = cls(pad_id=pad_id, unk_id=unk_id)
+ loaded = dict(data.get("token_to_id") or {})
+ if loaded:
+ vocab.token_to_id = loaded
+ vocab._next_id = int(data.get("next_id", max(loaded.values()) + 1))
+ else:
+ vocab._next_id = int(data.get("next_id", 2))
+ return vocab
+
+ def save(self, path: str) -> None:
+ Path(path).parent.mkdir(parents=True, exist_ok=True)
+ Path(path).write_bytes(orjson.dumps(self.to_json(), option=orjson.OPT_SORT_KEYS))
+
+ @classmethod
+ def load(cls, path: str) -> "ConceptVocab":
+ return cls.from_json(orjson.loads(Path(path).read_bytes()))
+
+
+def ensure_special_tokens(vocab: ConceptVocab) -> Dict[str, int]:
+ """Add EHRMamba/CEHR special tokens and return their ids."""
+ return {name: vocab.add_token(name) for name in ("", "", "", "")}
+
+
+# ---------------------------------------------------------------------------
+# Processor
+# ---------------------------------------------------------------------------
+
+
+@register_processor("cehr")
+class CehrProcessor(FeatureProcessor):
+ """Map a sample's list of concept-key strings to a 1-D LongTensor of ids.
+
+ The task is expected to have already done all boundary-token insertion
+ (```` / ```` / ````) and left-padding with ````. This
+ processor's only state is a :class:`ConceptVocab`, grown during the
+ standard :meth:`~pyhealth.datasets.sample_dataset.SampleBuilder.fit`
+ pass over cached samples.
+ """
+
+ def __init__(self, max_len: int = 512) -> None:
+ self.vocab = ConceptVocab()
+ ensure_special_tokens(self.vocab)
+ self.max_len = max_len
+
+ def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> "CehrProcessor":
+ for sample in samples:
+ keys = sample.get(field)
+ if not keys:
+ continue
+ for key in keys:
+ if isinstance(key, str):
+ self.vocab.add_token(key)
+ return self
+
+ def process(self, value: List[Any]) -> torch.Tensor:
+ ids = [
+ self.vocab[k] if isinstance(k, str) else int(k)
+ for k in value
+ ]
+ return torch.tensor(ids, dtype=torch.long)
+
+ def save(self, path: str) -> None:
+ self.vocab.save(path)
+
+ def load(self, path: str) -> None:
+ self.vocab = ConceptVocab.load(path)
+
+ def is_token(self) -> bool:
+ return True
+
+ def schema(self) -> tuple[str, ...]:
+ return ("value",)
+
+ def dim(self) -> tuple[int, ...]:
+ return (1,)
+
+ def spatial(self) -> tuple[bool, ...]:
+ return (True,)
+
+ def __repr__(self) -> str:
+ return f"CehrProcessor(max_len={self.max_len})"
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index a32618f9c..3729cf5f5 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -67,3 +67,11 @@
VariantClassificationClinVar,
)
from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task
+
+
+def __getattr__(name: str):
+ if name == "MPFClinicalPredictionTask":
+ from .mpf_clinical_prediction import MPFClinicalPredictionTask
+
+ return MPFClinicalPredictionTask
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/pyhealth/tasks/mpf_clinical_prediction.py b/pyhealth/tasks/mpf_clinical_prediction.py
new file mode 100644
index 000000000..8e387b512
--- /dev/null
+++ b/pyhealth/tasks/mpf_clinical_prediction.py
@@ -0,0 +1,310 @@
+"""Multitask Prompted Fine-tuning (MPF) clinical prediction on FHIR timelines.
+
+The task reads per-patient events via :meth:`pyhealth.data.Patient.get_events`
+and :class:`~pyhealth.data.Event` attribute access (the standard PyHealth
+idiom). It builds six aligned CEHR feature sequences, inserts MPF boundary
+specials, and left-pads to ``max_len``.
+
+Concept-key → integer-id mapping happens later, inside the standard pipeline:
+``SampleBuilder.fit`` walks the cached ``task_df.ld`` and fits a
+:class:`~pyhealth.processors.CehrProcessor` on the ``concept_ids`` field;
+that processor's vocab is then applied per sample by ``_proc_transform``.
+The other five sequences are plain numeric lists handled by the standard
+tensor processor.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+import polars as pl
+
+from pyhealth.data import Event, Patient
+from pyhealth.processors.cehr_processor import PAD_TOKEN
+
+from .base_task import BaseTask
+
+__all__ = [
+ "EVENT_TYPE_TO_TOKEN_TYPE",
+ "MPFClinicalPredictionTask",
+ "collect_cehr_timeline_events",
+ "infer_mortality_label",
+]
+
+
+# ---------------------------------------------------------------------------
+# Constants
+# ---------------------------------------------------------------------------
+
+EVENT_TYPE_TO_TOKEN_TYPE: Dict[str, int] = {
+ "encounter": 1,
+ "condition": 2,
+ "medication_request": 3,
+ "observation": 4,
+ "procedure": 5,
+}
+
+_CLINICAL_EVENT_TYPES: Tuple[str, ...] = (
+ "condition",
+ "observation",
+ "medication_request",
+ "procedure",
+)
+
+
+# ---------------------------------------------------------------------------
+# Small pure helpers
+# ---------------------------------------------------------------------------
+
+
+def _deceased_boolean_column_means_dead(value: Any) -> bool:
+ """True only for an explicit ``"true"`` flag (not Python truthiness)."""
+ if value is None:
+ return False
+ return str(value).strip().lower() == "true"
+
+
+def _encounter_concept_key(event: Any) -> str:
+ enc_class = getattr(event, "encounter_class", None)
+ if enc_class:
+ return f"encounter|{enc_class}"
+ return "encounter|unknown"
+
+
+def _sequential_visit_idx_for_time(
+ event_time: Optional[datetime],
+ visit_encounters: List[Tuple[datetime, int]],
+) -> int:
+ """Bucket an unlinked event into the nearest preceding encounter's index."""
+ if not visit_encounters:
+ return 0
+ if event_time is None:
+ return visit_encounters[-1][1]
+ chosen = visit_encounters[0][1]
+ for encounter_start, visit_idx in visit_encounters:
+ if encounter_start <= event_time:
+ chosen = visit_idx
+ else:
+ break
+ return chosen
+
+
+def _birth_datetime_from_patient(patient: Patient) -> Optional[datetime]:
+ """Patient's birth date.
+
+ The ``patient`` table's yaml entry declares ``timestamp: birth_date``, so
+ the Event's ``timestamp`` field is the birth date itself.
+ """
+ events = patient.get_events(event_type="patient")
+ return events[0].timestamp if events else None
+
+
+# ---------------------------------------------------------------------------
+# Timeline extraction
+# ---------------------------------------------------------------------------
+
+
+def collect_cehr_timeline_events(
+ patient: Patient,
+) -> List[Tuple[datetime, str, str, int]]:
+ """Collect ``(time, concept_key, event_type, visit_idx)`` tuples for one patient.
+
+ Encounters define the visit boundaries. Clinical events that reference a
+ known encounter id are linked directly; events without a matching
+ encounter reference are bucketed into the chronologically nearest
+ preceding visit.
+ """
+ # Only well-formed encounters (real id + non-null timestamp) define visit
+ # indices. We have to inspect the raw polars frame here:
+ # ``Event.__init__`` silently coerces ``timestamp=None`` to
+ # ``datetime.now()`` (data.py:43-45), so by the time we get back an Event
+ # we can no longer tell which encounters were timestamp-less.
+ encounters_df = patient.get_events(event_type="encounter", return_df=True)
+ valid_encounters = [
+ Event.from_dict(row)
+ for row in encounters_df.filter(
+ pl.col("timestamp").is_not_null()
+ & pl.col("encounter/encounter_id").is_not_null()
+ ).iter_rows(named=True)
+ ]
+
+ encounter_visit_idx: Dict[str, int] = {}
+ encounter_start_by_id: Dict[str, datetime] = {}
+ visit_encounters: List[Tuple[datetime, int]] = []
+ for idx, enc in enumerate(valid_encounters):
+ enc_id = enc.encounter_id
+ encounter_visit_idx[enc_id] = idx
+ encounter_start_by_id[enc_id] = enc.timestamp
+ visit_encounters.append((enc.timestamp, idx))
+
+ events: List[Tuple[datetime, str, str, int]] = []
+ unlinked: List[Tuple[Optional[datetime], str, str]] = []
+
+ for enc in valid_encounters:
+ events.append(
+ (
+ enc.timestamp,
+ _encounter_concept_key(enc),
+ "encounter",
+ encounter_visit_idx[enc.encounter_id],
+ )
+ )
+
+ for et in _CLINICAL_EVENT_TYPES:
+ for ev in patient.get_events(event_type=et):
+ concept_key = getattr(ev, "concept_key", None) or f"{et}|unknown"
+ enc_id = getattr(ev, "encounter_id", None)
+ t = ev.timestamp
+ if enc_id and enc_id in encounter_visit_idx:
+ if t is None:
+ t = encounter_start_by_id.get(enc_id)
+ if t is None:
+ continue
+ events.append((t, concept_key, et, encounter_visit_idx[enc_id]))
+ else:
+ unlinked.append((t, concept_key, et))
+
+ for t, concept_key, et in unlinked:
+ idx = _sequential_visit_idx_for_time(t, visit_encounters)
+ if t is None:
+ if not visit_encounters:
+ continue
+ # Use the start of the chosen visit; fall back to the latest encounter.
+ t = next(
+ (start for start, v_idx in visit_encounters if v_idx == idx),
+ visit_encounters[-1][0],
+ )
+ events.append((t, concept_key, et, idx))
+
+ events.sort(key=lambda item: item[0])
+ return events
+
+
+# ---------------------------------------------------------------------------
+# Label
+# ---------------------------------------------------------------------------
+
+
+def infer_mortality_label(patient: Patient) -> int:
+ """Heuristic binary mortality label from flattened patient rows."""
+ for ev in patient.get_events(event_type="patient"):
+ if _deceased_boolean_column_means_dead(getattr(ev, "deceased_boolean", None)):
+ return 1
+ if getattr(ev, "deceased_datetime", None):
+ return 1
+ for ev in patient.get_events(event_type="condition"):
+ ck = (getattr(ev, "concept_key", None) or "").lower()
+ if any(token in ck for token in ("death", "deceased", "mortality")):
+ return 1
+ return 0
+
+
+# ---------------------------------------------------------------------------
+# Task
+# ---------------------------------------------------------------------------
+
+
+class MPFClinicalPredictionTask(BaseTask):
+ """Binary mortality prediction from FHIR CEHR sequences with optional MPF tokens.
+
+ The task does timeline extraction and emits **raw** per-event lists,
+ including concept keys as strings. Tokenization is the
+ :class:`~pyhealth.processors.CehrProcessor`'s job, fit during the
+ standard ``SampleBuilder.fit(dataset)`` pass.
+
+ Attributes:
+ max_len: Output sequence length (must be >= 2 for boundary tokens).
+ use_mpf: If True, prepend ```` to the sequence; else ````.
+ The closing ```` is always emitted.
+ """
+
+ task_name: str = "MPFClinicalPredictionFHIR"
+ output_schema: Dict[str, str] = {"label": "binary"}
+
+ def __init__(self, max_len: int = 512, use_mpf: bool = True) -> None:
+ if max_len < 2:
+ raise ValueError("max_len must be >= 2 for MPF boundary tokens")
+ self.max_len = max_len
+ self.use_mpf = use_mpf
+ self.boundary_start = "" if use_mpf else ""
+ self.boundary_end = ""
+ self.input_schema: Dict[str, Any] = {
+ "concept_ids": ("cehr", {"max_len": max_len}),
+ "token_type_ids": ("tensor", {"dtype": torch.long}),
+ "time_stamps": ("tensor", {"dtype": torch.float32}),
+ "ages": ("tensor", {"dtype": torch.float32}),
+ "visit_orders": ("tensor", {"dtype": torch.long}),
+ "visit_segments": ("tensor", {"dtype": torch.long}),
+ }
+
+ def __call__(self, patient: Patient) -> List[Dict[str, Any]]:
+ """Build one labeled sample dict per patient."""
+ timeline = collect_cehr_timeline_events(patient)
+ birth = _birth_datetime_from_patient(patient)
+
+ clinical_cap = self.max_len - 2
+ tail = timeline[-clinical_cap:] if clinical_cap > 0 else []
+ base_time = tail[0][0] if tail else None
+
+ # Build the six aligned sequences in a single pass.
+ keys: List[str] = [self.boundary_start]
+ token_types: List[int] = [0]
+ time_stamps: List[float] = [0.0]
+ ages: List[float] = [0.0]
+ vis_o: List[int] = [0]
+ vis_s: List[int] = [0]
+
+ for event_time, concept_key, event_type, visit_idx in tail:
+ time_delta = (
+ float((event_time - base_time).total_seconds())
+ if base_time is not None and event_time is not None
+ else 0.0
+ )
+ age_years = (
+ (event_time - birth).days / 365.25
+ if birth is not None and event_time is not None
+ else 0.0
+ )
+ keys.append(concept_key)
+ token_types.append(EVENT_TYPE_TO_TOKEN_TYPE.get(event_type, 0))
+ time_stamps.append(time_delta)
+ ages.append(age_years)
+ vis_o.append(min(visit_idx, 511))
+ vis_s.append(visit_idx % 2)
+
+ keys.append(self.boundary_end)
+ token_types.append(0)
+ time_stamps.append(0.0)
+ ages.append(0.0)
+ vis_o.append(0)
+ vis_s.append(0)
+
+ ml = self.max_len
+ keys = _left_pad(keys, ml, PAD_TOKEN)
+ token_types = _left_pad(token_types, ml, 0)
+ time_stamps = _left_pad(time_stamps, ml, 0.0)
+ ages = _left_pad(ages, ml, 0.0)
+ vis_o = _left_pad(vis_o, ml, 0)
+ vis_s = _left_pad(vis_s, ml, 0)
+
+ return [
+ {
+ "patient_id": patient.patient_id,
+ "concept_ids": keys,
+ "token_type_ids": token_types,
+ "time_stamps": time_stamps,
+ "ages": ages,
+ "visit_orders": vis_o,
+ "visit_segments": vis_s,
+ "label": infer_mortality_label(patient),
+ }
+ ]
+
+
+def _left_pad(seq: List[Any], max_len: int, pad: Any) -> List[Any]:
+ if len(seq) >= max_len:
+ return seq[-max_len:]
+ return [pad] * (max_len - len(seq)) + seq
diff --git a/pyproject.toml b/pyproject.toml
index 98f88d47b..65e0e2757 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
"dask[complete]~=2025.11.0",
"litdata~=0.2.59",
"pyarrow~=22.0.0",
+ "orjson~=3.10",
"narwhals~=2.13.0",
"more-itertools~=10.8.0",
"einops>=0.8.0",
@@ -64,7 +65,7 @@ graph = [
"torch-geometric>=2.6.0",
]
nlp = [
- "editdistance~=0.8.1",
+ "rapidfuzz>=3.0.0",
"rouge_score~=0.1.2",
"nltk~=3.9.1",
]
diff --git a/tests/core/test_ehrmamba_cehr.py b/tests/core/test_ehrmamba_cehr.py
new file mode 100644
index 000000000..81c995f28
--- /dev/null
+++ b/tests/core/test_ehrmamba_cehr.py
@@ -0,0 +1,126 @@
+import unittest
+
+import torch
+
+from pyhealth.datasets import create_sample_dataset, get_dataloader
+from pyhealth.models import EHRMambaCEHR
+from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+
+def _tiny_samples(seq: int = 16) -> tuple:
+ """Build hand-crafted samples in the new task's emitted format.
+
+ ``concept_ids`` carries raw string tokens (``""`` / ``""`` / a
+ filler concept key); the ``CehrProcessor`` registered via the task's
+ ``input_schema`` does the string → integer-id mapping during
+ ``SampleBuilder.fit``.
+ """
+ task = MPFClinicalPredictionTask(max_len=seq, use_mpf=True)
+ samples = []
+ for lab in (0, 1):
+ samples.append(
+ {
+ "patient_id": f"p{lab}",
+ "visit_id": f"v{lab}",
+ "concept_ids": [""] + ["test|filler"] * (seq - 2) + [""],
+ "token_type_ids": [0] * seq,
+ "time_stamps": [0.0] * seq,
+ "ages": [50.0] * seq,
+ "visit_orders": [0] * seq,
+ "visit_segments": [0] * seq,
+ "label": lab,
+ }
+ )
+ return samples, task
+
+
+class TestEHRMambaCEHR(unittest.TestCase):
+ def test_readout_pools_rightmost_non_pad(self) -> None:
+ """MPF padding between tokens must not make pooling pick a pad position."""
+
+ from pyhealth.models.utils import (
+ get_last_visit,
+ get_rightmost_masked_timestep,
+ )
+
+ h = torch.tensor([[[1.0, 0.0], [2.0, 0.0], [0.0, 0.0], [99.0, 0.0]]])
+ m = torch.tensor([[True, True, False, True]])
+ out = get_rightmost_masked_timestep(h, m)
+ self.assertTrue(torch.allclose(out[0], torch.tensor([99.0, 0.0])))
+ wrong = get_last_visit(h, m)
+ self.assertFalse(torch.allclose(out[0], wrong[0]))
+
+ def test_end_to_end_fhir_pipeline(self) -> None:
+ import tempfile
+ from pathlib import Path
+
+ from pyhealth.datasets import MIMIC4FHIR, create_sample_dataset
+ from pyhealth.datasets import get_dataloader
+
+ from tests.core.test_fhir_ndjson_fixtures import run_task, write_two_class_ndjson
+
+ task = MPFClinicalPredictionTask(max_len=32, use_mpf=True)
+ with tempfile.TemporaryDirectory() as tmp:
+ write_two_class_ndjson(Path(tmp))
+ ds = MIMIC4FHIR(
+ root=tmp, glob_pattern="*.ndjson", cache_dir=tmp
+ )
+ samples = run_task(ds, task)
+ sample_ds = create_sample_dataset(
+ samples=samples,
+ input_schema=task.input_schema,
+ output_schema=task.output_schema,
+ dataset_name="fhir_test",
+ )
+ vocab_size = sample_ds.input_processors["concept_ids"].vocab.vocab_size
+ model = EHRMambaCEHR(
+ dataset=sample_ds,
+ vocab_size=vocab_size,
+ embedding_dim=64,
+ num_layers=1,
+ )
+ batch = next(
+ iter(get_dataloader(sample_ds, batch_size=2, shuffle=False))
+ )
+ out = model(**batch)
+ self.assertIn("loss", out)
+ out["loss"].backward()
+
+ def test_forward_backward(self) -> None:
+ samples, task = _tiny_samples()
+ ds = create_sample_dataset(
+ samples=samples,
+ input_schema=task.input_schema,
+ output_schema=task.output_schema,
+ )
+ vocab_size = ds.input_processors["concept_ids"].vocab.vocab_size
+ model = EHRMambaCEHR(
+ dataset=ds,
+ vocab_size=vocab_size,
+ embedding_dim=64,
+ num_layers=1,
+ state_size=8,
+ )
+ batch = next(iter(get_dataloader(ds, batch_size=2, shuffle=False)))
+ out = model(**batch)
+ self.assertEqual(out["logit"].shape[0], 2)
+ out["loss"].backward()
+
+ def test_eval_mode(self) -> None:
+ samples, task = _tiny_samples()
+ ds = create_sample_dataset(
+ samples=samples,
+ input_schema=task.input_schema,
+ output_schema=task.output_schema,
+ )
+ vocab_size = ds.input_processors["concept_ids"].vocab.vocab_size
+ model = EHRMambaCEHR(dataset=ds, vocab_size=vocab_size, embedding_dim=32, num_layers=1)
+ model.eval()
+ with torch.no_grad():
+ batch = next(iter(get_dataloader(ds, batch_size=2, shuffle=False)))
+ out = model(**batch)
+ self.assertIn("y_prob", out)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/test_fhir_dataset.py b/tests/core/test_fhir_dataset.py
new file mode 100644
index 000000000..235c9eeeb
--- /dev/null
+++ b/tests/core/test_fhir_dataset.py
@@ -0,0 +1,764 @@
+import gzip
+import shutil
+import tempfile
+import unittest
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import orjson
+import polars as pl
+
+from pyhealth.data import Patient
+from pyhealth.datasets import MIMIC4FHIR
+from pyhealth.datasets.fhir.utils import (
+ flatten_resource,
+ load_resource_specs_from_yaml,
+)
+from pyhealth.processors.cehr_processor import ConceptVocab
+from pyhealth.tasks.mpf_clinical_prediction import (
+ MPFClinicalPredictionTask,
+ collect_cehr_timeline_events,
+ infer_mortality_label,
+)
+
+
+def _mimic4_specs():
+ """Load the bundled MIMIC4 ResourceSpec registry from its YAML."""
+ import yaml as _yaml
+ with open(MIMIC4FHIR.DEFAULT_CONFIG_PATH, encoding="utf-8") as _f:
+ return load_resource_specs_from_yaml(_yaml.safe_load(_f))
+
+
+_MIMIC4_SPECS = _mimic4_specs()
+
+
+def _flatten_resource_to_table_row(resource):
+ """Flatten one resource via the MIMIC4 spec registry (test convenience)."""
+ return flatten_resource(resource, _MIMIC4_SPECS)
+
+
+def _clinical_slice(sample: Dict[str, object]) -> Tuple[List[str], List[int], List[int]]:
+ """Drop ```` and the leading/trailing boundary tokens from a sample.
+
+ Returns the per-event ``(concept_keys, visit_orders, visit_segments)``
+ lists for the patient's clinical events only.
+ """
+ keys = list(sample["concept_ids"]) # type: ignore[arg-type]
+ v_o = list(sample["visit_orders"]) # type: ignore[arg-type]
+ v_s = list(sample["visit_segments"]) # type: ignore[arg-type]
+ non_pad = [
+ (k, o, s) for k, o, s in zip(keys, v_o, v_s) if k != ""
+ ]
+ # Strip leading boundary (/) and trailing .
+ middle = non_pad[1:-1] if len(non_pad) >= 2 else []
+ return (
+ [k for k, _, _ in middle],
+ [o for _, o, _ in middle],
+ [s for _, _, s in middle],
+ )
+
+from tests.core.test_fhir_ndjson_fixtures import (
+ ndjson_two_class_text,
+ write_one_patient_ndjson,
+ write_two_class_ndjson,
+)
+
+
+def _third_patient_loinc_resources() -> List[Dict[str, object]]:
+ return [
+ {
+ "resourceType": "Patient",
+ "id": "p-synth-3",
+ "birthDate": "1960-01-01",
+ },
+ {
+ "resourceType": "Encounter",
+ "id": "e3",
+ "subject": {"reference": "Patient/p-synth-3"},
+ "period": {"start": "2020-08-01T10:00:00Z"},
+ "class": {"code": "IMP"},
+ },
+ {
+ "resourceType": "Observation",
+ "id": "o3",
+ "subject": {"reference": "Patient/p-synth-3"},
+ "encounter": {"reference": "Encounter/e3"},
+ "effectiveDateTime": "2020-08-01T12:00:00Z",
+ "code": {"coding": [{"system": "http://loinc.org", "code": "999-9"}]},
+ },
+ ]
+
+
+def write_two_class_plus_third_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
+ lines = ndjson_two_class_text().strip().split("\n")
+ lines.extend(orjson.dumps(r).decode("utf-8") for r in _third_patient_loinc_resources())
+ path = directory / name
+ path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+ return path
+
+
+def _patient_from_rows(patient_id: str, rows: List[Dict[str, object]]) -> Patient:
+ """Build a Patient whose ``timestamp`` column is a real datetime, matching
+ the shape ``FHIRDataset.load_table`` produces in production.
+ """
+ df = pl.DataFrame(rows).with_columns(
+ pl.col("timestamp").str.to_datetime(strict=False)
+ )
+ return Patient(patient_id=patient_id, data_source=df)
+
+
+class TestDeceasedBooleanFlattening(unittest.TestCase):
+ def test_string_false_not_coerced_by_python_bool(self) -> None:
+ """Non-conformant ``\"false\"`` string must not become stored ``\"true\"``."""
+ row = _flatten_resource_to_table_row(
+ {
+ "resourceType": "Patient",
+ "id": "p-str-false",
+ "deceasedBoolean": "false",
+ }
+ )
+ self.assertIsNotNone(row)
+ _table, payload = row
+ self.assertEqual(payload.get("deceased_boolean"), "false")
+
+ def test_string_true_parsed(self) -> None:
+ row = _flatten_resource_to_table_row(
+ {
+ "resourceType": "Patient",
+ "id": "p-str-true",
+ "deceasedBoolean": "true",
+ }
+ )
+ self.assertIsNotNone(row)
+ self.assertEqual(row[1].get("deceased_boolean"), "true")
+
+ def test_json_booleans_unchanged(self) -> None:
+ for raw, expected in ((True, "true"), (False, "false")):
+ with self.subTest(raw=raw):
+ row = _flatten_resource_to_table_row(
+ {
+ "resourceType": "Patient",
+ "id": "p-bool",
+ "deceasedBoolean": raw,
+ }
+ )
+ self.assertIsNotNone(row)
+ self.assertEqual(row[1].get("deceased_boolean"), expected)
+
+ def test_unknown_deceased_type_stored_as_none(self) -> None:
+ row = _flatten_resource_to_table_row(
+ {
+ "resourceType": "Patient",
+ "id": "p-garbage",
+ "deceasedBoolean": {"unexpected": "object"},
+ }
+ )
+ self.assertIsNotNone(row)
+ self.assertIsNone(row[1].get("deceased_boolean"))
+
+ def test_infer_mortality_respects_string_false_row(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "event_type": "patient",
+ "timestamp": "2020-01-01T00:00:00",
+ "patient/deceased_boolean": "false",
+ },
+ ],
+ )
+ self.assertEqual(infer_mortality_label(patient), 0)
+
+
+class TestFHIRDataset(unittest.TestCase):
+ def test_concept_vocab_from_json_empty_token_to_id(self) -> None:
+ v = ConceptVocab.from_json({"token_to_id": {}})
+ self.assertIn("", v.token_to_id)
+ self.assertIn("", v.token_to_id)
+ self.assertEqual(v._next_id, 2)
+
+ def test_concept_vocab_from_json_empty_respects_next_id(self) -> None:
+ v = ConceptVocab.from_json({"token_to_id": {}, "next_id": 50})
+ self.assertEqual(v._next_id, 50)
+
+ def test_sorted_ndjson_files_accepts_sequence_and_dedupes(self) -> None:
+ from pyhealth.datasets.fhir.utils import sorted_ndjson_files
+
+ with tempfile.TemporaryDirectory() as tmp:
+ root = Path(tmp)
+ (root / "MimicPatient.ndjson.gz").write_text("x", encoding="utf-8")
+ (root / "MimicMedication.ndjson.gz").write_text("y", encoding="utf-8")
+ (root / "notes.txt").write_text("z", encoding="utf-8")
+ wide = sorted_ndjson_files(root, "**/*.ndjson.gz")
+ narrow = sorted_ndjson_files(
+ root,
+ ["MimicPatient*.ndjson.gz", "**/MimicPatient*.ndjson.gz"],
+ )
+ self.assertEqual(len(wide), 2)
+ self.assertEqual(len(narrow), 1)
+ self.assertEqual(narrow[0].name, "MimicPatient.ndjson.gz")
+
+ def test_dataset_accepts_glob_patterns_kwarg(self) -> None:
+ with tempfile.TemporaryDirectory() as tmp:
+ write_one_patient_ndjson(Path(tmp))
+ ds = MIMIC4FHIR(
+ root=tmp, glob_patterns=["*.ndjson"], cache_dir=tmp
+ )
+ self.assertEqual(ds.glob_patterns, ["*.ndjson"])
+
+ def test_dataset_rejects_both_glob_kwargs(self) -> None:
+ with tempfile.TemporaryDirectory() as tmp:
+ with self.assertRaises(ValueError):
+ MIMIC4FHIR(
+ root=tmp,
+ glob_pattern="*.ndjson",
+ glob_patterns=["*.ndjson"],
+ cache_dir=tmp,
+ )
+
+ def test_disk_ingest_gz_and_max_patients(self) -> None:
+ """gzip ingest path + ``max_patients`` cap, covered in one build.
+
+ The heavier build/schema/set_task/pre_filter assertions now live in
+ ``TestFHIRSharedWorkflow`` (one shared build), so this is the only
+ ingest-variant build left in this class.
+ """
+ with tempfile.TemporaryDirectory() as tmp:
+ gz_path = Path(tmp) / "fixture.ndjson.gz"
+ with gzip.open(gz_path, "wt", encoding="utf-8") as gz:
+ gz.write(ndjson_two_class_text())
+ ds = MIMIC4FHIR(root=tmp, glob_pattern="*.ndjson.gz", max_patients=5)
+ self.assertEqual(len(ds.unique_patient_ids), 2)
+
+ def test_encounter_reference_requires_exact_id(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "e1",
+ "encounter/encounter_class": "AMB",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-07-02T10:00:00",
+ "encounter/encounter_id": "e10",
+ "encounter/encounter_class": "IMP",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-07-02T11:00:00",
+ "condition/encounter_id": "e10",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I99",
+ },
+ ],
+ )
+ sample = MPFClinicalPredictionTask(max_len=64, use_mpf=True)(patient)[0]
+ self.assertEqual(
+ sample["concept_ids"].count("http://hl7.org/fhir/sid/icd-10-cm|I99"),
+ 1,
+ )
+
+ def test_unlinked_condition_emitted_once_with_two_encounters(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "ea",
+ "encounter/encounter_class": "AMB",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-07-01T10:00:00",
+ "encounter/encounter_id": "eb",
+ "encounter/encounter_class": "IMP",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-06-15T12:00:00",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|Z00",
+ },
+ ],
+ )
+ sample = MPFClinicalPredictionTask(max_len=64, use_mpf=True)(patient)[0]
+ self.assertEqual(
+ sample["concept_ids"].count("http://hl7.org/fhir/sid/icd-10-cm|Z00"),
+ 1,
+ )
+
+ def test_max_len_two_keeps_only_boundary_tokens(self) -> None:
+ """``max_len=2`` leaves room for only the two boundary tokens; the
+ clinical timeline is truncated away.
+ """
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "e1",
+ "encounter/encounter_class": "AMB",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-06-01T11:00:00",
+ "condition/encounter_id": "e1",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10",
+ },
+ ],
+ )
+ sample = MPFClinicalPredictionTask(max_len=2, use_mpf=True)(patient)[0]
+ self.assertEqual(sample["concept_ids"], ["", ""])
+ self.assertEqual(sample["visit_segments"], [0, 0])
+
+ def test_visit_segments_alternate_by_visit_index(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "e0",
+ "encounter/encounter_class": "AMB",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-06-01T11:00:00",
+ "condition/encounter_id": "e0",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-07-01T10:00:00",
+ "encounter/encounter_id": "e1",
+ "encounter/encounter_class": "IMP",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-07-01T11:00:00",
+ "condition/encounter_id": "e1",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I20",
+ },
+ ],
+ )
+ sample = MPFClinicalPredictionTask(max_len=64, use_mpf=True)(patient)[0]
+ _, _, visit_segments = _clinical_slice(sample)
+ self.assertEqual(visit_segments, [0, 0, 1, 1])
+
+ def test_unlinked_visit_idx_matches_sequential_counter(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": None,
+ "encounter/encounter_id": "e_bad",
+ "encounter/encounter_class": "AMB",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-03-01T10:00:00",
+ "encounter/encounter_id": "e_ok",
+ "encounter/encounter_class": "IMP",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-03-05T11:00:00",
+ "condition/encounter_id": "e_ok",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|I10",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-03-15T12:00:00",
+ "condition/concept_key": "http://hl7.org/fhir/sid/icd-10-cm|Z00",
+ },
+ ],
+ )
+ sample = MPFClinicalPredictionTask(max_len=64, use_mpf=True)(patient)[0]
+ keys = sample["concept_ids"]
+ i_link = keys.index("http://hl7.org/fhir/sid/icd-10-cm|I10")
+ i_free = keys.index("http://hl7.org/fhir/sid/icd-10-cm|Z00")
+ self.assertEqual(sample["visit_orders"][i_link], sample["visit_orders"][i_free])
+ self.assertEqual(sample["visit_segments"][i_link], sample["visit_segments"][i_free])
+
+ def test_medication_request_uses_medication_codeable_concept(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "e1",
+ "encounter/encounter_class": "IMP",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "medication_request",
+ "timestamp": "2020-06-01T11:00:00",
+ "medication_request/encounter_id": "e1",
+ "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|111",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "medication_request",
+ "timestamp": "2020-06-01T12:00:00",
+ "medication_request/encounter_id": "e1",
+ "medication_request/concept_key": "http://www.nlm.nih.gov/research/umls/rxnorm|222",
+ },
+ ],
+ )
+ sample = MPFClinicalPredictionTask(max_len=64, use_mpf=True)(patient)[0]
+ keys = sample["concept_ids"]
+ ka = "http://www.nlm.nih.gov/research/umls/rxnorm|111"
+ kb = "http://www.nlm.nih.gov/research/umls/rxnorm|222"
+ self.assertEqual(keys.count(ka), 1)
+ self.assertEqual(keys.count(kb), 1)
+
+ def test_medication_request_medication_reference_token(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "e1",
+ "encounter/encounter_class": "IMP",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "medication_request",
+ "timestamp": "2020-06-01T11:00:00",
+ "medication_request/encounter_id": "e1",
+ "medication_request/concept_key": "MedicationRequest/reference|med-abc",
+ },
+ ],
+ )
+ sample = MPFClinicalPredictionTask(max_len=64, use_mpf=True)(patient)[0]
+ key = "MedicationRequest/reference|med-abc"
+ self.assertIn(key, sample["concept_ids"])
+ self.assertEqual(sample["concept_ids"].count(key), 1)
+
+ def test_collect_cehr_timeline_events_orders_by_timestamp(self) -> None:
+ patient = _patient_from_rows(
+ "p1",
+ [
+ {
+ "patient_id": "p1",
+ "event_type": "patient",
+ "timestamp": None,
+ "patient/birth_date": "1950-01-01",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "encounter",
+ "timestamp": "2020-06-01T10:00:00",
+ "encounter/encounter_id": "e1",
+ "encounter/encounter_class": "AMB",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "condition",
+ "timestamp": "2020-06-01T11:00:00",
+ "condition/encounter_id": "e1",
+ "condition/concept_key": "a|1",
+ },
+ {
+ "patient_id": "p1",
+ "event_type": "observation",
+ "timestamp": "2020-06-01T12:00:00",
+ "observation/encounter_id": "e1",
+ "observation/concept_key": "b|2",
+ },
+ ],
+ )
+ events = collect_cehr_timeline_events(patient)
+ self.assertEqual([event[1] for event in events], ["encounter|AMB", "a|1", "b|2"])
+
+ def test_observation_effective_period_start_yields_event_time(self) -> None:
+ """Choice-type fix: an Observation carrying only ``effectivePeriod.start``
+ (no ``effectiveDateTime``) must still resolve a non-null event_time.
+ The pre-refactor extractor silently dropped this variant.
+ """
+ row = _flatten_resource_to_table_row(
+ {
+ "resourceType": "Observation",
+ "id": "o-period",
+ "subject": {"reference": "Patient/p"},
+ "effectivePeriod": {"start": "2022-02-02T00:00:00Z"},
+ "code": {"coding": [{"system": "http://loinc.org", "code": "1-1"}]},
+ }
+ )
+ self.assertIsNotNone(row)
+ self.assertEqual(row[1]["event_time"], "2022-02-02T00:00:00Z")
+
+ def test_new_resource_type_via_registry_flows_through(self) -> None:
+ """A resource type absent from MIMIC4's specs flows end-to-end purely by
+ adding a YAML entry — no engine change.
+
+ Also exercises the directly-usable generic ``FHIRDataset`` (whole
+ ingest contract authored in a single YAML, no subclass).
+ """
+ from pyhealth.datasets import FHIRDataset
+
+ resources = [
+ {"resourceType": "Patient", "id": "imm-1", "birthDate": "1970-01-01"},
+ {
+ "resourceType": "Immunization",
+ "id": "i1",
+ "patient": {"reference": "Patient/imm-1"},
+ "occurrenceDateTime": "2021-01-01T00:00:00Z",
+ "vaccineCode": {
+ "coding": [{"system": "http://hl7.org/fhir/sid/cvx", "code": "208"}]
+ },
+ },
+ ]
+ config_yaml = (
+ "version: test\n"
+ "resource_specs:\n"
+ " Patient:\n"
+ " table: patient\n"
+ " columns:\n"
+ " patient_id: { locate: [id], required: true }\n"
+ " birth_date: { locate: [birthDate] }\n"
+ " Immunization:\n"
+ " table: immunization\n"
+ " columns:\n"
+ " patient_id: { locate: [patient.reference], transform: ref_id, required: true }\n"
+ " resource_id: { locate: [id] }\n"
+ " encounter_id: { locate: [encounter.reference], transform: ref_id }\n"
+ " event_time: { locate: [occurrenceDateTime, recorded] }\n"
+ " concept_key: { locate: [vaccineCode], transform: coding_key }\n"
+ "tables:\n"
+ " patient:\n"
+ " file_path: patient.parquet\n"
+ " patient_id: patient_id\n"
+ " timestamp: birth_date\n"
+ " attributes: [birth_date]\n"
+ " immunization:\n"
+ " file_path: immunization.parquet\n"
+ " patient_id: patient_id\n"
+ " timestamp: event_time\n"
+ " attributes: [resource_id, encounter_id, event_time, concept_key]\n"
+ )
+ with tempfile.TemporaryDirectory() as tmp:
+ tmpp = Path(tmp)
+ (tmpp / "fx.ndjson").write_text(
+ "\n".join(orjson.dumps(r).decode("utf-8") for r in resources) + "\n",
+ encoding="utf-8",
+ )
+ cfg = tmpp / "immun.yaml"
+ cfg.write_text(config_yaml, encoding="utf-8")
+ ds = FHIRDataset(
+ root=str(tmpp),
+ config_path=str(cfg),
+ glob_pattern="*.ndjson",
+ cache_dir=str(tmpp),
+ )
+ df = ds.global_event_df.collect(engine="streaming")
+ self.assertIn("immunization/concept_key", df.columns)
+ keys = (
+ df.filter(pl.col("event_type") == "immunization")[
+ "immunization/concept_key"
+ ]
+ .to_list()
+ )
+ self.assertIn("http://hl7.org/fhir/sid/cvx|208", keys)
+
+ def test_fhir_dataset_requires_specs(self) -> None:
+ """Bare ``FHIRDataset`` (no specs, no subclass) errors clearly."""
+ from pyhealth.datasets import FHIRDataset
+
+ with tempfile.TemporaryDirectory() as tmp:
+ with self.assertRaises(ValueError):
+ FHIRDataset(root=tmp, glob_pattern="*.ndjson", cache_dir=tmp)
+
+
+class TestFHIRSharedWorkflow(unittest.TestCase):
+ """Build the dataset and run ``set_task`` ONCE, then assert over the shared
+ artifacts. Mirrors a realistic "ingest once, do many things" workflow and
+ keeps the suite fast: a single Dask build (plus one canonical ``set_task``)
+ shared by every assertion, instead of rebuilding per test.
+ """
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ cls._tmp = tempfile.mkdtemp()
+ write_two_class_plus_third_ndjson(Path(cls._tmp))
+ cls.ds = MIMIC4FHIR(
+ root=cls._tmp, glob_pattern="*.ndjson", cache_dir=cls._tmp, num_workers=1
+ )
+ # The single Dask build for the whole class.
+ cls.global_df = cls.ds.global_event_df.collect(engine="streaming")
+ # The canonical set_task (reuses the build above; no rebuild).
+ cls.sample_ds = cls.ds.set_task(
+ MPFClinicalPredictionTask(max_len=48, use_mpf=True), num_workers=1
+ )
+ cls.samples = sorted(
+ [cls.sample_ds[i] for i in range(len(cls.sample_ds))],
+ key=lambda s: s["patient_id"],
+ )
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ shutil.rmtree(cls._tmp, ignore_errors=True)
+
+ def test_build_produces_expected_tables_and_schema(self) -> None:
+ """Flat parquet tables exist, and the global event frame has the
+ expected long-format + namespaced columns with a patient's events.
+ """
+ prepared = self.ds.prepared_tables_dir
+ for name in ("patient", "encounter", "condition", "observation"):
+ self.assertTrue((prepared / f"{name}.parquet").is_file())
+ for col in (
+ "patient_id",
+ "timestamp",
+ "event_type",
+ "condition/concept_key",
+ "observation/concept_key",
+ "patient/deceased_boolean",
+ ):
+ self.assertIn(col, self.global_df.columns)
+ sub = self.global_df.filter(pl.col("patient_id") == "p-synth-1")
+ self.assertGreaterEqual(len(sub), 2)
+
+ def test_set_task_builds_vocab(self) -> None:
+ vocab = self.sample_ds.input_processors["concept_ids"].vocab
+ self.assertGreater(vocab.vocab_size, 6)
+
+ def test_set_task_produces_correct_samples(self) -> None:
+ self.assertEqual(len(self.samples), 3)
+ self.assertEqual(
+ {s["patient_id"] for s in self.samples},
+ {"p-synth-1", "p-synth-2", "p-synth-3"},
+ )
+ for s in self.samples:
+ self.assertIn("concept_ids", s)
+ self.assertIn("label", s)
+ self.assertEqual({int(s["label"]) for s in self.samples}, {0, 1})
+
+ def test_cehr_sequence_shapes(self) -> None:
+ patient = self.ds.get_patient("p-synth-1")
+ sample = MPFClinicalPredictionTask(max_len=32, use_mpf=True)(patient)[0]
+ n = len(sample["concept_ids"])
+ self.assertEqual(n, 32)
+ for key in (
+ "token_type_ids", "time_stamps", "ages", "visit_orders", "visit_segments",
+ ):
+ self.assertEqual(len(sample[key]), n)
+ non_special = {
+ k for k in sample["concept_ids"]
+ if k not in ("", "", "", "")
+ }
+ self.assertGreater(len(non_special), 0)
+
+ def test_mpf_pre_filter_single_patient_limits_effective_workers(self) -> None:
+ """Pre-filter yielding one patient caps effective_workers to 1 (the
+ formula is verified directly; a 1-patient ``set_task`` would raise on a
+ single label class).
+ """
+ class OnePatientMPFTask(MPFClinicalPredictionTask):
+ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
+ return df.filter(pl.col("patient_id") == "p-synth-1")
+
+ warmup_pids = (
+ OnePatientMPFTask(max_len=48, use_mpf=True)
+ .pre_filter(self.ds.global_event_df)
+ .select("patient_id")
+ .unique()
+ .collect(engine="streaming")
+ .to_series()
+ .sort()
+ .to_list()
+ )
+ self.assertEqual(warmup_pids, ["p-synth-1"])
+ effective_workers = min(2, len(warmup_pids)) if warmup_pids else 1
+ self.assertEqual(effective_workers, 1)
+
+ def test_mpf_pre_filter_excludes_dropped_patients_from_vocab(self) -> None:
+ """A task ``pre_filter`` that drops a patient also drops their concept
+ keys from the fitted vocab. Reuses the shared build; a distinct
+ ``task_name`` keeps this run's sample cache separate from the canonical
+ ``set_task`` in setUpClass (identical params would otherwise collide on
+ the shared ``cache_dir``).
+ """
+ class TwoPatientMPFTask(MPFClinicalPredictionTask):
+ task_name = "MPFClinicalPredictionFHIR_prefilter"
+
+ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
+ return df.filter(
+ pl.col("patient_id").is_in(["p-synth-1", "p-synth-2"])
+ )
+
+ sample_ds = self.ds.set_task(
+ TwoPatientMPFTask(max_len=48, use_mpf=True), num_workers=1
+ )
+ vocab = sample_ds.input_processors["concept_ids"].vocab
+ self.assertNotIn("http://loinc.org|999-9", vocab.token_to_id)
+ self.assertIn("http://loinc.org|789-0", vocab.token_to_id)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/test_fhir_ndjson_fixtures.py b/tests/core/test_fhir_ndjson_fixtures.py
new file mode 100644
index 000000000..7311d4802
--- /dev/null
+++ b/tests/core/test_fhir_ndjson_fixtures.py
@@ -0,0 +1,110 @@
+"""NDJSON file bodies for :mod:`tests.core.test_fhir_dataset` (disk-only ingest)."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Dict, List
+
+import orjson
+
+
+# ---------------------------------------------------------------------------
+# Synthetic in-memory FHIR resources
+# ---------------------------------------------------------------------------
+
+
+def _one_patient_resources() -> List[Dict[str, Any]]:
+ return [
+ {"resourceType": "Patient", "id": "p-synth-1", "birthDate": "1950-01-01", "gender": "female"},
+ {
+ "resourceType": "Encounter",
+ "id": "e1",
+ "subject": {"reference": "Patient/p-synth-1"},
+ "period": {"start": "2020-06-01T10:00:00Z"},
+ "class": {"code": "IMP"},
+ },
+ {
+ "resourceType": "Condition",
+ "id": "c1",
+ "subject": {"reference": "Patient/p-synth-1"},
+ "encounter": {"reference": "Encounter/e1"},
+ "code": {"coding": [{"system": "http://hl7.org/fhir/sid/icd-10-cm", "code": "I10"}]},
+ "onsetDateTime": "2020-06-01T11:00:00Z",
+ },
+ ]
+
+
+def _two_patient_resources() -> List[Dict[str, Any]]:
+ return [
+ *_one_patient_resources(),
+ {"resourceType": "Patient", "id": "p-synth-2", "birthDate": "1940-05-05", "deceasedBoolean": True},
+ {
+ "resourceType": "Encounter",
+ "id": "e-dead",
+ "subject": {"reference": "Patient/p-synth-2"},
+ "period": {"start": "2020-07-01T10:00:00Z"},
+ "class": {"code": "IMP"},
+ },
+ {
+ "resourceType": "Observation",
+ "id": "o-dead",
+ "subject": {"reference": "Patient/p-synth-2"},
+ "encounter": {"reference": "Encounter/e-dead"},
+ "effectiveDateTime": "2020-07-01T12:00:00Z",
+ "code": {"coding": [{"system": "http://loinc.org", "code": "789-0"}]},
+ },
+ ]
+
+
+# ---------------------------------------------------------------------------
+# Text serialisers
+# ---------------------------------------------------------------------------
+
+
+def ndjson_one_patient_text() -> str:
+ return "\n".join(orjson.dumps(r).decode("utf-8") for r in _one_patient_resources()) + "\n"
+
+
+def ndjson_two_class_text() -> str:
+ return "\n".join(orjson.dumps(r).decode("utf-8") for r in _two_patient_resources()) + "\n"
+
+
+# ---------------------------------------------------------------------------
+# Disk writers
+# ---------------------------------------------------------------------------
+
+
+def write_two_class_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
+ path = directory / name
+ path.write_text(ndjson_two_class_text(), encoding="utf-8")
+ return path
+
+
+def write_one_patient_ndjson(directory: Path, *, name: str = "fixture.ndjson") -> Path:
+ path = directory / name
+ path.write_text(ndjson_one_patient_text(), encoding="utf-8")
+ return path
+
+
+# ---------------------------------------------------------------------------
+# Shared test helper
+# ---------------------------------------------------------------------------
+
+
+def run_task(ds: Any, task: Any) -> List[Dict[str, Any]]:
+ """Run *task* over every patient in *ds* without the LitData caching pipeline.
+
+ This helper mirrors the direct-iteration path that the old
+ ``FHIRDataset.gather_samples`` provided. It is intentionally kept
+ here (the shared fixture module) so all FHIR test files can import it
+ rather than each maintaining their own copy.
+
+ Args:
+ ds: A :class:`~pyhealth.datasets.FHIRDataset` instance whose
+ ``global_event_df`` has already been built.
+ task: A :class:`~pyhealth.tasks.MPFClinicalPredictionTask` instance.
+
+ Returns:
+ Flat list of sample dicts, one per patient.
+ """
+ return [s for patient in ds.iter_patients() for s in task(patient)]
diff --git a/tests/core/test_mpf_task.py b/tests/core/test_mpf_task.py
new file mode 100644
index 000000000..a1b261719
--- /dev/null
+++ b/tests/core/test_mpf_task.py
@@ -0,0 +1,99 @@
+import shutil
+import tempfile
+import unittest
+from pathlib import Path
+
+from pyhealth.datasets import MIMIC4FHIR
+from pyhealth.processors.cehr_processor import PAD_TOKEN
+from pyhealth.tasks.mpf_clinical_prediction import MPFClinicalPredictionTask
+
+from tests.core.test_fhir_ndjson_fixtures import (
+ run_task,
+ write_two_class_ndjson,
+)
+
+
+class TestMPFClinicalPredictionTask(unittest.TestCase):
+ """Verifies the task emits boundary-marker strings at the expected
+ positions in its raw output. Vocab → integer-id mapping is the
+ ``CehrProcessor``'s job and is exercised separately via the standard
+ ``SampleBuilder.fit`` pipeline.
+ """
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ cls._tmp = tempfile.mkdtemp()
+ write_two_class_ndjson(Path(cls._tmp))
+ # One shared build for the whole class; tests reuse it via run_task
+ # (run_task just applies the task to cached patients — no rebuild).
+ cls.ds = MIMIC4FHIR(
+ root=cls._tmp, glob_pattern="*.ndjson", cache_dir=cls._tmp
+ )
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ shutil.rmtree(cls._tmp, ignore_errors=True)
+
+ def test_max_len_validation(self) -> None:
+ with self.assertRaises(ValueError):
+ MPFClinicalPredictionTask(max_len=1, use_mpf=True)
+
+ def test_mpf_sets_boundary_tokens(self) -> None:
+ task = MPFClinicalPredictionTask(max_len=32, use_mpf=True)
+ samples = run_task(self.ds, task)
+ self.assertGreater(len(samples), 0)
+ keys = samples[0]["concept_ids"]
+ first = next(i for i, x in enumerate(keys) if x != PAD_TOKEN)
+ last_nz = next(
+ i for i in range(len(keys) - 1, -1, -1) if keys[i] != PAD_TOKEN
+ )
+ self.assertEqual(keys[first], "")
+ self.assertEqual(keys[last_nz], "")
+ self.assertEqual(keys[-1], "")
+
+ def test_no_mpf_uses_cls_reg(self) -> None:
+ task = MPFClinicalPredictionTask(max_len=32, use_mpf=False)
+ samples = run_task(self.ds, task)
+ keys = samples[0]["concept_ids"]
+ first = next(i for i, x in enumerate(keys) if x != PAD_TOKEN)
+ last_nz = next(
+ i for i in range(len(keys) - 1, -1, -1) if keys[i] != PAD_TOKEN
+ )
+ self.assertEqual(keys[first], "")
+ self.assertEqual(keys[last_nz], "")
+ self.assertEqual(keys[-1], "")
+
+ def test_schema_keys(self) -> None:
+ task = MPFClinicalPredictionTask(max_len=16, use_mpf=True)
+ samples = run_task(self.ds, task)
+ for k in task.input_schema:
+ self.assertIn(k, samples[0])
+ self.assertIn("label", samples[0])
+
+ def test_max_len_two_keeps_boundary_tokens(self) -> None:
+ """At ``max_len=2`` the sequence is exactly ``[, ]``."""
+
+ task = MPFClinicalPredictionTask(max_len=2, use_mpf=True)
+ samples = run_task(self.ds, task)
+ for s in samples:
+ keys = s["concept_ids"]
+ self.assertEqual(len(keys), 2)
+ self.assertEqual(keys[0], "")
+ self.assertEqual(keys[1], "")
+
+ def test_fixed_length_alignment(self) -> None:
+ """All six per-event lists must be the same length (max_len)."""
+
+ task = MPFClinicalPredictionTask(max_len=24, use_mpf=True)
+ samples = run_task(self.ds, task)
+ for s in samples:
+ self.assertEqual(len(s["concept_ids"]), 24)
+ self.assertEqual(len(s["token_type_ids"]), 24)
+ self.assertEqual(len(s["time_stamps"]), 24)
+ self.assertEqual(len(s["ages"]), 24)
+ self.assertEqual(len(s["visit_orders"]), 24)
+ self.assertEqual(len(s["visit_segments"]), 24)
+
+
+if __name__ == "__main__":
+ unittest.main()