From 0d627502a6ec122d2b7ac296c744d3e65e1ad075 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 8 Apr 2026 09:05:02 +0000 Subject: [PATCH 1/4] refactor hf_sim to use xarray/dask for distributed station processing - Add HostType(str, Enum) with 'local' and 'slurm' values - Add dask_cluster context manager for LocalCluster/SLURMCluster - Add load_hf_dataset: loads stations into chunked xarray Dataset - Add process_hf_dataset: map_blocks-compatible station processor - Refactor run_hf CLI to use --host option and Dask workflow - Replace ThreadPoolExecutor with xr.map_blocks + dask.distributed - Add dask[distributed] and dask-jobqueue dependencies - Add tests for HostType, load_hf_dataset, process_hf_dataset Agent-Logs-Url: https://github.com/ucgmsim/workflow/sessions/66b1316d-545e-44d8-96d2-317087688eed Co-authored-by: lispandfound <12835929+lispandfound@users.noreply.github.com> --- pyproject.toml | 3 + tests/test_hf.py | 129 ++++++++++++++++ workflow/scripts/hf_sim.py | 291 ++++++++++++++++++++++++++++++------- 3 files changed, 373 insertions(+), 50 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fdd6492a..1d6d48ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,9 @@ dependencies = [ "pyyaml", "xarray[io]", + # Distributed Computing + "dask[distributed]", + "dask-jobqueue", # Numerics "numpy", "scipy", diff --git a/tests/test_hf.py b/tests/test_hf.py index 7d72c8b6..ac7a3fb9 100644 --- a/tests/test_hf.py +++ b/tests/test_hf.py @@ -2,6 +2,7 @@ from types import SimpleNamespace import numpy as np +import pandas as pd import pytest from hypothesis import given from hypothesis import strategies as st @@ -12,6 +13,7 @@ RuptureVelocity, ) from workflow.scripts import hf_sim +from workflow.scripts.hf_sim import HostType def test_build_hf_input_serialisation() -> None: @@ -169,3 +171,130 @@ def test_create_hf_dataset_structure() -> None: assert ds.attrs["dt"] == dt assert ds.attrs["nt"] == n_time assert ds.attrs["units"] == "cm/s^2" + + +def test_host_type_values() -> None: + assert HostType.local == "local" + assert HostType.slurm == "slurm" + assert HostType("local") is HostType.local + assert HostType("slurm") is HostType.slurm + + +def test_load_hf_dataset(tmp_path: Path) -> None: + station_file = tmp_path / "stations.ll" + station_file.write_text("172.6 -43.5 STAT_A\n172.7 -43.6 STAT_B\n172.8 -43.7 STAT_C\n") + + seeds = SimpleNamespace(hf_seed=42) + resolution = Resolution(resolution=0.1) + hf_config = SimpleNamespace(t_sec=0.0) + velocity_model = SimpleNamespace( + model=pd.DataFrame({"Vs": [0.5]}), + ) + domain_parameters = SimpleNamespace(duration=100.0) + + ds = hf_sim.load_hf_dataset( + station_file, + seeds, # type: ignore[arg-type] + resolution, + hf_config, # type: ignore[arg-type] + velocity_model, # type: ignore[arg-type] + domain_parameters, # type: ignore[arg-type] + ) + + assert "station" in ds.dims + assert ds.sizes["station"] == 3 + np.testing.assert_array_equal( + ds.station.values, ["STAT_A", "STAT_B", "STAT_C"] + ) + + assert "latitude" in ds.data_vars + assert "longitude" in ds.data_vars + assert "seed" in ds.data_vars + assert "vref" in ds.data_vars + + np.testing.assert_allclose(ds["latitude"].values, [-43.5, -43.6, -43.7]) + np.testing.assert_allclose(ds["longitude"].values, [172.6, 172.7, 172.8]) + + assert ds.attrs["nt"] > 0 + assert ds.attrs["dt"] == pytest.approx(0.005) + assert ds.attrs["start_sec"] == 0.0 + + # vref should be Vs * 1000 + np.testing.assert_allclose(ds["vref"].values, [500.0, 500.0, 500.0]) + + +def test_load_hf_dataset_chunking(tmp_path: Path) -> None: + # Create a station file with 1500 stations to test chunking logic + station_file = tmp_path / "stations.ll" + lines = [f"172.{i % 10} -43.{i % 10} STAT_{i:04d}" for i in range(1500)] + station_file.write_text("\n".join(lines) + "\n") + + seeds = SimpleNamespace(hf_seed=42) + resolution = Resolution(resolution=0.1) + hf_config = SimpleNamespace(t_sec=0.0) + velocity_model = SimpleNamespace( + model=pd.DataFrame({"Vs": [0.5]}), + ) + domain_parameters = SimpleNamespace(duration=10.0) + + ds = hf_sim.load_hf_dataset( + station_file, + seeds, # type: ignore[arg-type] + resolution, + hf_config, # type: ignore[arg-type] + velocity_model, # type: ignore[arg-type] + domain_parameters, # type: ignore[arg-type] + ) + + # chunk_size = max(1, 1500 // 500) = 3 + assert ds.chunks is not None + station_chunks = ds.chunks["station"] + assert all(c <= 3 for c in station_chunks) + + +def test_process_hf_dataset_structure() -> None: + import xarray as xr + + nt = 100 + dt = 0.02 + station_names = np.array(["STAT_A", "STAT_B"]) + + input_ds = xr.Dataset( + { + "latitude": ("station", np.array([-43.5, -43.6])), + "longitude": ("station", np.array([172.6, 172.7])), + "seed": ("station", np.array([123, 456])), + "vref": ("station", np.array([500.0, 500.0])), + }, + coords={"station": ("station", station_names)}, + attrs={"nt": nt, "dt": dt, "start_sec": 0.0}, + ) + + # We can't actually run the binary, but we can test that the function + # signature is correct and the output structure is as expected by + # mocking hf_simulate_station + import unittest.mock as mock + + mock_waveform = np.random.rand(nt, 3).astype(np.float32) + + with mock.patch.object( + hf_sim, + "hf_simulate_station", + side_effect=[ + ("STAT_A", 10.5, mock_waveform), + ("STAT_B", 20.1, mock_waveform), + ], + ): + result = hf_sim.process_hf_dataset( + input_ds, + hf_sim_path="/fake/path", + hf_input_template="template", + ) + + assert "waveform" in result.data_vars + assert "epicentre_distance" in result.data_vars + assert result["waveform"].dims == ("component", "station", "time") + assert result.sizes == {"component": 3, "station": 2, "time": nt} + assert result["epicentre_distance"].dims == ("station",) + np.testing.assert_array_equal(result.station.values, station_names) + np.testing.assert_array_equal(result.component.values, ["x", "y", "z"]) diff --git a/workflow/scripts/hf_sim.py b/workflow/scripts/hf_sim.py index 99048001..156df024 100644 --- a/workflow/scripts/hf_sim.py +++ b/workflow/scripts/hf_sim.py @@ -23,6 +23,10 @@ > [!NOTE] > The high-frequency code is very brittle. It is recommended you have both versions 6.0.3 and 5.4.5 built to run with. Sometimes it is necessary to switch between versions if one does not work. +> [!NOTE] +> Dask worker memory limits should account for the external binary's footprint, +> as `hb_high_binmod` runs as a subprocess and its memory usage is not tracked by Dask. + Usage ----- `hf-sim [OPTIONS] REALISATION_FFP STOCH_FFP STATION_FILE OUT_FILE` @@ -32,11 +36,11 @@ See the output of `hf-sim --help`. """ -import concurrent.futures import subprocess import tempfile -from collections.abc import Iterable -from concurrent.futures.thread import ThreadPoolExecutor +from collections.abc import Generator, Iterable +from contextlib import contextmanager +from enum import Enum from pathlib import Path from typing import Annotated @@ -45,6 +49,7 @@ import pandas as pd import typer import xarray as xr +from dask.distributed import Client, LocalCluster from qcore import cli from workflow import log_utils, realisations, utils @@ -61,6 +66,44 @@ app = typer.Typer() +class HostType(str, Enum): + """Cluster host type for Dask scheduling.""" + + local = "local" + slurm = "slurm" + + +@contextmanager +def dask_cluster(host: HostType) -> Generator[Client, None, None]: + """Create and manage a Dask distributed client for the given host type. + + Parameters + ---------- + host : HostType + The host type, either ``local`` for a `LocalCluster` or + ``slurm`` for a `SLURMCluster`. + + Yields + ------ + Client + A connected Dask distributed client. + """ + if host == HostType.local: + cluster = LocalCluster() + else: + from dask_jobqueue import SLURMCluster + + cluster = SLURMCluster( + cores=2, + memory="4GB", + walltime="01:00:00", + account="default", + ) + cluster.adapt(minimum=1, maximum=50) + with Client(cluster) as client: + yield client + + def rupture_velocity_hf_transition_bands( rupture_velocity: RuptureVelocity, ) -> tuple[float, float, float, float]: @@ -272,6 +315,146 @@ def station_seeds(seed: int, stations: Iterable[str]) -> npt.NDArray[np.int32]: return np.int32(seed) ^ station_hashes +def load_hf_dataset( + station_file: Path, + seeds: Seeds, + resolution: Resolution, + hf_config: HFConfig, + velocity_model: HFVelocityModel1D, + domain_parameters: DomainParameters, +) -> xr.Dataset: + """Load station data into a chunked xarray Dataset for distributed processing. + + Parameters + ---------- + station_file : Path + Path to station CSV file (columns: longitude, latitude, name). + seeds : Seeds + Seed configuration for the simulation. + resolution : Resolution + HF simulation resolution. + hf_config : HFConfig + The high-frequency config. + velocity_model : HFVelocityModel1D + The 1D velocity model. + domain_parameters : DomainParameters + The simulation domain parameters. + + Returns + ------- + xr.Dataset + A dataset indexed by ``station`` with variables for latitude, + longitude, seed, and vref, chunked for distributed processing. + """ + stations = pd.read_csv( + station_file, + delimiter=r"\s+", + header=None, + names=["longitude", "latitude", "name"], + ).set_index("name") + + seeds_array = station_seeds(seeds.hf_seed, stations.index) + vs = velocity_model.model["Vs"].iloc[0] * 1000 + vref = np.full(len(stations), vs, dtype=np.float64) + + nt = int( + np.float32(domain_parameters.duration) / np.float32(resolution.dt) + ) + + total_stations = len(stations) + chunk_size = max(1, total_stations // 500) + + ds = xr.Dataset( + { + "latitude": ("station", stations["latitude"].values), + "longitude": ("station", stations["longitude"].values), + "seed": ("station", seeds_array), + "vref": ("station", vref), + }, + coords={ + "station": ("station", stations.index.values.astype(str)), + }, + attrs={ + "nt": nt, + "dt": resolution.dt, + "start_sec": hf_config.t_sec, + }, + ) + return ds.chunk({"station": chunk_size}) + + +def process_hf_dataset( + ds: xr.Dataset, + *, + hf_sim_path: str, + hf_input_template: str, +) -> xr.Dataset: + """Process a chunk of the HF dataset by running station simulations. + + Designed to be used with :func:`xarray.map_blocks`. Iterates over the + stations in the chunk and executes ``hf_simulate_station`` for each one. + + Parameters + ---------- + ds : xr.Dataset + A chunk of the input dataset with ``latitude``, ``longitude``, + and ``seed`` variables indexed by ``station``. + hf_sim_path : str + Path to the HF simulation binary (passed as string for + serialization). + hf_input_template : str + The stdin input template for the HF simulation binary. + + Returns + ------- + xr.Dataset + A dataset containing ``waveform`` (dims: component, station, time) + and ``epicentre_distance`` (dims: station) for the stations in the + chunk. + """ + station_names = ds.station.values + n_stations = len(station_names) + nt = ds.attrs["nt"] + + waveform = np.empty((3, n_stations, nt), dtype=np.float32) + epicentre_distances = np.empty(n_stations, dtype=np.float64) + + for i, station_name in enumerate(station_names): + lat = float(ds["latitude"].values[i]) + lon = float(ds["longitude"].values[i]) + seed = int(ds["seed"].values[i]) + + _, epicentre, station_waveform = hf_simulate_station( + Path(hf_sim_path), + hf_input_template, + lat, + lon, + str(station_name), + seed, + ) + epicentre_distances[i] = epicentre + for component in range(3): + waveform[component, i] = station_waveform[:, component] + + dt = ds.attrs["dt"] + time_coords = np.arange(nt) * dt + + return xr.Dataset( + { + "waveform": ( + ["component", "station", "time"], + waveform, + ), + "epicentre_distance": (["station"], epicentre_distances), + }, + coords={ + "station": ("station", station_names), + "component": ("component", ["x", "y", "z"]), + "time": ("time", time_coords), + }, + ) + + def create_hf_dataset( # array-like used here to reduce the number of times we have to # change the types if the downstream function inputs change. @@ -369,15 +552,15 @@ def run_hf( Path, typer.Option(exists=True, writable=True, file_okay=False), ] = Path("/out"), + host: Annotated[HostType, typer.Option()] = HostType.local, ) -> None: """Run the HF (High-Frequency) simulation and generate the HF output file. This function performs the following steps: 1. Reads configuration and domain parameters from the realisation file. - 2. Filters stations based on their location relative to the domain. - 3. Uses multiprocessing to simulate each station and calculate epicentre distances. - 4. Reads the velocity model and calculates the `vs` value. - 5. Writes the HF output file, including header and station-specific data. + 2. Loads station data into a chunked xarray Dataset. + 3. Uses Dask distributed to simulate each station chunk in parallel. + 4. Writes the HF output file in NetCDF format. Parameters ---------- @@ -393,6 +576,9 @@ def run_hf( Path to the HF simulation binary. work_directory : Path, optional Directory for intermediate files. Must be writable. + host : HostType, optional + Dask cluster host type. Use ``local`` for a local cluster or + ``slurm`` for a SLURM cluster. Defaults to ``local``. Returns ------- @@ -416,19 +602,8 @@ def run_hf( realisation_ffp, metadata.defaults_version ) - stations = pd.read_csv( - station_file, - delimiter=r"\s+", - header=None, - names=["longitude", "latitude", "name"], - ).set_index("name") - stations["seed"] = station_seeds(seeds.hf_seed, stations.index) velocity_model_path = work_directory / "velocity_model" velocity_model.write_velocity_model(velocity_model_path) - nt = int( - np.float32(domain_parameters.duration) / np.float32(resolution.dt) - ) # Match Fortran's single-precision for consistent nt calculation - waveform = np.empty((3, len(stations), nt), dtype=np.float32) hf_input_template = build_hf_input( stoch_ffp, @@ -438,42 +613,58 @@ def run_hf( rupture_velocity, domain_parameters, ) - stations["epicentre_distance"] = np.nan - - with ThreadPoolExecutor(max_workers=utils.get_available_cores()) as executor: - station_index = {station: i for i, station in enumerate(stations.index)} - futures = [ - executor.submit( - hf_simulate_station, - hf_sim_path, - hf_input_template, - station["latitude"], - station["longitude"], - str(name), - int(station["seed"]), - ) - for name, station in stations.iterrows() - ] - for future in concurrent.futures.as_completed(futures): - station, epicentre, station_waveform = future.result() - stations.loc[station, "epicentre_distance"] = epicentre - i = station_index[station] - for component in range(3): - waveform[component, i] = station_waveform[:, component] + input_ds = load_hf_dataset( + station_file, + seeds, + resolution, + hf_config, + velocity_model, + domain_parameters, + ) + + nt = input_ds.attrs["nt"] + dt = input_ds.attrs["dt"] + station_names = input_ds.station.values + n_stations = len(station_names) + time_coords = np.arange(nt) * dt - vs = velocity_model.model["Vs"].iloc[0] * 1000 - stations["vs"] = vs + template = xr.Dataset( + { + "waveform": ( + ["component", "station", "time"], + np.empty((3, n_stations, nt), dtype=np.float32), + ), + "epicentre_distance": (["station"], np.empty(n_stations, dtype=np.float64)), + }, + coords={ + "station": ("station", station_names), + "component": ("component", ["x", "y", "z"]), + "time": ("time", time_coords), + }, + ) + + with dask_cluster(host) as client: + result_ds = xr.map_blocks( + process_hf_dataset, + input_ds, + template=template, + kwargs={ + "hf_sim_path": str(hf_sim_path), + "hf_input_template": hf_input_template, + }, + ) + result_ds = client.compute(result_ds).result() ds = create_hf_dataset( - waveform=waveform, - latitude=stations["latitude"], - longitude=stations["longitude"], - names=stations.index, - epicentre_distance=stations["epicentre_distance"], - seed=stations["seed"], - vref=stations["vs"], - dt=resolution.dt, + waveform=result_ds["waveform"].values, + latitude=input_ds["latitude"].values, + longitude=input_ds["longitude"].values, + names=station_names, + epicentre_distance=result_ds["epicentre_distance"].values, + seed=input_ds["seed"].values, + vref=input_ds["vref"].values, + dt=dt, start_sec=hf_config.t_sec, ) ds.to_netcdf(out_file, engine="h5netcdf") From 8bc6d62d1141bb0b71c514b8ca7df6fed3c8ed87 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 8 Apr 2026 09:06:34 +0000 Subject: [PATCH 2/4] address review: extract _TARGET_TASK_COUNT constant, improve chunking test Agent-Logs-Url: https://github.com/ucgmsim/workflow/sessions/66b1316d-545e-44d8-96d2-317087688eed Co-authored-by: lispandfound <12835929+lispandfound@users.noreply.github.com> --- tests/test_hf.py | 1 + workflow/scripts/hf_sim.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/test_hf.py b/tests/test_hf.py index ac7a3fb9..fb26b5b7 100644 --- a/tests/test_hf.py +++ b/tests/test_hf.py @@ -250,6 +250,7 @@ def test_load_hf_dataset_chunking(tmp_path: Path) -> None: assert ds.chunks is not None station_chunks = ds.chunks["station"] assert all(c <= 3 for c in station_chunks) + assert sum(station_chunks) == 1500 def test_process_hf_dataset_structure() -> None: diff --git a/workflow/scripts/hf_sim.py b/workflow/scripts/hf_sim.py index 156df024..0a69abe1 100644 --- a/workflow/scripts/hf_sim.py +++ b/workflow/scripts/hf_sim.py @@ -65,6 +65,15 @@ app = typer.Typer() +_TARGET_TASK_COUNT = 500 +"""Target number of Dask tasks for station chunking. + +Chosen to balance scheduler overhead against parallelism: too few tasks +under-utilise workers, while too many (e.g. one per station at 100 k+) +flood the scheduler with graph overhead. A value of 500–1 000 keeps the +task graph manageable while still saturating a large cluster. +""" + class HostType(str, Enum): """Cluster host type for Dask scheduling.""" @@ -362,7 +371,7 @@ def load_hf_dataset( ) total_stations = len(stations) - chunk_size = max(1, total_stations // 500) + chunk_size = max(1, total_stations // _TARGET_TASK_COUNT) ds = xr.Dataset( { From 995399215ae174af0e7bcd84b65a54ce01c74cd8 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 8 Apr 2026 09:57:03 +0000 Subject: [PATCH 3/4] remove SLURM support, make pipeline fully lazy for larger-than-memory output - Remove HostType enum, dask_cluster context manager, --host CLI option - Remove dask-jobqueue dependency - Use dask.array.empty for template to avoid allocating memory - Remove all eager .compute()/.result()/.values on large data - Let Dask graph flow lazily to to_netcdf for chunk-by-chunk writing - Remove test_host_type_values test Agent-Logs-Url: https://github.com/ucgmsim/workflow/sessions/dee90e3b-9cd0-4ac3-afcc-58a8a3d10c06 Co-authored-by: lispandfound <12835929+lispandfound@users.noreply.github.com> --- pyproject.toml | 1 - tests/test_hf.py | 8 --- workflow/scripts/hf_sim.py | 103 ++++++++++++++----------------------- 3 files changed, 39 insertions(+), 73 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1d6d48ed..94e632db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,6 @@ dependencies = [ # Distributed Computing "dask[distributed]", - "dask-jobqueue", # Numerics "numpy", "scipy", diff --git a/tests/test_hf.py b/tests/test_hf.py index fb26b5b7..b01b2d62 100644 --- a/tests/test_hf.py +++ b/tests/test_hf.py @@ -13,7 +13,6 @@ RuptureVelocity, ) from workflow.scripts import hf_sim -from workflow.scripts.hf_sim import HostType def test_build_hf_input_serialisation() -> None: @@ -173,13 +172,6 @@ def test_create_hf_dataset_structure() -> None: assert ds.attrs["units"] == "cm/s^2" -def test_host_type_values() -> None: - assert HostType.local == "local" - assert HostType.slurm == "slurm" - assert HostType("local") is HostType.local - assert HostType("slurm") is HostType.slurm - - def test_load_hf_dataset(tmp_path: Path) -> None: station_file = tmp_path / "stations.ll" station_file.write_text("172.6 -43.5 STAT_A\n172.7 -43.6 STAT_B\n172.8 -43.7 STAT_C\n") diff --git a/workflow/scripts/hf_sim.py b/workflow/scripts/hf_sim.py index 0a69abe1..8f672d65 100644 --- a/workflow/scripts/hf_sim.py +++ b/workflow/scripts/hf_sim.py @@ -38,12 +38,11 @@ import subprocess import tempfile -from collections.abc import Generator, Iterable -from contextlib import contextmanager -from enum import Enum +from collections.abc import Iterable from pathlib import Path from typing import Annotated +import dask.array as da import numpy as np import numpy.typing as npt import pandas as pd @@ -75,44 +74,6 @@ """ -class HostType(str, Enum): - """Cluster host type for Dask scheduling.""" - - local = "local" - slurm = "slurm" - - -@contextmanager -def dask_cluster(host: HostType) -> Generator[Client, None, None]: - """Create and manage a Dask distributed client for the given host type. - - Parameters - ---------- - host : HostType - The host type, either ``local`` for a `LocalCluster` or - ``slurm`` for a `SLURMCluster`. - - Yields - ------ - Client - A connected Dask distributed client. - """ - if host == HostType.local: - cluster = LocalCluster() - else: - from dask_jobqueue import SLURMCluster - - cluster = SLURMCluster( - cores=2, - memory="4GB", - walltime="01:00:00", - account="default", - ) - cluster.adapt(minimum=1, maximum=50) - with Client(cluster) as client: - yield client - - def rupture_velocity_hf_transition_bands( rupture_velocity: RuptureVelocity, ) -> tuple[float, float, float, float]: @@ -561,15 +522,15 @@ def run_hf( Path, typer.Option(exists=True, writable=True, file_okay=False), ] = Path("/out"), - host: Annotated[HostType, typer.Option()] = HostType.local, ) -> None: """Run the HF (High-Frequency) simulation and generate the HF output file. This function performs the following steps: 1. Reads configuration and domain parameters from the realisation file. 2. Loads station data into a chunked xarray Dataset. - 3. Uses Dask distributed to simulate each station chunk in parallel. - 4. Writes the HF output file in NetCDF format. + 3. Uses Dask to lazily simulate each station chunk in parallel. + 4. Writes the HF output file in NetCDF format chunk-by-chunk to + support larger-than-memory datasets. Parameters ---------- @@ -585,9 +546,6 @@ def run_hf( Path to the HF simulation binary. work_directory : Path, optional Directory for intermediate files. Must be writable. - host : HostType, optional - Dask cluster host type. Use ``local`` for a local cluster or - ``slurm`` for a SLURM cluster. Defaults to ``local``. Returns ------- @@ -634,17 +592,25 @@ def run_hf( nt = input_ds.attrs["nt"] dt = input_ds.attrs["dt"] + # Station coordinate labels are always in-memory in xarray (not + # dask-backed), so accessing .values here is safe and necessary to + # construct the template. station_names = input_ds.station.values n_stations = len(station_names) time_coords = np.arange(nt) * dt + # Use dask.array.empty so the template does not allocate memory for + # the full waveform array (which can be 100 GB+ for large runs). template = xr.Dataset( { "waveform": ( ["component", "station", "time"], - np.empty((3, n_stations, nt), dtype=np.float32), + da.empty((3, n_stations, nt), dtype=np.float32), + ), + "epicentre_distance": ( + ["station"], + da.empty(n_stations, dtype=np.float64), ), - "epicentre_distance": (["station"], np.empty(n_stations, dtype=np.float64)), }, coords={ "station": ("station", station_names), @@ -653,7 +619,7 @@ def run_hf( }, ) - with dask_cluster(host) as client: + with LocalCluster() as cluster, Client(cluster): result_ds = xr.map_blocks( process_hf_dataset, input_ds, @@ -663,18 +629,27 @@ def run_hf( "hf_input_template": hf_input_template, }, ) - result_ds = client.compute(result_ds).result() - - ds = create_hf_dataset( - waveform=result_ds["waveform"].values, - latitude=input_ds["latitude"].values, - longitude=input_ds["longitude"].values, - names=station_names, - epicentre_distance=result_ds["epicentre_distance"].values, - seed=input_ds["seed"].values, - vref=input_ds["vref"].values, - dt=dt, - start_sec=hf_config.t_sec, - ) - ds.to_netcdf(out_file, engine="h5netcdf") + + # Attach station metadata from the input dataset. These are + # small 1-D arrays (one value per station) so they do not + # contribute to memory pressure. + result_ds["seed"] = input_ds["seed"] + result_ds["vref"] = input_ds["vref"] + result_ds = result_ds.assign_coords( + lat=input_ds["latitude"], + lon=input_ds["longitude"], + ) + result_ds.attrs.update( + { + "start_sec": hf_config.t_sec, + "nt": nt, + "dt": dt, + "units": "cm/s^2", + } + ) + + # Write lazily — Dask streams chunks to disk one at a time so + # the full waveform array never needs to reside in memory. + result_ds.to_netcdf(out_file, engine="h5netcdf") + realisations.append_log_entry(realisation_ffp) From 9b6983991de2c427324822131e4d1993841a62ca Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 8 Apr 2026 10:19:41 +0000 Subject: [PATCH 4/4] refactor test_hf to use xarray testing facilities Replace np.testing.assert_allclose/.assert_array_equal with xr.testing.assert_allclose/assert_equal for dataset and coordinate comparisons in test_load_hf_dataset and test_process_hf_dataset_structure. Agent-Logs-Url: https://github.com/ucgmsim/workflow/sessions/eb2dbecf-d37e-45e1-bc15-020c0d78111f Co-authored-by: lispandfound <12835929+lispandfound@users.noreply.github.com> --- tests/test_hf.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/tests/test_hf.py b/tests/test_hf.py index b01b2d62..561ac1d2 100644 --- a/tests/test_hf.py +++ b/tests/test_hf.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import pytest +import xarray as xr from hypothesis import given from hypothesis import strategies as st @@ -193,27 +194,24 @@ def test_load_hf_dataset(tmp_path: Path) -> None: domain_parameters, # type: ignore[arg-type] ) - assert "station" in ds.dims - assert ds.sizes["station"] == 3 - np.testing.assert_array_equal( - ds.station.values, ["STAT_A", "STAT_B", "STAT_C"] + station_names = ["STAT_A", "STAT_B", "STAT_C"] + expected = xr.Dataset( + { + "latitude": ("station", [-43.5, -43.6, -43.7]), + "longitude": ("station", [172.6, 172.7, 172.8]), + "vref": ("station", [500.0, 500.0, 500.0]), + }, + coords={"station": ("station", station_names)}, + ) + xr.testing.assert_allclose( + ds[["latitude", "longitude", "vref"]], expected ) - assert "latitude" in ds.data_vars - assert "longitude" in ds.data_vars assert "seed" in ds.data_vars - assert "vref" in ds.data_vars - - np.testing.assert_allclose(ds["latitude"].values, [-43.5, -43.6, -43.7]) - np.testing.assert_allclose(ds["longitude"].values, [172.6, 172.7, 172.8]) - assert ds.attrs["nt"] > 0 assert ds.attrs["dt"] == pytest.approx(0.005) assert ds.attrs["start_sec"] == 0.0 - # vref should be Vs * 1000 - np.testing.assert_allclose(ds["vref"].values, [500.0, 500.0, 500.0]) - def test_load_hf_dataset_chunking(tmp_path: Path) -> None: # Create a station file with 1500 stations to test chunking logic @@ -246,8 +244,6 @@ def test_load_hf_dataset_chunking(tmp_path: Path) -> None: def test_process_hf_dataset_structure() -> None: - import xarray as xr - nt = 100 dt = 0.02 station_names = np.array(["STAT_A", "STAT_B"]) @@ -289,5 +285,8 @@ def test_process_hf_dataset_structure() -> None: assert result["waveform"].dims == ("component", "station", "time") assert result.sizes == {"component": 3, "station": 2, "time": nt} assert result["epicentre_distance"].dims == ("station",) - np.testing.assert_array_equal(result.station.values, station_names) - np.testing.assert_array_equal(result.component.values, ["x", "y", "z"]) + xr.testing.assert_equal(result.station, input_ds.station) + expected_components = xr.DataArray( + ["x", "y", "z"], coords={"component": ["x", "y", "z"]}, dims="component" + ) + xr.testing.assert_equal(result.component, expected_components)