From 05508ebfe77d170db02f0e5e2f9e2cada0f57644 Mon Sep 17 00:00:00 2001 From: Phani Aenugula Date: Tue, 12 May 2026 13:51:02 -0700 Subject: [PATCH] Internal PiperOrigin-RevId: 914460149 --- .../dataset/transformations/interleave.py | 103 +++++++++-- .../transformations/interleave_test.py | 172 +++++++++++++++++- .../dataset/transformations/prefetch.py | 29 +++ 3 files changed, 292 insertions(+), 12 deletions(-) diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index f91077637..3f8280959 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -17,7 +17,7 @@ from collections.abc import Sequence import copy import functools -from typing import Any, TypeVar +from typing import Any, TypeVar, cast import weakref from absl import logging @@ -28,7 +28,6 @@ from grain._src.python.dataset import stats from grain._src.python.dataset.transformations import prefetch - T = TypeVar("T") @@ -92,6 +91,7 @@ def __init__( ] = [None] * self._cycle_length # Future states used for elastic iterators self._future_states: dict[int, Any] = {} + self._iterator_start_states: dict[int, Any] = {} @stats.record_next_duration_if_output @stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING) @@ -208,6 +208,41 @@ def get_state(self): } return state + def get_shard_states(self) -> Sequence[Any]: + state = self.get_state() + indices = state["iterators_in_use_indices"] + states = state["iterators_in_use_states"] + exhausted = state["exhausted"] + next_index_in_datasets = state["next_index_in_datasets"] + + shard_states = [None] * len(self._datasets) + + for i in range(len(self._datasets)): # pylint: disable=protected-access + # If the current shard index is greater than or equal to the next + # index in datasets, it means the current shard has not yet started + # to be iterated on. + if i >= next_index_in_datasets: + shard_states[i] = { + "exhausted": 0, + "state": self._get_iterator_start_state(i), # pylint: disable=protected-access + } + elif i not in indices: + # These shards are exhausted but should still create a state to maintain + # static state spec shapes. + shard_states[i] = { + "exhausted": 1, + "state": self._get_iterator_start_state(i), # pylint: disable=protected-access + } + + for index, ds_state, is_exhausted in zip(indices, states, exhausted): + # These shards are currently being iterated on. + shard_states[index] = { + "exhausted": is_exhausted, + "state": ds_state, + } + + return shard_states + def set_state(self, state): exhausted = state["exhausted"] for index_in_cycle, (index_in_datasets, it_state) in enumerate( @@ -252,7 +287,52 @@ def set_state(self, state): self._next_index_in_cycle = state["next_index_in_cycle"] self._next_index_in_datasets = state["next_index_in_datasets"] self._iterators_in_use_indices = state["iterators_in_use_indices"] - self._future_states = state.get("future_states", {}) + self._future_states = cast(dict[int, Any], state.get("future_states", {})) + + def set_shard_states(self, shard_states: Sequence[Any]) -> None: + active_states = [] + for ind, shard_state in enumerate(shard_states): + if not shard_state["exhausted"]: + active_states.append((ind, shard_state["state"])) + + iterators_in_use_indices = [] + iterators_in_use_states = [] + exhausted = [] + count = 0 + future_states = {} + for ind, s in active_states: + if count < self._cycle_length: + iterators_in_use_indices.append(ind) + iterators_in_use_states.append(s) + exhausted.append(0) + count += 1 + elif s: + future_states[ind] = s + next_index_in_datasets = ( + max(iterators_in_use_indices) + 1 if iterators_in_use_indices else 0 + ) + + # This is the case where the cycle length is greater than the number of + # non-exhausted shards. We will just go back through the datasets to fill + # the cycle length. + fill_index = 0 + while count < self._cycle_length: + if shard_states[fill_index]["exhausted"]: + iterators_in_use_indices.append(fill_index) + iterators_in_use_states.append(shard_states[fill_index]["state"]) + exhausted.append(1) + count += 1 + fill_index += 1 + + new_state = { + "next_index_in_cycle": 0, + "next_index_in_datasets": next_index_in_datasets, + "iterators_in_use_indices": iterators_in_use_indices, + "iterators_in_use_states": iterators_in_use_states, + "exhausted": exhausted, + "future_states": future_states, + } + self.set_state(new_state) def _get_next_index(self) -> int: if len(self._datasets) == 1: @@ -334,14 +414,15 @@ def __str__(self) -> str: ) def _get_iterator_start_state(self, index: int) -> dict[str, Any]: - it = _add_prefetch_and_make_iterator( - self._datasets[index], - weakref.ref(self), - start_prefetch=False, - ) - state = it.get_state() - del it - return state + if index not in self._iterator_start_states: + it = _add_prefetch_and_make_iterator( + self._datasets[index], + weakref.ref(self), + start_prefetch=False, + ) + self._iterator_start_states[index] = it.get_state() + it.close() + return self._iterator_start_states[index] def _add_prefetch_and_make_iterator( diff --git a/grain/_src/python/dataset/transformations/interleave_test.py b/grain/_src/python/dataset/transformations/interleave_test.py index 29cae4704..5b48925c6 100644 --- a/grain/_src/python/dataset/transformations/interleave_test.py +++ b/grain/_src/python/dataset/transformations/interleave_test.py @@ -21,10 +21,10 @@ from grain._src.python.dataset import base from grain._src.python.dataset import dataset from grain._src.python.dataset.transformations import interleave +from grain._src.python.dataset.transformations import prefetch from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint import numpy as np - _INTERLEAVE_TEST_CASES = ( dict( testcase_name="cycle_length_1", @@ -339,6 +339,176 @@ def test_future_states(self): with self.assertRaises(StopIteration): next(ds_iter) + @parameterized.named_parameters( + dict( + testcase_name="cycle_length_equals_num_datasets", + ds_elements=[[1, 2, 3], [4, 5, 6]], + cycle_length=2, + expected_shard_state=[ + {"exhausted": 0, "state": {"next_index": 1}}, + {"exhausted": 0, "state": {"next_index": 1}}, + ], + expected_remaining=[2, 5, 3, 6], + ), + dict( + testcase_name="cycle_length_less_than_num_datasets", + ds_elements=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + cycle_length=2, + expected_shard_state=[ + {"exhausted": 0, "state": {"next_index": 1}}, + {"exhausted": 0, "state": {"next_index": 1}}, + {"exhausted": 0, "state": {"next_index": 0}}, + ], + expected_remaining=[2, 5, 3, 6, 7, 8, 9], + ), + ) + def test_slice_state_management_checkpoints_correctly( + self, + ds_elements, + cycle_length, + expected_shard_state, + expected_remaining, + ): + datasets = [ + dataset.MapDataset.source(elements).to_iter_dataset() + for elements in ds_elements + ] + ds = self._create_dataset(datasets, cycle_length=cycle_length) + ds = self._maybe_wrap_ds(ds) + it = ds.__iter__() + + # Consume some elements to advance state. + for _ in range(2): + next(it) + + # Get the shard state. + assert isinstance(it, prefetch.SupportsSlicedStateManagement) + shard_state = it.get_shard_states() + self.assertEqual(shard_state, expected_shard_state) + + # Create a new iterator and restore state. + it2 = ds.__iter__() + assert isinstance(it2, prefetch.SupportsSlicedStateManagement) + it2.set_shard_states(shard_state) + + # Verify it continues from the correct position. + self.assertSequenceEqual(list(it2), expected_remaining) + + @parameterized.named_parameters( + dict( + testcase_name="cycle_length_equals_num_datasets", + ds_elements=[[1, 2, 3], [4, 5, 6]], + cycle_length=2, + expected_shard_state=[ + {"exhausted": 0, "state": {"next_index": 1}}, + {"exhausted": 0, "state": {"next_index": 1}}, + ], + expected_future_states={}, + expected_next_index_in_datasets=2, + expected_iterators_in_use_indices=[0, 1], + expected_iterators_in_use_states=[ + {"next_index": 1}, + {"next_index": 1}, + ], + ), + dict( + testcase_name="cycle_length_less_than_num_datasets", + ds_elements=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + cycle_length=2, + expected_shard_state=[ + {"exhausted": 0, "state": {"next_index": 1}}, + {"exhausted": 0, "state": {"next_index": 1}}, + {"exhausted": 0, "state": {"next_index": 0}}, + ], + expected_future_states={2: {"next_index": 0}}, + expected_next_index_in_datasets=2, + expected_iterators_in_use_indices=[0, 1], + expected_iterators_in_use_states=[ + {"next_index": 1}, + {"next_index": 1}, + ], + ), + ) + def test_correct_interleave_state_after_setting_shards( + self, + ds_elements, + cycle_length, + expected_shard_state, + expected_future_states, + expected_next_index_in_datasets, + expected_iterators_in_use_indices, + expected_iterators_in_use_states, + ): + datasets = [ + dataset.MapDataset.source(elements).to_iter_dataset() + for elements in ds_elements + ] + ds = self._create_dataset(datasets, cycle_length=cycle_length) + ds = self._maybe_wrap_ds(ds) + it = ds.__iter__() + + # Consume some elements to advance state. + for _ in range(2): + next(it) + + assert isinstance(it, prefetch.SupportsSlicedStateManagement) + shard_state = it.get_shard_states() + self.assertEqual(shard_state, expected_shard_state) + + # Create a new iterator and restore state. + it2 = ds.__iter__() + assert isinstance(it2, prefetch.SupportsSlicedStateManagement) + it2.set_shard_states(shard_state) + + # Check get_shard_states() returns the set shard states correctly. + self.assertEqual(it2.get_shard_states(), expected_shard_state) + + # Check get_state() internal values. + state = it2.get_state() + self.assertEqual(state["next_index_in_cycle"], 0) + self.assertEqual( + state["next_index_in_datasets"], expected_next_index_in_datasets + ) + self.assertEqual( + state["iterators_in_use_indices"], expected_iterators_in_use_indices + ) + self.assertEqual( + state["iterators_in_use_states"], expected_iterators_in_use_states + ) + self.assertEqual(state["future_states"], expected_future_states) + + def test_setting_shard_state_with_exhausted_states(self): + datasets_data = [[1], [2, 3], [4, 5]] + datasets = [ + dataset.MapDataset.source(elements).to_iter_dataset() + for elements in datasets_data + ] + ds = self._create_dataset(datasets, cycle_length=2) + ds = self._maybe_wrap_ds(ds) + it = ds.__iter__() + + shard_state = [ + {"exhausted": 1, "state": {"next_index": 1}}, + {"exhausted": 1, "state": {"next_index": 2}}, + {"exhausted": 0, "state": {"next_index": 0}}, + ] + + # Create a new iterator and restore state. + assert isinstance(it, prefetch.SupportsSlicedStateManagement) + it.set_shard_states(shard_state) + + # Check get_state() internal values. + state = it.get_state() + self.assertEqual(state["next_index_in_cycle"], 0) + self.assertEqual(state["next_index_in_datasets"], 3) + self.assertEqual(state["iterators_in_use_indices"], [2, 0]) + self.assertEqual( + state["iterators_in_use_states"], + [{"next_index": 0}, {"next_index": 1}], + ) + if isinstance(self, InterleaveIterDatasetTest): + self.assertEqual(state["exhausted"], [0, 1]) + if __name__ == "__main__": absltest.main() diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 651b09275..3fe681eb1 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -88,6 +88,35 @@ def set_slice(self, sl: slice, sequential_slice: bool = False) -> None: ... +@typing.runtime_checkable +class SupportsSlicedStateManagement(Protocol): + """Iterators that support setting a sliced state. + + This protocol is mainly used to support elastic resizing of iterators. + """ + + def get_shard_states(self) -> Sequence[Any]: + """Returns the states of all shards managed by this iterator. + + This is used for elastic resizing to capture the current progress of each + shard. + """ + ... + + def set_shard_states(self, shard_states: Sequence[Any]): + """Sets the states of all shards managed by this iterator. + + This is used for elastic resizing to restore the progress of each shard. + + Args: + shard_states: A sequence of dictionaries, one for each shard. Each dict + must contain 'exhausted' key with value bool indicating if the shard is + exhausted and 'state' key with value Any representing the state of the + shard. + """ + ... + + class PrefetchIterDataset(dataset.IterDataset[T]): """Iterable dataset that uses a thread pool for prefetching."""