Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 92 additions & 11 deletions grain/_src/python/dataset/transformations/interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections.abc import Sequence
import copy
import functools
from typing import Any, TypeVar
from typing import Any, TypeVar, cast
import weakref

from absl import logging
Expand All @@ -28,7 +28,6 @@
from grain._src.python.dataset import stats
from grain._src.python.dataset.transformations import prefetch


T = TypeVar("T")


Expand Down Expand Up @@ -92,6 +91,7 @@ def __init__(
] = [None] * self._cycle_length
# Future states used for elastic iterators
self._future_states: dict[int, Any] = {}
self._iterator_start_states: dict[int, Any] = {}

@stats.record_next_duration_if_output
@stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING)
Expand Down Expand Up @@ -208,6 +208,41 @@ def get_state(self):
}
return state

def get_shard_states(self) -> Sequence[Any]:
  """Returns one state entry per underlying dataset (shard).

  Each entry is a dict with two keys:
    "exhausted": truthy if the shard is reported as finished.
    "state": the checkpoint state of the shard's iterator.

  Shards are classified from the result of `get_state()`:
    * Shards listed in "iterators_in_use_indices" report the state and the
      exhausted flag stored for their slot in the cycle.
    * Shards at or beyond "next_index_in_datasets" have not been started yet.
      They report the state parked for them in "future_states" when there is
      one, and the iterator start state otherwise.
    * Every other shard is reported as exhausted. It still carries the
      iterator start state so that all entries share the same structure.

  Returns:
    A list with `len(self._datasets)` entries, indexed by shard.
  """
  state = self.get_state()
  next_index_in_datasets = state["next_index_in_datasets"]
  in_use_indices = state["iterators_in_use_indices"]
  # States assigned to shards that have not been opened yet. Reporting the
  # start state for these shards would silently drop their progress.
  future_states = state.get("future_states", {})
  in_use = set(in_use_indices)

  shard_states = [None] * len(self._datasets)
  for i in range(len(self._datasets)):
    if i >= next_index_in_datasets:
      # Not started: prefer a pending future state over the start state.
      if i in future_states:
        pending_state = future_states[i]
      else:
        pending_state = self._get_iterator_start_state(i)
      shard_states[i] = {"exhausted": 0, "state": pending_state}
    elif i not in in_use:
      # Already passed and no longer in the cycle: exhausted. A state is still
      # produced to maintain static state spec shapes.
      shard_states[i] = {
          "exhausted": 1,
          "state": self._get_iterator_start_state(i),
      }

  # Shards currently in the cycle report their live state and flag.
  for index, ds_state, is_exhausted in zip(
      in_use_indices, state["iterators_in_use_states"], state["exhausted"]
  ):
    shard_states[index] = {"exhausted": is_exhausted, "state": ds_state}

  return shard_states

def set_state(self, state):
exhausted = state["exhausted"]
for index_in_cycle, (index_in_datasets, it_state) in enumerate(
Expand Down Expand Up @@ -252,7 +287,52 @@ def set_state(self, state):
self._next_index_in_cycle = state["next_index_in_cycle"]
self._next_index_in_datasets = state["next_index_in_datasets"]
self._iterators_in_use_indices = state["iterators_in_use_indices"]
self._future_states = state.get("future_states", {})
self._future_states = cast(dict[int, Any], state.get("future_states", {}))

def set_shard_states(self, shard_states: Sequence[Any]) -> None:
  """Restores this iterator from per-shard states.

  Non-exhausted shards are assigned to the cycle in ascending shard order
  until `self._cycle_length` slots are taken. Remaining non-exhausted shards
  with a non-empty state are stored as future states, keyed by shard index.
  If fewer non-exhausted shards than cycle slots exist, the remaining slots
  are filled with exhausted shards, again in ascending shard order.

  The resulting state dict is applied through `set_state()`.

  Args:
    shard_states: One dict per shard, each with an "exhausted" flag and a
      "state" entry.

  Raises:
    ValueError: If `shard_states` has too few entries to fill the cycle.
  """
  iterators_in_use_indices = []
  iterators_in_use_states = []
  exhausted = []
  future_states = {}

  for index, shard_state in enumerate(shard_states):
    if shard_state["exhausted"]:
      continue
    if len(iterators_in_use_indices) < self._cycle_length:
      iterators_in_use_indices.append(index)
      iterators_in_use_states.append(shard_state["state"])
      exhausted.append(0)
    elif shard_state["state"]:
      # The cycle is full: park the state until the shard is opened.
      future_states[index] = shard_state["state"]

  # Computed before the fill step below so that only live shards move the
  # cursor.
  # NOTE(review): when every shard is exhausted this leaves the cursor at 0
  # rather than past the last shard -- confirm that shards are not restarted
  # in that case.
  next_index_in_datasets = (
      max(iterators_in_use_indices) + 1 if iterators_in_use_indices else 0
  )

  # This is the case where the cycle length is greater than the number of
  # non-exhausted shards. Go back through the shards and use exhausted ones
  # to fill the cycle length.
  for index, shard_state in enumerate(shard_states):
    if len(iterators_in_use_indices) >= self._cycle_length:
      break
    if shard_state["exhausted"]:
      iterators_in_use_indices.append(index)
      iterators_in_use_states.append(shard_state["state"])
      exhausted.append(1)

  if len(iterators_in_use_indices) < self._cycle_length:
    raise ValueError(
        f"Cannot fill a cycle of length {self._cycle_length} from"
        f" {len(shard_states)} shard states."
    )

  self.set_state({
      "next_index_in_cycle": 0,
      "next_index_in_datasets": next_index_in_datasets,
      "iterators_in_use_indices": iterators_in_use_indices,
      "iterators_in_use_states": iterators_in_use_states,
      "exhausted": exhausted,
      "future_states": future_states,
  })

def _get_next_index(self) -> int:
if len(self._datasets) == 1:
Expand Down Expand Up @@ -334,14 +414,15 @@ def __str__(self) -> str:
)

