Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
071ebd6
[Data] Add per-stage training-thread blocking attribution and pipelin…
OneSizeFitsQuorum Jun 17, 2026
f3935ab
[Data] Remove unused ShufflingBatcher compaction tracking
OneSizeFitsQuorum Jun 20, 2026
543a68d
[Data] Remove restore_order stage from blocked attribution
OneSizeFitsQuorum Jun 22, 2026
3a39cbd
[Data] Consolidate timing into StageTiming context manager
OneSizeFitsQuorum Jun 22, 2026
16fea70
[Data] Capture upstream blocked time in fetch stage
OneSizeFitsQuorum Jun 22, 2026
71da139
[Data] Add docstrings to timing dataclasses and _report_batch_timings
OneSizeFitsQuorum Jun 22, 2026
9fcde56
[Data] Restore isinstance check for BlockWithTiming compatibility
OneSizeFitsQuorum Jun 22, 2026
1266c92
[Data] Refactor to eliminate isinstance/Union in _BatchingIterator
OneSizeFitsQuorum Jun 22, 2026
ff658f2
[Data] Fix merge_fetch idle gap and blocked window alignment
OneSizeFitsQuorum Jun 22, 2026
b9bf847
[Data] Remove unused _merge_stage method
OneSizeFitsQuorum Jun 22, 2026
aa41de5
[Data] Revert merge_fetch to span-based approach
OneSizeFitsQuorum Jun 22, 2026
7874ef4
[Data] Fix test failures from BlockWithTiming refactor
OneSizeFitsQuorum Jun 23, 2026
2d311f4
[Data] Introduce TimeSpan and consolidate timing into Timer
OneSizeFitsQuorum Jun 26, 2026
74b9326
[Data] Split fetch stage into production_wait and data_transfer
OneSizeFitsQuorum Jun 26, 2026
261317c
[Data] Refine data model with BlockFetchResult and Optional timing
OneSizeFitsQuorum Jun 26, 2026
304dfe8
[Data] Add PipelineStage enum and match-based gauge mapping
OneSizeFitsQuorum Jun 26, 2026
09933dc
[Data] Defer iteration metrics rank label to follow-up PR
OneSizeFitsQuorum Jun 27, 2026
775c0a6
[Data] Fix nested timer double-counting in resolve_block_refs
OneSizeFitsQuorum Jun 27, 2026
f988f33
[Data] Add regression test for ref_bundles timer non-accumulation
OneSizeFitsQuorum Jun 27, 2026
c604a0a
[Data] Document blocked-attribution known limitations and typing-shim…
OneSizeFitsQuorum Jun 27, 2026
ea957ba
[Data] Add unit tests for Timer.timer, _timed, merge_fetch, get_block…
OneSizeFitsQuorum Jun 27, 2026
8fe5479
[Data] Apply black formatting to recent changes
OneSizeFitsQuorum Jun 27, 2026
4ac2b5c
[Data] Address review: rename IterationStage/_maybe_time, simplify co…
OneSizeFitsQuorum Jun 28, 2026
b13112d
[Data] Finish IterationStage rename in docstrings; slim _attribute_bl…
OneSizeFitsQuorum Jun 28, 2026
57ede0e
[Data] Record iter_total_s and flush metrics on early exit
OneSizeFitsQuorum Jun 28, 2026
39ebab5
[Data] Test code style cleanup
OneSizeFitsQuorum Jun 28, 2026
2e8f64d
[Data] Trim docstrings and comments for readability
OneSizeFitsQuorum Jun 28, 2026
e4d5df9
[Data] Drop circular "see docstring" pointer in resolve_block_refs
OneSizeFitsQuorum Jun 28, 2026
3d1694e
[Data] Drop circular "see _attribute_blocked_time" pointer in BatchTi…
OneSizeFitsQuorum Jun 28, 2026
dc78b7b
[Data] Drop Sphinx markup from Timer.timer() docstring
OneSizeFitsQuorum Jun 28, 2026
6ec8c26
[Data] Naming and test cleanup from self-review
OneSizeFitsQuorum Jun 28, 2026
9519cdf
[Data] Fix import sort order in test_stats.py
OneSizeFitsQuorum Jun 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion python/ray/data/_internal/block_batching/block_batching.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable, Iterator, Optional, TypeVar

