Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions docs/conf_file.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ Resume is rejected when any of the following checks fail:
- the infiltration model changes between the archived run and the resumed run;
- a surface-flow parameter changes outside the set explicitly allowed on
resume;
- ``start_time`` differs from the value stored in the hotstart;
- ``end_time`` is changed to a value that is not strictly after the archived
checkpoint time;

Expand Down Expand Up @@ -139,9 +140,10 @@ The table below reflects the current implementation.
- A run with drainage can only resume from a hotstart that also contains
drainage state, and vice versa.
* - ``[time] start_time``
- Keep unchanged
- The archived simulation clock and scheduler are restored from the
hotstart. They are not remapped to a new start time.
- Must match
- Changing it raises a hotstart compatibility error. The archived
simulation clock and scheduler are restored from the hotstart and are
not remapped to a new start time.
* - Input map names, ``[drainage] swmm_inp``, and other forcing paths
- Not cross-checked
- Itzi validates the resumed domain and mask, but it does not verify that
Expand Down
8 changes: 8 additions & 0 deletions src/itzi/simulation_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,14 @@ def _validate_resume_config_congruence(
"""Validate which runtime settings may change across a hotstart resume."""
hotstart_sim_time = hotstart_state.sim_time

if self.sim_config.start_time != hotstart_config.start_time:
raise HotstartError(
"Hotstart start_time mismatch: "
f"current={self.sim_config.start_time}, "
f"hotstart={hotstart_config.start_time}. "
"Resume must keep the archived start_time unchanged."
)

# Keep this defensive check here even though SimulationConfig also validates
# user input: model_copy(update=...) can bypass Pydantic validation in tests
# and internal resume flows.
Expand Down
14 changes: 14 additions & 0 deletions tests/core/test_hotstart_state_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,20 @@ def test_build_allows_record_step_change(

assert simulation.report.dt == resumed_record_step

def test_build_rejects_start_time_mismatch(
    self,
    domain_5by5,
    sim_config: SimulationConfig,
    valid_hotstart_bytes: io.BytesIO,
) -> None:
    """A resumed config whose start_time differs from the hotstart must be refused."""
    # Shift the start by a few seconds; model_copy(update=...) is used to
    # produce the altered configuration from the fixture.
    shifted_start = sim_config.start_time + timedelta(seconds=5)
    mismatched_config = sim_config.model_copy(update={"start_time": shifted_start})

    with pytest.raises(HotstartError, match="start_time"):
        self._build_with_hotstart(domain_5by5, mismatched_config, valid_hotstart_bytes)

def test_build_rejects_end_time_not_after_hotstart_time(
self,
domain_5by5,
Expand Down
254 changes: 254 additions & 0 deletions tests/core/test_hotstart_timed_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
"""Hotstart integration tests for provider-backed timed xarray inputs."""

from __future__ import annotations

from datetime import datetime, timedelta
from typing import TYPE_CHECKING

import numpy as np
import pytest

pytest.importorskip("xarray")

import xarray as xr

from itzi.const import InfiltrationModelType, TemporalType
from itzi.data_containers import SimulationConfig, SurfaceFlowParameters
from itzi.providers.memory_output import MemoryRasterOutputProvider, MemoryVectorOutputProvider
from itzi.providers.xarray_input import XarrayRasterInputProvider
from itzi.simulation_builder import SimulationBuilder

if TYPE_CHECKING:
from itzi.simulation import Simulation


# Module-level pytest mark: every test collected from this module carries the
# ``cloud`` marker.
pytestmark = pytest.mark.cloud

# Offsets, in seconds, of the time slices used for the timed rain input.
TIME_SLICE_SECONDS = (0, 10, 20, 30)
# Rain value attached to each slice offset (mm/h, per the constant's name).
RAIN_MM_PER_HOUR = dict(zip(TIME_SLICE_SECONDS, (0.0, 180.0, 360.0, 540.0)))


def _run_to_end(simulation: "Simulation", *, skip_initialize: bool = False) -> None:
    """Step *simulation* until its clock reaches ``end_time``, then finalize it.

    ``simulation.initialize()`` is called first unless ``skip_initialize`` is
    truthy. ``simulation.finalize()`` is always called once the stepping loop
    has ended, even when no update was needed.
    """
    needs_initialize = not skip_initialize
    if needs_initialize:
        simulation.initialize()

    while simulation.sim_time < simulation.end_time:
        simulation.update()

    simulation.finalize()


def _assert_final_state_matches(resumed: "Simulation", reference: "Simulation") -> None:
    """Assert that the resumed run ends with the same rasters as the reference run.

    The ``water_depth``, ``qe`` and ``qs`` arrays of both raster domains are
    compared with ``np.testing.assert_allclose`` (``rtol=1e-5``, ``atol=1e-7``).
    The first mismatch raises ``AssertionError`` naming the offending key.
    """
    compared_keys = ("water_depth", "qe", "qs")
    for key in compared_keys:
        actual = resumed.raster_domain.get_array(key)
        desired = reference.raster_domain.get_array(key)
        np.testing.assert_allclose(
            actual,
            desired,
            rtol=1e-5,
            atol=1e-7,
            err_msg=f"Final {key} mismatch",
        )


def _make_xarray_input_dataset(
    domain_5by5,
    start_time: datetime,
    *,
    use_relative_time: bool,
) -> tuple[xr.Dataset, dict[int, np.ndarray]]:
    """Build the xarray dataset used as raster input, plus the expected rain arrays.

    The dataset holds static ``dem``, ``friction`` and ``water_depth`` layers
    copied from *domain_5by5* and a ``rain`` variable with one uniform slice per
    entry of ``TIME_SLICE_SECONDS``. The ``time`` coordinate is a
    ``timedelta64[s]`` offset when ``use_relative_time`` is true, otherwise a
    ``datetime64[s]`` stamp computed from *start_time*.

    Returns:
        ``(dataset, expected_rain_arrays)`` where the second item maps each
        slice offset to a float32 array filled with the slice's rain value
        divided by ``1000 * 3600`` (presumably mm/h converted to m/s — confirm
        against the input provider).
    """
    grid_coords = domain_5by5.domain_data.get_coordinates()
    grid_shape = domain_5by5.domain_data.shape

    if use_relative_time:
        time_values = [np.timedelta64(offset, "s") for offset in TIME_SLICE_SECONDS]
        time_coords = np.array(time_values, dtype="timedelta64[s]")
    else:
        time_values = [
            np.datetime64(start_time + timedelta(seconds=offset), "s")
            for offset in TIME_SLICE_SECONDS
        ]
        time_coords = np.array(time_values, dtype="datetime64[s]")

    # One uniform raster per slice, in the units stored in the dataset.
    rain_cube = np.stack(
        [
            np.full(grid_shape, RAIN_MM_PER_HOUR[offset], dtype=np.float32)
            for offset in TIME_SLICE_SECONDS
        ],
        axis=0,
    )
    # The same slices after dividing by 1000 * 3600, keyed by slice offset.
    expected_rain_arrays: dict[int, np.ndarray] = {
        offset: np.full(
            grid_shape,
            RAIN_MM_PER_HOUR[offset] / (1000 * 3600),
            dtype=np.float32,
        )
        for offset in TIME_SLICE_SECONDS
    }

    data_vars = {
        "dem": (["y", "x"], domain_5by5.arr_dem_flat.copy()),
        "friction": (["y", "x"], domain_5by5.arr_n.copy()),
        "water_depth": (["y", "x"], domain_5by5.arr_start_h.copy()),
        "rain": (["time", "y", "x"], rain_cube),
    }
    coords = {
        "time": time_coords,
        "x": grid_coords["x"],
        "y": grid_coords["y"],
    }
    dataset = xr.Dataset(
        data_vars,
        coords=coords,
        attrs={"crs_wkt": domain_5by5.domain_data.crs_wkt},
    )
    return dataset, expected_rain_arrays


def _make_simulation_config(
    start_time: datetime,
    end_time: datetime,
    *,
    temporal_type: TemporalType,
) -> SimulationConfig:
    """Return the ``SimulationConfig`` shared by the timed-input hotstart tests.

    Only the time window and ``temporal_type`` vary. The record step (10 s),
    map names, surface-flow parameters and the NULL infiltration model are
    fixed.
    """
    input_maps = {
        "dem": "dem",
        "friction": "friction",
        "water_depth": "water_depth",
        "rain": "rain",
    }
    output_maps = {"water_depth": "out_hotstart_timed_inputs_water_depth"}
    flow_parameters = SurfaceFlowParameters(hmin=0.0001, dtmax=0.3, cfl=0.2)

    return SimulationConfig(
        start_time=start_time,
        end_time=end_time,
        record_step=timedelta(seconds=10),
        temporal_type=temporal_type,
        input_map_names=input_maps,
        output_map_names=output_maps,
        surface_flow_parameters=flow_parameters,
        infiltration_model=InfiltrationModelType.NULL,
    )


def _build_provider_simulation(
    sim_config: SimulationConfig,
    domain_5by5,
    dataset: xr.Dataset,
    *,
    hotstart_bytes: bytes | None = None,
) -> "Simulation":
    """Build a simulation fed by an xarray input provider and in-memory outputs.

    A fresh ``XarrayRasterInputProvider`` is created over *dataset* for the
    time window in *sim_config*. When ``hotstart_bytes`` is given, it is
    registered on the builder before ``build()`` is called.
    """
    provider_options = {
        "dataset": dataset,
        "input_map_names": sim_config.input_map_names,
        "simulation_start_time": sim_config.start_time,
        "simulation_end_time": sim_config.end_time,
    }
    input_provider = XarrayRasterInputProvider(provider_options)

    builder = SimulationBuilder(sim_config, domain_5by5.arr_mask, np.float32)
    builder = builder.with_input_provider(input_provider)
    raster_output = MemoryRasterOutputProvider({"out_map_names": sim_config.output_map_names})
    builder = builder.with_raster_output_provider(raster_output)
    vector_output = MemoryVectorOutputProvider({})
    builder = builder.with_vector_output_provider(vector_output)

    if hotstart_bytes is None:
        return builder.build()

    builder.with_hotstart(hotstart_bytes)
    return builder.build()


def _run_reference_with_hotstart_checkpoint(
    sim_config: SimulationConfig,
    domain_5by5,
    dataset: xr.Dataset,
    split_target_time: datetime,
) -> tuple[dict[str, datetime | np.ndarray], bytes, "Simulation"]:
    """Run an uninterrupted reference simulation, capturing a hotstart midway.

    The simulation is stepped until its clock is no longer before
    *split_target_time*. At that point the clock, a copy of the ``rain`` array
    and the hotstart archive bytes are recorded. The run then continues to
    ``end_time`` and is finalized.

    Returns:
        ``(checkpoint, hotstart_bytes, simulation)`` where ``checkpoint`` has
        the keys ``"sim_time"`` and ``"rain"``.
    """
    reference = _build_provider_simulation(sim_config, domain_5by5, dataset)
    reference.initialize()

    while reference.sim_time < split_target_time:
        reference.update()

    checkpoint_time = reference.sim_time
    checkpoint_rain = reference.raster_domain.get_array("rain").copy()
    checkpoint = {"sim_time": checkpoint_time, "rain": checkpoint_rain}
    hotstart_archive = reference.create_hotstart()
    hotstart_bytes = hotstart_archive.getvalue()

    # Already initialized above: just finish stepping and finalize.
    _run_to_end(reference, skip_initialize=True)

    return checkpoint, hotstart_bytes, reference


def _assert_resume_with_timed_xarray_inputs(
    domain_5by5,
    *,
    use_relative_time: bool,
) -> None:
    """Check that resuming from a hotstart keeps timed xarray inputs aligned.

    A 25 s reference run is checkpointed once its clock passes 12 s, which
    falls inside the second rain slice (10 s to 20 s). A second simulation is
    rebuilt from that hotstart with a fresh input provider. The test verifies
    the restored clock and rain array, performs one update and verifies that
    the rain ``TimedArray`` covers the second slice, then runs to the end and
    compares the final state with the reference.
    """
    start_time = datetime(2000, 1, 1, 0, 0, 0)
    end_time = start_time + timedelta(seconds=25)
    split_target_time = start_time + timedelta(seconds=12)
    if use_relative_time:
        temporal_type = TemporalType.RELATIVE
    else:
        temporal_type = TemporalType.ABSOLUTE

    dataset, expected_rain_arrays = _make_xarray_input_dataset(
        domain_5by5, start_time, use_relative_time=use_relative_time
    )
    sim_config = _make_simulation_config(start_time, end_time, temporal_type=temporal_type)
    checkpoint, hotstart_bytes, reference = _run_reference_with_hotstart_checkpoint(
        sim_config, domain_5by5, dataset, split_target_time
    )

    # The checkpoint must fall inside the second rain slice.
    saved_sim_time = checkpoint["sim_time"]
    second_slice_start = start_time + timedelta(seconds=10)
    second_slice_end = start_time + timedelta(seconds=20)
    assert second_slice_start <= saved_sim_time < second_slice_end

    resumed = _build_provider_simulation(
        sim_config, domain_5by5, dataset, hotstart_bytes=hotstart_bytes
    )

    assert resumed.sim_time == saved_sim_time
    second_slice_rain = expected_rain_arrays[10]
    np.testing.assert_allclose(resumed.raster_domain.get_array("rain"), checkpoint["rain"])
    np.testing.assert_allclose(resumed.raster_domain.get_array("rain"), second_slice_rain)

    # TimedArray cache is rebuilt from the fresh provider during construction.
    # After resume, the first update must realign that cache to the restored clock.
    resumed.update()

    assert resumed.timed_arrays is not None
    rain_timed_array = resumed.timed_arrays["rain"]
    assert second_slice_start <= resumed.sim_time < second_slice_end
    assert rain_timed_array.arr_start == second_slice_start
    assert rain_timed_array.arr_end == second_slice_end
    np.testing.assert_allclose(resumed.raster_domain.get_array("rain"), second_slice_rain)
    assert not np.allclose(resumed.raster_domain.get_array("rain"), expected_rain_arrays[0])

    _run_to_end(resumed, skip_initialize=True)
    _assert_final_state_matches(resumed, reference)


def test_resume_with_relative_time_xarray_inputs(domain_5by5) -> None:
    """Hotstart resume with a rain input indexed by relative (timedelta) time."""
    _assert_resume_with_timed_xarray_inputs(domain_5by5, use_relative_time=True)


def test_resume_with_absolute_time_xarray_inputs(domain_5by5) -> None:
    """Hotstart resume with a rain input indexed by absolute (datetime) time."""
    _assert_resume_with_timed_xarray_inputs(domain_5by5, use_relative_time=False)
Loading