def _get_iterator_start_state(self, index: int) -> dict[str, Any]:
it = _add_prefetch_and_make_iterator(
self._datasets[index],
weakref.ref(self),
start_prefetch=False,
)
state = it.get_state()
del it
return state
if index not in self._iterator_start_states:
it = _add_prefetch_and_make_iterator(
self._datasets[index],
weakref.ref(self),
start_prefetch=False,
)
self._iterator_start_states[index] = it.get_state()
it.close()
return self._iterator_start_states[index]


def _add_prefetch_and_make_iterator(
Expand Down
172 changes: 171 additions & 1 deletion grain/_src/python/dataset/transformations/interleave_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import interleave
from grain._src.python.dataset.transformations import prefetch
from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint
import numpy as np


_INTERLEAVE_TEST_CASES = (
dict(
testcase_name="cycle_length_1",
Expand Down Expand Up @@ -339,6 +339,176 @@ def test_future_states(self):
with self.assertRaises(StopIteration):
next(ds_iter)

@parameterized.named_parameters(
    {
        "testcase_name": "cycle_length_equals_num_datasets",
        "ds_elements": [[1, 2, 3], [4, 5, 6]],
        "cycle_length": 2,
        "expected_shard_state": [
            {"exhausted": 0, "state": {"next_index": 1}},
            {"exhausted": 0, "state": {"next_index": 1}},
        ],
        "expected_remaining": [2, 5, 3, 6],
    },
    {
        "testcase_name": "cycle_length_less_than_num_datasets",
        "ds_elements": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        "cycle_length": 2,
        "expected_shard_state": [
            {"exhausted": 0, "state": {"next_index": 1}},
            {"exhausted": 0, "state": {"next_index": 1}},
            {"exhausted": 0, "state": {"next_index": 0}},
        ],
        "expected_remaining": [2, 5, 3, 6, 7, 8, 9],
    },
)
def test_slice_state_management_checkpoints_correctly(
    self,
    ds_elements,
    cycle_length,
    expected_shard_state,
    expected_remaining,
):
  """Shard states captured mid-iteration resume a fresh iterator correctly."""
  shards = [
      dataset.MapDataset.source(elements).to_iter_dataset()
      for elements in ds_elements
  ]
  interleaved = self._maybe_wrap_ds(
      self._create_dataset(shards, cycle_length=cycle_length)
  )
  source_it = interleaved.__iter__()

  # Advance the iterator by two elements before capturing its state.
  next(source_it)
  next(source_it)

  # Capture the per-shard state and compare it with the expectation.
  assert isinstance(source_it, prefetch.SupportsSlicedStateManagement)
  captured = source_it.get_shard_states()
  self.assertEqual(captured, expected_shard_state)

  # Load the captured state into a brand new iterator.
  restored_it = interleaved.__iter__()
  assert isinstance(restored_it, prefetch.SupportsSlicedStateManagement)
  restored_it.set_shard_states(captured)

  # The restored iterator must yield exactly the elements not yet consumed.
  self.assertSequenceEqual(list(restored_it), expected_remaining)

@parameterized.named_parameters(
    {
        "testcase_name": "cycle_length_equals_num_datasets",
        "ds_elements": [[1, 2, 3], [4, 5, 6]],
        "cycle_length": 2,
        "expected_shard_state": [
            {"exhausted": 0, "state": {"next_index": 1}},
            {"exhausted": 0, "state": {"next_index": 1}},
        ],
        "expected_future_states": {},
        "expected_next_index_in_datasets": 2,
        "expected_iterators_in_use_indices": [0, 1],
        "expected_iterators_in_use_states": [
            {"next_index": 1},
            {"next_index": 1},
        ],
    },
    {
        "testcase_name": "cycle_length_less_than_num_datasets",
        "ds_elements": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        "cycle_length": 2,
        "expected_shard_state": [
            {"exhausted": 0, "state": {"next_index": 1}},
            {"exhausted": 0, "state": {"next_index": 1}},
            {"exhausted": 0, "state": {"next_index": 0}},
        ],
        "expected_future_states": {2: {"next_index": 0}},
        "expected_next_index_in_datasets": 2,
        "expected_iterators_in_use_indices": [0, 1],
        "expected_iterators_in_use_states": [
            {"next_index": 1},
            {"next_index": 1},
        ],
    },
)
def test_correct_interleave_state_after_setting_shards(
    self,
    ds_elements,
    cycle_length,
    expected_shard_state,
    expected_future_states,
    expected_next_index_in_datasets,
    expected_iterators_in_use_indices,
    expected_iterators_in_use_states,
):
  """Setting shard states yields the expected internal interleave state."""
  shards = [
      dataset.MapDataset.source(elements).to_iter_dataset()
      for elements in ds_elements
  ]
  interleaved = self._maybe_wrap_ds(
      self._create_dataset(shards, cycle_length=cycle_length)
  )
  source_it = interleaved.__iter__()

  # Advance the iterator by two elements before capturing its state.
  next(source_it)
  next(source_it)

  assert isinstance(source_it, prefetch.SupportsSlicedStateManagement)
  captured = source_it.get_shard_states()
  self.assertEqual(captured, expected_shard_state)

  # Load the captured state into a brand new iterator.
  restored_it = interleaved.__iter__()
  assert isinstance(restored_it, prefetch.SupportsSlicedStateManagement)
  restored_it.set_shard_states(captured)

  # The shard states must survive the set/get round trip unchanged.
  self.assertEqual(restored_it.get_shard_states(), expected_shard_state)

  # Inspect the regular checkpoint state of the restored iterator.
  restored_state = restored_it.get_state()
  self.assertEqual(restored_state["next_index_in_cycle"], 0)
  self.assertEqual(
      restored_state["next_index_in_datasets"],
      expected_next_index_in_datasets,
  )
  self.assertEqual(
      restored_state["iterators_in_use_indices"],
      expected_iterators_in_use_indices,
  )
  self.assertEqual(
      restored_state["iterators_in_use_states"],
      expected_iterators_in_use_states,
  )
  self.assertEqual(restored_state["future_states"], expected_future_states)

def test_setting_shard_state_with_exhausted_states(self):
  """Exhausted shards are used to fill the cycle when too few are active."""
  shards = [
      dataset.MapDataset.source(elements).to_iter_dataset()
      for elements in ([1], [2, 3], [4, 5])
  ]
  interleaved = self._maybe_wrap_ds(
      self._create_dataset(shards, cycle_length=2)
  )
  it = interleaved.__iter__()

  # Two exhausted shards and one shard that has not produced anything yet.
  shard_states = [
      {"exhausted": 1, "state": {"next_index": 1}},
      {"exhausted": 1, "state": {"next_index": 2}},
      {"exhausted": 0, "state": {"next_index": 0}},
  ]

  assert isinstance(it, prefetch.SupportsSlicedStateManagement)
  it.set_shard_states(shard_states)

  # Inspect the regular checkpoint state after the shard states were applied.
  resulting_state = it.get_state()
  self.assertEqual(resulting_state["next_index_in_cycle"], 0)
  self.assertEqual(resulting_state["next_index_in_datasets"], 3)
  self.assertEqual(resulting_state["iterators_in_use_indices"], [2, 0])
  self.assertEqual(
      resulting_state["iterators_in_use_states"],
      [{"next_index": 0}, {"next_index": 1}],
  )
  # The exhausted flags are only checked for the IterDataset test class.
  if isinstance(self, InterleaveIterDatasetTest):
    self.assertEqual(resulting_state["exhausted"], [0, 1])


# Run the test cases through absltest when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()
29 changes: 29 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,35 @@ def set_slice(self, sl: slice, sequential_slice: bool = False) -> None:
...


@typing.runtime_checkable
class SupportsSlicedStateManagement(Protocol):
  """Iterators that support getting and setting per-shard (sliced) state.

  This protocol is mainly used to support elastic resizing of iterators.
  """

  def get_shard_states(self) -> Sequence[Any]:
    """Returns the states of all shards managed by this iterator.

    This is used for elastic resizing to capture the current progress of each
    shard.

    Returns:
      A sequence with one entry per shard, intended to be passed back to
      `set_shard_states`.
    """
    ...

  def set_shard_states(self, shard_states: Sequence[Any]) -> None:
    """Sets the states of all shards managed by this iterator.

    This is used for elastic resizing to restore the progress of each shard.

    Args:
      shard_states: A sequence of dictionaries, one for each shard. Each dict
        must contain an 'exhausted' key whose value is truthy if the shard is
        exhausted, and a 'state' key with value Any representing the state of
        the shard.
    """
    ...


class PrefetchIterDataset(dataset.IterDataset[T]):
"""Iterable dataset that uses a thread pool for prefetching."""

Expand Down
Loading