from ray.data._internal.block_batching.interfaces import BlockFetchResult
from ray.data._internal.block_batching.util import (
_MappingIterator,
blocks_to_batches,
Expand Down Expand Up @@ -29,10 +30,19 @@ def batch_blocks(
This function takes in an iterator of already fetched blocks. Consequently, this
function doesn't support block prefetching.
"""
# Wrap raw blocks in BlockFetchResult with no fetch timing (these
# blocks were already resolved before entering the pipeline).
# Use map() instead of a generator expression to avoid holding
# references to blocks.
#
# TODO: make fetch timing optional at the _BatchingIterator level so
# this BlockFetchResult wrapping shim can be removed.
wrapped_blocks = map(lambda b: BlockFetchResult(block=b), blocks)

# Build the processing pipeline
batch_iter = format_batches(
blocks_to_batches(
block_iter=blocks,
block_iter=wrapped_blocks,
stats=stats,
batch_size=batch_size,
drop_last=drop_last,
Expand Down
74 changes: 72 additions & 2 deletions python/ray/data/_internal/block_batching/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,91 @@
import abc
from dataclasses import dataclass
from typing import Any, List
from dataclasses import dataclass, field
from typing import Any, Iterable, List, Optional, Tuple

from ray.data._internal.stats import IterationStage, TimeSpan
from ray.data.block import Block, DataBatch
from ray.types import ObjectRef


@dataclass
class BlockFetchTiming:
"""Fetch timing for a single block (production_wait + data_transfer)."""

production_wait: Optional[TimeSpan] = None
data_transfer: Optional[TimeSpan] = None


@dataclass
class BlockFetchResult:
"""A resolved block paired with its per-block fetch timing.

``fetch`` is None when no timing was recorded (e.g. blocks already
resolved before entering the pipeline).
"""

block: Block
fetch: Optional[BlockFetchTiming] = None


@dataclass
class BatchTimings:
"""Per-batch timing windows for each iteration stage.

Each field is the ``(start_s, end_s)`` window a stage was active, or
None if the stage didn't run. Compared against the training thread's
blocked window to attribute stall.
"""

production_wait: Optional[TimeSpan] = None
data_transfer: Optional[TimeSpan] = None
batching: Optional[TimeSpan] = None
format: Optional[TimeSpan] = None
collate: Optional[TimeSpan] = None
finalize: Optional[TimeSpan] = None

def stages(self) -> Iterable[Tuple[IterationStage, Optional[TimeSpan]]]:
"""Yield (stage, timing) pairs."""
return (
(IterationStage.PRODUCTION_WAIT, self.production_wait),
(IterationStage.DATA_TRANSFER, self.data_transfer),
(IterationStage.BATCHING, self.batching),
(IterationStage.FORMAT, self.format),
(IterationStage.COLLATE, self.collate),
(IterationStage.FINALIZE, self.finalize),
)

def merge_fetch(self, src: BlockFetchTiming) -> None:
"""Merge per-block fetch timings into this batch's fetch windows."""
self.production_wait = _merge_span(self.production_wait, src.production_wait)
self.data_transfer = _merge_span(self.data_transfer, src.data_transfer)


def _merge_span(dst: Optional[TimeSpan], src: Optional[TimeSpan]) -> Optional[TimeSpan]:
"""Return the union of two optional windows (or the non-None one)."""
if src is None:
return dst
if dst is None:
return TimeSpan(start_s=src.start_s, end_s=src.end_s)
return TimeSpan(
start_s=min(dst.start_s, src.start_s),
end_s=max(dst.end_s, src.end_s),
)
Comment thread
cursor[bot] marked this conversation as resolved.


@dataclass
class BatchMetadata:
"""Metadata associated with a batch.

Attributes:
batch_idx: The global index of this batch so that downstream operations can
maintain ordering.
num_rows: Number of rows in this batch (for ``iter_rows_total``).
timings: Per-stage timing windows.
"""

batch_idx: int
num_rows: int = 0
timings: BatchTimings = field(default_factory=BatchTimings)


@dataclass
Expand Down
48 changes: 47 additions & 1 deletion python/ray/data/_internal/block_batching/iter_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _prefetch_blocks(

def _resolve_block_refs(
self, block_refs: Iterator[ObjectRef[Block]]
) -> Iterator[Block]:
) -> Iterator[Any]:
return resolve_block_refs(block_ref_iter=block_refs, stats=self._stats)

def _blocks_to_batches(self, blocks: Iterator[Block]) -> Iterator[Batch]:
Expand Down Expand Up @@ -243,21 +243,63 @@ def _pipeline(self, ref_bundles: Iterator[RefBundle]) -> Iterator[Batch]:
yield from batch_iter

def _iter_batches(self) -> Iterator[DataBatch]:
"""Pull batches from the pipeline and yield batch data.

Captures the training thread's blocked window around each ``next()``
call and attributes it to pipeline stages via
``_attribute_blocked_time``.
"""
batch_iter = iter_threaded(self._ref_bundles, fn=self._pipeline)

self.before_epoch_start()

while True:
with self.get_next_batch_context():
blocked_start_s = time.perf_counter()
try:
batch = next(batch_iter)
except StopIteration:
break
blocked_end_s = time.perf_counter()
self._attribute_blocked_time(batch, blocked_start_s, blocked_end_s)
with self.yield_batch_context(batch):
yield batch.data

self.after_epoch_end()

def _attribute_blocked_time(
self, batch: Batch, blocked_start_s: float, blocked_end_s: float
) -> None:
"""Attribute per-stage blocked time via overlap with the training window.

Each stage's window on ``batch.metadata.timings`` is intersected with
the training thread's blocked window ``[blocked_start_s, blocked_end_s]``::

overlap = min(timing.end, blocked_end) - max(timing.start, blocked_start)

TODO: ``sum(iter_blocked_*)`` only approximates ``iter_total_blocked_s``
— split fetch stages overlap for multi-block batches, and reorder
buffer wait under ``preserve_order`` is unattributed.

Args:
batch: Batch whose per-stage timings should be attributed.
blocked_start_s: perf_counter() just before next().
blocked_end_s: perf_counter() just after next() returned.
"""
if self._stats is None:
return
timings = batch.metadata.timings
for stage, timing in timings.stages():
if timing is None:
continue
Comment thread
OneSizeFitsQuorum marked this conversation as resolved.
overlap_s = min(timing.end_s, blocked_end_s) - max(
timing.start_s, blocked_start_s
)
if overlap_s > 0:
self._stats.get_blocked_timer(stage).add(overlap_s)
Comment thread
cursor[bot] marked this conversation as resolved.
self._stats.iter_batches_total += 1
self._stats.iter_rows_total += batch.metadata.num_rows

def __iter__(self) -> Iterator[DataBatch]:
return self._iter_batches()

Expand All @@ -276,6 +318,8 @@ def after_epoch_end(self):

@contextmanager
def get_next_batch_context(self):
"""Context around ``next(batch_iter)``: tracks total blocked time
and time-to-first-batch."""
try:
if self._stats:
# Always track total blocked time
Expand All @@ -295,6 +339,8 @@ def get_next_batch_context(self):

@contextmanager
def yield_batch_context(self, batch: Batch):
"""Context around yielding a batch to the user: tracks user time
and periodically flushes metrics."""
with self._stats.iter_user_s.timer() if self._stats else nullcontext():
yield

Expand Down
Loading
Loading