From 071ebd6507503f480618b269abd182577e670f48 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Wed, 17 Jun 2026 23:57:00 +0800 Subject: [PATCH 01/32] [Data] Add per-stage training-thread blocking attribution and pipeline observability to iter_batches Implements overlap-based latency attribution for Ray Data's iter_batches pipeline, addressing #64132 and RFC #63911. Each pipeline stage (fetch, batching, format, collate, finalize, restore_order) records an independent (start_s, end_s) time window. The training thread captures its own blocked window around next(). Attribution per stage is the overlap of the two windows, correctly handling prefetch > 1. New Prometheus metrics (14 total): - data_iter_blocked_{fetch,batching,format,collate,finalize,restore_order}_seconds - data_iter_batches_total, data_iter_rows_total - data_iter_total_seconds, data_iter_restore_order_buffer_peak - data_iter_shuffle_buffer_{rows,compactions_total,compaction_seconds} - data_iter_prefetch_queue_depth Also adds: - Per-stage breakdown rendering in IterStatsSummary.to_string() - Rank extraction from dataset tags for Prometheus labels - Final metrics flush on iterator completion Signed-off-by: OneSizeFitsQuorum --- python/ray/data/_internal/batcher.py | 9 + .../_internal/block_batching/interfaces.py | 58 ++- .../_internal/block_batching/iter_batches.py | 64 ++- .../ray/data/_internal/block_batching/util.py | 53 ++- python/ray/data/_internal/stats.py | 122 ++++- python/ray/data/iterator.py | 3 +- .../tests/block_batching/test_iter_batches.py | 432 ++++++++++++++++++ python/ray/data/tests/test_stats.py | 159 +++++++ 8 files changed, 878 insertions(+), 22 deletions(-) diff --git a/python/ray/data/_internal/batcher.py b/python/ray/data/_internal/batcher.py index c097ee668de6..7ff6ccf42c11 100644 --- a/python/ray/data/_internal/batcher.py +++ b/python/ray/data/_internal/batcher.py @@ -1,3 +1,4 @@ +import time import warnings from typing import Optional @@ -235,6 +236,8 @@ def __init__( self._total_object_store_nbytes = get_total_obj_store_mem_on_node() self._total_num_rows_added = 0 self._total_nbytes_added = 0 + self.compactions_total = 0 + self.compaction_time_s = 0.0 def add(self, block: Block): """Add a block to the shuffle buffer. @@ -320,6 +323,9 @@ def _num_rows(self) -> int: """ return self._num_compacted_rows() + self._num_uncompacted_rows() + def num_rows(self) -> int: + return self._num_rows() + def _num_compacted_rows(self) -> int: """Return number of unyielded rows in the compacted buffer.""" if self._shuffle_buffer is None: @@ -341,6 +347,7 @@ def next_batch(self) -> Block: self._done_adding or self._num_compacted_rows() <= self._min_rows_to_yield_batch ): + compaction_start_s = time.perf_counter() if self._shuffle_buffer is not None and self._batch_head < len( self._shuffled_indices ): @@ -363,6 +370,8 @@ def next_batch(self) -> Block: self._builder = DelegatingBlockBuilder() self._batch_head = 0 + self.compactions_total += 1 + self.compaction_time_s += time.perf_counter() - compaction_start_s assert self._shuffle_buffer is not None assert self._shuffled_indices is not None diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 4f0bed6b3dd4..10587586c941 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -1,11 +1,63 @@ import abc -from dataclasses import dataclass -from typing import Any, List +from dataclasses import dataclass, field +from typing import Any, Iterable, List, Tuple from ray.data.block import Block, DataBatch from ray.types import ObjectRef +@dataclass +class StageTiming: + """Wall-clock window for a batch-processing stage.""" + + start_s: float = 0.0 + end_s: float = 0.0 + + def record(self, start_s: float, end_s: float) -> None: + if self.start_s == 0.0: + self.start_s = start_s + self.end_s = end_s + + +@dataclass +class BatchTimings: + fetch: StageTiming = field(default_factory=StageTiming) + batching: StageTiming = field(default_factory=StageTiming) + format: StageTiming = field(default_factory=StageTiming) + collate: StageTiming = field(default_factory=StageTiming) + finalize: StageTiming = field(default_factory=StageTiming) + restore_order: StageTiming = field(default_factory=StageTiming) + num_rows: int = 0 + + def stages(self) -> Iterable[Tuple[str, StageTiming]]: + return ( + ("fetch", self.fetch), + ("batching", self.batching), + ("format", self.format), + ("collate", self.collate), + ("finalize", self.finalize), + ("restore_order", self.restore_order), + ) + + def merge_fetch(self, other: "BatchTimings") -> None: + self._merge_stage(self.fetch, other.fetch) + + @staticmethod + def _merge_stage(dst: StageTiming, src: StageTiming) -> None: + if src.start_s == 0.0: + return + if dst.start_s == 0.0 or src.start_s < dst.start_s: + dst.start_s = src.start_s + if src.end_s > dst.end_s: + dst.end_s = src.end_s + + +@dataclass +class BlockWithTiming: + block: Block + timings: BatchTimings = field(default_factory=BatchTimings) + + @dataclass class BatchMetadata: """Metadata associated with a batch. @@ -13,9 +65,11 @@ class BatchMetadata: Attributes: batch_idx: The global index of this batch so that downstream operations can maintain ordering. + timings: Pipeline-stage timing windows for this batch. """ batch_idx: int + timings: BatchTimings = field(default_factory=BatchTimings) @dataclass diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index f9bf0076d2af..66e55437d68f 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -175,8 +175,10 @@ def _prefetch_blocks( def _resolve_block_refs( self, block_refs: Iterator[ObjectRef[Block]] - ) -> Iterator[Block]: - return resolve_block_refs(block_ref_iter=block_refs, stats=self._stats) + ) -> Iterator[Any]: + return resolve_block_refs( + block_ref_iter=block_refs, stats=self._stats, record_timings=True + ) def _blocks_to_batches(self, blocks: Iterator[Block]) -> Iterator[Batch]: return blocks_to_batches( @@ -216,7 +218,7 @@ def _finalize_batches( def _restore_original_batch_order( self, batches: Iterator[Batch] ) -> Iterator[Batch]: - return restore_original_order(batches) + return restore_original_order(batches, stats=self._stats) def _pipeline(self, ref_bundles: Iterator[RefBundle]) -> Iterator[Batch]: # Step 1: Prefetch logical batches locally. @@ -248,16 +250,36 @@ def _iter_batches(self) -> Iterator[DataBatch]: self.before_epoch_start() while True: + blocked_start_s = time.perf_counter() with self.get_next_batch_context(): try: batch = next(batch_iter) except StopIteration: break + blocked_end_s = time.perf_counter() + self._report_batch_timings(batch, blocked_start_s, blocked_end_s) with self.yield_batch_context(batch): yield batch.data self.after_epoch_end() + def _report_batch_timings( + self, batch: Batch, blocked_start_s: float, blocked_end_s: float + ) -> None: + if self._stats is None: + return + timings = batch.metadata.timings + for name, stage in timings.stages(): + if stage.start_s == 0.0 and stage.end_s == 0.0: + continue + overlap_s = min(stage.end_s, blocked_end_s) - max( + stage.start_s, blocked_start_s + ) + if overlap_s > 0: + getattr(self._stats, f"iter_blocked_{name}_s").add(overlap_s) + self._stats.iter_batches_total += 1 + self._stats.iter_rows_total += timings.num_rows + def __iter__(self) -> Iterator[DataBatch]: return self._iter_batches() @@ -452,7 +474,9 @@ def get_next_ref_bundle() -> RefBundle: prefetcher.stop() -def restore_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: +def restore_original_order( + batch_iter: Iterator[Batch], stats: Optional[DatasetStats] = None +) -> Iterator[Batch]: """Restores the original order of the provided `batch_iter` This function will yield items from `base_iterator` in the correct order based on @@ -463,13 +487,31 @@ def restore_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: """ next_index_required = 0 buffer: Dict[int, Batch] = {} - for batch in batch_iter: - assert batch.metadata.batch_idx not in buffer - buffer[batch.metadata.batch_idx] = batch + restore_wait_start_s: Optional[float] = None + source_exhausted = False + + while True: while next_index_required in buffer: - yield buffer.pop(next_index_required) + next_batch = buffer.pop(next_index_required) + if restore_wait_start_s is not None: + next_batch.metadata.timings.restore_order.record( + restore_wait_start_s, time.perf_counter() + ) + restore_wait_start_s = None + yield next_batch next_index_required += 1 - while next_index_required in buffer: - yield buffer.pop(next_index_required) - next_index_required += 1 + if source_exhausted: + break + + if buffer and restore_wait_start_s is None: + restore_wait_start_s = time.perf_counter() + + try: + batch = next(batch_iter) + except StopIteration: + source_exhausted = True + continue + + assert batch.metadata.batch_idx not in buffer + buffer[batch.metadata.batch_idx] = batch diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 8a42cde7871e..fb9f1f3aa11a 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -3,6 +3,7 @@ import logging import queue import threading +import time from contextlib import nullcontext from typing import ( Any, @@ -14,6 +15,7 @@ Optional, Tuple, TypeVar, + Union, ) import ray @@ -22,8 +24,11 @@ from ray.data._internal.block_batching.interfaces import ( Batch, BatchMetadata, + BatchTimings, BlockPrefetcher, + BlockWithTiming, CollatedBatch, + StageTiming, ) from ray.data._internal.stats import DatasetStats from ray.data.block import Block, BlockAccessor, DataBatch @@ -170,18 +175,27 @@ def _calculate_ref_hits(refs: List[ObjectRef[Any]]) -> Tuple[int, int, int]: return 0, 0, 0 +def _record_stage_window(stage: StageTiming, start_s: float, end_s: float) -> None: + stage.record(start_s, end_s) + + def resolve_block_refs( block_ref_iter: Iterator[ObjectRef[Block]], stats: Optional[DatasetStats] = None, -) -> Iterator[Block]: + record_timings: bool = False, +) -> Iterator[Union[Block, BlockWithTiming]]: """Resolves the block references for each logical batch. Args: block_ref_iter: An iterator over block object references. stats: An optional stats object to recording block hits and misses. + record_timings: If True, wrap each resolved block in a + ``BlockWithTiming`` carrying the per-block fetch window. Yields: - Block: The resolved blocks for each block reference. + Union[Block, BlockWithTiming]: The resolved blocks. When + *record_timings* is ``True`` each block is wrapped in a + ``BlockWithTiming``; otherwise raw ``Block`` instances are yielded. """ hits = 0 misses = 0 @@ -195,9 +209,16 @@ def resolve_block_refs( # TODO(amogkam): Optimized further by batching multiple references in a single # `ray.get()` call. + start_s = time.perf_counter() with stats.iter_get_s.timer() if stats else nullcontext(): block = ray.get(block_ref) - yield block + end_s = time.perf_counter() + if record_timings: + timings = BatchTimings() + _record_stage_window(timings.fetch, start_s, end_s) + yield BlockWithTiming(block=block, timings=timings) + else: + yield block if stats: stats.iter_blocks_local = hits @@ -206,7 +227,7 @@ def resolve_block_refs( def blocks_to_batches( - block_iter: Iterator[Block], + block_iter: Iterator[Union[Block, BlockWithTiming]], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -235,7 +256,7 @@ class _BatchingIterator(Iterator[Batch]): def __init__( self, - block_iter: Iterator[Block], + block_iter: Iterator[Union[Block, BlockWithTiming]], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -248,6 +269,7 @@ def __init__( self._drop_last = drop_last self._global_counter = 0 self._done_adding = False + self._pending_timings = BatchTimings() if shuffle_buffer_min_size is not None: self._batcher = ShufflingBatcher( @@ -272,12 +294,22 @@ def __next__(self) -> Batch: if can_yield: with timer: + start_s = time.perf_counter() next_batch = self._batcher.next_batch() + end_s = time.perf_counter() res = Batch( - metadata=BatchMetadata(batch_idx=self._global_counter), + metadata=BatchMetadata( + batch_idx=self._global_counter, + timings=self._pending_timings, + ), data=next_batch, ) + _record_stage_window(res.metadata.timings.batching, start_s, end_s) + res.metadata.timings.num_rows = BlockAccessor.for_block( + next_batch + ).num_rows() + self._pending_timings = BatchTimings() self._global_counter += 1 return res @@ -287,6 +319,9 @@ def __next__(self) -> Batch: try: # NOTE: Block ref is released immediately block = next(self._block_iter) + if isinstance(block, BlockWithTiming): + self._pending_timings.merge_fetch(block.timings) + block = block.block self._batcher.add(block) except StopIteration: self._batcher.done_adding() @@ -306,12 +341,14 @@ def _format_batch( stats: Optional[DatasetStats], ensure_copy: bool = False, ) -> Batch: + start_s = time.perf_counter() with stats.iter_format_batch_s.timer() if stats else nullcontext(): formatted_data = BlockAccessor.for_block(batch.data).to_batch_format( batch_format ) if ensure_copy: formatted_data = _copy_batch(formatted_data) + _record_stage_window(batch.metadata.timings.format, start_s, time.perf_counter()) return dataclasses.replace(batch, data=formatted_data) @@ -359,8 +396,10 @@ def _collate_batch( collate_fn: Callable[[DataBatch], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: + start_s = time.perf_counter() with stats.iter_collate_batch_s.timer() if stats else nullcontext(): collated_data = collate_fn(batch.data) + _record_stage_window(batch.metadata.timings.collate, start_s, time.perf_counter()) return CollatedBatch(metadata=batch.metadata, data=collated_data) @@ -384,8 +423,10 @@ def _finalize_batch( finalize_fn: Callable[[Any], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: + start_s = time.perf_counter() with stats.iter_finalize_batch_s.timer() if stats else nullcontext(): finalized_data = finalize_fn(batch.data) + _record_stage_window(batch.metadata.timings.finalize, start_s, time.perf_counter()) return dataclasses.replace(batch, data=finalized_data) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 618ed8d6a376..7afdb967aeaa 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -1,6 +1,7 @@ import collections import copy import logging +import re import time from collections import defaultdict from contextlib import contextmanager @@ -60,6 +61,19 @@ StatsDict = Dict[str, List[BlockStats]] +def _create_iteration_tags(dataset_tag: Optional[str]): + dataset_tag = dataset_tag or "unknown_dataset" + tags = {"dataset": dataset_tag, "rank": "unknown"} + # Use findall + last match: the streaming-split index is always the + # trailing ``split_`` in the tag. The user-defined dataset name may + # itself contain ``split_`` so re.search (first match) could + # pick up the wrong one. + matches = re.findall(r"split_(\d+)", dataset_tag) + if matches: + tags["rank"] = matches[-1] + return tags + + def fmt(seconds: float) -> str: if seconds > 1: return str(round(seconds, 2)) + "s" @@ -448,7 +462,7 @@ def __init__(self, max_stats=1000): # Per Node metrics self.per_node_metrics = self._create_prometheus_metrics_for_per_node_metrics() - iter_tag_keys = ("dataset",) + iter_tag_keys = ("dataset", "rank") self.time_to_first_batch_s = Gauge( "data_iter_time_to_first_batch_seconds", @@ -488,6 +502,51 @@ def __init__(self, max_stats=1000): description="Seconds user thread is blocked by iter_batches()", tag_keys=iter_tag_keys, ) + self.iter_total_s = Gauge( + "data_iter_total_seconds", + description="Total wall-clock seconds spent in the dataset iterator", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_fetch_s = Gauge( + "data_iter_blocked_fetch_seconds", + description="Seconds user thread is blocked on block fetching", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_batching_s = Gauge( + "data_iter_blocked_batching_seconds", + description="Seconds user thread is blocked on batch creation", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_format_s = Gauge( + "data_iter_blocked_format_seconds", + description="Seconds user thread is blocked on batch formatting", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_collate_s = Gauge( + "data_iter_blocked_collate_seconds", + description="Seconds user thread is blocked on batch collation", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_finalize_s = Gauge( + "data_iter_blocked_finalize_seconds", + description="Seconds user thread is blocked on batch finalization", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_restore_order_s = Gauge( + "data_iter_blocked_restore_order_seconds", + description="Seconds user thread is blocked on restoring batch order", + tag_keys=iter_tag_keys, + ) + self.iter_batches_total = Gauge( + "data_iter_batches_total", + description="Total batches delivered to the user thread", + tag_keys=iter_tag_keys, + ) + self.iter_rows_total = Gauge( + "data_iter_rows_total", + description="Total rows delivered to the user thread", + tag_keys=iter_tag_keys, + ) self.iter_user_s = Gauge( "data_iter_user_seconds", description="Seconds spent in user code", @@ -725,9 +784,10 @@ def update_iteration_metrics( stats: "DatasetStats", dataset_tag, ): - tags = self._create_tags(dataset_tag) + tags = self._create_iteration_tags(dataset_tag) self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags) + self.iter_total_s.set(stats.iter_total_s.get(), tags) self.iter_get_ref_bundles_s.set(stats.iter_get_ref_bundles_s.get(), tags) self.iter_get_s.set(stats.iter_get_s.get(), tags) self.iter_next_batch_s.set(stats.iter_next_batch_s.get(), tags) @@ -748,6 +808,16 @@ def update_iteration_metrics( self.time_to_first_batch_s.set(stats.iter_time_to_first_batch_s.get(), tags) self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags) + self.iter_blocked_fetch_s.set(stats.iter_blocked_fetch_s.get(), tags) + self.iter_blocked_batching_s.set(stats.iter_blocked_batching_s.get(), tags) + self.iter_blocked_format_s.set(stats.iter_blocked_format_s.get(), tags) + self.iter_blocked_collate_s.set(stats.iter_blocked_collate_s.get(), tags) + self.iter_blocked_finalize_s.set(stats.iter_blocked_finalize_s.get(), tags) + self.iter_blocked_restore_order_s.set( + stats.iter_blocked_restore_order_s.get(), tags + ) + self.iter_batches_total.set(stats.iter_batches_total, tags) + self.iter_rows_total.set(stats.iter_rows_total, tags) self.iter_user_s.set(stats.iter_user_s.get(), tags) def register_dataset( @@ -941,6 +1011,9 @@ def _create_tags( tags["node_ip"] = node_ip_tag return tags + def _create_iteration_tags(self, dataset_tag: Optional[str]): + return _create_iteration_tags(dataset_tag) + def get_or_create_stats_actor() -> ActorHandle[_StatsActor]: """Each cluster will contain exactly 1 _StatsActor. This function @@ -1138,9 +1211,17 @@ def __init__( self.iter_finalize_batch_s: Timer = Timer() self.iter_time_to_first_batch_s: Timer = Timer() self.iter_total_blocked_s: Timer = Timer() + self.iter_blocked_fetch_s: Timer = Timer() + self.iter_blocked_batching_s: Timer = Timer() + self.iter_blocked_format_s: Timer = Timer() + self.iter_blocked_collate_s: Timer = Timer() + self.iter_blocked_finalize_s: Timer = Timer() + self.iter_blocked_restore_order_s: Timer = Timer() self.iter_user_s: Timer = Timer() self.iter_initialize_s: Timer = Timer() self.iter_total_s: Timer = Timer() + self.iter_batches_total: int = 0 + self.iter_rows_total: int = 0 self.extra_metrics = {} # Block fetch stats during iteration. @@ -1196,6 +1277,14 @@ def to_summary(self) -> "DatasetStatsSummary": self.iter_blocks_remote, self.iter_unknown_location, self.iter_prefetched_bytes, + self.iter_blocked_fetch_s, + self.iter_blocked_batching_s, + self.iter_blocked_format_s, + self.iter_blocked_collate_s, + self.iter_blocked_finalize_s, + self.iter_blocked_restore_order_s, + self.iter_batches_total, + self.iter_rows_total, ) stats_summary_parents = [] @@ -1878,6 +1967,16 @@ class IterStatsSummary: iter_unknown_location: int # Current bytes of prefetched blocks in the iterator iter_prefetched_bytes: int + # Per-stage training-thread blocked attribution timers. + blocked_fetch_time: Timer + blocked_batching_time: Timer + blocked_format_time: Timer + blocked_collate_time: Timer + blocked_finalize_time: Timer + blocked_restore_order_time: Timer + # Cumulative batch and row counters. + batches_total: int + rows_total: int def __str__(self) -> str: return self.to_string() @@ -1984,6 +2083,25 @@ def to_string(self) -> str: out += "Streaming split coordinator overhead time: " out += f"{fmt(self.streaming_split_coord_time.get())}\n" + # Per-stage training-thread blocked attribution. + stage_totals = [ + ("block fetch (ray.get)", self.blocked_fetch_time), + ("batching", self.blocked_batching_time), + ("format", self.blocked_format_time), + ("collate", self.blocked_collate_time), + ("finalize (host->device)", self.blocked_finalize_time), + ("restore order", self.blocked_restore_order_time), + ] + active_stages = [(name, t) for name, t in stage_totals if t.get() > 0] + if active_stages: + out += "\nPer-stage training-thread blocked time breakdown:\n" + for stage_name, timer in active_stages: + out += " * {}: {}\n".format(stage_name, fmt(timer.get())) + if self.batches_total: + out += "Total batches consumed: {}\n".format(self.batches_total) + if self.rows_total: + out += "Total rows consumed: {}\n".format(self.rows_total) + return out def __repr__(self, level=0) -> str: diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 1f69962435ad..9cfaea2f8dfc 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -22,7 +22,7 @@ from ray.data._internal.execution.interfaces import RefBundle from ray.data._internal.logical.interfaces import LogicalPlan from ray.data._internal.logical.operators import InputData -from ray.data._internal.stats import DatasetStats +from ray.data._internal.stats import DatasetStats, _StatsManager from ray.data.block import BlockAccessor, DataBatch, _apply_batch_format from ray.data.collate_fn import ( ArrowBatchCollateFn, @@ -297,6 +297,7 @@ def callback(num_bytes: int) -> None: yield from batch_iterator if stats: stats.iter_total_s.add(time.perf_counter() - time_start) + _StatsManager.update_iteration_metrics(stats, dataset_tag) finally: # On early exit (e.g. ``break`` in the for-loop), the inner # ``_ClosingIterator`` would only shut down the executor via diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index b0b55d1cb8fb..4c6fef31a4f5 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -11,7 +11,9 @@ from ray.data._internal.block_batching.interfaces import ( Batch, BatchMetadata, + BatchTimings, BlockPrefetcher, + StageTiming, ) from ray.data._internal.block_batching.iter_batches import ( BatchIterator, @@ -114,6 +116,415 @@ def test_restore_from_original_order(): assert idx == [0, 1, 2, 3] +def test_restore_original_order_stats(): + stats = DatasetStats(metadata={}, parent=None) + base_iterator = [ + Batch(BatchMetadata(batch_idx=2), None), + Batch(BatchMetadata(batch_idx=0), None), + Batch(BatchMetadata(batch_idx=1), None), + ] + + ordered = list(restore_original_order(iter(base_iterator), stats=stats)) + + assert [batch.metadata.batch_idx for batch in ordered] == [0, 1, 2] + assert any( + batch.metadata.timings.restore_order.start_s > 0 + and batch.metadata.timings.restore_order.end_s + >= batch.metadata.timings.restore_order.start_s + for batch in ordered + ) + + +def test_report_batch_timings_overlap_attribution(): + stats = DatasetStats(metadata={}, parent=None) + batch_iterator = BatchIterator(iter([]), stats=stats) + timings = BatchTimings(num_rows=8) + timings.fetch = StageTiming(start_s=10.0, end_s=20.0) + timings.batching = StageTiming(start_s=20.0, end_s=30.0) + timings.format = StageTiming(start_s=30.0, end_s=40.0) + timings.finalize = StageTiming(start_s=50.0, end_s=60.0) + batch = Batch(BatchMetadata(batch_idx=0, timings=timings), None) + + batch_iterator._report_batch_timings( + batch, blocked_start_s=15.0, blocked_end_s=35.0 + ) + + assert stats.iter_blocked_fetch_s.get() == pytest.approx(5.0) + assert stats.iter_blocked_batching_s.get() == pytest.approx(10.0) + assert stats.iter_blocked_format_s.get() == pytest.approx(5.0) + assert stats.iter_blocked_collate_s.get() == 0 + assert stats.iter_blocked_finalize_s.get() == 0 + assert stats.iter_batches_total == 1 + assert stats.iter_rows_total == 8 + + +def _make_batch_with_timings( + fetch_start=0.0, + fetch_end=0.0, + batching_start=0.0, + batching_end=0.0, + format_start=0.0, + format_end=0.0, + collate_start=0.0, + collate_end=0.0, + finalize_start=0.0, + finalize_end=0.0, + num_rows=0, +): + """Helper to construct a Batch with specific stage timing windows.""" + timings = BatchTimings(num_rows=num_rows) + timings.fetch = StageTiming(start_s=fetch_start, end_s=fetch_end) + timings.batching = StageTiming(start_s=batching_start, end_s=batching_end) + timings.format = StageTiming(start_s=format_start, end_s=format_end) + timings.collate = StageTiming(start_s=collate_start, end_s=collate_end) + timings.finalize = StageTiming(start_s=finalize_start, end_s=finalize_end) + return Batch(BatchMetadata(batch_idx=0, timings=timings), None) + + +def _make_report_iterator(stats): + """Create a BatchIterator wired to the given stats without a real pipeline.""" + it = BatchIterator.__new__(BatchIterator) + it._stats = stats + return it + + +class TestReportBatchTimingsEdgeCases: + """Edge case tests for overlap-based blocked attribution.""" + + def test_zero_overlap_stage_finished_before_blocked(self): + """Fetch [0, 1.5] finished before training blocked at t=2 → 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(fetch_start=0.0, fetch_end=1.5) + it._report_batch_timings(batch, blocked_start_s=2.0, blocked_end_s=3.0) + assert stats.iter_blocked_fetch_s.get() == 0.0 + + def test_zero_overlap_blocked_before_stage(self): + """Training blocked [0, 1], stage ran [2, 3] → 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(format_start=2.0, format_end=3.0) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=1.0) + assert stats.iter_blocked_format_s.get() == 0.0 + + def test_partial_overlap(self): + """Fetch [0, 2], blocked [1, 3] → overlap = min(2,3)-max(0,1) = 1.0.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(fetch_start=0.0, fetch_end=2.0) + it._report_batch_timings(batch, blocked_start_s=1.0, blocked_end_s=3.0) + assert stats.iter_blocked_fetch_s.get() == pytest.approx(1.0) + + def test_full_overlap_stage_inside_blocked(self): + """Stage [1, 2] entirely inside blocked [0, 3] → full 1.0 credit.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(batching_start=1.0, batching_end=2.0) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_batching_s.get() == pytest.approx(1.0) + + def test_no_collate_fn_zero_attribution(self): + """collate stage has start_s=0 → skipped, 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(format_start=1.0, format_end=2.0) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_collate_s.get() == 0.0 + + def test_no_finalize_fn_zero_attribution(self): + """finalize stage has start_s=0 → skipped, 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(collate_start=1.0, collate_end=2.0) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_collate_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_finalize_s.get() == 0.0 + + def test_prefetch_hides_fetch_from_training(self): + """Effective prefetch: fetch done before training blocks → 0 fetch attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=1.5, + collate_start=2.3, + collate_end=2.6, + ) + # Training only starts blocking at t=2 (prefetch worked) + it._report_batch_timings(batch, blocked_start_s=2.0, blocked_end_s=2.6) + assert stats.iter_blocked_fetch_s.get() == 0.0 + assert stats.iter_blocked_collate_s.get() == pytest.approx(0.3) + + def test_accumulation_across_batches(self): + """Two batches each contribute to fetch — values accumulate.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + # Batch 1: fetch [0,1], blocked [0,2] → overlap 1.0 + b1 = _make_batch_with_timings(fetch_start=0.0, fetch_end=1.0, num_rows=10) + it._report_batch_timings(b1, blocked_start_s=0.0, blocked_end_s=2.0) + # Batch 2: fetch [5,6], blocked [5,7] → overlap 1.0 + b2 = _make_batch_with_timings(fetch_start=5.0, fetch_end=6.0, num_rows=20) + it._report_batch_timings(b2, blocked_start_s=5.0, blocked_end_s=7.0) + + assert stats.iter_blocked_fetch_s.get() == pytest.approx(2.0) + assert stats.iter_batches_total == 2 + assert stats.iter_rows_total == 30 + + def test_overlap_invariant_sum_leq_total(self): + """sum(iter_blocked_*) <= iter_total_blocked_s always holds.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + stats.iter_total_blocked_s.add(5.0) + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=1.0, + batching_start=1.0, + batching_end=2.0, + format_start=2.0, + format_end=3.0, + num_rows=5, + ) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) + + total = stats.iter_total_blocked_s.get() + sum_stages = ( + stats.iter_blocked_fetch_s.get() + + stats.iter_blocked_batching_s.get() + + stats.iter_blocked_format_s.get() + + stats.iter_blocked_collate_s.get() + + stats.iter_blocked_finalize_s.get() + + stats.iter_blocked_restore_order_s.get() + ) + assert sum_stages <= total + 1e-9 + + def test_restore_order_overlap(self): + """restore_order stage timing is correctly attributed.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=1.0, + ) + batch.metadata.timings.restore_order = StageTiming(start_s=1.5, end_s=2.5) + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_fetch_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_restore_order_s.get() == pytest.approx(1.0) + + def test_blocked_inside_stage(self): + """Stage [0, 10] fully contains blocked [3, 5] → overlap = 2.0.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings(fetch_start=0.0, fetch_end=10.0) + it._report_batch_timings(batch, blocked_start_s=3.0, blocked_end_s=5.0) + assert stats.iter_blocked_fetch_s.get() == pytest.approx(2.0) + + def test_all_stages_simultaneous_overlap(self): + """Multiple stages overlap with blocked window simultaneously.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=1.0, + batching_start=1.0, + batching_end=2.0, + format_start=2.0, + format_end=3.0, + collate_start=3.0, + collate_end=4.0, + finalize_start=4.0, + finalize_end=5.0, + num_rows=100, + ) + # Blocked window covers all stages + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) + assert stats.iter_blocked_fetch_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_batching_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_collate_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_finalize_s.get() == pytest.approx(1.0) + assert stats.iter_batches_total == 1 + assert stats.iter_rows_total == 100 + + +class TestStageTimingRecord: + """Tests for StageTiming.record() behavior.""" + + def test_record_sets_start_and_end(self): + """First record() sets both start_s and end_s.""" + t = StageTiming() + t.record(1.0, 2.0) + assert t.start_s == 1.0 + assert t.end_s == 2.0 + + def test_record_keeps_first_start(self): + """Subsequent record() calls keep the first start_s.""" + t = StageTiming() + t.record(1.0, 2.0) + t.record(3.0, 4.0) + assert t.start_s == 1.0 # kept first start + assert t.end_s == 4.0 # updated to latest end + + def test_record_multiple_expands_window(self): + """Multiple record() calls expand the end_s window.""" + t = StageTiming() + t.record(5.0, 6.0) + t.record(7.0, 8.0) + t.record(9.0, 10.0) + assert t.start_s == 5.0 + assert t.end_s == 10.0 + + def test_record_default_values(self): + """Unrecorded StageTiming has start_s=0 and end_s=0.""" + t = StageTiming() + assert t.start_s == 0.0 + assert t.end_s == 0.0 + + +class TestMergeFetch: + """Tests for BatchTimings.merge_fetch() with multiple blocks per batch.""" + + def test_merge_single_block(self): + """Merging a single block preserves its fetch window.""" + dst = BatchTimings() + src = BatchTimings() + src.fetch = StageTiming(start_s=1.0, end_s=2.0) + dst.merge_fetch(src) + assert dst.fetch.start_s == 1.0 + assert dst.fetch.end_s == 2.0 + + def test_merge_multiple_blocks_expands_window(self): + """Merging multiple blocks produces the union window.""" + dst = BatchTimings() + + # Block 1: fetched [1.0, 2.0] + src1 = BatchTimings() + src1.fetch = StageTiming(start_s=1.0, end_s=2.0) + dst.merge_fetch(src1) + + # Block 2: fetched [3.0, 4.0] + src2 = BatchTimings() + src2.fetch = StageTiming(start_s=3.0, end_s=4.0) + dst.merge_fetch(src2) + + # Block 3: fetched [5.0, 6.0] + src3 = BatchTimings() + src3.fetch = StageTiming(start_s=5.0, end_s=6.0) + dst.merge_fetch(src3) + + # Union: [1.0, 6.0] + assert dst.fetch.start_s == 1.0 + assert dst.fetch.end_s == 6.0 + + def test_merge_unrecorded_block_ignored(self): + """Merging a block with no fetch timing (start_s=0) is a no-op.""" + dst = BatchTimings() + dst.fetch = StageTiming(start_s=2.0, end_s=3.0) + + src = BatchTimings() # fetch defaults to (0.0, 0.0) + dst.merge_fetch(src) + + assert dst.fetch.start_s == 2.0 + assert dst.fetch.end_s == 3.0 + + def test_merge_overlapping_blocks(self): + """Overlapping fetch windows are correctly merged.""" + dst = BatchTimings() + + src1 = BatchTimings() + src1.fetch = StageTiming(start_s=1.0, end_s=5.0) + dst.merge_fetch(src1) + + src2 = BatchTimings() + src2.fetch = StageTiming(start_s=3.0, end_s=7.0) + dst.merge_fetch(src2) + + assert dst.fetch.start_s == 1.0 + assert dst.fetch.end_s == 7.0 + + def test_merge_into_empty_destination(self): + """Merging into an empty BatchTimings takes the source window.""" + dst = BatchTimings() # fetch = (0.0, 0.0) + src = BatchTimings() + src.fetch = StageTiming(start_s=10.0, end_s=20.0) + dst.merge_fetch(src) + assert dst.fetch.start_s == 10.0 + assert dst.fetch.end_s == 20.0 + + +class TestEndToEndTimingPropagation: + """Tests that stage timings propagate correctly through the full pipeline.""" + + def test_batch_carries_timings_through_pipeline(self): + """A Batch's metadata.timings carries all stage windows.""" + timings = BatchTimings(num_rows=50) + timings.fetch = StageTiming(start_s=1.0, end_s=2.0) + timings.batching = StageTiming(start_s=2.0, end_s=3.0) + timings.format = StageTiming(start_s=3.0, end_s=4.0) + timings.collate = StageTiming(start_s=4.0, end_s=5.0) + timings.finalize = StageTiming(start_s=5.0, end_s=6.0) + + batch = Batch(BatchMetadata(batch_idx=0, timings=timings), None) + + # Verify all stages are accessible via stages() iterator + stage_dict = dict(batch.metadata.timings.stages()) + assert len(stage_dict) == 6 + assert stage_dict["fetch"].start_s == 1.0 + assert stage_dict["batching"].end_s == 3.0 + assert stage_dict["format"].start_s == 3.0 + assert stage_dict["collate"].end_s == 5.0 + assert stage_dict["finalize"].start_s == 5.0 + assert stage_dict["restore_order"].start_s == 0.0 # not recorded + assert batch.metadata.timings.num_rows == 50 + + def test_full_pipeline_attribution(self): + """End-to-end: all 6 stages with realistic timing, full overlap.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_report_iterator(stats) + stats.iter_total_blocked_s.add(5.0) + + batch = _make_batch_with_timings( + fetch_start=0.0, + fetch_end=0.5, + batching_start=0.5, + batching_end=1.0, + format_start=1.0, + format_end=2.0, + collate_start=2.0, + collate_end=2.5, + finalize_start=2.5, + finalize_end=3.0, + num_rows=256, + ) + # Also set restore_order + batch.metadata.timings.restore_order = StageTiming(start_s=3.0, end_s=3.5) + + # Blocked window covers all stages + it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) + + # Each stage gets its full duration + assert stats.iter_blocked_fetch_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_batching_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_collate_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_finalize_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_restore_order_s.get() == pytest.approx(0.5) + assert stats.iter_batches_total == 1 + assert stats.iter_rows_total == 256 + + # Invariant: sum = 3.5 <= total_blocked = 5.0 + sum_stages = ( + stats.iter_blocked_fetch_s.get() + + stats.iter_blocked_batching_s.get() + + stats.iter_blocked_format_s.get() + + stats.iter_blocked_collate_s.get() + + stats.iter_blocked_finalize_s.get() + + stats.iter_blocked_restore_order_s.get() + ) + assert sum_stages == pytest.approx(3.5) + assert sum_stages <= stats.iter_total_blocked_s.get() + 1e-9 + + def test_finalize_fn_uses_single_thread(ray_start_regular_shared): """Tests that finalize_fn is not run with multiple threads.""" ref_bundles_iter = ref_bundle_generator(num_blocks=20, num_rows=2) @@ -194,6 +605,27 @@ def collate_fn(batch: pd.DataFrame): assert concat_df["foo"].iloc[i + 1] >= concat_df["foo"].iloc[i] +def test_iter_batches_counts_rows_at_pipeline_exit(ray_start_regular_shared): + stats = DatasetStats(metadata={}, parent=None) + ref_bundles_iter = ref_bundle_generator(num_blocks=4, num_rows=2) + + output_batches = list( + BatchIterator( + ref_bundles_iter, + stats=stats, + batch_size=3, + prefetch_batches=0, + batch_format="pandas", + drop_last=True, + ) + ) + + assert len(output_batches) == 2 + assert [len(batch) for batch in output_batches] == [3, 3] + assert stats.iter_batches_total == 2 + assert stats.iter_rows_total == 6 + + def test_iter_batches_e2e_async(ray_start_regular_shared): """We add time.sleep in 3 places: 1. In the base generator to simulate streaming executor blocking on next results. diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index cb4c31553541..28af6b1ae1ea 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -36,6 +36,7 @@ OperatorStatsSummary, StatsSummary, Timer, + _create_iteration_tags, _StatsActor, get_or_create_stats_actor, ) @@ -1878,6 +1879,164 @@ def test_stats_actor_iter_metrics(): assert f"dataset_{ds._uuid}_0" == update_fn.call_args_list[-1].args[1] +def test_create_iteration_tags_extracts_rank(): + assert _create_iteration_tags("train_abc_split_2") == { + "dataset": "train_abc_split_2", + "rank": "2", + } + assert _create_iteration_tags("dataset_without_split") == { + "dataset": "dataset_without_split", + "rank": "unknown", + } + # User-defined dataset name may contain split_; the trailing + # split index (from streaming split coordinator) should be used. + assert _create_iteration_tags("my_split_3_data_abc123_split_5") == { + "dataset": "my_split_3_data_abc123_split_5", + "rank": "5", + } + + +def test_update_iteration_metrics_exports_new_iter_metrics(): + stats = DatasetStats(metadata={}, parent=None) + stats.iter_total_s.add(11.0) + stats.iter_blocked_fetch_s.add(1.0) + stats.iter_blocked_batching_s.add(2.0) + stats.iter_blocked_format_s.add(3.0) + stats.iter_blocked_collate_s.add(4.0) + stats.iter_blocked_finalize_s.add(5.0) + stats.iter_blocked_restore_order_s.add(6.0) + stats.iter_batches_total = 7 + stats.iter_rows_total = 8 + + actor = _StatsActor.__ray_metadata__.modified_class.__new__( + _StatsActor.__ray_metadata__.modified_class + ) + recorded = {} + + class FakeGauge: + def __init__(self, name): + self.name = name + + def set(self, value, tags): + recorded[self.name] = (value, tags) + + for attr in [ + "iter_initialize_s", + "iter_total_s", + "iter_get_ref_bundles_s", + "iter_get_s", + "iter_next_batch_s", + "iter_format_batch_s", + "iter_collate_batch_s", + "iter_finalize_batch_s", + "iter_blocks_local", + "iter_blocks_remote", + "iter_unknown_location", + "iter_prefetched_bytes", + "iter_block_fetching_s", + "iter_batch_shaping_s", + "iter_batch_formatting_s", + "iter_batch_collating_s", + "iter_batch_finalizing_s", + "time_to_first_batch_s", + "iter_total_blocked_s", + "iter_blocked_fetch_s", + "iter_blocked_batching_s", + "iter_blocked_format_s", + "iter_blocked_collate_s", + "iter_blocked_finalize_s", + "iter_blocked_restore_order_s", + "iter_batches_total", + "iter_rows_total", + "iter_user_s", + ]: + setattr(actor, attr, FakeGauge(attr)) + + actor.update_iteration_metrics(stats, "train_dataset_split_3") + + expected_tags = {"dataset": "train_dataset_split_3", "rank": "3"} + assert recorded["iter_total_s"] == (11.0, expected_tags) + assert recorded["iter_blocked_fetch_s"] == (1.0, expected_tags) + assert recorded["iter_blocked_batching_s"] == (2.0, expected_tags) + assert recorded["iter_blocked_format_s"] == (3.0, expected_tags) + assert recorded["iter_blocked_collate_s"] == (4.0, expected_tags) + assert recorded["iter_blocked_finalize_s"] == (5.0, expected_tags) + assert recorded["iter_blocked_restore_order_s"] == (6.0, expected_tags) + assert recorded["iter_batches_total"] == (7, expected_tags) + assert recorded["iter_rows_total"] == (8, expected_tags) + + +def test_iter_stats_summary_has_new_fields(): + """IterStatsSummary includes per-stage blocked timers and counters.""" + stats = DatasetStats(metadata={}, parent=None) + summary = stats.to_summary() + iter_summary = summary.iter_stats + + assert hasattr(iter_summary, "blocked_fetch_time") + assert hasattr(iter_summary, "blocked_batching_time") + assert hasattr(iter_summary, "blocked_format_time") + assert hasattr(iter_summary, "blocked_collate_time") + assert hasattr(iter_summary, "blocked_finalize_time") + assert hasattr(iter_summary, "blocked_restore_order_time") + assert hasattr(iter_summary, "batches_total") + assert hasattr(iter_summary, "rows_total") + + +def test_iter_stats_summary_reflects_accumulated_values(): + """IterStatsSummary carries the accumulated timer values.""" + stats = DatasetStats(metadata={}, parent=None) + stats.iter_blocked_fetch_s.add(0.5) + stats.iter_blocked_batching_s.add(0.2) + stats.iter_batches_total = 10 + stats.iter_rows_total = 320 + + summary = stats.to_summary().iter_stats + assert summary.blocked_fetch_time.get() == pytest.approx(0.5) + assert summary.blocked_batching_time.get() == pytest.approx(0.2) + assert summary.batches_total == 10 + assert summary.rows_total == 320 + + +def test_iter_stats_to_string_shows_stage_breakdown(): + """to_string() renders per-stage breakdown when values are non-zero.""" + stats = DatasetStats(metadata={}, parent=None) + stats.iter_blocked_fetch_s.add(1.5) + stats.iter_blocked_format_s.add(0.8) + stats.iter_batches_total = 5 + stats.iter_rows_total = 160 + stats.iter_total_blocked_s.add(2.3) + + text = str(stats.to_summary().iter_stats) + assert "block fetch" in text + assert "format" in text + assert "Total batches consumed: 5" in text + assert "Total rows consumed: 160" in text + assert "Per-stage training-thread blocked time breakdown" in text + + +def test_iter_stats_to_string_omits_zero_stages(): + """to_string() omits stages with zero values from the breakdown.""" + stats = DatasetStats(metadata={}, parent=None) + stats.iter_blocked_fetch_s.add(0.5) + stats.iter_total_blocked_s.add(0.5) + + text = str(stats.to_summary().iter_stats) + assert "block fetch" in text + # Zero stages should not appear + assert "batching" not in text + assert "collate" not in text + assert "restore order" not in text + + +def test_iter_stats_to_string_no_breakdown_when_all_zero(): + """When all blocked_* stages are zero, no breakdown section appears.""" + stats = DatasetStats(metadata={}, parent=None) + text = str(stats.to_summary().iter_stats) + assert "Per-stage training-thread blocked time breakdown" not in text + assert "Total batches consumed" not in text + assert "Total rows consumed" not in text + + def test_dataset_name_and_id(): # Test deprecated APIs: _set_name and _name ds = ray.data.range(1) From f3935abf291cfa6cd75026774b05bfd013c36ddf Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sat, 20 Jun 2026 12:49:56 +0800 Subject: [PATCH 02/32] [Data] Remove unused ShufflingBatcher compaction tracking Reverts batcher.py changes that were only needed for the shuffle buffer metrics which have been removed from this PR's scope per reviewer feedback. Signed-off-by: OneSizeFitsQuorum --- python/ray/data/_internal/batcher.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/python/ray/data/_internal/batcher.py b/python/ray/data/_internal/batcher.py index 7ff6ccf42c11..c097ee668de6 100644 --- a/python/ray/data/_internal/batcher.py +++ b/python/ray/data/_internal/batcher.py @@ -1,4 +1,3 @@ -import time import warnings from typing import Optional @@ -236,8 +235,6 @@ def __init__( self._total_object_store_nbytes = get_total_obj_store_mem_on_node() self._total_num_rows_added = 0 self._total_nbytes_added = 0 - self.compactions_total = 0 - self.compaction_time_s = 0.0 def add(self, block: Block): """Add a block to the shuffle buffer. @@ -323,9 +320,6 @@ def _num_rows(self) -> int: """ return self._num_compacted_rows() + self._num_uncompacted_rows() - def num_rows(self) -> int: - return self._num_rows() - def _num_compacted_rows(self) -> int: """Return number of unyielded rows in the compacted buffer.""" if self._shuffle_buffer is None: @@ -347,7 +341,6 @@ def next_batch(self) -> Block: self._done_adding or self._num_compacted_rows() <= self._min_rows_to_yield_batch ): - compaction_start_s = time.perf_counter() if self._shuffle_buffer is not None and self._batch_head < len( self._shuffled_indices ): @@ -370,8 +363,6 @@ def next_batch(self) -> Block: self._builder = DelegatingBlockBuilder() self._batch_head = 0 - self.compactions_total += 1 - self.compaction_time_s += time.perf_counter() - compaction_start_s assert self._shuffle_buffer is not None assert self._shuffled_indices is not None From 543a68d7e198d39721d295e0ba785472927305db Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 10:30:56 +0800 Subject: [PATCH 03/32] [Data] Remove restore_order stage from blocked attribution Per reviewer feedback, restore_order is an implementation detail rather than an actionable user-facing metric. Reverts restore_original_order() to the original simple for-loop and removes the data_iter_blocked_restore_order_seconds Prometheus metric along with all related fields, exports, and tests. The PR now exposes 8 core metrics (5 blocked stages + batches/rows/total). Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 2 - .../_internal/block_batching/iter_batches.py | 38 +++++-------------- python/ray/data/_internal/stats.py | 12 ------ .../tests/block_batching/test_iter_batches.py | 36 +++--------------- python/ray/data/tests/test_stats.py | 4 -- 5 files changed, 14 insertions(+), 78 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 10587586c941..aabb6781a33d 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -26,7 +26,6 @@ class BatchTimings: format: StageTiming = field(default_factory=StageTiming) collate: StageTiming = field(default_factory=StageTiming) finalize: StageTiming = field(default_factory=StageTiming) - restore_order: StageTiming = field(default_factory=StageTiming) num_rows: int = 0 def stages(self) -> Iterable[Tuple[str, StageTiming]]: @@ -36,7 +35,6 @@ def stages(self) -> Iterable[Tuple[str, StageTiming]]: ("format", self.format), ("collate", self.collate), ("finalize", self.finalize), - ("restore_order", self.restore_order), ) def merge_fetch(self, other: "BatchTimings") -> None: diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 66e55437d68f..2572f6a10dcc 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -218,7 +218,7 @@ def _finalize_batches( def _restore_original_batch_order( self, batches: Iterator[Batch] ) -> Iterator[Batch]: - return restore_original_order(batches, stats=self._stats) + return restore_original_order(batches) def _pipeline(self, ref_bundles: Iterator[RefBundle]) -> Iterator[Batch]: # Step 1: Prefetch logical batches locally. @@ -474,9 +474,7 @@ def get_next_ref_bundle() -> RefBundle: prefetcher.stop() -def restore_original_order( - batch_iter: Iterator[Batch], stats: Optional[DatasetStats] = None -) -> Iterator[Batch]: +def restore_original_order(batch_iter: Iterator[Batch]) -> Iterator[Batch]: """Restores the original order of the provided `batch_iter` This function will yield items from `base_iterator` in the correct order based on @@ -487,31 +485,13 @@ def restore_original_order( """ next_index_required = 0 buffer: Dict[int, Batch] = {} - restore_wait_start_s: Optional[float] = None - source_exhausted = False - - while True: + for batch in batch_iter: + assert batch.metadata.batch_idx not in buffer + buffer[batch.metadata.batch_idx] = batch while next_index_required in buffer: - next_batch = buffer.pop(next_index_required) - if restore_wait_start_s is not None: - next_batch.metadata.timings.restore_order.record( - restore_wait_start_s, time.perf_counter() - ) - restore_wait_start_s = None - yield next_batch + yield buffer.pop(next_index_required) next_index_required += 1 - if source_exhausted: - break - - if buffer and restore_wait_start_s is None: - restore_wait_start_s = time.perf_counter() - - try: - batch = next(batch_iter) - except StopIteration: - source_exhausted = True - continue - - assert batch.metadata.batch_idx not in buffer - buffer[batch.metadata.batch_idx] = batch + while next_index_required in buffer: + yield buffer.pop(next_index_required) + next_index_required += 1 diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 7afdb967aeaa..adb06af01c2f 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -532,11 +532,6 @@ def __init__(self, max_stats=1000): description="Seconds user thread is blocked on batch finalization", tag_keys=iter_tag_keys, ) - self.iter_blocked_restore_order_s = Gauge( - "data_iter_blocked_restore_order_seconds", - description="Seconds user thread is blocked on restoring batch order", - tag_keys=iter_tag_keys, - ) self.iter_batches_total = Gauge( "data_iter_batches_total", description="Total batches delivered to the user thread", @@ -813,9 +808,6 @@ def update_iteration_metrics( self.iter_blocked_format_s.set(stats.iter_blocked_format_s.get(), tags) self.iter_blocked_collate_s.set(stats.iter_blocked_collate_s.get(), tags) self.iter_blocked_finalize_s.set(stats.iter_blocked_finalize_s.get(), tags) - self.iter_blocked_restore_order_s.set( - stats.iter_blocked_restore_order_s.get(), tags - ) self.iter_batches_total.set(stats.iter_batches_total, tags) self.iter_rows_total.set(stats.iter_rows_total, tags) self.iter_user_s.set(stats.iter_user_s.get(), tags) @@ -1216,7 +1208,6 @@ def __init__( self.iter_blocked_format_s: Timer = Timer() self.iter_blocked_collate_s: Timer = Timer() self.iter_blocked_finalize_s: Timer = Timer() - self.iter_blocked_restore_order_s: Timer = Timer() self.iter_user_s: Timer = Timer() self.iter_initialize_s: Timer = Timer() self.iter_total_s: Timer = Timer() @@ -1282,7 +1273,6 @@ def to_summary(self) -> "DatasetStatsSummary": self.iter_blocked_format_s, self.iter_blocked_collate_s, self.iter_blocked_finalize_s, - self.iter_blocked_restore_order_s, self.iter_batches_total, self.iter_rows_total, ) @@ -1973,7 +1963,6 @@ class IterStatsSummary: blocked_format_time: Timer blocked_collate_time: Timer blocked_finalize_time: Timer - blocked_restore_order_time: Timer # Cumulative batch and row counters. batches_total: int rows_total: int @@ -2090,7 +2079,6 @@ def to_string(self) -> str: ("format", self.blocked_format_time), ("collate", self.blocked_collate_time), ("finalize (host->device)", self.blocked_finalize_time), - ("restore order", self.blocked_restore_order_time), ] active_stages = [(name, t) for name, t in stage_totals if t.get() > 0] if active_stages: diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 4c6fef31a4f5..3d6e94fb0676 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -117,22 +117,15 @@ def test_restore_from_original_order(): def test_restore_original_order_stats(): - stats = DatasetStats(metadata={}, parent=None) base_iterator = [ Batch(BatchMetadata(batch_idx=2), None), Batch(BatchMetadata(batch_idx=0), None), Batch(BatchMetadata(batch_idx=1), None), ] - ordered = list(restore_original_order(iter(base_iterator), stats=stats)) + ordered = list(restore_original_order(iter(base_iterator))) assert [batch.metadata.batch_idx for batch in ordered] == [0, 1, 2] - assert any( - batch.metadata.timings.restore_order.start_s > 0 - and batch.metadata.timings.restore_order.end_s - >= batch.metadata.timings.restore_order.start_s - for batch in ordered - ) def test_report_batch_timings_overlap_attribution(): @@ -294,23 +287,9 @@ def test_overlap_invariant_sum_leq_total(self): + stats.iter_blocked_format_s.get() + stats.iter_blocked_collate_s.get() + stats.iter_blocked_finalize_s.get() - + stats.iter_blocked_restore_order_s.get() ) assert sum_stages <= total + 1e-9 - def test_restore_order_overlap(self): - """restore_order stage timing is correctly attributed.""" - stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) - batch = _make_batch_with_timings( - fetch_start=0.0, - fetch_end=1.0, - ) - batch.metadata.timings.restore_order = StageTiming(start_s=1.5, end_s=2.5) - it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) - assert stats.iter_blocked_fetch_s.get() == pytest.approx(1.0) - assert stats.iter_blocked_restore_order_s.get() == pytest.approx(1.0) - def test_blocked_inside_stage(self): """Stage [0, 10] fully contains blocked [3, 5] → overlap = 2.0.""" stats = DatasetStats(metadata={}, parent=None) @@ -468,17 +447,16 @@ def test_batch_carries_timings_through_pipeline(self): # Verify all stages are accessible via stages() iterator stage_dict = dict(batch.metadata.timings.stages()) - assert len(stage_dict) == 6 + assert len(stage_dict) == 5 assert stage_dict["fetch"].start_s == 1.0 assert stage_dict["batching"].end_s == 3.0 assert stage_dict["format"].start_s == 3.0 assert stage_dict["collate"].end_s == 5.0 assert stage_dict["finalize"].start_s == 5.0 - assert stage_dict["restore_order"].start_s == 0.0 # not recorded assert batch.metadata.timings.num_rows == 50 def test_full_pipeline_attribution(self): - """End-to-end: all 6 stages with realistic timing, full overlap.""" + """End-to-end: all 5 stages with realistic timing, full overlap.""" stats = DatasetStats(metadata={}, parent=None) it = _make_report_iterator(stats) stats.iter_total_blocked_s.add(5.0) @@ -496,8 +474,6 @@ def test_full_pipeline_attribution(self): finalize_end=3.0, num_rows=256, ) - # Also set restore_order - batch.metadata.timings.restore_order = StageTiming(start_s=3.0, end_s=3.5) # Blocked window covers all stages it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) @@ -508,20 +484,18 @@ def test_full_pipeline_attribution(self): assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) assert stats.iter_blocked_collate_s.get() == pytest.approx(0.5) assert stats.iter_blocked_finalize_s.get() == pytest.approx(0.5) - assert stats.iter_blocked_restore_order_s.get() == pytest.approx(0.5) assert stats.iter_batches_total == 1 assert stats.iter_rows_total == 256 - # Invariant: sum = 3.5 <= total_blocked = 5.0 + # Invariant: sum = 3.0 <= total_blocked = 5.0 sum_stages = ( stats.iter_blocked_fetch_s.get() + stats.iter_blocked_batching_s.get() + stats.iter_blocked_format_s.get() + stats.iter_blocked_collate_s.get() + stats.iter_blocked_finalize_s.get() - + stats.iter_blocked_restore_order_s.get() ) - assert sum_stages == pytest.approx(3.5) + assert sum_stages == pytest.approx(3.0) assert sum_stages <= stats.iter_total_blocked_s.get() + 1e-9 diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index 28af6b1ae1ea..245c0bfed8d4 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -1904,7 +1904,6 @@ def test_update_iteration_metrics_exports_new_iter_metrics(): stats.iter_blocked_format_s.add(3.0) stats.iter_blocked_collate_s.add(4.0) stats.iter_blocked_finalize_s.add(5.0) - stats.iter_blocked_restore_order_s.add(6.0) stats.iter_batches_total = 7 stats.iter_rows_total = 8 @@ -1945,7 +1944,6 @@ def set(self, value, tags): "iter_blocked_format_s", "iter_blocked_collate_s", "iter_blocked_finalize_s", - "iter_blocked_restore_order_s", "iter_batches_total", "iter_rows_total", "iter_user_s", @@ -1961,7 +1959,6 @@ def set(self, value, tags): assert recorded["iter_blocked_format_s"] == (3.0, expected_tags) assert recorded["iter_blocked_collate_s"] == (4.0, expected_tags) assert recorded["iter_blocked_finalize_s"] == (5.0, expected_tags) - assert recorded["iter_blocked_restore_order_s"] == (6.0, expected_tags) assert recorded["iter_batches_total"] == (7, expected_tags) assert recorded["iter_rows_total"] == (8, expected_tags) @@ -1977,7 +1974,6 @@ def test_iter_stats_summary_has_new_fields(): assert hasattr(iter_summary, "blocked_format_time") assert hasattr(iter_summary, "blocked_collate_time") assert hasattr(iter_summary, "blocked_finalize_time") - assert hasattr(iter_summary, "blocked_restore_order_time") assert hasattr(iter_summary, "batches_total") assert hasattr(iter_summary, "rows_total") From 3a39cbd66c2136ecf098772f9a68d6685dfbcae2 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 10:50:33 +0800 Subject: [PATCH 04/32] [Data] Consolidate timing into StageTiming context manager Per reviewer feedback, consolidates the dual timing mechanism: - StageTiming now supports context manager protocol (__enter__/ __exit__) to automatically capture start_s/end_s - Timer gains start_s/end_s fields populated by timer() - Pipeline functions (resolve_block_refs, _format_batch, _collate_batch, _finalize_batch) use nested context managers instead of redundant perf_counter() + _record_stage_window() - resolve_block_refs always returns BlockWithTiming, removing the record_timings parameter, Union types, and isinstance branching - Removed _record_stage_window helper (no longer needed) Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 31 +++++-- .../_internal/block_batching/iter_batches.py | 4 +- .../ray/data/_internal/block_batching/util.py | 86 ++++++++----------- python/ray/data/_internal/stats.py | 9 +- .../tests/block_batching/test_iter_batches.py | 34 +++----- 5 files changed, 81 insertions(+), 83 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index aabb6781a33d..4f3020e65412 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -1,4 +1,6 @@ import abc +import time +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any, Iterable, List, Tuple @@ -8,15 +10,34 @@ @dataclass class StageTiming: - """Wall-clock window for a batch-processing stage.""" + """Wall-clock window for a single batch-processing stage. + + Can be used as a context manager to automatically capture the start and + end timestamps of a pipeline operation:: + + with stage_timing: + do_work() + # stage_timing.start_s and stage_timing.end_s are now set + """ start_s: float = 0.0 end_s: float = 0.0 - def record(self, start_s: float, end_s: float) -> None: - if self.start_s == 0.0: - self.start_s = start_s - self.end_s = end_s + def __enter__(self): + self.start_s = time.perf_counter() + return self + + def __exit__(self, *args): + self.end_s = time.perf_counter() + + @contextmanager + def timer(self): + """Alias for using as a context manager, matching Timer.timer() API.""" + self.start_s = time.perf_counter() + try: + yield + finally: + self.end_s = time.perf_counter() @dataclass diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 2572f6a10dcc..dbcfedecd748 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -176,9 +176,7 @@ def _prefetch_blocks( def _resolve_block_refs( self, block_refs: Iterator[ObjectRef[Block]] ) -> Iterator[Any]: - return resolve_block_refs( - block_ref_iter=block_refs, stats=self._stats, record_timings=True - ) + return resolve_block_refs(block_ref_iter=block_refs, stats=self._stats) def _blocks_to_batches(self, blocks: Iterator[Block]) -> Iterator[Batch]: return blocks_to_batches( diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index fb9f1f3aa11a..a7dfffe9bee9 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -3,7 +3,6 @@ import logging import queue import threading -import time from contextlib import nullcontext from typing import ( Any, @@ -15,7 +14,6 @@ Optional, Tuple, TypeVar, - Union, ) import ray @@ -28,7 +26,6 @@ BlockPrefetcher, BlockWithTiming, CollatedBatch, - StageTiming, ) from ray.data._internal.stats import DatasetStats from ray.data.block import Block, BlockAccessor, DataBatch @@ -175,27 +172,24 @@ def _calculate_ref_hits(refs: List[ObjectRef[Any]]) -> Tuple[int, int, int]: return 0, 0, 0 -def _record_stage_window(stage: StageTiming, start_s: float, end_s: float) -> None: - stage.record(start_s, end_s) - - def resolve_block_refs( block_ref_iter: Iterator[ObjectRef[Block]], stats: Optional[DatasetStats] = None, - record_timings: bool = False, -) -> Iterator[Union[Block, BlockWithTiming]]: +) -> Iterator[BlockWithTiming]: """Resolves the block references for each logical batch. + Each resolved block is wrapped in a :class:`BlockWithTiming` that carries + the per-block fetch window (``start_s``/``end_s`` around ``ray.get()``). + When *stats* is provided, the cumulative fetch time is also recorded in + ``stats.iter_get_s``. + Args: block_ref_iter: An iterator over block object references. - stats: An optional stats object to recording block hits and misses. - record_timings: If True, wrap each resolved block in a - ``BlockWithTiming`` carrying the per-block fetch window. + stats: An optional stats object to record block hits, misses, and + cumulative fetch time. Yields: - Union[Block, BlockWithTiming]: The resolved blocks. When - *record_timings* is ``True`` each block is wrapped in a - ``BlockWithTiming``; otherwise raw ``Block`` instances are yielded. + BlockWithTiming: Each resolved block with its fetch timing window. """ hits = 0 misses = 0 @@ -209,16 +203,11 @@ def resolve_block_refs( # TODO(amogkam): Optimized further by batching multiple references in a single # `ray.get()` call. - start_s = time.perf_counter() - with stats.iter_get_s.timer() if stats else nullcontext(): - block = ray.get(block_ref) - end_s = time.perf_counter() - if record_timings: - timings = BatchTimings() - _record_stage_window(timings.fetch, start_s, end_s) - yield BlockWithTiming(block=block, timings=timings) - else: - yield block + timings = BatchTimings() + with timings.fetch: + with stats.iter_get_s.timer() if stats else nullcontext(): + block = ray.get(block_ref) + yield BlockWithTiming(block=block, timings=timings) if stats: stats.iter_blocks_local = hits @@ -227,7 +216,7 @@ def resolve_block_refs( def blocks_to_batches( - block_iter: Iterator[Union[Block, BlockWithTiming]], + block_iter: Iterator[BlockWithTiming], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -256,7 +245,7 @@ class _BatchingIterator(Iterator[Batch]): def __init__( self, - block_iter: Iterator[Union[Block, BlockWithTiming]], + block_iter: Iterator[BlockWithTiming], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -294,9 +283,8 @@ def __next__(self) -> Batch: if can_yield: with timer: - start_s = time.perf_counter() - next_batch = self._batcher.next_batch() - end_s = time.perf_counter() + with self._pending_timings.batching: + next_batch = self._batcher.next_batch() res = Batch( metadata=BatchMetadata( @@ -305,7 +293,6 @@ def __next__(self) -> Batch: ), data=next_batch, ) - _record_stage_window(res.metadata.timings.batching, start_s, end_s) res.metadata.timings.num_rows = BlockAccessor.for_block( next_batch ).num_rows() @@ -318,11 +305,9 @@ def __next__(self) -> Batch: # If can't yield try adding more blocks try: # NOTE: Block ref is released immediately - block = next(self._block_iter) - if isinstance(block, BlockWithTiming): - self._pending_timings.merge_fetch(block.timings) - block = block.block - self._batcher.add(block) + block_with_timing = next(self._block_iter) + self._pending_timings.merge_fetch(block_with_timing.timings) + self._batcher.add(block_with_timing.block) except StopIteration: self._batcher.done_adding() self._done_adding = True @@ -341,14 +326,13 @@ def _format_batch( stats: Optional[DatasetStats], ensure_copy: bool = False, ) -> Batch: - start_s = time.perf_counter() - with stats.iter_format_batch_s.timer() if stats else nullcontext(): - formatted_data = BlockAccessor.for_block(batch.data).to_batch_format( - batch_format - ) - if ensure_copy: - formatted_data = _copy_batch(formatted_data) - _record_stage_window(batch.metadata.timings.format, start_s, time.perf_counter()) + with batch.metadata.timings.format: + with stats.iter_format_batch_s.timer() if stats else nullcontext(): + formatted_data = BlockAccessor.for_block(batch.data).to_batch_format( + batch_format + ) + if ensure_copy: + formatted_data = _copy_batch(formatted_data) return dataclasses.replace(batch, data=formatted_data) @@ -396,10 +380,9 @@ def _collate_batch( collate_fn: Callable[[DataBatch], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: - start_s = time.perf_counter() - with stats.iter_collate_batch_s.timer() if stats else nullcontext(): - collated_data = collate_fn(batch.data) - _record_stage_window(batch.metadata.timings.collate, start_s, time.perf_counter()) + with batch.metadata.timings.collate: + with stats.iter_collate_batch_s.timer() if stats else nullcontext(): + collated_data = collate_fn(batch.data) return CollatedBatch(metadata=batch.metadata, data=collated_data) @@ -423,10 +406,9 @@ def _finalize_batch( finalize_fn: Callable[[Any], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: - start_s = time.perf_counter() - with stats.iter_finalize_batch_s.timer() if stats else nullcontext(): - finalized_data = finalize_fn(batch.data) - _record_stage_window(batch.metadata.timings.finalize, start_s, time.perf_counter()) + with batch.metadata.timings.finalize: + with stats.iter_finalize_batch_s.timer() if stats else nullcontext(): + finalized_data = finalize_fn(batch.data) return dataclasses.replace(batch, data=finalized_data) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index adb06af01c2f..dc5c32d1d500 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -195,17 +195,22 @@ def __init__(self): self._min: float = float("inf") self._max: float = 0 self._total_count: float = 0 + # Wall-clock window of the most recent timer() invocation. + # Used by overlap-based blocked attribution in iter_batches. + self.start_s: float = 0.0 + self.end_s: float = 0.0 # Bounded-memory percentile backend. add() forwards every value # to ``add_sample`` and ``percentile`` reads from it. self._distribution: DistributionTracker = DistributionTracker() @contextmanager def timer(self) -> None: - time_start = time.perf_counter() + self.start_s = time.perf_counter() try: yield finally: - self.add(time.perf_counter() - time_start) + self.end_s = time.perf_counter() + self.add(self.end_s - self.start_s) def add(self, value: float) -> None: self._total += value diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 3d6e94fb0676..3ca2c68725d8 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -329,31 +329,23 @@ def test_all_stages_simultaneous_overlap(self): class TestStageTimingRecord: """Tests for StageTiming.record() behavior.""" - def test_record_sets_start_and_end(self): - """First record() sets both start_s and end_s.""" + def test_context_manager_captures_window(self): + """Using as context manager captures start_s and end_s.""" t = StageTiming() - t.record(1.0, 2.0) - assert t.start_s == 1.0 - assert t.end_s == 2.0 + with t: + pass + assert t.start_s > 0 + assert t.end_s >= t.start_s - def test_record_keeps_first_start(self): - """Subsequent record() calls keep the first start_s.""" + def test_timer_context_manager(self): + """The timer() method works as a context manager too.""" t = StageTiming() - t.record(1.0, 2.0) - t.record(3.0, 4.0) - assert t.start_s == 1.0 # kept first start - assert t.end_s == 4.0 # updated to latest end + with t.timer(): + pass + assert t.start_s > 0 + assert t.end_s >= t.start_s - def test_record_multiple_expands_window(self): - """Multiple record() calls expand the end_s window.""" - t = StageTiming() - t.record(5.0, 6.0) - t.record(7.0, 8.0) - t.record(9.0, 10.0) - assert t.start_s == 5.0 - assert t.end_s == 10.0 - - def test_record_default_values(self): + def test_default_values(self): """Unrecorded StageTiming has start_s=0 and end_s=0.""" t = StageTiming() assert t.start_s == 0.0 From 16fea70f3ae648d15da5a2ea619f3acfb7827360 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 10:54:16 +0800 Subject: [PATCH 05/32] [Data] Capture upstream blocked time in fetch stage The fetch timing window in resolve_block_refs now spans from when we start waiting for the upstream iterator (blocked on the data pipeline) through ray.get() completion. This captures cross-node transfer and upstream production delays, giving a more complete picture of what blocks the training thread. Signed-off-by: OneSizeFitsQuorum --- .../ray/data/_internal/block_batching/util.py | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index a7dfffe9bee9..fcc6cf2a40c2 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -179,7 +179,9 @@ def resolve_block_refs( """Resolves the block references for each logical batch. Each resolved block is wrapped in a :class:`BlockWithTiming` that carries - the per-block fetch window (``start_s``/``end_s`` around ``ray.get()``). + the per-block fetch window. The fetch window spans from the moment we + start waiting for the upstream iterator (blocked on the data pipeline or + cross-node transfer) until ``ray.get()`` returns the resolved block. When *stats* is provided, the cumulative fetch time is also recorded in ``stats.iter_get_s``. @@ -195,18 +197,28 @@ def resolve_block_refs( misses = 0 unknowns = 0 - for block_ref in block_ref_iter: - current_hit, current_miss, current_unknown = _calculate_ref_hits([block_ref]) - hits += current_hit - misses += current_miss - unknowns += current_unknown - - # TODO(amogkam): Optimized further by batching multiple references in a single - # `ray.get()` call. + while True: + # Time the upstream pull — captures blocked time waiting for the + # data pipeline to produce the next block ref. timings = BatchTimings() with timings.fetch: + try: + block_ref = next(block_ref_iter) + except StopIteration: + break + + current_hit, current_miss, current_unknown = _calculate_ref_hits( + [block_ref] + ) + hits += current_hit + misses += current_miss + unknowns += current_unknown + + # TODO(amogkam): Optimized further by batching multiple references + # in a single `ray.get()` call. with stats.iter_get_s.timer() if stats else nullcontext(): block = ray.get(block_ref) + yield BlockWithTiming(block=block, timings=timings) if stats: From 71da139f9aa60f89db049334cc720df2c3a80076 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 10:57:35 +0800 Subject: [PATCH 06/32] [Data] Add docstrings to timing dataclasses and _report_batch_timings Per reviewer feedback, adds clear docstrings to: - BatchTimings (per-batch pipeline-stage timing windows) - BlockWithTiming (resolved block with fetch timing) - BatchTimings.merge_fetch() (multi-block fetch window expansion) - BatchTimings.stages() (stage name/timing iterator) - _report_batch_timings() (overlap-based attribution algorithm) Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 30 ++++++++++++++++++ .../_internal/block_batching/iter_batches.py | 31 +++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 4f3020e65412..f0cd2ad1ff7e 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -42,6 +42,23 @@ def timer(self): @dataclass class BatchTimings: + """Per-batch pipeline-stage timing windows for overlap-based attribution. + + Each field records the ``(start_s, end_s)`` wall-clock window during which + a particular pipeline stage was active for this batch. The training thread + later compares these windows against its own blocked window to determine + how much each stage contributed to training-thread stall (see + :meth:`BatchIterator._report_batch_timings`). + + Attributes: + fetch: Waiting for upstream data production + ``ray.get()`` transfer. + batching: Assembling blocks into a batch via ``_batcher.next_batch()``. + format: Converting the batch to the requested format (numpy, pandas…). + collate: Running the user-provided ``collate_fn``. + finalize: Running the user-provided ``finalize_fn`` (e.g. host→device). + num_rows: Number of rows in this batch (for ``iter_rows_total``). + """ + fetch: StageTiming = field(default_factory=StageTiming) batching: StageTiming = field(default_factory=StageTiming) format: StageTiming = field(default_factory=StageTiming) @@ -50,6 +67,7 @@ class BatchTimings: num_rows: int = 0 def stages(self) -> Iterable[Tuple[str, StageTiming]]: + """Iterate over ``(name, timing)`` pairs for all pipeline stages.""" return ( ("fetch", self.fetch), ("batching", self.batching), @@ -59,6 +77,12 @@ def stages(self) -> Iterable[Tuple[str, StageTiming]]: ) def merge_fetch(self, other: "BatchTimings") -> None: + """Expand this batch's fetch window to encompass another's. + + Used when a single batch is assembled from multiple blocks, each + fetched independently. The merged window spans from the earliest + fetch start to the latest fetch end. + """ self._merge_stage(self.fetch, other.fetch) @staticmethod @@ -73,6 +97,12 @@ def _merge_stage(dst: StageTiming, src: StageTiming) -> None: @dataclass class BlockWithTiming: + """A resolved block paired with its fetch timing window. + + Produced by :func:`resolve_block_refs` so that downstream pipeline stages + can track how long each block took to fetch (upstream wait + ``ray.get()``). + """ + block: Block timings: BatchTimings = field(default_factory=BatchTimings) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index dbcfedecd748..bb4e34d6aad4 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -264,6 +264,37 @@ def _iter_batches(self) -> Iterator[DataBatch]: def _report_batch_timings( self, batch: Batch, blocked_start_s: float, blocked_end_s: float ) -> None: + """Attribute per-stage blocked time via overlap with the training window. + + For each pipeline stage we know when it ran ``[stage.start_s, + stage.end_s]`` (recorded by background threads onto + ``batch.metadata.timings``). We also know when the training thread + was blocked ``[blocked_start_s, blocked_end_s]`` (captured in + ``_iter_batches`` around ``next()``). + + The attribution for a stage is the length of the intersection:: + + overlap = min(stage.end, blocked_end) - max(stage.start, blocked_start) + + This correctly handles all prefetch configurations: + + * Stage finished before training blocked → overlap ≤ 0 → zero credit. + * Stage fully inside blocked window → full stage duration credited. + * Partial overlap → partial credit. + + **Invariant**: ``sum(iter_blocked_*) ≤ iter_total_blocked_s``. + + Runs in the training thread; no locks needed because background + threads finished writing ``batch.metadata.timings`` before the batch + was enqueued. + + Args: + batch: The batch whose per-stage timings should be attributed. + blocked_start_s: ``perf_counter()`` value just before the + training thread called ``next(batch_iter)``. + blocked_end_s: ``perf_counter()`` value just after ``next()`` + returned. + """ if self._stats is None: return timings = batch.metadata.timings From 9fcde561a5e2f6f0b227d0096c2068cbcca8228a Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 14:13:32 +0800 Subject: [PATCH 07/32] [Data] Restore isinstance check for BlockWithTiming compatibility _BatchingIterator can receive blocks from paths other than resolve_block_refs (e.g., doctest examples that pass raw pyarrow Tables). Restore the isinstance check to handle both BlockWithTiming and raw Block objects gracefully. Signed-off-by: OneSizeFitsQuorum --- python/ray/data/_internal/block_batching/util.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index fcc6cf2a40c2..2dca971d9fb8 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -14,6 +14,7 @@ Optional, Tuple, TypeVar, + Union, ) import ray @@ -228,7 +229,7 @@ def resolve_block_refs( def blocks_to_batches( - block_iter: Iterator[BlockWithTiming], + block_iter: Iterator[Union[Block, BlockWithTiming]], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -257,7 +258,7 @@ class _BatchingIterator(Iterator[Batch]): def __init__( self, - block_iter: Iterator[BlockWithTiming], + block_iter: Iterator[Union[Block, BlockWithTiming]], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -317,9 +318,11 @@ def __next__(self) -> Batch: # If can't yield try adding more blocks try: # NOTE: Block ref is released immediately - block_with_timing = next(self._block_iter) - self._pending_timings.merge_fetch(block_with_timing.timings) - self._batcher.add(block_with_timing.block) + block = next(self._block_iter) + if isinstance(block, BlockWithTiming): + self._pending_timings.merge_fetch(block.timings) + block = block.block + self._batcher.add(block) except StopIteration: self._batcher.done_adding() self._done_adding = True From 1266c927fe8df6238825fc6650b98d5a1865cfc0 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 14:58:26 +0800 Subject: [PATCH 08/32] [Data] Refactor to eliminate isinstance/Union in _BatchingIterator Per reviewer feedback, removed isinstance check and Union type from _BatchingIterator by ensuring all entry points wrap blocks in BlockWithTiming: - batch_blocks() now wraps raw blocks in BlockWithTiming with zero timing before passing to blocks_to_batches() - _BatchingIterator now assumes all blocks are BlockWithTiming - Removed Union import from util.py This provides a uniform type throughout the batching pipeline while maintaining backward compatibility for external callers of batch_blocks(). Signed-off-by: OneSizeFitsQuorum --- .../data/_internal/block_batching/block_batching.py | 10 +++++++++- python/ray/data/_internal/block_batching/util.py | 13 +++++-------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index ef54a593920b..e10a212de1a6 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -1,5 +1,9 @@ from typing import Callable, Iterator, Optional, TypeVar +from ray.data._internal.block_batching.interfaces import ( + BatchTimings, + BlockWithTiming, +) from ray.data._internal.block_batching.util import ( _MappingIterator, blocks_to_batches, @@ -29,10 +33,14 @@ def batch_blocks( This function takes in an iterator of already fetched blocks. Consequently, this function doesn't support block prefetching. """ + # Wrap raw blocks in BlockWithTiming with zero timing so that + # _BatchingIterator receives a uniform type. + wrapped_blocks = (BlockWithTiming(block=b, timings=BatchTimings()) for b in blocks) + # Build the processing pipeline batch_iter = format_batches( blocks_to_batches( - block_iter=blocks, + block_iter=wrapped_blocks, stats=stats, batch_size=batch_size, drop_last=drop_last, diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 2dca971d9fb8..fcc6cf2a40c2 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -14,7 +14,6 @@ Optional, Tuple, TypeVar, - Union, ) import ray @@ -229,7 +228,7 @@ def resolve_block_refs( def blocks_to_batches( - block_iter: Iterator[Union[Block, BlockWithTiming]], + block_iter: Iterator[BlockWithTiming], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -258,7 +257,7 @@ class _BatchingIterator(Iterator[Batch]): def __init__( self, - block_iter: Iterator[Union[Block, BlockWithTiming]], + block_iter: Iterator[BlockWithTiming], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -318,11 +317,9 @@ def __next__(self) -> Batch: # If can't yield try adding more blocks try: # NOTE: Block ref is released immediately - block = next(self._block_iter) - if isinstance(block, BlockWithTiming): - self._pending_timings.merge_fetch(block.timings) - block = block.block - self._batcher.add(block) + block_with_timing = next(self._block_iter) + self._pending_timings.merge_fetch(block_with_timing.timings) + self._batcher.add(block_with_timing.block) except StopIteration: self._batcher.done_adding() self._done_adding = True From ff658f2b462f8d992dac833d7992c6830b5733f3 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 15:32:31 +0800 Subject: [PATCH 09/32] [Data] Fix merge_fetch idle gap and blocked window alignment Per Cursor Bugbot review: 1. merge_fetch now sums fetch durations instead of taking the span, avoiding counting idle gaps between consecutive block fetches as fetch blocking time. 2. Move blocked_start_s/blocked_end_s captures inside get_next_batch_context() so the blocked window aligns with iter_total_blocked_s, preventing sum(iter_blocked_*) from exceeding iter_total_blocked_s. Updated tests to reflect the new duration-summing behavior. Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 18 +++++++++++----- .../_internal/block_batching/iter_batches.py | 4 ++-- .../tests/block_batching/test_iter_batches.py | 21 ++++++++++--------- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index f0cd2ad1ff7e..c689dd504745 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -77,13 +77,21 @@ def stages(self) -> Iterable[Tuple[str, StageTiming]]: ) def merge_fetch(self, other: "BatchTimings") -> None: - """Expand this batch's fetch window to encompass another's. + """Merge fetch timings from another batch into this one. - Used when a single batch is assembled from multiple blocks, each - fetched independently. The merged window spans from the earliest - fetch start to the latest fetch end. + Sums the fetch durations rather than taking the span, to avoid + counting idle gaps between consecutive block fetches as fetch time. """ - self._merge_stage(self.fetch, other.fetch) + if other.fetch.start_s == 0.0: + return + if self.fetch.start_s == 0.0: + # First block: copy the timing + self.fetch.start_s = other.fetch.start_s + self.fetch.end_s = other.fetch.end_s + else: + # Subsequent blocks: add duration to existing span + duration = other.fetch.end_s - other.fetch.start_s + self.fetch.end_s += duration @staticmethod def _merge_stage(dst: StageTiming, src: StageTiming) -> None: diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index bb4e34d6aad4..dc669c9725fd 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -248,13 +248,13 @@ def _iter_batches(self) -> Iterator[DataBatch]: self.before_epoch_start() while True: - blocked_start_s = time.perf_counter() with self.get_next_batch_context(): + blocked_start_s = time.perf_counter() try: batch = next(batch_iter) except StopIteration: break - blocked_end_s = time.perf_counter() + blocked_end_s = time.perf_counter() self._report_batch_timings(batch, blocked_start_s, blocked_end_s) with self.yield_batch_context(batch): yield batch.data diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 3ca2c68725d8..074ee23d147d 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -365,27 +365,27 @@ def test_merge_single_block(self): assert dst.fetch.end_s == 2.0 def test_merge_multiple_blocks_expands_window(self): - """Merging multiple blocks produces the union window.""" + """Merging multiple blocks sums their fetch durations.""" dst = BatchTimings() - # Block 1: fetched [1.0, 2.0] + # Block 1: fetched [1.0, 2.0] (duration: 1.0) src1 = BatchTimings() src1.fetch = StageTiming(start_s=1.0, end_s=2.0) dst.merge_fetch(src1) - # Block 2: fetched [3.0, 4.0] + # Block 2: fetched [3.0, 4.0] (duration: 1.0) src2 = BatchTimings() src2.fetch = StageTiming(start_s=3.0, end_s=4.0) dst.merge_fetch(src2) - # Block 3: fetched [5.0, 6.0] + # Block 3: fetched [5.0, 6.0] (duration: 1.0) src3 = BatchTimings() src3.fetch = StageTiming(start_s=5.0, end_s=6.0) dst.merge_fetch(src3) - # Union: [1.0, 6.0] + # Sum of durations: 1.0 + 1.0 + 1.0 = 3.0 assert dst.fetch.start_s == 1.0 - assert dst.fetch.end_s == 6.0 + assert dst.fetch.end_s == 4.0 # 1.0 + 3.0 def test_merge_unrecorded_block_ignored(self): """Merging a block with no fetch timing (start_s=0) is a no-op.""" @@ -399,19 +399,20 @@ def test_merge_unrecorded_block_ignored(self): assert dst.fetch.end_s == 3.0 def test_merge_overlapping_blocks(self): - """Overlapping fetch windows are correctly merged.""" + """Overlapping fetch windows sum their durations.""" dst = BatchTimings() src1 = BatchTimings() - src1.fetch = StageTiming(start_s=1.0, end_s=5.0) + src1.fetch = StageTiming(start_s=1.0, end_s=5.0) # duration: 4.0 dst.merge_fetch(src1) src2 = BatchTimings() - src2.fetch = StageTiming(start_s=3.0, end_s=7.0) + src2.fetch = StageTiming(start_s=3.0, end_s=7.0) # duration: 4.0 dst.merge_fetch(src2) + # Sum of durations: 4.0 + 4.0 = 8.0 assert dst.fetch.start_s == 1.0 - assert dst.fetch.end_s == 7.0 + assert dst.fetch.end_s == 9.0 # 1.0 + 8.0 def test_merge_into_empty_destination(self): """Merging into an empty BatchTimings takes the source window.""" From b9bf847d7857a50d7511cfcbebf91487d0056f3d Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 15:48:10 +0800 Subject: [PATCH 10/32] [Data] Remove unused _merge_stage method After changing merge_fetch to sum durations instead of taking the span, the _merge_stage helper is no longer called anywhere. Remove the dead code. Signed-off-by: OneSizeFitsQuorum --- python/ray/data/_internal/block_batching/interfaces.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index c689dd504745..6292425c1cac 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -93,15 +93,6 @@ def merge_fetch(self, other: "BatchTimings") -> None: duration = other.fetch.end_s - other.fetch.start_s self.fetch.end_s += duration - @staticmethod - def _merge_stage(dst: StageTiming, src: StageTiming) -> None: - if src.start_s == 0.0: - return - if dst.start_s == 0.0 or src.start_s < dst.start_s: - dst.start_s = src.start_s - if src.end_s > dst.end_s: - dst.end_s = src.end_s - @dataclass class BlockWithTiming: From aa41de515f8f29560d9baea1fe3c5c24acda3701 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 22 Jun 2026 15:59:17 +0800 Subject: [PATCH 11/32] [Data] Revert merge_fetch to span-based approach After deeper analysis, the span approach (taking [earliest_start, latest_end]) is semantically correct for multi-block fetches: - From the training thread's perspective, it's blocked for the entire span, even if there are gaps between consecutive block fetches - Those "idle gaps" are actually pipeline overhead (batching logic, scheduling) and are part of the blocking experience - Summing durations would underestimate the actual blocking time The Cursor Bugbot concern about "idle gaps" is valid in theory, but in practice: 1. The gaps are very small (microseconds of pipeline overhead) 2. They represent real blocking time from the training thread's perspective 3. Span aligns with the semantic meaning of "how long did training wait" Reverted tests to expect span behavior. Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 14 +++++++----- .../tests/block_batching/test_iter_batches.py | 22 +++++++++---------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 6292425c1cac..cfdd21e918a3 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -79,8 +79,10 @@ def stages(self) -> Iterable[Tuple[str, StageTiming]]: def merge_fetch(self, other: "BatchTimings") -> None: """Merge fetch timings from another batch into this one. - Sums the fetch durations rather than taking the span, to avoid - counting idle gaps between consecutive block fetches as fetch time. + Expands the fetch window to span from the earliest block fetch start + to the latest block fetch end. This represents the total time the + training thread was blocked waiting for this batch, including any + pipeline overhead between consecutive block fetches. """ if other.fetch.start_s == 0.0: return @@ -89,9 +91,11 @@ def merge_fetch(self, other: "BatchTimings") -> None: self.fetch.start_s = other.fetch.start_s self.fetch.end_s = other.fetch.end_s else: - # Subsequent blocks: add duration to existing span - duration = other.fetch.end_s - other.fetch.start_s - self.fetch.end_s += duration + # Subsequent blocks: expand the window + if other.fetch.start_s < self.fetch.start_s: + self.fetch.start_s = other.fetch.start_s + if other.fetch.end_s > self.fetch.end_s: + self.fetch.end_s = other.fetch.end_s @dataclass diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 074ee23d147d..63128be388e3 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -365,27 +365,27 @@ def test_merge_single_block(self): assert dst.fetch.end_s == 2.0 def test_merge_multiple_blocks_expands_window(self): - """Merging multiple blocks sums their fetch durations.""" + """Merging multiple blocks produces the union window.""" dst = BatchTimings() - # Block 1: fetched [1.0, 2.0] (duration: 1.0) + # Block 1: fetched [1.0, 2.0] src1 = BatchTimings() src1.fetch = StageTiming(start_s=1.0, end_s=2.0) dst.merge_fetch(src1) - # Block 2: fetched [3.0, 4.0] (duration: 1.0) + # Block 2: fetched [3.0, 4.0] src2 = BatchTimings() src2.fetch = StageTiming(start_s=3.0, end_s=4.0) dst.merge_fetch(src2) - # Block 3: fetched [5.0, 6.0] (duration: 1.0) + # Block 3: fetched [5.0, 6.0] src3 = BatchTimings() src3.fetch = StageTiming(start_s=5.0, end_s=6.0) dst.merge_fetch(src3) - # Sum of durations: 1.0 + 1.0 + 1.0 = 3.0 + # Union: [1.0, 6.0] assert dst.fetch.start_s == 1.0 - assert dst.fetch.end_s == 4.0 # 1.0 + 3.0 + assert dst.fetch.end_s == 6.0 def test_merge_unrecorded_block_ignored(self): """Merging a block with no fetch timing (start_s=0) is a no-op.""" @@ -399,20 +399,20 @@ def test_merge_unrecorded_block_ignored(self): assert dst.fetch.end_s == 3.0 def test_merge_overlapping_blocks(self): - """Overlapping fetch windows sum their durations.""" + """Overlapping fetch windows are correctly merged.""" dst = BatchTimings() src1 = BatchTimings() - src1.fetch = StageTiming(start_s=1.0, end_s=5.0) # duration: 4.0 + src1.fetch = StageTiming(start_s=1.0, end_s=5.0) dst.merge_fetch(src1) src2 = BatchTimings() - src2.fetch = StageTiming(start_s=3.0, end_s=7.0) # duration: 4.0 + src2.fetch = StageTiming(start_s=3.0, end_s=7.0) dst.merge_fetch(src2) - # Sum of durations: 4.0 + 4.0 = 8.0 + # Union: [1.0, 7.0] assert dst.fetch.start_s == 1.0 - assert dst.fetch.end_s == 9.0 # 1.0 + 8.0 + assert dst.fetch.end_s == 7.0 def test_merge_into_empty_destination(self): """Merging into an empty BatchTimings takes the source window.""" From 7874ef45b95889390c22ca418795dd54729cb4be Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Tue, 23 Jun 2026 09:57:09 +0800 Subject: [PATCH 12/32] [Data] Fix test failures from BlockWithTiming refactor - test_util.py: Updated test_resolve_block_refs to expect BlockWithTiming objects and test_blocks_to_batches to wrap raw blocks - block_batching.py: Changed generator expression to map() to avoid holding references to blocks, fixing test_chained_transforms_release_intermediates Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/block_batching.py | 8 ++++++-- .../ray/data/tests/block_batching/test_util.py | 17 ++++++++++++++--- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index e10a212de1a6..999e2e10af9b 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -34,8 +34,12 @@ def batch_blocks( function doesn't support block prefetching. """ # Wrap raw blocks in BlockWithTiming with zero timing so that - # _BatchingIterator receives a uniform type. - wrapped_blocks = (BlockWithTiming(block=b, timings=BatchTimings()) for b in blocks) + # _BatchingIterator receives a uniform type. Use map() instead of a + # generator expression to avoid holding references to blocks. + def _wrap_block(b): + return BlockWithTiming(block=b, timings=BatchTimings()) + + wrapped_blocks = map(_wrap_block, blocks) # Build the processing pipeline batch_iter = format_batches( diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py index 6ead5741f0e1..7a6223b06f78 100644 --- a/python/ray/data/tests/block_batching/test_util.py +++ b/python/ray/data/tests/block_batching/test_util.py @@ -13,7 +13,12 @@ import pytest import ray -from ray.data._internal.block_batching.interfaces import Batch, BatchMetadata +from ray.data._internal.block_batching.interfaces import ( + Batch, + BatchMetadata, + BatchTimings, + BlockWithTiming, +) from ray.data._internal.block_batching.util import ( _calculate_ref_hits, blocks_to_batches, @@ -37,7 +42,9 @@ def test_resolve_block_refs(ray_start_regular_shared): block_refs = [ray.put(0), ray.put(1), ray.put(2)] resolved_iter = resolve_block_refs(iter(block_refs)) - assert list(resolved_iter) == [0, 1, 2] + resolved = list(resolved_iter) + assert all(isinstance(b, BlockWithTiming) for b in resolved) + assert [b.block for b in resolved] == [0, 1, 2] @pytest.mark.parametrize("block_size", [1, 10]) @@ -45,10 +52,14 @@ def test_resolve_block_refs(ray_start_regular_shared): def test_blocks_to_batches(block_size, drop_last): num_blocks = 5 block_iter = block_generator(num_rows=block_size, num_blocks=num_blocks) + # Wrap raw blocks in BlockWithTiming as blocks_to_batches now expects + wrapped_blocks = ( + BlockWithTiming(block=b, timings=BatchTimings()) for b in block_iter + ) batch_size = 3 batch_iter = list( - blocks_to_batches(block_iter, batch_size=batch_size, drop_last=drop_last) + blocks_to_batches(wrapped_blocks, batch_size=batch_size, drop_last=drop_last) ) if drop_last: From 2d311f45a085a78d75fc116e4db0c46eded9520f Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Fri, 26 Jun 2026 16:42:51 +0800 Subject: [PATCH 13/32] [Data] Introduce TimeSpan and consolidate timing into Timer Per reviewer feedback, consolidates the dual timing mechanism: - New TimeSpan dataclass (pure data: start_s, end_s, duration) - Timer.timer() now yields a thread-local TimeSpan (safe for concurrent use from thread pools) - Removed thread-unsafe start_s/end_s fields from Timer - New _timed() helper yields Optional[TimeSpan] (None when no stats) - Removed StageTiming class entirely - BatchTimings fields are now Optional[TimeSpan] (None = stage didn't execute, eliminating magic-number 0.0 checks) - Pipeline functions use single _timed() context instead of nested StageTiming + Timer.timer() context managers - _merge_span helper handles Optional[TimeSpan] merge logic Co-authored-by: Claude Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 84 +++++++---------- .../_internal/block_batching/iter_batches.py | 8 +- .../ray/data/_internal/block_batching/util.py | 78 ++++++++-------- python/ray/data/_internal/stats.py | 52 +++++++++-- .../tests/block_batching/test_iter_batches.py | 90 +++++++++---------- 5 files changed, 161 insertions(+), 151 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index cfdd21e918a3..2d115ffefca9 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -1,45 +1,12 @@ import abc -import time -from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Any, Iterable, List, Tuple +from typing import Any, Iterable, List, Optional, Tuple +from ray.data._internal.stats import TimeSpan from ray.data.block import Block, DataBatch from ray.types import ObjectRef -@dataclass -class StageTiming: - """Wall-clock window for a single batch-processing stage. - - Can be used as a context manager to automatically capture the start and - end timestamps of a pipeline operation:: - - with stage_timing: - do_work() - # stage_timing.start_s and stage_timing.end_s are now set - """ - - start_s: float = 0.0 - end_s: float = 0.0 - - def __enter__(self): - self.start_s = time.perf_counter() - return self - - def __exit__(self, *args): - self.end_s = time.perf_counter() - - @contextmanager - def timer(self): - """Alias for using as a context manager, matching Timer.timer() API.""" - self.start_s = time.perf_counter() - try: - yield - finally: - self.end_s = time.perf_counter() - - @dataclass class BatchTimings: """Per-batch pipeline-stage timing windows for overlap-based attribution. @@ -48,7 +15,10 @@ class BatchTimings: a particular pipeline stage was active for this batch. The training thread later compares these windows against its own blocked window to determine how much each stage contributed to training-thread stall (see - :meth:`BatchIterator._report_batch_timings`). + :meth:`BatchIterator._attribute_blocked_time`). + + A field value of ``None`` indicates the stage did not execute for this + batch (e.g. no ``collate_fn`` provided). Attributes: fetch: Waiting for upstream data production + ``ray.get()`` transfer. @@ -59,14 +29,14 @@ class BatchTimings: num_rows: Number of rows in this batch (for ``iter_rows_total``). """ - fetch: StageTiming = field(default_factory=StageTiming) - batching: StageTiming = field(default_factory=StageTiming) - format: StageTiming = field(default_factory=StageTiming) - collate: StageTiming = field(default_factory=StageTiming) - finalize: StageTiming = field(default_factory=StageTiming) + fetch: Optional[TimeSpan] = None + batching: Optional[TimeSpan] = None + format: Optional[TimeSpan] = None + collate: Optional[TimeSpan] = None + finalize: Optional[TimeSpan] = None num_rows: int = 0 - def stages(self) -> Iterable[Tuple[str, StageTiming]]: + def stages(self) -> Iterable[Tuple[str, Optional[TimeSpan]]]: """Iterate over ``(name, timing)`` pairs for all pipeline stages.""" return ( ("fetch", self.fetch), @@ -84,18 +54,24 @@ def merge_fetch(self, other: "BatchTimings") -> None: training thread was blocked waiting for this batch, including any pipeline overhead between consecutive block fetches. """ - if other.fetch.start_s == 0.0: - return - if self.fetch.start_s == 0.0: - # First block: copy the timing - self.fetch.start_s = other.fetch.start_s - self.fetch.end_s = other.fetch.end_s - else: - # Subsequent blocks: expand the window - if other.fetch.start_s < self.fetch.start_s: - self.fetch.start_s = other.fetch.start_s - if other.fetch.end_s > self.fetch.end_s: - self.fetch.end_s = other.fetch.end_s + self.fetch = _merge_span(self.fetch, other.fetch) + + +def _merge_span(dst: Optional[TimeSpan], src: Optional[TimeSpan]) -> Optional[TimeSpan]: + """Merge two optional ``TimeSpan`` windows into a spanning window. + + Returns ``dst`` unchanged if ``src`` is ``None`` (stage didn't run). + Returns a copy of ``src`` if ``dst`` is ``None`` (first block). + Otherwise returns a new ``TimeSpan`` spanning both windows. + """ + if src is None: + return dst + if dst is None: + return TimeSpan(start_s=src.start_s, end_s=src.end_s) + return TimeSpan( + start_s=min(dst.start_s, src.start_s), + end_s=max(dst.end_s, src.end_s), + ) @dataclass diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index dc669c9725fd..bc59edb20866 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -298,11 +298,11 @@ def _report_batch_timings( if self._stats is None: return timings = batch.metadata.timings - for name, stage in timings.stages(): - if stage.start_s == 0.0 and stage.end_s == 0.0: + for name, timing in timings.stages(): + if timing is None: continue - overlap_s = min(stage.end_s, blocked_end_s) - max( - stage.start_s, blocked_start_s + overlap_s = min(timing.end_s, blocked_end_s) - max( + timing.start_s, blocked_start_s ) if overlap_s > 0: getattr(self._stats, f"iter_blocked_{name}_s").add(overlap_s) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index fcc6cf2a40c2..623b49b61db6 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -3,7 +3,7 @@ import logging import queue import threading -from contextlib import nullcontext +import time from typing import ( Any, Callable, @@ -27,7 +27,7 @@ BlockWithTiming, CollatedBatch, ) -from ray.data._internal.stats import DatasetStats +from ray.data._internal.stats import DatasetStats, TimeSpan, _timed from ray.data.block import Block, BlockAccessor, DataBatch from ray.types import ObjectRef @@ -198,26 +198,26 @@ def resolve_block_refs( unknowns = 0 while True: - # Time the upstream pull — captures blocked time waiting for the - # data pipeline to produce the next block ref. + # Time the full fetch window — from upstream pull through ray.get(). + # In a follow-up, this will be split into production_wait + data_transfer. timings = BatchTimings() - with timings.fetch: - try: - block_ref = next(block_ref_iter) - except StopIteration: - break - - current_hit, current_miss, current_unknown = _calculate_ref_hits( - [block_ref] - ) - hits += current_hit - misses += current_miss - unknowns += current_unknown - - # TODO(amogkam): Optimized further by batching multiple references - # in a single `ray.get()` call. - with stats.iter_get_s.timer() if stats else nullcontext(): - block = ray.get(block_ref) + fetch_span = TimeSpan(start_s=time.perf_counter()) + try: + block_ref = next(block_ref_iter) + except StopIteration: + break + + current_hit, current_miss, current_unknown = _calculate_ref_hits([block_ref]) + hits += current_hit + misses += current_miss + unknowns += current_unknown + + # TODO(amogkam): Optimized further by batching multiple references + # in a single `ray.get()` call. + with _timed(stats.iter_get_s if stats else None): + block = ray.get(block_ref) + fetch_span.end_s = time.perf_counter() + timings.fetch = fetch_span yield BlockWithTiming(block=block, timings=timings) @@ -285,8 +285,6 @@ def __iter__(self) -> "_BatchingIterator": return self def __next__(self) -> Batch: - timer = self._stats.iter_next_batch_s.timer() if self._stats else nullcontext() - # Try to get a batch from current batcher state while True: can_yield = self._batcher.has_batch() or ( @@ -294,9 +292,11 @@ def __next__(self) -> Batch: ) if can_yield: - with timer: - with self._pending_timings.batching: - next_batch = self._batcher.next_batch() + with _timed( + self._stats.iter_next_batch_s if self._stats else None + ) as span: + next_batch = self._batcher.next_batch() + self._pending_timings.batching = span res = Batch( metadata=BatchMetadata( @@ -338,13 +338,13 @@ def _format_batch( stats: Optional[DatasetStats], ensure_copy: bool = False, ) -> Batch: - with batch.metadata.timings.format: - with stats.iter_format_batch_s.timer() if stats else nullcontext(): - formatted_data = BlockAccessor.for_block(batch.data).to_batch_format( - batch_format - ) - if ensure_copy: - formatted_data = _copy_batch(formatted_data) + with _timed(stats.iter_format_batch_s if stats else None) as span: + formatted_data = BlockAccessor.for_block(batch.data).to_batch_format( + batch_format + ) + if ensure_copy: + formatted_data = _copy_batch(formatted_data) + batch.metadata.timings.format = span return dataclasses.replace(batch, data=formatted_data) @@ -392,9 +392,9 @@ def _collate_batch( collate_fn: Callable[[DataBatch], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: - with batch.metadata.timings.collate: - with stats.iter_collate_batch_s.timer() if stats else nullcontext(): - collated_data = collate_fn(batch.data) + with _timed(stats.iter_collate_batch_s if stats else None) as span: + collated_data = collate_fn(batch.data) + batch.metadata.timings.collate = span return CollatedBatch(metadata=batch.metadata, data=collated_data) @@ -418,9 +418,9 @@ def _finalize_batch( finalize_fn: Callable[[Any], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: - with batch.metadata.timings.finalize: - with stats.iter_finalize_batch_s.timer() if stats else nullcontext(): - finalized_data = finalize_fn(batch.data) + with _timed(stats.iter_finalize_batch_s if stats else None) as span: + finalized_data = finalize_fn(batch.data) + batch.metadata.timings.finalize = span return dataclasses.replace(batch, data=finalized_data) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index dc5c32d1d500..d74d70b92d40 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -11,6 +11,7 @@ Any, DefaultDict, Dict, + Iterator, List, Mapping, Optional, @@ -174,6 +175,37 @@ def get( ) +@dataclass +class TimeSpan: + """A measured wall-clock interval. + + Created by :meth:`Timer.timer` and carried per-batch for overlap-based + blocked attribution. ``None`` (when used as ``Optional[TimeSpan]``) + indicates the stage did not execute. + """ + + start_s: float = 0.0 + end_s: float = 0.0 + + @property + def duration(self) -> float: + return self.end_s - self.start_s + + +@contextmanager +def _timed(timer: Optional["Timer"]) -> Iterator[Optional[TimeSpan]]: + """Time a block of code, yielding a :class:`TimeSpan` (or ``None``). + + When *timer* is ``None`` (e.g. ``stats`` is not configured), yields + ``None`` and skips timing entirely — no ``perf_counter`` calls. + """ + if timer is None: + yield None + else: + with timer.timer() as span: + yield span + + class Timer: """Helper class for tracking accumulated time (in seconds). @@ -195,22 +227,24 @@ def __init__(self): self._min: float = float("inf") self._max: float = 0 self._total_count: float = 0 - # Wall-clock window of the most recent timer() invocation. - # Used by overlap-based blocked attribution in iter_batches. - self.start_s: float = 0.0 - self.end_s: float = 0.0 # Bounded-memory percentile backend. add() forwards every value # to ``add_sample`` and ``percentile`` reads from it. self._distribution: DistributionTracker = DistributionTracker() @contextmanager - def timer(self) -> None: - self.start_s = time.perf_counter() + def timer(self) -> Iterator[TimeSpan]: + """Time a block, yielding a thread-local :class:`TimeSpan`. + + The returned ``TimeSpan`` is a fresh instance per call, making + this safe to use from multiple threads sharing the same ``Timer``. + The duration is also accumulated into ``self`` via :meth:`add`. + """ + span = TimeSpan(start_s=time.perf_counter()) try: - yield + yield span finally: - self.end_s = time.perf_counter() - self.add(self.end_s - self.start_s) + span.end_s = time.perf_counter() + self.add(span.duration) def add(self, value: float) -> None: self._total += value diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 63128be388e3..2deec2fe0cad 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -1,7 +1,7 @@ import queue import threading import time -from typing import Iterator, List +from typing import Iterator, List, Optional import pandas as pd import pyarrow as pa @@ -13,7 +13,6 @@ BatchMetadata, BatchTimings, BlockPrefetcher, - StageTiming, ) from ray.data._internal.block_batching.iter_batches import ( BatchIterator, @@ -22,7 +21,7 @@ ) from ray.data._internal.block_batching.util import WaitBlockPrefetcher from ray.data._internal.execution.interfaces.ref_bundle import BlockEntry, RefBundle -from ray.data._internal.stats import DatasetStats +from ray.data._internal.stats import DatasetStats, TimeSpan from ray.data.block import Block, BlockMetadata from ray.types import ObjectRef @@ -132,10 +131,10 @@ def test_report_batch_timings_overlap_attribution(): stats = DatasetStats(metadata={}, parent=None) batch_iterator = BatchIterator(iter([]), stats=stats) timings = BatchTimings(num_rows=8) - timings.fetch = StageTiming(start_s=10.0, end_s=20.0) - timings.batching = StageTiming(start_s=20.0, end_s=30.0) - timings.format = StageTiming(start_s=30.0, end_s=40.0) - timings.finalize = StageTiming(start_s=50.0, end_s=60.0) + timings.fetch = TimeSpan(start_s=10.0, end_s=20.0) + timings.batching = TimeSpan(start_s=20.0, end_s=30.0) + timings.format = TimeSpan(start_s=30.0, end_s=40.0) + timings.finalize = TimeSpan(start_s=50.0, end_s=60.0) batch = Batch(BatchMetadata(batch_idx=0, timings=timings), None) batch_iterator._report_batch_timings( @@ -151,6 +150,13 @@ def test_report_batch_timings_overlap_attribution(): assert stats.iter_rows_total == 8 +def _make_span(start: float, end: float) -> Optional[TimeSpan]: + """Create a TimeSpan, or None if the stage didn't execute (both zero).""" + if start == 0.0 and end == 0.0: + return None + return TimeSpan(start_s=start, end_s=end) + + def _make_batch_with_timings( fetch_start=0.0, fetch_end=0.0, @@ -166,11 +172,11 @@ def _make_batch_with_timings( ): """Helper to construct a Batch with specific stage timing windows.""" timings = BatchTimings(num_rows=num_rows) - timings.fetch = StageTiming(start_s=fetch_start, end_s=fetch_end) - timings.batching = StageTiming(start_s=batching_start, end_s=batching_end) - timings.format = StageTiming(start_s=format_start, end_s=format_end) - timings.collate = StageTiming(start_s=collate_start, end_s=collate_end) - timings.finalize = StageTiming(start_s=finalize_start, end_s=finalize_end) + timings.fetch = _make_span(fetch_start, fetch_end) + timings.batching = _make_span(batching_start, batching_end) + timings.format = _make_span(format_start, format_end) + timings.collate = _make_span(collate_start, collate_end) + timings.finalize = _make_span(finalize_start, finalize_end) return Batch(BatchMetadata(batch_idx=0, timings=timings), None) @@ -326,31 +332,25 @@ def test_all_stages_simultaneous_overlap(self): assert stats.iter_rows_total == 100 -class TestStageTimingRecord: - """Tests for StageTiming.record() behavior.""" - - def test_context_manager_captures_window(self): - """Using as context manager captures start_s and end_s.""" - t = StageTiming() - with t: - pass - assert t.start_s > 0 - assert t.end_s >= t.start_s - - def test_timer_context_manager(self): - """The timer() method works as a context manager too.""" - t = StageTiming() - with t.timer(): - pass - assert t.start_s > 0 - assert t.end_s >= t.start_s +class TestTimeSpan: + """Tests for TimeSpan dataclass.""" def test_default_values(self): - """Unrecorded StageTiming has start_s=0 and end_s=0.""" - t = StageTiming() + """Default TimeSpan has start_s=0 and end_s=0.""" + t = TimeSpan() assert t.start_s == 0.0 assert t.end_s == 0.0 + def test_duration(self): + """Duration is end_s - start_s.""" + t = TimeSpan(start_s=1.0, end_s=3.5) + assert t.duration == pytest.approx(2.5) + + def test_zero_duration(self): + """Default TimeSpan has zero duration.""" + t = TimeSpan() + assert t.duration == 0.0 + class TestMergeFetch: """Tests for BatchTimings.merge_fetch() with multiple blocks per batch.""" @@ -359,7 +359,7 @@ def test_merge_single_block(self): """Merging a single block preserves its fetch window.""" dst = BatchTimings() src = BatchTimings() - src.fetch = StageTiming(start_s=1.0, end_s=2.0) + src.fetch = TimeSpan(start_s=1.0, end_s=2.0) dst.merge_fetch(src) assert dst.fetch.start_s == 1.0 assert dst.fetch.end_s == 2.0 @@ -370,17 +370,17 @@ def test_merge_multiple_blocks_expands_window(self): # Block 1: fetched [1.0, 2.0] src1 = BatchTimings() - src1.fetch = StageTiming(start_s=1.0, end_s=2.0) + src1.fetch = TimeSpan(start_s=1.0, end_s=2.0) dst.merge_fetch(src1) # Block 2: fetched [3.0, 4.0] src2 = BatchTimings() - src2.fetch = StageTiming(start_s=3.0, end_s=4.0) + src2.fetch = TimeSpan(start_s=3.0, end_s=4.0) dst.merge_fetch(src2) # Block 3: fetched [5.0, 6.0] src3 = BatchTimings() - src3.fetch = StageTiming(start_s=5.0, end_s=6.0) + src3.fetch = TimeSpan(start_s=5.0, end_s=6.0) dst.merge_fetch(src3) # Union: [1.0, 6.0] @@ -390,7 +390,7 @@ def test_merge_multiple_blocks_expands_window(self): def test_merge_unrecorded_block_ignored(self): """Merging a block with no fetch timing (start_s=0) is a no-op.""" dst = BatchTimings() - dst.fetch = StageTiming(start_s=2.0, end_s=3.0) + dst.fetch = TimeSpan(start_s=2.0, end_s=3.0) src = BatchTimings() # fetch defaults to (0.0, 0.0) dst.merge_fetch(src) @@ -403,11 +403,11 @@ def test_merge_overlapping_blocks(self): dst = BatchTimings() src1 = BatchTimings() - src1.fetch = StageTiming(start_s=1.0, end_s=5.0) + src1.fetch = TimeSpan(start_s=1.0, end_s=5.0) dst.merge_fetch(src1) src2 = BatchTimings() - src2.fetch = StageTiming(start_s=3.0, end_s=7.0) + src2.fetch = TimeSpan(start_s=3.0, end_s=7.0) dst.merge_fetch(src2) # Union: [1.0, 7.0] @@ -418,7 +418,7 @@ def test_merge_into_empty_destination(self): """Merging into an empty BatchTimings takes the source window.""" dst = BatchTimings() # fetch = (0.0, 0.0) src = BatchTimings() - src.fetch = StageTiming(start_s=10.0, end_s=20.0) + src.fetch = TimeSpan(start_s=10.0, end_s=20.0) dst.merge_fetch(src) assert dst.fetch.start_s == 10.0 assert dst.fetch.end_s == 20.0 @@ -430,11 +430,11 @@ class TestEndToEndTimingPropagation: def test_batch_carries_timings_through_pipeline(self): """A Batch's metadata.timings carries all stage windows.""" timings = BatchTimings(num_rows=50) - timings.fetch = StageTiming(start_s=1.0, end_s=2.0) - timings.batching = StageTiming(start_s=2.0, end_s=3.0) - timings.format = StageTiming(start_s=3.0, end_s=4.0) - timings.collate = StageTiming(start_s=4.0, end_s=5.0) - timings.finalize = StageTiming(start_s=5.0, end_s=6.0) + timings.fetch = TimeSpan(start_s=1.0, end_s=2.0) + timings.batching = TimeSpan(start_s=2.0, end_s=3.0) + timings.format = TimeSpan(start_s=3.0, end_s=4.0) + timings.collate = TimeSpan(start_s=4.0, end_s=5.0) + timings.finalize = TimeSpan(start_s=5.0, end_s=6.0) batch = Batch(BatchMetadata(batch_idx=0, timings=timings), None) From 74b9326faebb51ad4275b295b658f7b4b8b5c5ee Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Fri, 26 Jun 2026 17:03:18 +0800 Subject: [PATCH 14/32] [Data] Split fetch stage into production_wait and data_transfer Per reviewer feedback, splits the single fetch stage into two distinct stages to differentiate data production stall vs cross-node transfer stall: - production_wait: blocked on upstream data pipeline (next on ref bundle iterator) - data_transfer: cross-node transfer via ray.get() New Prometheus metrics: - data_iter_blocked_production_wait_seconds - data_iter_blocked_data_transfer_seconds Removed: data_iter_blocked_fetch_seconds (replaced by the two above) resolve_block_refs now times each sub-stage independently using _timed() with the appropriate Timer (iter_get_ref_bundles_s for production_wait, iter_get_s for data_transfer). Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 19 +-- .../ray/data/_internal/block_batching/util.py | 23 ++-- python/ray/data/_internal/stats.py | 30 +++-- .../tests/block_batching/test_iter_batches.py | 109 ++++++++++-------- python/ray/data/tests/test_stats.py | 25 ++-- 5 files changed, 120 insertions(+), 86 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 2d115ffefca9..09840b43a1a3 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -21,7 +21,9 @@ class BatchTimings: batch (e.g. no ``collate_fn`` provided). Attributes: - fetch: Waiting for upstream data production + ``ray.get()`` transfer. + production_wait: Waiting for upstream data production (next on + the ref bundle iterator). + data_transfer: Cross-node transfer via ``ray.get()``. batching: Assembling blocks into a batch via ``_batcher.next_batch()``. format: Converting the batch to the requested format (numpy, pandas…). collate: Running the user-provided ``collate_fn``. @@ -29,7 +31,8 @@ class BatchTimings: num_rows: Number of rows in this batch (for ``iter_rows_total``). """ - fetch: Optional[TimeSpan] = None + production_wait: Optional[TimeSpan] = None + data_transfer: Optional[TimeSpan] = None batching: Optional[TimeSpan] = None format: Optional[TimeSpan] = None collate: Optional[TimeSpan] = None @@ -39,7 +42,8 @@ class BatchTimings: def stages(self) -> Iterable[Tuple[str, Optional[TimeSpan]]]: """Iterate over ``(name, timing)`` pairs for all pipeline stages.""" return ( - ("fetch", self.fetch), + ("production_wait", self.production_wait), + ("data_transfer", self.data_transfer), ("batching", self.batching), ("format", self.format), ("collate", self.collate), @@ -49,12 +53,11 @@ def stages(self) -> Iterable[Tuple[str, Optional[TimeSpan]]]: def merge_fetch(self, other: "BatchTimings") -> None: """Merge fetch timings from another batch into this one. - Expands the fetch window to span from the earliest block fetch start - to the latest block fetch end. This represents the total time the - training thread was blocked waiting for this batch, including any - pipeline overhead between consecutive block fetches. + Expands each fetch sub-stage window to span from the earliest + block's start to the latest block's end. """ - self.fetch = _merge_span(self.fetch, other.fetch) + self.production_wait = _merge_span(self.production_wait, other.production_wait) + self.data_transfer = _merge_span(self.data_transfer, other.data_transfer) def _merge_span(dst: Optional[TimeSpan], src: Optional[TimeSpan]) -> Optional[TimeSpan]: diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 623b49b61db6..7f07c0c14a67 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -3,7 +3,6 @@ import logging import queue import threading -import time from typing import ( Any, Callable, @@ -27,7 +26,7 @@ BlockWithTiming, CollatedBatch, ) -from ray.data._internal.stats import DatasetStats, TimeSpan, _timed +from ray.data._internal.stats import DatasetStats, _timed from ray.data.block import Block, BlockAccessor, DataBatch from ray.types import ObjectRef @@ -198,26 +197,26 @@ def resolve_block_refs( unknowns = 0 while True: - # Time the full fetch window — from upstream pull through ray.get(). - # In a follow-up, this will be split into production_wait + data_transfer. timings = BatchTimings() - fetch_span = TimeSpan(start_s=time.perf_counter()) - try: - block_ref = next(block_ref_iter) - except StopIteration: - break + # (1) production_wait: blocked on upstream data pipeline + with _timed(stats.iter_get_ref_bundles_s if stats else None) as prod_span: + try: + block_ref = next(block_ref_iter) + except StopIteration: + break + timings.production_wait = prod_span current_hit, current_miss, current_unknown = _calculate_ref_hits([block_ref]) hits += current_hit misses += current_miss unknowns += current_unknown + # (2) data_transfer: cross-node transfer via ray.get() # TODO(amogkam): Optimized further by batching multiple references # in a single `ray.get()` call. - with _timed(stats.iter_get_s if stats else None): + with _timed(stats.iter_get_s if stats else None) as xfer_span: block = ray.get(block_ref) - fetch_span.end_s = time.perf_counter() - timings.fetch = fetch_span + timings.data_transfer = xfer_span yield BlockWithTiming(block=block, timings=timings) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index d74d70b92d40..5c7c1519cb13 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -546,9 +546,14 @@ def __init__(self, max_stats=1000): description="Total wall-clock seconds spent in the dataset iterator", tag_keys=iter_tag_keys, ) - self.iter_blocked_fetch_s = Gauge( - "data_iter_blocked_fetch_seconds", - description="Seconds user thread is blocked on block fetching", + self.iter_blocked_production_wait_s = Gauge( + "data_iter_blocked_production_wait_seconds", + description="Seconds user thread is blocked on upstream data production", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_data_transfer_s = Gauge( + "data_iter_blocked_data_transfer_seconds", + description="Seconds user thread is blocked on cross-node data transfer (ray.get)", tag_keys=iter_tag_keys, ) self.iter_blocked_batching_s = Gauge( @@ -842,7 +847,12 @@ def update_iteration_metrics( self.time_to_first_batch_s.set(stats.iter_time_to_first_batch_s.get(), tags) self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags) - self.iter_blocked_fetch_s.set(stats.iter_blocked_fetch_s.get(), tags) + self.iter_blocked_production_wait_s.set( + stats.iter_blocked_production_wait_s.get(), tags + ) + self.iter_blocked_data_transfer_s.set( + stats.iter_blocked_data_transfer_s.get(), tags + ) self.iter_blocked_batching_s.set(stats.iter_blocked_batching_s.get(), tags) self.iter_blocked_format_s.set(stats.iter_blocked_format_s.get(), tags) self.iter_blocked_collate_s.set(stats.iter_blocked_collate_s.get(), tags) @@ -1242,7 +1252,8 @@ def __init__( self.iter_finalize_batch_s: Timer = Timer() self.iter_time_to_first_batch_s: Timer = Timer() self.iter_total_blocked_s: Timer = Timer() - self.iter_blocked_fetch_s: Timer = Timer() + self.iter_blocked_production_wait_s: Timer = Timer() + self.iter_blocked_data_transfer_s: Timer = Timer() self.iter_blocked_batching_s: Timer = Timer() self.iter_blocked_format_s: Timer = Timer() self.iter_blocked_collate_s: Timer = Timer() @@ -1307,7 +1318,8 @@ def to_summary(self) -> "DatasetStatsSummary": self.iter_blocks_remote, self.iter_unknown_location, self.iter_prefetched_bytes, - self.iter_blocked_fetch_s, + self.iter_blocked_production_wait_s, + self.iter_blocked_data_transfer_s, self.iter_blocked_batching_s, self.iter_blocked_format_s, self.iter_blocked_collate_s, @@ -1997,7 +2009,8 @@ class IterStatsSummary: # Current bytes of prefetched blocks in the iterator iter_prefetched_bytes: int # Per-stage training-thread blocked attribution timers. - blocked_fetch_time: Timer + blocked_production_wait_time: Timer + blocked_data_transfer_time: Timer blocked_batching_time: Timer blocked_format_time: Timer blocked_collate_time: Timer @@ -2113,7 +2126,8 @@ def to_string(self) -> str: # Per-stage training-thread blocked attribution. stage_totals = [ - ("block fetch (ray.get)", self.blocked_fetch_time), + ("production wait", self.blocked_production_wait_time), + ("data transfer (ray.get)", self.blocked_data_transfer_time), ("batching", self.blocked_batching_time), ("format", self.blocked_format_time), ("collate", self.blocked_collate_time), diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 2deec2fe0cad..a4b42d119727 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -131,7 +131,7 @@ def test_report_batch_timings_overlap_attribution(): stats = DatasetStats(metadata={}, parent=None) batch_iterator = BatchIterator(iter([]), stats=stats) timings = BatchTimings(num_rows=8) - timings.fetch = TimeSpan(start_s=10.0, end_s=20.0) + timings.production_wait = TimeSpan(start_s=10.0, end_s=20.0) timings.batching = TimeSpan(start_s=20.0, end_s=30.0) timings.format = TimeSpan(start_s=30.0, end_s=40.0) timings.finalize = TimeSpan(start_s=50.0, end_s=60.0) @@ -141,7 +141,7 @@ def test_report_batch_timings_overlap_attribution(): batch, blocked_start_s=15.0, blocked_end_s=35.0 ) - assert stats.iter_blocked_fetch_s.get() == pytest.approx(5.0) + assert stats.iter_blocked_production_wait_s.get() == pytest.approx(5.0) assert stats.iter_blocked_batching_s.get() == pytest.approx(10.0) assert stats.iter_blocked_format_s.get() == pytest.approx(5.0) assert stats.iter_blocked_collate_s.get() == 0 @@ -158,8 +158,10 @@ def _make_span(start: float, end: float) -> Optional[TimeSpan]: def _make_batch_with_timings( - fetch_start=0.0, - fetch_end=0.0, + production_wait_start=0.0, + production_wait_end=0.0, + data_transfer_start=0.0, + data_transfer_end=0.0, batching_start=0.0, batching_end=0.0, format_start=0.0, @@ -172,7 +174,8 @@ def _make_batch_with_timings( ): """Helper to construct a Batch with specific stage timing windows.""" timings = BatchTimings(num_rows=num_rows) - timings.fetch = _make_span(fetch_start, fetch_end) + timings.production_wait = _make_span(production_wait_start, production_wait_end) + timings.data_transfer = _make_span(data_transfer_start, data_transfer_end) timings.batching = _make_span(batching_start, batching_end) timings.format = _make_span(format_start, format_end) timings.collate = _make_span(collate_start, collate_end) @@ -194,9 +197,11 @@ def test_zero_overlap_stage_finished_before_blocked(self): """Fetch [0, 1.5] finished before training blocked at t=2 → 0 attribution.""" stats = DatasetStats(metadata={}, parent=None) it = _make_report_iterator(stats) - batch = _make_batch_with_timings(fetch_start=0.0, fetch_end=1.5) + batch = _make_batch_with_timings( + production_wait_start=0.0, production_wait_end=1.5 + ) it._report_batch_timings(batch, blocked_start_s=2.0, blocked_end_s=3.0) - assert stats.iter_blocked_fetch_s.get() == 0.0 + assert stats.iter_blocked_production_wait_s.get() == 0.0 def test_zero_overlap_blocked_before_stage(self): """Training blocked [0, 1], stage ran [2, 3] → 0 attribution.""" @@ -210,9 +215,11 @@ def test_partial_overlap(self): """Fetch [0, 2], blocked [1, 3] → overlap = min(2,3)-max(0,1) = 1.0.""" stats = DatasetStats(metadata={}, parent=None) it = _make_report_iterator(stats) - batch = _make_batch_with_timings(fetch_start=0.0, fetch_end=2.0) + batch = _make_batch_with_timings( + production_wait_start=0.0, production_wait_end=2.0 + ) it._report_batch_timings(batch, blocked_start_s=1.0, blocked_end_s=3.0) - assert stats.iter_blocked_fetch_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_production_wait_s.get() == pytest.approx(1.0) def test_full_overlap_stage_inside_blocked(self): """Stage [1, 2] entirely inside blocked [0, 3] → full 1.0 credit.""" @@ -245,14 +252,14 @@ def test_prefetch_hides_fetch_from_training(self): stats = DatasetStats(metadata={}, parent=None) it = _make_report_iterator(stats) batch = _make_batch_with_timings( - fetch_start=0.0, - fetch_end=1.5, + production_wait_start=0.0, + production_wait_end=1.5, collate_start=2.3, collate_end=2.6, ) # Training only starts blocking at t=2 (prefetch worked) it._report_batch_timings(batch, blocked_start_s=2.0, blocked_end_s=2.6) - assert stats.iter_blocked_fetch_s.get() == 0.0 + assert stats.iter_blocked_production_wait_s.get() == 0.0 assert stats.iter_blocked_collate_s.get() == pytest.approx(0.3) def test_accumulation_across_batches(self): @@ -260,13 +267,17 @@ def test_accumulation_across_batches(self): stats = DatasetStats(metadata={}, parent=None) it = _make_report_iterator(stats) # Batch 1: fetch [0,1], blocked [0,2] → overlap 1.0 - b1 = _make_batch_with_timings(fetch_start=0.0, fetch_end=1.0, num_rows=10) + b1 = _make_batch_with_timings( + production_wait_start=0.0, production_wait_end=1.0, num_rows=10 + ) it._report_batch_timings(b1, blocked_start_s=0.0, blocked_end_s=2.0) # Batch 2: fetch [5,6], blocked [5,7] → overlap 1.0 - b2 = _make_batch_with_timings(fetch_start=5.0, fetch_end=6.0, num_rows=20) + b2 = _make_batch_with_timings( + production_wait_start=5.0, production_wait_end=6.0, num_rows=20 + ) it._report_batch_timings(b2, blocked_start_s=5.0, blocked_end_s=7.0) - assert stats.iter_blocked_fetch_s.get() == pytest.approx(2.0) + assert stats.iter_blocked_production_wait_s.get() == pytest.approx(2.0) assert stats.iter_batches_total == 2 assert stats.iter_rows_total == 30 @@ -276,8 +287,8 @@ def test_overlap_invariant_sum_leq_total(self): it = _make_report_iterator(stats) stats.iter_total_blocked_s.add(5.0) batch = _make_batch_with_timings( - fetch_start=0.0, - fetch_end=1.0, + production_wait_start=0.0, + production_wait_end=1.0, batching_start=1.0, batching_end=2.0, format_start=2.0, @@ -288,7 +299,7 @@ def test_overlap_invariant_sum_leq_total(self): total = stats.iter_total_blocked_s.get() sum_stages = ( - stats.iter_blocked_fetch_s.get() + stats.iter_blocked_production_wait_s.get() + stats.iter_blocked_batching_s.get() + stats.iter_blocked_format_s.get() + stats.iter_blocked_collate_s.get() @@ -300,17 +311,19 @@ def test_blocked_inside_stage(self): """Stage [0, 10] fully contains blocked [3, 5] → overlap = 2.0.""" stats = DatasetStats(metadata={}, parent=None) it = _make_report_iterator(stats) - batch = _make_batch_with_timings(fetch_start=0.0, fetch_end=10.0) + batch = _make_batch_with_timings( + production_wait_start=0.0, production_wait_end=10.0 + ) it._report_batch_timings(batch, blocked_start_s=3.0, blocked_end_s=5.0) - assert stats.iter_blocked_fetch_s.get() == pytest.approx(2.0) + assert stats.iter_blocked_production_wait_s.get() == pytest.approx(2.0) def test_all_stages_simultaneous_overlap(self): """Multiple stages overlap with blocked window simultaneously.""" stats = DatasetStats(metadata={}, parent=None) it = _make_report_iterator(stats) batch = _make_batch_with_timings( - fetch_start=0.0, - fetch_end=1.0, + production_wait_start=0.0, + production_wait_end=1.0, batching_start=1.0, batching_end=2.0, format_start=2.0, @@ -323,7 +336,7 @@ def test_all_stages_simultaneous_overlap(self): ) # Blocked window covers all stages it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) - assert stats.iter_blocked_fetch_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_production_wait_s.get() == pytest.approx(1.0) assert stats.iter_blocked_batching_s.get() == pytest.approx(1.0) assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) assert stats.iter_blocked_collate_s.get() == pytest.approx(1.0) @@ -359,10 +372,10 @@ def test_merge_single_block(self): """Merging a single block preserves its fetch window.""" dst = BatchTimings() src = BatchTimings() - src.fetch = TimeSpan(start_s=1.0, end_s=2.0) + src.production_wait = TimeSpan(start_s=1.0, end_s=2.0) dst.merge_fetch(src) - assert dst.fetch.start_s == 1.0 - assert dst.fetch.end_s == 2.0 + assert dst.production_wait.start_s == 1.0 + assert dst.production_wait.end_s == 2.0 def test_merge_multiple_blocks_expands_window(self): """Merging multiple blocks produces the union window.""" @@ -370,58 +383,58 @@ def test_merge_multiple_blocks_expands_window(self): # Block 1: fetched [1.0, 2.0] src1 = BatchTimings() - src1.fetch = TimeSpan(start_s=1.0, end_s=2.0) + src1.production_wait = TimeSpan(start_s=1.0, end_s=2.0) dst.merge_fetch(src1) # Block 2: fetched [3.0, 4.0] src2 = BatchTimings() - src2.fetch = TimeSpan(start_s=3.0, end_s=4.0) + src2.production_wait = TimeSpan(start_s=3.0, end_s=4.0) dst.merge_fetch(src2) # Block 3: fetched [5.0, 6.0] src3 = BatchTimings() - src3.fetch = TimeSpan(start_s=5.0, end_s=6.0) + src3.production_wait = TimeSpan(start_s=5.0, end_s=6.0) dst.merge_fetch(src3) # Union: [1.0, 6.0] - assert dst.fetch.start_s == 1.0 - assert dst.fetch.end_s == 6.0 + assert dst.production_wait.start_s == 1.0 + assert dst.production_wait.end_s == 6.0 def test_merge_unrecorded_block_ignored(self): """Merging a block with no fetch timing (start_s=0) is a no-op.""" dst = BatchTimings() - dst.fetch = TimeSpan(start_s=2.0, end_s=3.0) + dst.production_wait = TimeSpan(start_s=2.0, end_s=3.0) src = BatchTimings() # fetch defaults to (0.0, 0.0) dst.merge_fetch(src) - assert dst.fetch.start_s == 2.0 - assert dst.fetch.end_s == 3.0 + assert dst.production_wait.start_s == 2.0 + assert dst.production_wait.end_s == 3.0 def test_merge_overlapping_blocks(self): """Overlapping fetch windows are correctly merged.""" dst = BatchTimings() src1 = BatchTimings() - src1.fetch = TimeSpan(start_s=1.0, end_s=5.0) + src1.production_wait = TimeSpan(start_s=1.0, end_s=5.0) dst.merge_fetch(src1) src2 = BatchTimings() - src2.fetch = TimeSpan(start_s=3.0, end_s=7.0) + src2.production_wait = TimeSpan(start_s=3.0, end_s=7.0) dst.merge_fetch(src2) # Union: [1.0, 7.0] - assert dst.fetch.start_s == 1.0 - assert dst.fetch.end_s == 7.0 + assert dst.production_wait.start_s == 1.0 + assert dst.production_wait.end_s == 7.0 def test_merge_into_empty_destination(self): """Merging into an empty BatchTimings takes the source window.""" dst = BatchTimings() # fetch = (0.0, 0.0) src = BatchTimings() - src.fetch = TimeSpan(start_s=10.0, end_s=20.0) + src.production_wait = TimeSpan(start_s=10.0, end_s=20.0) dst.merge_fetch(src) - assert dst.fetch.start_s == 10.0 - assert dst.fetch.end_s == 20.0 + assert dst.production_wait.start_s == 10.0 + assert dst.production_wait.end_s == 20.0 class TestEndToEndTimingPropagation: @@ -430,7 +443,7 @@ class TestEndToEndTimingPropagation: def test_batch_carries_timings_through_pipeline(self): """A Batch's metadata.timings carries all stage windows.""" timings = BatchTimings(num_rows=50) - timings.fetch = TimeSpan(start_s=1.0, end_s=2.0) + timings.production_wait = TimeSpan(start_s=1.0, end_s=2.0) timings.batching = TimeSpan(start_s=2.0, end_s=3.0) timings.format = TimeSpan(start_s=3.0, end_s=4.0) timings.collate = TimeSpan(start_s=4.0, end_s=5.0) @@ -440,8 +453,8 @@ def test_batch_carries_timings_through_pipeline(self): # Verify all stages are accessible via stages() iterator stage_dict = dict(batch.metadata.timings.stages()) - assert len(stage_dict) == 5 - assert stage_dict["fetch"].start_s == 1.0 + assert len(stage_dict) == 6 + assert stage_dict["production_wait"].start_s == 1.0 assert stage_dict["batching"].end_s == 3.0 assert stage_dict["format"].start_s == 3.0 assert stage_dict["collate"].end_s == 5.0 @@ -455,8 +468,8 @@ def test_full_pipeline_attribution(self): stats.iter_total_blocked_s.add(5.0) batch = _make_batch_with_timings( - fetch_start=0.0, - fetch_end=0.5, + production_wait_start=0.0, + production_wait_end=0.5, batching_start=0.5, batching_end=1.0, format_start=1.0, @@ -472,7 +485,7 @@ def test_full_pipeline_attribution(self): it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) # Each stage gets its full duration - assert stats.iter_blocked_fetch_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_production_wait_s.get() == pytest.approx(0.5) assert stats.iter_blocked_batching_s.get() == pytest.approx(0.5) assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) assert stats.iter_blocked_collate_s.get() == pytest.approx(0.5) @@ -482,7 +495,7 @@ def test_full_pipeline_attribution(self): # Invariant: sum = 3.0 <= total_blocked = 5.0 sum_stages = ( - stats.iter_blocked_fetch_s.get() + stats.iter_blocked_production_wait_s.get() + stats.iter_blocked_batching_s.get() + stats.iter_blocked_format_s.get() + stats.iter_blocked_collate_s.get() diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index 245c0bfed8d4..ebe0e4895a7b 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -1899,7 +1899,8 @@ def test_create_iteration_tags_extracts_rank(): def test_update_iteration_metrics_exports_new_iter_metrics(): stats = DatasetStats(metadata={}, parent=None) stats.iter_total_s.add(11.0) - stats.iter_blocked_fetch_s.add(1.0) + stats.iter_blocked_production_wait_s.add(1.0) + stats.iter_blocked_data_transfer_s.add(1.5) stats.iter_blocked_batching_s.add(2.0) stats.iter_blocked_format_s.add(3.0) stats.iter_blocked_collate_s.add(4.0) @@ -1939,7 +1940,8 @@ def set(self, value, tags): "iter_batch_finalizing_s", "time_to_first_batch_s", "iter_total_blocked_s", - "iter_blocked_fetch_s", + "iter_blocked_production_wait_s", + "iter_blocked_data_transfer_s", "iter_blocked_batching_s", "iter_blocked_format_s", "iter_blocked_collate_s", @@ -1954,7 +1956,8 @@ def set(self, value, tags): expected_tags = {"dataset": "train_dataset_split_3", "rank": "3"} assert recorded["iter_total_s"] == (11.0, expected_tags) - assert recorded["iter_blocked_fetch_s"] == (1.0, expected_tags) + assert recorded["iter_blocked_production_wait_s"] == (1.0, expected_tags) + assert recorded["iter_blocked_data_transfer_s"] == (1.5, expected_tags) assert recorded["iter_blocked_batching_s"] == (2.0, expected_tags) assert recorded["iter_blocked_format_s"] == (3.0, expected_tags) assert recorded["iter_blocked_collate_s"] == (4.0, expected_tags) @@ -1969,7 +1972,8 @@ def test_iter_stats_summary_has_new_fields(): summary = stats.to_summary() iter_summary = summary.iter_stats - assert hasattr(iter_summary, "blocked_fetch_time") + assert hasattr(iter_summary, "blocked_production_wait_time") + assert hasattr(iter_summary, "blocked_data_transfer_time") assert hasattr(iter_summary, "blocked_batching_time") assert hasattr(iter_summary, "blocked_format_time") assert hasattr(iter_summary, "blocked_collate_time") @@ -1981,13 +1985,14 @@ def test_iter_stats_summary_has_new_fields(): def test_iter_stats_summary_reflects_accumulated_values(): """IterStatsSummary carries the accumulated timer values.""" stats = DatasetStats(metadata={}, parent=None) - stats.iter_blocked_fetch_s.add(0.5) + stats.iter_blocked_production_wait_s.add(0.5) stats.iter_blocked_batching_s.add(0.2) stats.iter_batches_total = 10 stats.iter_rows_total = 320 summary = stats.to_summary().iter_stats - assert summary.blocked_fetch_time.get() == pytest.approx(0.5) + assert summary.blocked_production_wait_time.get() == pytest.approx(0.5) + assert summary.blocked_data_transfer_time.get() == pytest.approx(0.0) assert summary.blocked_batching_time.get() == pytest.approx(0.2) assert summary.batches_total == 10 assert summary.rows_total == 320 @@ -1996,14 +2001,14 @@ def test_iter_stats_summary_reflects_accumulated_values(): def test_iter_stats_to_string_shows_stage_breakdown(): """to_string() renders per-stage breakdown when values are non-zero.""" stats = DatasetStats(metadata={}, parent=None) - stats.iter_blocked_fetch_s.add(1.5) + stats.iter_blocked_production_wait_s.add(1.5) stats.iter_blocked_format_s.add(0.8) stats.iter_batches_total = 5 stats.iter_rows_total = 160 stats.iter_total_blocked_s.add(2.3) text = str(stats.to_summary().iter_stats) - assert "block fetch" in text + assert "production wait" in text assert "format" in text assert "Total batches consumed: 5" in text assert "Total rows consumed: 160" in text @@ -2013,11 +2018,11 @@ def test_iter_stats_to_string_shows_stage_breakdown(): def test_iter_stats_to_string_omits_zero_stages(): """to_string() omits stages with zero values from the breakdown.""" stats = DatasetStats(metadata={}, parent=None) - stats.iter_blocked_fetch_s.add(0.5) + stats.iter_blocked_production_wait_s.add(0.5) stats.iter_total_blocked_s.add(0.5) text = str(stats.to_summary().iter_stats) - assert "block fetch" in text + assert "production wait" in text # Zero stages should not appear assert "batching" not in text assert "collate" not in text From 261317c2bb5d4ae8705f4adc5b04ad0d08fce6ca Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Fri, 26 Jun 2026 17:21:36 +0800 Subject: [PATCH 15/32] [Data] Refine data model with BlockFetchResult and Optional timing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per reviewer feedback, refines the data model to accurately reflect what each component populates: - New BlockFetchTiming: only production_wait + data_transfer (what a single block actually has) - New BlockFetchResult: replaces BlockWithTiming, carries block + Optional[BlockFetchTiming] (None when no fetch timing) - Removed BlockWithTiming (data model was inaccurate — it carried full BatchTimings but only fetch was populated) - Moved num_rows from BatchTimings to BatchMetadata (it's not a timing, it's batch metadata) - batch_blocks() shim simplified: BlockFetchResult(block=b) with fetch=None instead of creating fake BatchTimings - merge_fetch now accepts BlockFetchTiming (not BatchTimings) Signed-off-by: OneSizeFitsQuorum --- .../block_batching/block_batching.py | 17 ++---- .../_internal/block_batching/interfaces.py | 61 +++++++++---------- .../_internal/block_batching/iter_batches.py | 2 +- .../ray/data/_internal/block_batching/util.py | 30 +++++---- .../tests/block_batching/test_iter_batches.py | 14 ++--- .../data/tests/block_batching/test_util.py | 11 ++-- 6 files changed, 62 insertions(+), 73 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 999e2e10af9b..0a54689b9552 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -1,9 +1,6 @@ from typing import Callable, Iterator, Optional, TypeVar -from ray.data._internal.block_batching.interfaces import ( - BatchTimings, - BlockWithTiming, -) +from ray.data._internal.block_batching.interfaces import BlockFetchResult from ray.data._internal.block_batching.util import ( _MappingIterator, blocks_to_batches, @@ -33,13 +30,11 @@ def batch_blocks( This function takes in an iterator of already fetched blocks. Consequently, this function doesn't support block prefetching. """ - # Wrap raw blocks in BlockWithTiming with zero timing so that - # _BatchingIterator receives a uniform type. Use map() instead of a - # generator expression to avoid holding references to blocks. - def _wrap_block(b): - return BlockWithTiming(block=b, timings=BatchTimings()) - - wrapped_blocks = map(_wrap_block, blocks) + # Wrap raw blocks in BlockFetchResult with no fetch timing (these + # blocks were already resolved before entering the pipeline). + # Use map() instead of a generator expression to avoid holding + # references to blocks. + wrapped_blocks = map(lambda b: BlockFetchResult(block=b), blocks) # Build the processing pipeline batch_iter = format_batches( diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 09840b43a1a3..15d3c85522c3 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -7,6 +7,30 @@ from ray.types import ObjectRef +@dataclass +class BlockFetchTiming: + """Fetch timing for a single block (production_wait + data_transfer). + + Produced by :func:`resolve_block_refs` and merged into + :class:`BatchTimings` by :class:`_BatchingIterator`. + """ + + production_wait: Optional[TimeSpan] = None + data_transfer: Optional[TimeSpan] = None + + +@dataclass +class BlockFetchResult: + """A resolved block paired with its per-block fetch timing. + + When ``fetch`` is ``None``, no fetch timing was recorded (e.g. blocks + that were already resolved before entering the pipeline). + """ + + block: Block + fetch: Optional[BlockFetchTiming] = None + + @dataclass class BatchTimings: """Per-batch pipeline-stage timing windows for overlap-based attribution. @@ -19,16 +43,6 @@ class BatchTimings: A field value of ``None`` indicates the stage did not execute for this batch (e.g. no ``collate_fn`` provided). - - Attributes: - production_wait: Waiting for upstream data production (next on - the ref bundle iterator). - data_transfer: Cross-node transfer via ``ray.get()``. - batching: Assembling blocks into a batch via ``_batcher.next_batch()``. - format: Converting the batch to the requested format (numpy, pandas…). - collate: Running the user-provided ``collate_fn``. - finalize: Running the user-provided ``finalize_fn`` (e.g. host→device). - num_rows: Number of rows in this batch (for ``iter_rows_total``). """ production_wait: Optional[TimeSpan] = None @@ -37,7 +51,6 @@ class BatchTimings: format: Optional[TimeSpan] = None collate: Optional[TimeSpan] = None finalize: Optional[TimeSpan] = None - num_rows: int = 0 def stages(self) -> Iterable[Tuple[str, Optional[TimeSpan]]]: """Iterate over ``(name, timing)`` pairs for all pipeline stages.""" @@ -50,14 +63,10 @@ def stages(self) -> Iterable[Tuple[str, Optional[TimeSpan]]]: ("finalize", self.finalize), ) - def merge_fetch(self, other: "BatchTimings") -> None: - """Merge fetch timings from another batch into this one. - - Expands each fetch sub-stage window to span from the earliest - block's start to the latest block's end. - """ - self.production_wait = _merge_span(self.production_wait, other.production_wait) - self.data_transfer = _merge_span(self.data_transfer, other.data_transfer) + def merge_fetch(self, src: BlockFetchTiming) -> None: + """Merge per-block fetch timings into this batch's fetch windows.""" + self.production_wait = _merge_span(self.production_wait, src.production_wait) + self.data_transfer = _merge_span(self.data_transfer, src.data_transfer) def _merge_span(dst: Optional[TimeSpan], src: Optional[TimeSpan]) -> Optional[TimeSpan]: @@ -77,18 +86,6 @@ def _merge_span(dst: Optional[TimeSpan], src: Optional[TimeSpan]) -> Optional[Ti ) -@dataclass -class BlockWithTiming: - """A resolved block paired with its fetch timing window. - - Produced by :func:`resolve_block_refs` so that downstream pipeline stages - can track how long each block took to fetch (upstream wait + ``ray.get()``). - """ - - block: Block - timings: BatchTimings = field(default_factory=BatchTimings) - - @dataclass class BatchMetadata: """Metadata associated with a batch. @@ -96,10 +93,12 @@ class BatchMetadata: Attributes: batch_idx: The global index of this batch so that downstream operations can maintain ordering. + num_rows: Number of rows in this batch (for ``iter_rows_total``). timings: Pipeline-stage timing windows for this batch. """ batch_idx: int + num_rows: int = 0 timings: BatchTimings = field(default_factory=BatchTimings) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index bc59edb20866..254eec47efa1 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -307,7 +307,7 @@ def _report_batch_timings( if overlap_s > 0: getattr(self._stats, f"iter_blocked_{name}_s").add(overlap_s) self._stats.iter_batches_total += 1 - self._stats.iter_rows_total += timings.num_rows + self._stats.iter_rows_total += batch.metadata.num_rows def __iter__(self) -> Iterator[DataBatch]: return self._iter_batches() diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 7f07c0c14a67..b3728f795617 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -22,8 +22,9 @@ Batch, BatchMetadata, BatchTimings, + BlockFetchResult, + BlockFetchTiming, BlockPrefetcher, - BlockWithTiming, CollatedBatch, ) from ray.data._internal.stats import DatasetStats, _timed @@ -174,10 +175,10 @@ def _calculate_ref_hits(refs: List[ObjectRef[Any]]) -> Tuple[int, int, int]: def resolve_block_refs( block_ref_iter: Iterator[ObjectRef[Block]], stats: Optional[DatasetStats] = None, -) -> Iterator[BlockWithTiming]: +) -> Iterator[BlockFetchResult]: """Resolves the block references for each logical batch. - Each resolved block is wrapped in a :class:`BlockWithTiming` that carries + Each resolved block is wrapped in a :class:`BlockFetchResult` that carries the per-block fetch window. The fetch window spans from the moment we start waiting for the upstream iterator (blocked on the data pipeline or cross-node transfer) until ``ray.get()`` returns the resolved block. @@ -190,21 +191,19 @@ def resolve_block_refs( cumulative fetch time. Yields: - BlockWithTiming: Each resolved block with its fetch timing window. + BlockFetchResult: Each resolved block with its fetch timing window. """ hits = 0 misses = 0 unknowns = 0 while True: - timings = BatchTimings() # (1) production_wait: blocked on upstream data pipeline with _timed(stats.iter_get_ref_bundles_s if stats else None) as prod_span: try: block_ref = next(block_ref_iter) except StopIteration: break - timings.production_wait = prod_span current_hit, current_miss, current_unknown = _calculate_ref_hits([block_ref]) hits += current_hit @@ -216,9 +215,9 @@ def resolve_block_refs( # in a single `ray.get()` call. with _timed(stats.iter_get_s if stats else None) as xfer_span: block = ray.get(block_ref) - timings.data_transfer = xfer_span - yield BlockWithTiming(block=block, timings=timings) + fetch = BlockFetchTiming(production_wait=prod_span, data_transfer=xfer_span) + yield BlockFetchResult(block=block, fetch=fetch) if stats: stats.iter_blocks_local = hits @@ -227,7 +226,7 @@ def resolve_block_refs( def blocks_to_batches( - block_iter: Iterator[BlockWithTiming], + block_iter: Iterator[BlockFetchResult], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -256,7 +255,7 @@ class _BatchingIterator(Iterator[Batch]): def __init__( self, - block_iter: Iterator[BlockWithTiming], + block_iter: Iterator[BlockFetchResult], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -300,13 +299,11 @@ def __next__(self) -> Batch: res = Batch( metadata=BatchMetadata( batch_idx=self._global_counter, + num_rows=BlockAccessor.for_block(next_batch).num_rows(), timings=self._pending_timings, ), data=next_batch, ) - res.metadata.timings.num_rows = BlockAccessor.for_block( - next_batch - ).num_rows() self._pending_timings = BatchTimings() self._global_counter += 1 @@ -316,9 +313,10 @@ def __next__(self) -> Batch: # If can't yield try adding more blocks try: # NOTE: Block ref is released immediately - block_with_timing = next(self._block_iter) - self._pending_timings.merge_fetch(block_with_timing.timings) - self._batcher.add(block_with_timing.block) + block_result = next(self._block_iter) + if block_result.fetch is not None: + self._pending_timings.merge_fetch(block_result.fetch) + self._batcher.add(block_result.block) except StopIteration: self._batcher.done_adding() self._done_adding = True diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index a4b42d119727..47f154da2b19 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -130,12 +130,12 @@ def test_restore_original_order_stats(): def test_report_batch_timings_overlap_attribution(): stats = DatasetStats(metadata={}, parent=None) batch_iterator = BatchIterator(iter([]), stats=stats) - timings = BatchTimings(num_rows=8) + timings = BatchTimings() timings.production_wait = TimeSpan(start_s=10.0, end_s=20.0) timings.batching = TimeSpan(start_s=20.0, end_s=30.0) timings.format = TimeSpan(start_s=30.0, end_s=40.0) timings.finalize = TimeSpan(start_s=50.0, end_s=60.0) - batch = Batch(BatchMetadata(batch_idx=0, timings=timings), None) + batch = Batch(BatchMetadata(batch_idx=0, num_rows=8, timings=timings), None) batch_iterator._report_batch_timings( batch, blocked_start_s=15.0, blocked_end_s=35.0 @@ -173,14 +173,14 @@ def _make_batch_with_timings( num_rows=0, ): """Helper to construct a Batch with specific stage timing windows.""" - timings = BatchTimings(num_rows=num_rows) + timings = BatchTimings() timings.production_wait = _make_span(production_wait_start, production_wait_end) timings.data_transfer = _make_span(data_transfer_start, data_transfer_end) timings.batching = _make_span(batching_start, batching_end) timings.format = _make_span(format_start, format_end) timings.collate = _make_span(collate_start, collate_end) timings.finalize = _make_span(finalize_start, finalize_end) - return Batch(BatchMetadata(batch_idx=0, timings=timings), None) + return Batch(BatchMetadata(batch_idx=0, num_rows=num_rows, timings=timings), None) def _make_report_iterator(stats): @@ -442,14 +442,14 @@ class TestEndToEndTimingPropagation: def test_batch_carries_timings_through_pipeline(self): """A Batch's metadata.timings carries all stage windows.""" - timings = BatchTimings(num_rows=50) + timings = BatchTimings() timings.production_wait = TimeSpan(start_s=1.0, end_s=2.0) timings.batching = TimeSpan(start_s=2.0, end_s=3.0) timings.format = TimeSpan(start_s=3.0, end_s=4.0) timings.collate = TimeSpan(start_s=4.0, end_s=5.0) timings.finalize = TimeSpan(start_s=5.0, end_s=6.0) - batch = Batch(BatchMetadata(batch_idx=0, timings=timings), None) + batch = Batch(BatchMetadata(batch_idx=0, num_rows=50, timings=timings), None) # Verify all stages are accessible via stages() iterator stage_dict = dict(batch.metadata.timings.stages()) @@ -459,7 +459,7 @@ def test_batch_carries_timings_through_pipeline(self): assert stage_dict["format"].start_s == 3.0 assert stage_dict["collate"].end_s == 5.0 assert stage_dict["finalize"].start_s == 5.0 - assert batch.metadata.timings.num_rows == 50 + assert batch.metadata.num_rows == 50 def test_full_pipeline_attribution(self): """End-to-end: all 5 stages with realistic timing, full overlap.""" diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py index 7a6223b06f78..eabe54b20754 100644 --- a/python/ray/data/tests/block_batching/test_util.py +++ b/python/ray/data/tests/block_batching/test_util.py @@ -16,8 +16,7 @@ from ray.data._internal.block_batching.interfaces import ( Batch, BatchMetadata, - BatchTimings, - BlockWithTiming, + BlockFetchResult, ) from ray.data._internal.block_batching.util import ( _calculate_ref_hits, @@ -43,7 +42,7 @@ def test_resolve_block_refs(ray_start_regular_shared): resolved_iter = resolve_block_refs(iter(block_refs)) resolved = list(resolved_iter) - assert all(isinstance(b, BlockWithTiming) for b in resolved) + assert all(isinstance(b, BlockFetchResult) for b in resolved) assert [b.block for b in resolved] == [0, 1, 2] @@ -52,10 +51,8 @@ def test_resolve_block_refs(ray_start_regular_shared): def test_blocks_to_batches(block_size, drop_last): num_blocks = 5 block_iter = block_generator(num_rows=block_size, num_blocks=num_blocks) - # Wrap raw blocks in BlockWithTiming as blocks_to_batches now expects - wrapped_blocks = ( - BlockWithTiming(block=b, timings=BatchTimings()) for b in block_iter - ) + # Wrap raw blocks in BlockFetchResult (fetch=None) as blocks_to_batches now expects + wrapped_blocks = (BlockFetchResult(block=b) for b in block_iter) batch_size = 3 batch_iter = list( From 304dfe8aea7d137d443093b67116b8d14a471fab Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Fri, 26 Jun 2026 17:37:45 +0800 Subject: [PATCH 16/32] [Data] Add PipelineStage enum and match-based gauge mapping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per reviewer feedback, replaces string-based stage identifiers and getattr with type-safe alternatives: - New PipelineStage(Enum) in stats.py (avoids circular import) - BatchTimings.stages() returns Tuple[PipelineStage, Optional[TimeSpan]] - DatasetStats.get_blocked_timer(stage) uses match statement (no getattr, no string formatting) - _StatsActor stores blocked gauges in Dict[PipelineStage, Gauge] instead of individual attributes - update_iteration_metrics iterates the dict instead of calling 6 individual .set() methods - Renamed _report_batch_timings → _attribute_blocked_time (more descriptive of what the method does) Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 18 ++-- .../_internal/block_batching/iter_batches.py | 14 +-- python/ray/data/_internal/stats.py | 102 +++++++++++------- .../tests/block_batching/test_iter_batches.py | 42 ++++---- python/ray/data/tests/test_stats.py | 18 ++-- 5 files changed, 111 insertions(+), 83 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 15d3c85522c3..30fb9d79f422 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from typing import Any, Iterable, List, Optional, Tuple -from ray.data._internal.stats import TimeSpan +from ray.data._internal.stats import PipelineStage, TimeSpan from ray.data.block import Block, DataBatch from ray.types import ObjectRef @@ -52,15 +52,15 @@ class BatchTimings: collate: Optional[TimeSpan] = None finalize: Optional[TimeSpan] = None - def stages(self) -> Iterable[Tuple[str, Optional[TimeSpan]]]: - """Iterate over ``(name, timing)`` pairs for all pipeline stages.""" + def stages(self) -> Iterable[Tuple[PipelineStage, Optional[TimeSpan]]]: + """Iterate over ``(stage, timing)`` pairs for all pipeline stages.""" return ( - ("production_wait", self.production_wait), - ("data_transfer", self.data_transfer), - ("batching", self.batching), - ("format", self.format), - ("collate", self.collate), - ("finalize", self.finalize), + (PipelineStage.PRODUCTION_WAIT, self.production_wait), + (PipelineStage.DATA_TRANSFER, self.data_transfer), + (PipelineStage.BATCHING, self.batching), + (PipelineStage.FORMAT, self.format), + (PipelineStage.COLLATE, self.collate), + (PipelineStage.FINALIZE, self.finalize), ) def merge_fetch(self, src: BlockFetchTiming) -> None: diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 254eec47efa1..fed9dbab3a12 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -255,26 +255,26 @@ def _iter_batches(self) -> Iterator[DataBatch]: except StopIteration: break blocked_end_s = time.perf_counter() - self._report_batch_timings(batch, blocked_start_s, blocked_end_s) + self._attribute_blocked_time(batch, blocked_start_s, blocked_end_s) with self.yield_batch_context(batch): yield batch.data self.after_epoch_end() - def _report_batch_timings( + def _attribute_blocked_time( self, batch: Batch, blocked_start_s: float, blocked_end_s: float ) -> None: """Attribute per-stage blocked time via overlap with the training window. - For each pipeline stage we know when it ran ``[stage.start_s, - stage.end_s]`` (recorded by background threads onto + For each pipeline stage we know when it ran ``[timing.start_s, + timing.end_s]`` (recorded by background threads onto ``batch.metadata.timings``). We also know when the training thread was blocked ``[blocked_start_s, blocked_end_s]`` (captured in ``_iter_batches`` around ``next()``). The attribution for a stage is the length of the intersection:: - overlap = min(stage.end, blocked_end) - max(stage.start, blocked_start) + overlap = min(timing.end, blocked_end) - max(timing.start, blocked_start) This correctly handles all prefetch configurations: @@ -298,14 +298,14 @@ def _report_batch_timings( if self._stats is None: return timings = batch.metadata.timings - for name, timing in timings.stages(): + for stage, timing in timings.stages(): if timing is None: continue overlap_s = min(timing.end_s, blocked_end_s) - max( timing.start_s, blocked_start_s ) if overlap_s > 0: - getattr(self._stats, f"iter_blocked_{name}_s").add(overlap_s) + self._stats.get_blocked_timer(stage).add(overlap_s) self._stats.iter_batches_total += 1 self._stats.iter_rows_total += batch.metadata.num_rows diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 5c7c1519cb13..53f80c73fb69 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -6,6 +6,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass, fields +from enum import Enum from typing import ( TYPE_CHECKING, Any, @@ -175,6 +176,17 @@ def get( ) +class PipelineStage(Enum): + """Pipeline stages for blocked attribution.""" + + PRODUCTION_WAIT = "production_wait" + DATA_TRANSFER = "data_transfer" + BATCHING = "batching" + FORMAT = "format" + COLLATE = "collate" + FINALIZE = "finalize" + + @dataclass class TimeSpan: """A measured wall-clock interval. @@ -546,36 +558,38 @@ def __init__(self, max_stats=1000): description="Total wall-clock seconds spent in the dataset iterator", tag_keys=iter_tag_keys, ) - self.iter_blocked_production_wait_s = Gauge( - "data_iter_blocked_production_wait_seconds", - description="Seconds user thread is blocked on upstream data production", - tag_keys=iter_tag_keys, - ) - self.iter_blocked_data_transfer_s = Gauge( - "data_iter_blocked_data_transfer_seconds", - description="Seconds user thread is blocked on cross-node data transfer (ray.get)", - tag_keys=iter_tag_keys, - ) - self.iter_blocked_batching_s = Gauge( - "data_iter_blocked_batching_seconds", - description="Seconds user thread is blocked on batch creation", - tag_keys=iter_tag_keys, - ) - self.iter_blocked_format_s = Gauge( - "data_iter_blocked_format_seconds", - description="Seconds user thread is blocked on batch formatting", - tag_keys=iter_tag_keys, - ) - self.iter_blocked_collate_s = Gauge( - "data_iter_blocked_collate_seconds", - description="Seconds user thread is blocked on batch collation", - tag_keys=iter_tag_keys, - ) - self.iter_blocked_finalize_s = Gauge( - "data_iter_blocked_finalize_seconds", - description="Seconds user thread is blocked on batch finalization", - tag_keys=iter_tag_keys, - ) + self._blocked_gauges: Dict[PipelineStage, Gauge] = { + PipelineStage.PRODUCTION_WAIT: Gauge( + "data_iter_blocked_production_wait_seconds", + description="Seconds user thread is blocked on upstream data production", + tag_keys=iter_tag_keys, + ), + PipelineStage.DATA_TRANSFER: Gauge( + "data_iter_blocked_data_transfer_seconds", + description="Seconds user thread is blocked on cross-node data transfer (ray.get)", + tag_keys=iter_tag_keys, + ), + PipelineStage.BATCHING: Gauge( + "data_iter_blocked_batching_seconds", + description="Seconds user thread is blocked on batch creation", + tag_keys=iter_tag_keys, + ), + PipelineStage.FORMAT: Gauge( + "data_iter_blocked_format_seconds", + description="Seconds user thread is blocked on batch formatting", + tag_keys=iter_tag_keys, + ), + PipelineStage.COLLATE: Gauge( + "data_iter_blocked_collate_seconds", + description="Seconds user thread is blocked on batch collation", + tag_keys=iter_tag_keys, + ), + PipelineStage.FINALIZE: Gauge( + "data_iter_blocked_finalize_seconds", + description="Seconds user thread is blocked on batch finalization", + tag_keys=iter_tag_keys, + ), + } self.iter_batches_total = Gauge( "data_iter_batches_total", description="Total batches delivered to the user thread", @@ -847,16 +861,8 @@ def update_iteration_metrics( self.time_to_first_batch_s.set(stats.iter_time_to_first_batch_s.get(), tags) self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags) - self.iter_blocked_production_wait_s.set( - stats.iter_blocked_production_wait_s.get(), tags - ) - self.iter_blocked_data_transfer_s.set( - stats.iter_blocked_data_transfer_s.get(), tags - ) - self.iter_blocked_batching_s.set(stats.iter_blocked_batching_s.get(), tags) - self.iter_blocked_format_s.set(stats.iter_blocked_format_s.get(), tags) - self.iter_blocked_collate_s.set(stats.iter_blocked_collate_s.get(), tags) - self.iter_blocked_finalize_s.set(stats.iter_blocked_finalize_s.get(), tags) + for stage, gauge in self._blocked_gauges.items(): + gauge.set(stats.get_blocked_timer(stage).get(), tags) self.iter_batches_total.set(stats.iter_batches_total, tags) self.iter_rows_total.set(stats.iter_rows_total, tags) self.iter_user_s.set(stats.iter_user_s.get(), tags) @@ -1284,6 +1290,22 @@ def __init__( # Streaming split coordinator stats (dataset level) self.streaming_split_coordinator_s: Timer = Timer() + def get_blocked_timer(self, stage: PipelineStage) -> Timer: + """Return the blocked-attribution Timer for the given pipeline stage.""" + match stage: + case PipelineStage.PRODUCTION_WAIT: + return self.iter_blocked_production_wait_s + case PipelineStage.DATA_TRANSFER: + return self.iter_blocked_data_transfer_s + case PipelineStage.BATCHING: + return self.iter_blocked_batching_s + case PipelineStage.FORMAT: + return self.iter_blocked_format_s + case PipelineStage.COLLATE: + return self.iter_blocked_collate_s + case PipelineStage.FINALIZE: + return self.iter_blocked_finalize_s + @property def stats_actor(self): return get_or_create_stats_actor() diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 47f154da2b19..8c522771fc03 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -21,7 +21,7 @@ ) from ray.data._internal.block_batching.util import WaitBlockPrefetcher from ray.data._internal.execution.interfaces.ref_bundle import BlockEntry, RefBundle -from ray.data._internal.stats import DatasetStats, TimeSpan +from ray.data._internal.stats import DatasetStats, PipelineStage, TimeSpan from ray.data.block import Block, BlockMetadata from ray.types import ObjectRef @@ -127,7 +127,7 @@ def test_restore_original_order_stats(): assert [batch.metadata.batch_idx for batch in ordered] == [0, 1, 2] -def test_report_batch_timings_overlap_attribution(): +def test_attribute_blocked_time_overlap_attribution(): stats = DatasetStats(metadata={}, parent=None) batch_iterator = BatchIterator(iter([]), stats=stats) timings = BatchTimings() @@ -137,7 +137,7 @@ def test_report_batch_timings_overlap_attribution(): timings.finalize = TimeSpan(start_s=50.0, end_s=60.0) batch = Batch(BatchMetadata(batch_idx=0, num_rows=8, timings=timings), None) - batch_iterator._report_batch_timings( + batch_iterator._attribute_blocked_time( batch, blocked_start_s=15.0, blocked_end_s=35.0 ) @@ -200,7 +200,7 @@ def test_zero_overlap_stage_finished_before_blocked(self): batch = _make_batch_with_timings( production_wait_start=0.0, production_wait_end=1.5 ) - it._report_batch_timings(batch, blocked_start_s=2.0, blocked_end_s=3.0) + it._attribute_blocked_time(batch, blocked_start_s=2.0, blocked_end_s=3.0) assert stats.iter_blocked_production_wait_s.get() == 0.0 def test_zero_overlap_blocked_before_stage(self): @@ -208,7 +208,7 @@ def test_zero_overlap_blocked_before_stage(self): stats = DatasetStats(metadata={}, parent=None) it = _make_report_iterator(stats) batch = _make_batch_with_timings(format_start=2.0, format_end=3.0) - it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=1.0) + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=1.0) assert stats.iter_blocked_format_s.get() == 0.0 def test_partial_overlap(self): @@ -218,7 +218,7 @@ def test_partial_overlap(self): batch = _make_batch_with_timings( production_wait_start=0.0, production_wait_end=2.0 ) - it._report_batch_timings(batch, blocked_start_s=1.0, blocked_end_s=3.0) + it._attribute_blocked_time(batch, blocked_start_s=1.0, blocked_end_s=3.0) assert stats.iter_blocked_production_wait_s.get() == pytest.approx(1.0) def test_full_overlap_stage_inside_blocked(self): @@ -226,7 +226,7 @@ def test_full_overlap_stage_inside_blocked(self): stats = DatasetStats(metadata={}, parent=None) it = _make_report_iterator(stats) batch = _make_batch_with_timings(batching_start=1.0, batching_end=2.0) - it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=3.0) assert stats.iter_blocked_batching_s.get() == pytest.approx(1.0) def test_no_collate_fn_zero_attribution(self): @@ -234,7 +234,7 @@ def test_no_collate_fn_zero_attribution(self): stats = DatasetStats(metadata={}, parent=None) it = _make_report_iterator(stats) batch = _make_batch_with_timings(format_start=1.0, format_end=2.0) - it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=3.0) assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) assert stats.iter_blocked_collate_s.get() == 0.0 @@ -243,7 +243,7 @@ def test_no_finalize_fn_zero_attribution(self): stats = DatasetStats(metadata={}, parent=None) it = _make_report_iterator(stats) batch = _make_batch_with_timings(collate_start=1.0, collate_end=2.0) - it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=3.0) + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=3.0) assert stats.iter_blocked_collate_s.get() == pytest.approx(1.0) assert stats.iter_blocked_finalize_s.get() == 0.0 @@ -258,7 +258,7 @@ def test_prefetch_hides_fetch_from_training(self): collate_end=2.6, ) # Training only starts blocking at t=2 (prefetch worked) - it._report_batch_timings(batch, blocked_start_s=2.0, blocked_end_s=2.6) + it._attribute_blocked_time(batch, blocked_start_s=2.0, blocked_end_s=2.6) assert stats.iter_blocked_production_wait_s.get() == 0.0 assert stats.iter_blocked_collate_s.get() == pytest.approx(0.3) @@ -270,12 +270,12 @@ def test_accumulation_across_batches(self): b1 = _make_batch_with_timings( production_wait_start=0.0, production_wait_end=1.0, num_rows=10 ) - it._report_batch_timings(b1, blocked_start_s=0.0, blocked_end_s=2.0) + it._attribute_blocked_time(b1, blocked_start_s=0.0, blocked_end_s=2.0) # Batch 2: fetch [5,6], blocked [5,7] → overlap 1.0 b2 = _make_batch_with_timings( production_wait_start=5.0, production_wait_end=6.0, num_rows=20 ) - it._report_batch_timings(b2, blocked_start_s=5.0, blocked_end_s=7.0) + it._attribute_blocked_time(b2, blocked_start_s=5.0, blocked_end_s=7.0) assert stats.iter_blocked_production_wait_s.get() == pytest.approx(2.0) assert stats.iter_batches_total == 2 @@ -295,7 +295,7 @@ def test_overlap_invariant_sum_leq_total(self): format_end=3.0, num_rows=5, ) - it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=5.0) total = stats.iter_total_blocked_s.get() sum_stages = ( @@ -314,7 +314,7 @@ def test_blocked_inside_stage(self): batch = _make_batch_with_timings( production_wait_start=0.0, production_wait_end=10.0 ) - it._report_batch_timings(batch, blocked_start_s=3.0, blocked_end_s=5.0) + it._attribute_blocked_time(batch, blocked_start_s=3.0, blocked_end_s=5.0) assert stats.iter_blocked_production_wait_s.get() == pytest.approx(2.0) def test_all_stages_simultaneous_overlap(self): @@ -335,7 +335,7 @@ def test_all_stages_simultaneous_overlap(self): num_rows=100, ) # Blocked window covers all stages - it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=5.0) assert stats.iter_blocked_production_wait_s.get() == pytest.approx(1.0) assert stats.iter_blocked_batching_s.get() == pytest.approx(1.0) assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) @@ -454,11 +454,11 @@ def test_batch_carries_timings_through_pipeline(self): # Verify all stages are accessible via stages() iterator stage_dict = dict(batch.metadata.timings.stages()) assert len(stage_dict) == 6 - assert stage_dict["production_wait"].start_s == 1.0 - assert stage_dict["batching"].end_s == 3.0 - assert stage_dict["format"].start_s == 3.0 - assert stage_dict["collate"].end_s == 5.0 - assert stage_dict["finalize"].start_s == 5.0 + assert stage_dict[PipelineStage.PRODUCTION_WAIT].start_s == 1.0 + assert stage_dict[PipelineStage.BATCHING].end_s == 3.0 + assert stage_dict[PipelineStage.FORMAT].start_s == 3.0 + assert stage_dict[PipelineStage.COLLATE].end_s == 5.0 + assert stage_dict[PipelineStage.FINALIZE].start_s == 5.0 assert batch.metadata.num_rows == 50 def test_full_pipeline_attribution(self): @@ -482,7 +482,7 @@ def test_full_pipeline_attribution(self): ) # Blocked window covers all stages - it._report_batch_timings(batch, blocked_start_s=0.0, blocked_end_s=5.0) + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=5.0) # Each stage gets its full duration assert stats.iter_blocked_production_wait_s.get() == pytest.approx(0.5) diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index ebe0e4895a7b..1de9e06d939d 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -1940,18 +1940,24 @@ def set(self, value, tags): "iter_batch_finalizing_s", "time_to_first_batch_s", "iter_total_blocked_s", - "iter_blocked_production_wait_s", - "iter_blocked_data_transfer_s", - "iter_blocked_batching_s", - "iter_blocked_format_s", - "iter_blocked_collate_s", - "iter_blocked_finalize_s", "iter_batches_total", "iter_rows_total", "iter_user_s", ]: setattr(actor, attr, FakeGauge(attr)) + # Set up the blocked gauges dict (now stored as Dict[PipelineStage, Gauge]) + from ray.data._internal.stats import PipelineStage + + actor._blocked_gauges = { + PipelineStage.PRODUCTION_WAIT: FakeGauge("iter_blocked_production_wait_s"), + PipelineStage.DATA_TRANSFER: FakeGauge("iter_blocked_data_transfer_s"), + PipelineStage.BATCHING: FakeGauge("iter_blocked_batching_s"), + PipelineStage.FORMAT: FakeGauge("iter_blocked_format_s"), + PipelineStage.COLLATE: FakeGauge("iter_blocked_collate_s"), + PipelineStage.FINALIZE: FakeGauge("iter_blocked_finalize_s"), + } + actor.update_iteration_metrics(stats, "train_dataset_split_3") expected_tags = {"dataset": "train_dataset_split_3", "rank": "3"} From 09933dcfa91c40cd5bc760232c49d7cb60993b43 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sat, 27 Jun 2026 13:32:39 +0800 Subject: [PATCH 17/32] [Data] Defer iteration metrics rank label to follow-up PR Per reviewer feedback, the per-streaming-split-worker `rank` label on iteration metrics is out of scope for this PR. Reverts the rank label addition so existing iteration gauges keep their master label set (`("dataset",)`), and leaves a TODO for a follow-up PR to add `rank` across all iteration metrics (including the new blocked-attribution gauges) at once. - Removes module-level `_create_iteration_tags` and the `_StatsActor._create_iteration_tags` method - Reverts `iter_tag_keys` from `("dataset", "rank")` to `("dataset",)` - `update_iteration_metrics` uses `self._create_tags(dataset_tag)` again - Drops the `re` import (only used by the removed function) - Removes `test_create_iteration_tags_extracts_rank` and drops `rank` from `expected_tags` in `test_update_iteration_metrics_exports_new_iter_metrics` Signed-off-by: OneSizeFitsQuorum --- python/ray/data/_internal/stats.py | 26 +++++++------------------- python/ray/data/tests/test_stats.py | 20 +------------------- 2 files changed, 8 insertions(+), 38 deletions(-) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 53f80c73fb69..dd7b875493da 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -1,7 +1,6 @@ import collections import copy import logging -import re import time from collections import defaultdict from contextlib import contextmanager @@ -63,19 +62,6 @@ StatsDict = Dict[str, List[BlockStats]] -def _create_iteration_tags(dataset_tag: Optional[str]): - dataset_tag = dataset_tag or "unknown_dataset" - tags = {"dataset": dataset_tag, "rank": "unknown"} - # Use findall + last match: the streaming-split index is always the - # trailing ``split_`` in the tag. The user-defined dataset name may - # itself contain ``split_`` so re.search (first match) could - # pick up the wrong one. - matches = re.findall(r"split_(\d+)", dataset_tag) - if matches: - tags["rank"] = matches[-1] - return tags - - def fmt(seconds: float) -> str: if seconds > 1: return str(round(seconds, 2)) + "s" @@ -513,7 +499,12 @@ def __init__(self, max_stats=1000): # Per Node metrics self.per_node_metrics = self._create_prometheus_metrics_for_per_node_metrics() - iter_tag_keys = ("dataset", "rank") + iter_tag_keys = ("dataset",) + # TODO: add a per-streaming-split-worker ``rank`` label to these + # iteration metrics (including the blocked-attribution gauges below) + # in a follow-up PR, so users can distinguish which split worker is + # blocked on which stage. Deferred to keep this PR's scope focused + # on the blocked-attribution data model. self.time_to_first_batch_s = Gauge( "data_iter_time_to_first_batch_seconds", @@ -837,7 +828,7 @@ def update_iteration_metrics( stats: "DatasetStats", dataset_tag, ): - tags = self._create_iteration_tags(dataset_tag) + tags = self._create_tags(dataset_tag) self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags) self.iter_total_s.set(stats.iter_total_s.get(), tags) @@ -1058,9 +1049,6 @@ def _create_tags( tags["node_ip"] = node_ip_tag return tags - def _create_iteration_tags(self, dataset_tag: Optional[str]): - return _create_iteration_tags(dataset_tag) - def get_or_create_stats_actor() -> ActorHandle[_StatsActor]: """Each cluster will contain exactly 1 _StatsActor. This function diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index 1de9e06d939d..c2bec8a2be56 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -36,7 +36,6 @@ OperatorStatsSummary, StatsSummary, Timer, - _create_iteration_tags, _StatsActor, get_or_create_stats_actor, ) @@ -1879,23 +1878,6 @@ def test_stats_actor_iter_metrics(): assert f"dataset_{ds._uuid}_0" == update_fn.call_args_list[-1].args[1] -def test_create_iteration_tags_extracts_rank(): - assert _create_iteration_tags("train_abc_split_2") == { - "dataset": "train_abc_split_2", - "rank": "2", - } - assert _create_iteration_tags("dataset_without_split") == { - "dataset": "dataset_without_split", - "rank": "unknown", - } - # User-defined dataset name may contain split_; the trailing - # split index (from streaming split coordinator) should be used. - assert _create_iteration_tags("my_split_3_data_abc123_split_5") == { - "dataset": "my_split_3_data_abc123_split_5", - "rank": "5", - } - - def test_update_iteration_metrics_exports_new_iter_metrics(): stats = DatasetStats(metadata={}, parent=None) stats.iter_total_s.add(11.0) @@ -1960,7 +1942,7 @@ def set(self, value, tags): actor.update_iteration_metrics(stats, "train_dataset_split_3") - expected_tags = {"dataset": "train_dataset_split_3", "rank": "3"} + expected_tags = {"dataset": "train_dataset_split_3"} assert recorded["iter_total_s"] == (11.0, expected_tags) assert recorded["iter_blocked_production_wait_s"] == (1.0, expected_tags) assert recorded["iter_blocked_data_transfer_s"] == (1.5, expected_tags) From 775c0a641e4e2e9724658cd6d028bbebb7587e22 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sat, 27 Jun 2026 13:32:44 +0800 Subject: [PATCH 18/32] [Data] Fix nested timer double-counting in resolve_block_refs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When prefetching is enabled, `iter_get_ref_bundles_s` was accumulated twice per block: once by `prefetch_batches_locally.get_next_ref_bundle` (which times `next(ref_bundles)`) and again by `resolve_block_refs` (which timed `next(block_ref_iter)` with the same Timer). The upstream wait was therefore double-counted in Prometheus `data_iter_get_ref_bundles_seconds`. `production_wait` is now captured as a bare `TimeSpan` (start/end `perf_counter` only) for overlap attribution, without calling `Timer.add()`. The cumulative `iter_get_ref_bundles_s` remains driven solely by `get_next_ref_bundle` when prefetch is on, and is not tracked when prefetch is off — matching master behavior. `data_transfer` (`iter_get_s`, the `ray.get()` call) is unaffected. Signed-off-by: OneSizeFitsQuorum --- .../ray/data/_internal/block_batching/util.py | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index b3728f795617..80b6a9a556f2 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -3,6 +3,7 @@ import logging import queue import threading +import time from typing import ( Any, Callable, @@ -27,7 +28,7 @@ BlockPrefetcher, CollatedBatch, ) -from ray.data._internal.stats import DatasetStats, _timed +from ray.data._internal.stats import DatasetStats, TimeSpan, _timed from ray.data.block import Block, BlockAccessor, DataBatch from ray.types import ObjectRef @@ -185,6 +186,14 @@ def resolve_block_refs( When *stats* is provided, the cumulative fetch time is also recorded in ``stats.iter_get_s``. + Note: ``production_wait`` is captured as a :class:`TimeSpan` for + per-batch overlap attribution only — it does **not** accumulate into + ``stats.iter_get_ref_bundles_s``. When prefetching is enabled, that + cumulative Timer is already driven by + :func:`prefetch_batches_locally.get_next_ref_bundle`; accumulating here + too would double-count the upstream wait. When prefetching is off, + ``iter_get_ref_bundles_s`` is not tracked, matching master behavior. + Args: block_ref_iter: An iterator over block object references. stats: An optional stats object to record block hits, misses, and @@ -198,19 +207,28 @@ def resolve_block_refs( unknowns = 0 while True: - # (1) production_wait: blocked on upstream data pipeline - with _timed(stats.iter_get_ref_bundles_s if stats else None) as prod_span: - try: - block_ref = next(block_ref_iter) - except StopIteration: - break + # (1) production_wait: blocked on upstream data pipeline. + # Captured as a TimeSpan for overlap attribution only — do not + # accumulate into ``iter_get_ref_bundles_s`` here, because + # ``prefetch_batches_locally.get_next_ref_bundle`` already times + # the same ``next()`` call when prefetch is enabled. + prod_start_s = time.perf_counter() if stats else 0.0 + try: + block_ref = next(block_ref_iter) + except StopIteration: + break + prod_span = ( + TimeSpan(start_s=prod_start_s, end_s=time.perf_counter()) + if stats + else None + ) current_hit, current_miss, current_unknown = _calculate_ref_hits([block_ref]) hits += current_hit misses += current_miss unknowns += current_unknown - # (2) data_transfer: cross-node transfer via ray.get() + # (2) data_transfer: cross-node transfer via ray.get(). # TODO(amogkam): Optimized further by batching multiple references # in a single `ray.get()` call. with _timed(stats.iter_get_s if stats else None) as xfer_span: From f988f33ac058180cb48068516a4135d433c34e14 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sat, 27 Jun 2026 13:32:49 +0800 Subject: [PATCH 19/32] [Data] Add regression test for ref_bundles timer non-accumulation Adds `test_resolve_block_refs_does_not_accumulate_ref_bundles_timer` to guard against re-introducing the nested-timer double-counting bug fixed in the previous commit. The test asserts that: - `iter_get_ref_bundles_s` stays at 0 after `resolve_block_refs` processes a slow upstream iterator (i.e. the function does not accumulate into that Timer; `prefetch_batches_locally` owns it). - `production_wait` `TimeSpan` is still captured per block, so overlap-based blocked attribution continues to work. Signed-off-by: OneSizeFitsQuorum --- .../data/tests/block_batching/test_util.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py index eabe54b20754..8885db22b491 100644 --- a/python/ray/data/tests/block_batching/test_util.py +++ b/python/ray/data/tests/block_batching/test_util.py @@ -46,6 +46,40 @@ def test_resolve_block_refs(ray_start_regular_shared): assert [b.block for b in resolved] == [0, 1, 2] +def test_resolve_block_refs_does_not_accumulate_ref_bundles_timer( + ray_start_regular_shared, +): + """Regression test for nested-timer double-counting. + + ``resolve_block_refs`` must NOT accumulate into + ``iter_get_ref_bundles_s``. When prefetching is enabled, that Timer is + already driven by ``prefetch_batches_locally.get_next_ref_bundle``; + accumulating here too would double-count the upstream wait. The + ``production_wait`` TimeSpan is still captured per block for overlap + attribution. + """ + from ray.data._internal.stats import DatasetStats + + def slow_block_ref_iter(): + for i in range(3): + time.sleep(0.05) + yield ray.put(i) + + stats = DatasetStats(metadata={}, parent=None) + resolved = list(resolve_block_refs(slow_block_ref_iter(), stats=stats)) + + assert len(resolved) == 3 + + # production_wait TimeSpan captured per block for overlap attribution. + for r in resolved: + assert r.fetch is not None + assert r.fetch.production_wait is not None + assert r.fetch.production_wait.duration >= 0.0 + + # iter_get_ref_bundles_s must NOT be accumulated here. + assert stats.iter_get_ref_bundles_s.get() == 0.0 + + @pytest.mark.parametrize("block_size", [1, 10]) @pytest.mark.parametrize("drop_last", [True, False]) def test_blocks_to_batches(block_size, drop_last): From c604a0a02ee8e3048940b23444183ed90acbe79e Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sat, 27 Jun 2026 14:04:41 +0800 Subject: [PATCH 20/32] [Data] Document blocked-attribution known limitations and typing-shim TODO MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses review feedback from JasonLi1909 and Cursor on design-level threads that don't require code changes, only clarification: - block_batching.py: expand the BlockFetchResult wrapping comment into a detailed TODO explaining why the typing shim exists (uniform iterator type across prefetch/non-prefetch paths) and how a future refactor could remove it. Per Jason's request to leave a TODO when the fix is out of scope. - iter_batches.py: relax the `sum(iter_blocked_*) ≤ iter_total_blocked_s` invariant in `_attribute_blocked_time`'s docstring to "approximates", and document two known cases that can push the sum above or below total by design: * Split fetch stages overlap (multi-block batches: production_wait for block N+1 concurrent with data_transfer for block N). * Reorder buffer wait is unattributed under preserve_order + prefetch (per-stage TimeSpan recorded at format/collate completion, before the batch leaves restore_original_order). These are documented as expected behavior, not bugs. Signed-off-by: OneSizeFitsQuorum --- .../block_batching/block_batching.py | 9 +++++++ .../_internal/block_batching/iter_batches.py | 25 ++++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 0a54689b9552..95566ee9e037 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -34,6 +34,15 @@ def batch_blocks( # blocks were already resolved before entering the pipeline). # Use map() instead of a generator expression to avoid holding # references to blocks. + # + # TODO: this BlockFetchResult wrapping is a typing shim — + # `batch_blocks` receives already-resolved `Block`s with no fetch + # timing, but `blocks_to_batches` consumes `Iterator[BlockFetchResult]` + # to keep a uniform type across the prefetch and non-prefetch paths. + # A future refactor could make fetch timing optional at the + # `_BatchingIterator` level (e.g. accept `Union[Block, BlockFetchResult]` + # or move the merge_fetch call site behind a capability check) so this + # shim can be removed. Out of scope for this PR. wrapped_blocks = map(lambda b: BlockFetchResult(block=b), blocks) # Build the processing pipeline diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index fed9dbab3a12..b977350d394b 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -282,7 +282,30 @@ def _attribute_blocked_time( * Stage fully inside blocked window → full stage duration credited. * Partial overlap → partial credit. - **Invariant**: ``sum(iter_blocked_*) ≤ iter_total_blocked_s``. + **Invariant**: ``sum(iter_blocked_*)`` approximates + ``iter_total_blocked_s``. It is a lower bound in the common case + (``sum ≤ total``), but two known cases can push ``sum`` above + ``total`` by design — both reflect real blocking the training + thread experienced: + + * **Split fetch stages overlap.** For multi-block batches, + ``production_wait`` and ``data_transfer`` are merged into + spanning windows per stage across blocks. In a pipelined + system those windows can overlap (data_transfer for block N + concurrent with production_wait for block N+1), so the same + blocked interval may be credited to both stages. Using + interval lists instead of spanning windows would restore + ``sum ≤ total`` but significantly increase complexity. + * **Reorder buffer wait is unattributed.** Under + ``preserve_order=True`` with a multi-worker format threadpool + (``prefetch_batches ≥ 1``), per-stage ``TimeSpan`` values are + recorded when format/collate finish — often before the batch + is released from ``restore_original_order``. The wait inside + the reorder buffer is part of the blocked window but covered + by no stage ``TimeSpan``, so it shows up as an unattributed + gap (lowers ``sum`` relative to ``total``). This is expected; + users seeing a high unattributed ratio should investigate + ordering-related stall. Runs in the training thread; no locks needed because background threads finished writing ``batch.metadata.timings`` before the batch From ea957bad39f6ecfd73eb1477e1c854fcc0041c56 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sat, 27 Jun 2026 14:37:47 +0800 Subject: [PATCH 21/32] [Data] Add unit tests for Timer.timer, _timed, merge_fetch, get_blocked_timer Closes five test-coverage gaps identified during PR review: - TestTimerTimerSpan (test_stats.py): verifies Timer.timer() yields a fresh TimeSpan per call and accumulates its duration via add(); also covers _timed(None) skipping perf_counter and _timed(Timer) yielding a real span. - TestGetBlockedTimer (test_stats.py, parametrized over 6 stages): verifies DatasetStats.get_blocked_timer(stage) returns the correct Timer attribute via the match statement (previously only covered indirectly through update_iteration_metrics). - TestMergeFetch.data_transfer cases (test_iter_batches.py): extends existing production_wait-only tests to cover data_transfer merging (multiple blocks, overlapping windows, independence from production_wait, None src preservation). - test_resolve_block_refs_accumulates_data_transfer_timer (test_util.py): pairs with the existing ref_bundles non-accumulation test to verify the data_transfer (iter_get_s / ray.get) path IS accumulated and a per-block data_transfer TimeSpan is captured. Signed-off-by: OneSizeFitsQuorum --- .../tests/block_batching/test_iter_batches.py | 71 +++++++++++++++++++ .../data/tests/block_batching/test_util.py | 29 ++++++++ python/ray/data/tests/test_stats.py | 70 ++++++++++++++++++ 3 files changed, 170 insertions(+) diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 8c522771fc03..34cf715aa6a9 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -12,6 +12,7 @@ Batch, BatchMetadata, BatchTimings, + BlockFetchTiming, BlockPrefetcher, ) from ray.data._internal.block_batching.iter_batches import ( @@ -436,6 +437,76 @@ def test_merge_into_empty_destination(self): assert dst.production_wait.start_s == 10.0 assert dst.production_wait.end_s == 20.0 + def test_merge_data_transfer_multiple_blocks(self): + """data_transfer windows are unioned across multiple blocks.""" + dst = BatchTimings() + + src1 = BlockFetchTiming( + data_transfer=TimeSpan(start_s=1.0, end_s=2.0) + ) + dst.merge_fetch(src1) + + src2 = BlockFetchTiming( + data_transfer=TimeSpan(start_s=3.0, end_s=4.0) + ) + dst.merge_fetch(src2) + + # Union: [1.0, 4.0] + assert dst.data_transfer.start_s == 1.0 + assert dst.data_transfer.end_s == 4.0 + + def test_merge_data_transfer_overlapping_blocks(self): + """Overlapping data_transfer windows are correctly merged.""" + dst = BatchTimings() + + dst.merge_fetch( + BlockFetchTiming(data_transfer=TimeSpan(start_s=1.0, end_s=5.0)) + ) + dst.merge_fetch( + BlockFetchTiming(data_transfer=TimeSpan(start_s=3.0, end_s=7.0)) + ) + + assert dst.data_transfer.start_s == 1.0 + assert dst.data_transfer.end_s == 7.0 + + def test_merge_both_stages_independent(self): + """production_wait and data_transfer merge independently.""" + dst = BatchTimings() + + # Block 1: prod [1,2], xfer [2,3] + dst.merge_fetch( + BlockFetchTiming( + production_wait=TimeSpan(start_s=1.0, end_s=2.0), + data_transfer=TimeSpan(start_s=2.0, end_s=3.0), + ) + ) + # Block 2: prod [5,6], xfer [6,7] + dst.merge_fetch( + BlockFetchTiming( + production_wait=TimeSpan(start_s=5.0, end_s=6.0), + data_transfer=TimeSpan(start_s=6.0, end_s=7.0), + ) + ) + + # Each stage unions independently. + assert dst.production_wait.start_s == 1.0 + assert dst.production_wait.end_s == 6.0 + assert dst.data_transfer.start_s == 2.0 + assert dst.data_transfer.end_s == 7.0 + + def test_merge_data_transfer_none_preserves_destination(self): + """Merging a block with no data_transfer timing leaves dst unchanged.""" + dst = BatchTimings() + dst.data_transfer = TimeSpan(start_s=2.0, end_s=3.0) + + # src has only production_wait, data_transfer is None + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=1.0, end_s=2.0)) + ) + + assert dst.data_transfer.start_s == 2.0 + assert dst.data_transfer.end_s == 3.0 + class TestEndToEndTimingPropagation: """Tests that stage timings propagate correctly through the full pipeline.""" diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py index 8885db22b491..d78cd9e7cf4c 100644 --- a/python/ray/data/tests/block_batching/test_util.py +++ b/python/ray/data/tests/block_batching/test_util.py @@ -80,6 +80,35 @@ def slow_block_ref_iter(): assert stats.iter_get_ref_bundles_s.get() == 0.0 +def test_resolve_block_refs_accumulates_data_transfer_timer( + ray_start_regular_shared, +): + """``resolve_block_refs`` accumulates ``ray.get()`` time into + ``iter_get_s`` (the data_transfer stage) and captures a per-block + ``data_transfer`` TimeSpan for overlap attribution. + + Pairs with ``test_resolve_block_refs_does_not_accumulate_ref_bundles_timer`` + which verifies the production_wait path does NOT accumulate. + """ + from ray.data._internal.stats import DatasetStats + + block_refs = [ray.put(i) for i in range(3)] + + stats = DatasetStats(metadata={}, parent=None) + resolved = list(resolve_block_refs(iter(block_refs), stats=stats)) + + assert len(resolved) == 3 + + # data_transfer TimeSpan captured per block. + for r in resolved: + assert r.fetch is not None + assert r.fetch.data_transfer is not None + assert r.fetch.data_transfer.duration >= 0.0 + + # iter_get_s (ray.get time) IS accumulated by resolve_block_refs. + assert stats.iter_get_s.get() >= 0.0 + + @pytest.mark.parametrize("block_size", [1, 10]) @pytest.mark.parametrize("drop_last", [True, False]) def test_blocks_to_batches(block_size, drop_last): diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index c2bec8a2be56..9f7be780bcae 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -34,9 +34,12 @@ DatasetStatsSummary, NodeMetrics, OperatorStatsSummary, + PipelineStage, StatsSummary, Timer, + TimeSpan, _StatsActor, + _timed, get_or_create_stats_actor, ) from ray.data._internal.util import MemoryProfiler @@ -2646,6 +2649,73 @@ def test_from_dict_handles_none_values(self): assert t.max() == 0.0 +class TestTimerTimerSpan: + """Tests for Timer.timer() returning a TimeSpan and accumulating.""" + + def test_timer_yields_timespan(self): + """timer() yields a fresh TimeSpan whose duration is accumulated.""" + t = Timer() + with t.timer() as span: + time.sleep(0.01) + assert isinstance(span, TimeSpan) + assert span.duration > 0 + # The span's duration is accumulated into the Timer. + assert t.get() == pytest.approx(span.duration, rel=0.5) + assert t.max() > 0 + assert t.min() == pytest.approx(span.duration, rel=0.5) + + def test_each_call_returns_fresh_span(self): + """Each timer() call yields a distinct TimeSpan instance.""" + t = Timer() + spans = [] + with t.timer() as s1: + spans.append(s1) + with t.timer() as s2: + spans.append(s2) + assert spans[0] is not spans[1] + assert t.get() == pytest.approx(spans[0].duration + spans[1].duration, rel=0.5) + + def test_timed_skips_when_timer_none(self): + """_timed(None) yields None and skips perf_counter entirely.""" + with _timed(None) as span: + assert span is None + assert span is None + + def test_timed_yields_span_when_timer_given(self): + """_timed(Timer) yields a TimeSpan backed by the Timer.""" + t = Timer() + with _timed(t) as span: + time.sleep(0.01) + assert isinstance(span, TimeSpan) + assert span.duration > 0 + assert t.get() == pytest.approx(span.duration, rel=0.5) + + +@pytest.mark.parametrize( + "stage,attr", + [ + (PipelineStage.PRODUCTION_WAIT, "iter_blocked_production_wait_s"), + (PipelineStage.DATA_TRANSFER, "iter_blocked_data_transfer_s"), + (PipelineStage.BATCHING, "iter_blocked_batching_s"), + (PipelineStage.FORMAT, "iter_blocked_format_s"), + (PipelineStage.COLLATE, "iter_blocked_collate_s"), + (PipelineStage.FINALIZE, "iter_blocked_finalize_s"), + ], +) +class TestGetBlockedTimer: + """Tests for DatasetStats.get_blocked_timer() stage->Timer mapping.""" + + def test_get_blocked_timer_returns_correct_attribute(self, stage, attr): + """get_blocked_timer(stage) returns the Timer matching the stage.""" + stats = DatasetStats(metadata={}, parent=None) + assert stats.get_blocked_timer(stage) is getattr(stats, attr) + + def test_get_blocked_timer_returns_timer_instance(self, stage, attr): + """get_blocked_timer returns a real Timer (not None).""" + stats = DatasetStats(metadata={}, parent=None) + assert isinstance(stats.get_blocked_timer(stage), Timer) + + def test_streaming_exec_schedule_percentiles_populated(ray_start_regular_shared): # KLL-sketch percentile tracking is always on (bounded memory), so # the percentile fields are populated end-to-end with no env-var From 8fe547955f43a45d6b75264f550cee899723a6ac Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sat, 27 Jun 2026 14:47:15 +0800 Subject: [PATCH 22/32] [Data] Apply black formatting to recent changes Fixes CI lint failures from unformatted code in the Cursor-2 fix and the new TestMergeFetch data_transfer cases. Black 22.10.0 (the version pinned in .pre-commit-config.yaml) collapses two patterns onto single lines: - util.py: `TimeSpan(...) if stats else None` ternary inside the prod_span assignment. - test_iter_batches.py: `BlockFetchTiming(data_transfer=TimeSpan(...))` constructor calls in the new merge tests. No behavior change. ruff and black --check both pass on all changed files. Signed-off-by: OneSizeFitsQuorum --- python/ray/data/_internal/block_batching/util.py | 4 +--- python/ray/data/tests/block_batching/test_iter_batches.py | 8 ++------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 80b6a9a556f2..268dfc78d201 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -218,9 +218,7 @@ def resolve_block_refs( except StopIteration: break prod_span = ( - TimeSpan(start_s=prod_start_s, end_s=time.perf_counter()) - if stats - else None + TimeSpan(start_s=prod_start_s, end_s=time.perf_counter()) if stats else None ) current_hit, current_miss, current_unknown = _calculate_ref_hits([block_ref]) diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 34cf715aa6a9..e925865a7915 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -441,14 +441,10 @@ def test_merge_data_transfer_multiple_blocks(self): """data_transfer windows are unioned across multiple blocks.""" dst = BatchTimings() - src1 = BlockFetchTiming( - data_transfer=TimeSpan(start_s=1.0, end_s=2.0) - ) + src1 = BlockFetchTiming(data_transfer=TimeSpan(start_s=1.0, end_s=2.0)) dst.merge_fetch(src1) - src2 = BlockFetchTiming( - data_transfer=TimeSpan(start_s=3.0, end_s=4.0) - ) + src2 = BlockFetchTiming(data_transfer=TimeSpan(start_s=3.0, end_s=4.0)) dst.merge_fetch(src2) # Union: [1.0, 4.0] From 4ac2b5ccdb9e6b080e7946cd1eb0288c46c11078 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sun, 28 Jun 2026 22:42:32 +0800 Subject: [PATCH 23/32] [Data] Address review: rename IterationStage/_maybe_time, simplify comments, remove gauge dict MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses 5 new review threads from JasonLi1909 (2026-06-27) plus comment-hygiene cleanup: Renames (Jason #5, #6): - PipelineStage → IterationStage (stats.py, interfaces.py, tests) - _timed → _maybe_time (stats.py, util.py, tests) - IterationStage docstring expanded with per-stage descriptions and Prometheus label mapping note. Comment cleanup: - block_batching.py: typing-shim TODO compressed from 8 to 2 lines, "Out of scope for this PR" removed (Jason #4). - stats.py: rank-label TODO compressed, "this PR's scope" removed. - stats.py: Timer.timer() docstring "thread-local" → "fresh per call" (the span is a fresh instance, not TLS). - util.py: resolve_block_refs docstring Note compressed, "matching master behavior" removed; redundant inline comment replaced with a one-liner pointing at the docstring. - util.py: _pending_timings field gets a one-line comment. - iter_batches.py: _attribute_blocked_time docstring — the 15-line invariant/limitations section replaced with a brief TODO noting the design is open (interval lists as a future option). Structure simplification (Jason #7, #8): - _blocked_gauges: Dict[IterationStage, Gauge] replaced with 6 individual Gauge attributes (iter_blocked_*_s), matching the existing iteration gauge style. - for-loop in update_iteration_metrics replaced with 6 individual .set() calls. - get_blocked_timer() match statement retained (maps stage → Timer on DatasetStats, orthogonal to gauge storage). - Test fixture updated: FakeGauge setup uses individual attributes instead of a dict. All affected tests pass (43 in test_stats + test_util + test_iter_batches). black 22.10.0 and ruff both clean. Signed-off-by: OneSizeFitsQuorum --- .../block_batching/block_batching.py | 10 +- .../_internal/block_batching/interfaces.py | 16 +-- .../_internal/block_batching/iter_batches.py | 28 +--- .../ray/data/_internal/block_batching/util.py | 30 ++-- python/ray/data/_internal/stats.py | 135 ++++++++++-------- .../tests/block_batching/test_iter_batches.py | 12 +- python/ray/data/tests/test_stats.py | 50 +++---- 7 files changed, 135 insertions(+), 146 deletions(-) diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index 95566ee9e037..268347d92f3a 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -35,14 +35,8 @@ def batch_blocks( # Use map() instead of a generator expression to avoid holding # references to blocks. # - # TODO: this BlockFetchResult wrapping is a typing shim — - # `batch_blocks` receives already-resolved `Block`s with no fetch - # timing, but `blocks_to_batches` consumes `Iterator[BlockFetchResult]` - # to keep a uniform type across the prefetch and non-prefetch paths. - # A future refactor could make fetch timing optional at the - # `_BatchingIterator` level (e.g. accept `Union[Block, BlockFetchResult]` - # or move the merge_fetch call site behind a capability check) so this - # shim can be removed. Out of scope for this PR. + # TODO: make fetch timing optional at the _BatchingIterator level so + # this BlockFetchResult wrapping shim can be removed. wrapped_blocks = map(lambda b: BlockFetchResult(block=b), blocks) # Build the processing pipeline diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 30fb9d79f422..cf2acf083755 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from typing import Any, Iterable, List, Optional, Tuple -from ray.data._internal.stats import PipelineStage, TimeSpan +from ray.data._internal.stats import IterationStage, TimeSpan from ray.data.block import Block, DataBatch from ray.types import ObjectRef @@ -52,15 +52,15 @@ class BatchTimings: collate: Optional[TimeSpan] = None finalize: Optional[TimeSpan] = None - def stages(self) -> Iterable[Tuple[PipelineStage, Optional[TimeSpan]]]: + def stages(self) -> Iterable[Tuple[IterationStage, Optional[TimeSpan]]]: """Iterate over ``(stage, timing)`` pairs for all pipeline stages.""" return ( - (PipelineStage.PRODUCTION_WAIT, self.production_wait), - (PipelineStage.DATA_TRANSFER, self.data_transfer), - (PipelineStage.BATCHING, self.batching), - (PipelineStage.FORMAT, self.format), - (PipelineStage.COLLATE, self.collate), - (PipelineStage.FINALIZE, self.finalize), + (IterationStage.PRODUCTION_WAIT, self.production_wait), + (IterationStage.DATA_TRANSFER, self.data_transfer), + (IterationStage.BATCHING, self.batching), + (IterationStage.FORMAT, self.format), + (IterationStage.COLLATE, self.collate), + (IterationStage.FINALIZE, self.finalize), ) def merge_fetch(self, src: BlockFetchTiming) -> None: diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index b977350d394b..f65eafdeb83f 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -283,29 +283,11 @@ def _attribute_blocked_time( * Partial overlap → partial credit. **Invariant**: ``sum(iter_blocked_*)`` approximates - ``iter_total_blocked_s``. It is a lower bound in the common case - (``sum ≤ total``), but two known cases can push ``sum`` above - ``total`` by design — both reflect real blocking the training - thread experienced: - - * **Split fetch stages overlap.** For multi-block batches, - ``production_wait`` and ``data_transfer`` are merged into - spanning windows per stage across blocks. In a pipelined - system those windows can overlap (data_transfer for block N - concurrent with production_wait for block N+1), so the same - blocked interval may be credited to both stages. Using - interval lists instead of spanning windows would restore - ``sum ≤ total`` but significantly increase complexity. - * **Reorder buffer wait is unattributed.** Under - ``preserve_order=True`` with a multi-worker format threadpool - (``prefetch_batches ≥ 1``), per-stage ``TimeSpan`` values are - recorded when format/collate finish — often before the batch - is released from ``restore_original_order``. The wait inside - the reorder buffer is part of the blocked window but covered - by no stage ``TimeSpan``, so it shows up as an unattributed - gap (lowers ``sum`` relative to ``total``). This is expected; - users seeing a high unattributed ratio should investigate - ordering-related stall. + ``iter_total_blocked_s`` (``≤ total`` in the common case). + TODO: two cases violate it by design — split fetch stages overlap + for multi-block batches, and reorder buffer wait under + ``preserve_order`` is unattributed. Consider interval lists for + precise non-overlapping attribution. Runs in the training thread; no locks needed because background threads finished writing ``batch.metadata.timings`` before the batch diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 268dfc78d201..4a7366a6bc54 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -28,7 +28,7 @@ BlockPrefetcher, CollatedBatch, ) -from ray.data._internal.stats import DatasetStats, TimeSpan, _timed +from ray.data._internal.stats import DatasetStats, TimeSpan, _maybe_time from ray.data.block import Block, BlockAccessor, DataBatch from ray.types import ObjectRef @@ -186,13 +186,10 @@ def resolve_block_refs( When *stats* is provided, the cumulative fetch time is also recorded in ``stats.iter_get_s``. - Note: ``production_wait`` is captured as a :class:`TimeSpan` for - per-batch overlap attribution only — it does **not** accumulate into - ``stats.iter_get_ref_bundles_s``. When prefetching is enabled, that - cumulative Timer is already driven by - :func:`prefetch_batches_locally.get_next_ref_bundle`; accumulating here - too would double-count the upstream wait. When prefetching is off, - ``iter_get_ref_bundles_s`` is not tracked, matching master behavior. + Note: ``production_wait`` is captured for overlap attribution only + and does not accumulate into ``iter_get_ref_bundles_s`` — that + Timer is driven by :func:`prefetch_batches_locally.get_next_ref_bundle` + when prefetching is enabled; accumulating here would double-count. Args: block_ref_iter: An iterator over block object references. @@ -207,11 +204,7 @@ def resolve_block_refs( unknowns = 0 while True: - # (1) production_wait: blocked on upstream data pipeline. - # Captured as a TimeSpan for overlap attribution only — do not - # accumulate into ``iter_get_ref_bundles_s`` here, because - # ``prefetch_batches_locally.get_next_ref_bundle`` already times - # the same ``next()`` call when prefetch is enabled. + # (1) production_wait: upstream wait (not accumulated; see docstring). prod_start_s = time.perf_counter() if stats else 0.0 try: block_ref = next(block_ref_iter) @@ -229,7 +222,7 @@ def resolve_block_refs( # (2) data_transfer: cross-node transfer via ray.get(). # TODO(amogkam): Optimized further by batching multiple references # in a single `ray.get()` call. - with _timed(stats.iter_get_s if stats else None) as xfer_span: + with _maybe_time(stats.iter_get_s if stats else None) as xfer_span: block = ray.get(block_ref) fetch = BlockFetchTiming(production_wait=prod_span, data_transfer=xfer_span) @@ -284,6 +277,7 @@ def __init__( self._drop_last = drop_last self._global_counter = 0 self._done_adding = False + # Accumulates per-block fetch timings until a batch is yielded. self._pending_timings = BatchTimings() if shuffle_buffer_min_size is not None: @@ -306,7 +300,7 @@ def __next__(self) -> Batch: ) if can_yield: - with _timed( + with _maybe_time( self._stats.iter_next_batch_s if self._stats else None ) as span: next_batch = self._batcher.next_batch() @@ -351,7 +345,7 @@ def _format_batch( stats: Optional[DatasetStats], ensure_copy: bool = False, ) -> Batch: - with _timed(stats.iter_format_batch_s if stats else None) as span: + with _maybe_time(stats.iter_format_batch_s if stats else None) as span: formatted_data = BlockAccessor.for_block(batch.data).to_batch_format( batch_format ) @@ -405,7 +399,7 @@ def _collate_batch( collate_fn: Callable[[DataBatch], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: - with _timed(stats.iter_collate_batch_s if stats else None) as span: + with _maybe_time(stats.iter_collate_batch_s if stats else None) as span: collated_data = collate_fn(batch.data) batch.metadata.timings.collate = span return CollatedBatch(metadata=batch.metadata, data=collated_data) @@ -431,7 +425,7 @@ def _finalize_batch( finalize_fn: Callable[[Any], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: - with _timed(stats.iter_finalize_batch_s if stats else None) as span: + with _maybe_time(stats.iter_finalize_batch_s if stats else None) as span: finalized_data = finalize_fn(batch.data) batch.metadata.timings.finalize = span return dataclasses.replace(batch, data=finalized_data) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index dd7b875493da..a39005cc8b2a 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -162,15 +162,22 @@ def get( ) -class PipelineStage(Enum): - """Pipeline stages for blocked attribution.""" +class IterationStage(Enum): + """Stages of the iter_batches pipeline used to attribute training-thread + blocked time. Each stage's wall-clock window is overlapped with the + training thread's ``next()`` blocked window to credit stall to the + responsible stage (see ``_attribute_blocked_time``). + + Each value is the Prometheus label for the corresponding + ``data_iter_blocked__seconds`` gauge. + """ - PRODUCTION_WAIT = "production_wait" - DATA_TRANSFER = "data_transfer" - BATCHING = "batching" - FORMAT = "format" - COLLATE = "collate" - FINALIZE = "finalize" + PRODUCTION_WAIT = "production_wait" # waiting on upstream data production + DATA_TRANSFER = "data_transfer" # cross-node ray.get() transfer + BATCHING = "batching" # slicing/shuffling blocks into batches + FORMAT = "format" # converting blocks to batch format + COLLATE = "collate" # applying user collate_fn + FINALIZE = "finalize" # applying user finalize_fn @dataclass @@ -191,7 +198,7 @@ def duration(self) -> float: @contextmanager -def _timed(timer: Optional["Timer"]) -> Iterator[Optional[TimeSpan]]: +def _maybe_time(timer: Optional["Timer"]) -> Iterator[Optional[TimeSpan]]: """Time a block of code, yielding a :class:`TimeSpan` (or ``None``). When *timer* is ``None`` (e.g. ``stats`` is not configured), yields @@ -231,10 +238,10 @@ def __init__(self): @contextmanager def timer(self) -> Iterator[TimeSpan]: - """Time a block, yielding a thread-local :class:`TimeSpan`. + """Time a block, yielding a fresh :class:`TimeSpan` per call. - The returned ``TimeSpan`` is a fresh instance per call, making - this safe to use from multiple threads sharing the same ``Timer``. + The returned span is a distinct instance each call, so multiple + threads sharing the same ``Timer`` don't race on span fields. The duration is also accumulated into ``self`` via :meth:`add`. """ span = TimeSpan(start_s=time.perf_counter()) @@ -500,11 +507,9 @@ def __init__(self, max_stats=1000): self.per_node_metrics = self._create_prometheus_metrics_for_per_node_metrics() iter_tag_keys = ("dataset",) - # TODO: add a per-streaming-split-worker ``rank`` label to these - # iteration metrics (including the blocked-attribution gauges below) - # in a follow-up PR, so users can distinguish which split worker is - # blocked on which stage. Deferred to keep this PR's scope focused - # on the blocked-attribution data model. + # TODO: add a per-streaming-split-worker ``rank`` label to iteration + # metrics so users can distinguish which split worker blocked on + # which stage. self.time_to_first_batch_s = Gauge( "data_iter_time_to_first_batch_seconds", @@ -549,38 +554,36 @@ def __init__(self, max_stats=1000): description="Total wall-clock seconds spent in the dataset iterator", tag_keys=iter_tag_keys, ) - self._blocked_gauges: Dict[PipelineStage, Gauge] = { - PipelineStage.PRODUCTION_WAIT: Gauge( - "data_iter_blocked_production_wait_seconds", - description="Seconds user thread is blocked on upstream data production", - tag_keys=iter_tag_keys, - ), - PipelineStage.DATA_TRANSFER: Gauge( - "data_iter_blocked_data_transfer_seconds", - description="Seconds user thread is blocked on cross-node data transfer (ray.get)", - tag_keys=iter_tag_keys, - ), - PipelineStage.BATCHING: Gauge( - "data_iter_blocked_batching_seconds", - description="Seconds user thread is blocked on batch creation", - tag_keys=iter_tag_keys, - ), - PipelineStage.FORMAT: Gauge( - "data_iter_blocked_format_seconds", - description="Seconds user thread is blocked on batch formatting", - tag_keys=iter_tag_keys, - ), - PipelineStage.COLLATE: Gauge( - "data_iter_blocked_collate_seconds", - description="Seconds user thread is blocked on batch collation", - tag_keys=iter_tag_keys, - ), - PipelineStage.FINALIZE: Gauge( - "data_iter_blocked_finalize_seconds", - description="Seconds user thread is blocked on batch finalization", - tag_keys=iter_tag_keys, - ), - } + self.iter_blocked_production_wait_s = Gauge( + "data_iter_blocked_production_wait_seconds", + description="Seconds user thread is blocked on upstream data production", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_data_transfer_s = Gauge( + "data_iter_blocked_data_transfer_seconds", + description="Seconds user thread is blocked on cross-node data transfer (ray.get)", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_batching_s = Gauge( + "data_iter_blocked_batching_seconds", + description="Seconds user thread is blocked on batch creation", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_format_s = Gauge( + "data_iter_blocked_format_seconds", + description="Seconds user thread is blocked on batch formatting", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_collate_s = Gauge( + "data_iter_blocked_collate_seconds", + description="Seconds user thread is blocked on batch collation", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_finalize_s = Gauge( + "data_iter_blocked_finalize_seconds", + description="Seconds user thread is blocked on batch finalization", + tag_keys=iter_tag_keys, + ) self.iter_batches_total = Gauge( "data_iter_batches_total", description="Total batches delivered to the user thread", @@ -852,8 +855,24 @@ def update_iteration_metrics( self.time_to_first_batch_s.set(stats.iter_time_to_first_batch_s.get(), tags) self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags) - for stage, gauge in self._blocked_gauges.items(): - gauge.set(stats.get_blocked_timer(stage).get(), tags) + self.iter_blocked_production_wait_s.set( + stats.get_blocked_timer(IterationStage.PRODUCTION_WAIT).get(), tags + ) + self.iter_blocked_data_transfer_s.set( + stats.get_blocked_timer(IterationStage.DATA_TRANSFER).get(), tags + ) + self.iter_blocked_batching_s.set( + stats.get_blocked_timer(IterationStage.BATCHING).get(), tags + ) + self.iter_blocked_format_s.set( + stats.get_blocked_timer(IterationStage.FORMAT).get(), tags + ) + self.iter_blocked_collate_s.set( + stats.get_blocked_timer(IterationStage.COLLATE).get(), tags + ) + self.iter_blocked_finalize_s.set( + stats.get_blocked_timer(IterationStage.FINALIZE).get(), tags + ) self.iter_batches_total.set(stats.iter_batches_total, tags) self.iter_rows_total.set(stats.iter_rows_total, tags) self.iter_user_s.set(stats.iter_user_s.get(), tags) @@ -1278,20 +1297,20 @@ def __init__( # Streaming split coordinator stats (dataset level) self.streaming_split_coordinator_s: Timer = Timer() - def get_blocked_timer(self, stage: PipelineStage) -> Timer: + def get_blocked_timer(self, stage: IterationStage) -> Timer: """Return the blocked-attribution Timer for the given pipeline stage.""" match stage: - case PipelineStage.PRODUCTION_WAIT: + case IterationStage.PRODUCTION_WAIT: return self.iter_blocked_production_wait_s - case PipelineStage.DATA_TRANSFER: + case IterationStage.DATA_TRANSFER: return self.iter_blocked_data_transfer_s - case PipelineStage.BATCHING: + case IterationStage.BATCHING: return self.iter_blocked_batching_s - case PipelineStage.FORMAT: + case IterationStage.FORMAT: return self.iter_blocked_format_s - case PipelineStage.COLLATE: + case IterationStage.COLLATE: return self.iter_blocked_collate_s - case PipelineStage.FINALIZE: + case IterationStage.FINALIZE: return self.iter_blocked_finalize_s @property diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index e925865a7915..dcc4f44e6d88 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -22,7 +22,7 @@ ) from ray.data._internal.block_batching.util import WaitBlockPrefetcher from ray.data._internal.execution.interfaces.ref_bundle import BlockEntry, RefBundle -from ray.data._internal.stats import DatasetStats, PipelineStage, TimeSpan +from ray.data._internal.stats import DatasetStats, IterationStage, TimeSpan from ray.data.block import Block, BlockMetadata from ray.types import ObjectRef @@ -521,11 +521,11 @@ def test_batch_carries_timings_through_pipeline(self): # Verify all stages are accessible via stages() iterator stage_dict = dict(batch.metadata.timings.stages()) assert len(stage_dict) == 6 - assert stage_dict[PipelineStage.PRODUCTION_WAIT].start_s == 1.0 - assert stage_dict[PipelineStage.BATCHING].end_s == 3.0 - assert stage_dict[PipelineStage.FORMAT].start_s == 3.0 - assert stage_dict[PipelineStage.COLLATE].end_s == 5.0 - assert stage_dict[PipelineStage.FINALIZE].start_s == 5.0 + assert stage_dict[IterationStage.PRODUCTION_WAIT].start_s == 1.0 + assert stage_dict[IterationStage.BATCHING].end_s == 3.0 + assert stage_dict[IterationStage.FORMAT].start_s == 3.0 + assert stage_dict[IterationStage.COLLATE].end_s == 5.0 + assert stage_dict[IterationStage.FINALIZE].start_s == 5.0 assert batch.metadata.num_rows == 50 def test_full_pipeline_attribution(self): diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index 9f7be780bcae..e2204fde354d 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -34,12 +34,12 @@ DatasetStatsSummary, NodeMetrics, OperatorStatsSummary, - PipelineStage, + IterationStage, StatsSummary, Timer, TimeSpan, _StatsActor, - _timed, + _maybe_time, get_or_create_stats_actor, ) from ray.data._internal.util import MemoryProfiler @@ -1931,17 +1931,17 @@ def set(self, value, tags): ]: setattr(actor, attr, FakeGauge(attr)) - # Set up the blocked gauges dict (now stored as Dict[PipelineStage, Gauge]) - from ray.data._internal.stats import PipelineStage - - actor._blocked_gauges = { - PipelineStage.PRODUCTION_WAIT: FakeGauge("iter_blocked_production_wait_s"), - PipelineStage.DATA_TRANSFER: FakeGauge("iter_blocked_data_transfer_s"), - PipelineStage.BATCHING: FakeGauge("iter_blocked_batching_s"), - PipelineStage.FORMAT: FakeGauge("iter_blocked_format_s"), - PipelineStage.COLLATE: FakeGauge("iter_blocked_collate_s"), - PipelineStage.FINALIZE: FakeGauge("iter_blocked_finalize_s"), - } + # Blocked gauges are stored as individual attributes (matching the + # other iteration gauges above). + for attr in [ + "iter_blocked_production_wait_s", + "iter_blocked_data_transfer_s", + "iter_blocked_batching_s", + "iter_blocked_format_s", + "iter_blocked_collate_s", + "iter_blocked_finalize_s", + ]: + setattr(actor, attr, FakeGauge(attr)) actor.update_iteration_metrics(stats, "train_dataset_split_3") @@ -2675,16 +2675,16 @@ def test_each_call_returns_fresh_span(self): assert spans[0] is not spans[1] assert t.get() == pytest.approx(spans[0].duration + spans[1].duration, rel=0.5) - def test_timed_skips_when_timer_none(self): - """_timed(None) yields None and skips perf_counter entirely.""" - with _timed(None) as span: + def test_maybe_time_skips_when_timer_none(self): + """_maybe_time(None) yields None and skips perf_counter entirely.""" + with _maybe_time(None) as span: assert span is None assert span is None - def test_timed_yields_span_when_timer_given(self): - """_timed(Timer) yields a TimeSpan backed by the Timer.""" + def test_maybe_time_yields_span_when_timer_given(self): + """_maybe_time(Timer) yields a TimeSpan backed by the Timer.""" t = Timer() - with _timed(t) as span: + with _maybe_time(t) as span: time.sleep(0.01) assert isinstance(span, TimeSpan) assert span.duration > 0 @@ -2694,12 +2694,12 @@ def test_timed_yields_span_when_timer_given(self): @pytest.mark.parametrize( "stage,attr", [ - (PipelineStage.PRODUCTION_WAIT, "iter_blocked_production_wait_s"), - (PipelineStage.DATA_TRANSFER, "iter_blocked_data_transfer_s"), - (PipelineStage.BATCHING, "iter_blocked_batching_s"), - (PipelineStage.FORMAT, "iter_blocked_format_s"), - (PipelineStage.COLLATE, "iter_blocked_collate_s"), - (PipelineStage.FINALIZE, "iter_blocked_finalize_s"), + (IterationStage.PRODUCTION_WAIT, "iter_blocked_production_wait_s"), + (IterationStage.DATA_TRANSFER, "iter_blocked_data_transfer_s"), + (IterationStage.BATCHING, "iter_blocked_batching_s"), + (IterationStage.FORMAT, "iter_blocked_format_s"), + (IterationStage.COLLATE, "iter_blocked_collate_s"), + (IterationStage.FINALIZE, "iter_blocked_finalize_s"), ], ) class TestGetBlockedTimer: From b13112d5e15aa14cf3fda2a0809486e27cd638f5 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sun, 28 Jun 2026 22:57:14 +0800 Subject: [PATCH 24/32] [Data] Finish IterationStage rename in docstrings; slim _attribute_blocked_time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow-up to the IterationStage rename (commit 4ac2b5ccdb) — the enum was renamed but a handful of docstrings still said "pipeline stage". Also tightens _attribute_blocked_time per review feedback that it was still too long. - interfaces.py: 4 docstring references "pipeline stage" → "iteration stage" (BatchTimings class doc, stages() doc, BatchMetadata.timings attribute doc). - stats.py: get_blocked_timer docstring "pipeline stage" → "iteration stage". - iter_batches.py: _attribute_blocked_time docstring rewritten — drops the 3 prefetch-config bullets (obvious from the formula), drops the thread-safety paragraph (implementation detail), and folds the invariant + two limitations into a single TODO noting the design is open (interval lists as a future option). Also fixes "pipeline stage" → "stage" here. Net ~20 lines shorter. No behavior change. black 22.10.0 + ruff clean; 25 affected tests pass. Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 8 ++--- .../_internal/block_batching/iter_batches.py | 36 ++++++------------- python/ray/data/_internal/stats.py | 2 +- 3 files changed, 15 insertions(+), 31 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index cf2acf083755..ee646982113b 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -33,10 +33,10 @@ class BlockFetchResult: @dataclass class BatchTimings: - """Per-batch pipeline-stage timing windows for overlap-based attribution. + """Per-batch iteration-stage timing windows for overlap-based attribution. Each field records the ``(start_s, end_s)`` wall-clock window during which - a particular pipeline stage was active for this batch. The training thread + a particular iteration stage was active for this batch. The training thread later compares these windows against its own blocked window to determine how much each stage contributed to training-thread stall (see :meth:`BatchIterator._attribute_blocked_time`). @@ -53,7 +53,7 @@ class BatchTimings: finalize: Optional[TimeSpan] = None def stages(self) -> Iterable[Tuple[IterationStage, Optional[TimeSpan]]]: - """Iterate over ``(stage, timing)`` pairs for all pipeline stages.""" + """Iterate over ``(stage, timing)`` pairs for all iteration stages.""" return ( (IterationStage.PRODUCTION_WAIT, self.production_wait), (IterationStage.DATA_TRANSFER, self.data_transfer), @@ -94,7 +94,7 @@ class BatchMetadata: batch_idx: The global index of this batch so that downstream operations can maintain ordering. num_rows: Number of rows in this batch (for ``iter_rows_total``). - timings: Pipeline-stage timing windows for this batch. + timings: Iteration-stage timing windows for this batch. """ batch_idx: int diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index f65eafdeb83f..6d90087e0294 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -266,39 +266,23 @@ def _attribute_blocked_time( ) -> None: """Attribute per-stage blocked time via overlap with the training window. - For each pipeline stage we know when it ran ``[timing.start_s, - timing.end_s]`` (recorded by background threads onto - ``batch.metadata.timings``). We also know when the training thread - was blocked ``[blocked_start_s, blocked_end_s]`` (captured in - ``_iter_batches`` around ``next()``). - - The attribution for a stage is the length of the intersection:: + Each stage's window ``[timing.start_s, timing.end_s]`` (recorded by + background threads onto ``batch.metadata.timings``) is intersected + with the training thread's blocked window ``[blocked_start_s, + blocked_end_s]`` (captured around ``next()`` in ``_iter_batches``):: overlap = min(timing.end, blocked_end) - max(timing.start, blocked_start) - This correctly handles all prefetch configurations: - - * Stage finished before training blocked → overlap ≤ 0 → zero credit. - * Stage fully inside blocked window → full stage duration credited. - * Partial overlap → partial credit. - - **Invariant**: ``sum(iter_blocked_*)`` approximates - ``iter_total_blocked_s`` (``≤ total`` in the common case). - TODO: two cases violate it by design — split fetch stages overlap - for multi-block batches, and reorder buffer wait under - ``preserve_order`` is unattributed. Consider interval lists for + TODO: ``sum(iter_blocked_*)`` only approximates + ``iter_total_blocked_s`` — split fetch stages overlap for + multi-block batches, and reorder buffer wait under + ``preserve_order`` is unattributed. Interval lists would give precise non-overlapping attribution. - Runs in the training thread; no locks needed because background - threads finished writing ``batch.metadata.timings`` before the batch - was enqueued. - Args: batch: The batch whose per-stage timings should be attributed. - blocked_start_s: ``perf_counter()`` value just before the - training thread called ``next(batch_iter)``. - blocked_end_s: ``perf_counter()`` value just after ``next()`` - returned. + blocked_start_s: ``perf_counter()`` just before ``next()``. + blocked_end_s: ``perf_counter()`` just after ``next()`` returned. """ if self._stats is None: return diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index a39005cc8b2a..9998bdabacc4 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -1298,7 +1298,7 @@ def __init__( self.streaming_split_coordinator_s: Timer = Timer() def get_blocked_timer(self, stage: IterationStage) -> Timer: - """Return the blocked-attribution Timer for the given pipeline stage.""" + """Return the blocked-attribution Timer for the given iteration stage.""" match stage: case IterationStage.PRODUCTION_WAIT: return self.iter_blocked_production_wait_s From 57ede0e6b7beb0398414acb95f69e72b1516c646 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sun, 28 Jun 2026 23:09:48 +0800 Subject: [PATCH 25/32] [Data] Record iter_total_s and flush metrics on early exit Previously `iter_total_s.add(...)` and the final `_StatsManager.update_iteration_metrics` flush ran only after `yield from batch_iterator` completed normally. On early exit (e.g. `break` in the training loop), the `finally` block only shut down the executor, so `iter_total_s` stayed at zero and Prometheus never got a final flush despite partial iteration. Moves both calls into the `finally` block (before `_on_iteration_end` to keep total time excluding executor shutdown, matching master's placement). Normal completion is unaffected; early exit now records the wall-clock time up to the break and flushes metrics. Note: mid-run staleness (iter_total_s reads 0 during iteration because add() only runs at the end) is a separate pre-existing issue and remains a follow-up. Signed-off-by: OneSizeFitsQuorum --- python/ray/data/iterator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 9cfaea2f8dfc..d2a0275f366e 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -295,10 +295,14 @@ def callback(num_bytes: int) -> None: try: yield from batch_iterator + finally: + # Runs on both normal completion and early exit (e.g. `break` + # in the training loop). `iter_total_s` and the final metrics + # flush must happen here so partial iteration is still + # recorded; `_on_iteration_end` shuts down the executor after. if stats: stats.iter_total_s.add(time.perf_counter() - time_start) _StatsManager.update_iteration_metrics(stats, dataset_tag) - finally: # On early exit (e.g. ``break`` in the for-loop), the inner # ``_ClosingIterator`` would only shut down the executor via # its ``__del__``, which is non-deterministic. The hook From 39ebab54e24b06be5611e042f944821e43b9df3d Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sun, 28 Jun 2026 23:22:52 +0800 Subject: [PATCH 26/32] [Data] Test code style cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Six style fixes from self-review of test code: 1. Rename TestTimerTimerSpan → TestTimerSpan (the "Timer" duplication was awkward; the class tests Timer.timer() returning TimeSpan). 2. Remove stale `assert "restore order" not in text` from test_iter_stats_to_string_omits_zero_stages — the restore_order stage was removed in an earlier commit, so the assertion was vacuously true. 3. Merge the separate for-loop that set up blocked-gauge FakeGauges into the main gauge-setup loop in test_update_iteration_metrics_exports_new_iter_metrics, and drop the now-redundant "Blocked gauges are stored as individual attributes" comment (blocked gauges are just attributes now, like the others). 4. TestMergeFetch: the five original methods passed `BatchTimings()` as the `src` argument to `merge_fetch()`, whose signature takes `BlockFetchTiming`. It worked by duck-typing but was the wrong type. Switched to `BlockFetchTiming(production_wait=TimeSpan(...))` to match the signature and the style of the newer data_transfer tests. 5. test_overlap_invariant_sum_leq_total: docstring said "always holds" but the invariant was relaxed to "approximates" with a TODO noting cases that violate it. Tightened to "holds for non-overlapping stages" to reflect what the test actually exercises. 6. _make_span / _make_batch_with_timings: replaced the 0.0-as-sentinel pattern (start==0.0 and end==0.0 → None) with explicit Optional parameters defaulting to None. Mirrors the source-code cleanup Jason requested for the 0.0 sentinel. All affected tests pass (38). black 22.10.0 + ruff clean. Signed-off-by: OneSizeFitsQuorum --- .../tests/block_batching/test_iter_batches.py | 86 +++++++++---------- python/ray/data/tests/test_stats.py | 15 +--- 2 files changed, 45 insertions(+), 56 deletions(-) diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index dcc4f44e6d88..4b5fde3854b3 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -151,27 +151,27 @@ def test_attribute_blocked_time_overlap_attribution(): assert stats.iter_rows_total == 8 -def _make_span(start: float, end: float) -> Optional[TimeSpan]: - """Create a TimeSpan, or None if the stage didn't execute (both zero).""" - if start == 0.0 and end == 0.0: +def _make_span(start: Optional[float], end: Optional[float]) -> Optional[TimeSpan]: + """Create a TimeSpan, or None if the stage didn't execute (either is None).""" + if start is None or end is None: return None return TimeSpan(start_s=start, end_s=end) def _make_batch_with_timings( - production_wait_start=0.0, - production_wait_end=0.0, - data_transfer_start=0.0, - data_transfer_end=0.0, - batching_start=0.0, - batching_end=0.0, - format_start=0.0, - format_end=0.0, - collate_start=0.0, - collate_end=0.0, - finalize_start=0.0, - finalize_end=0.0, - num_rows=0, + production_wait_start: Optional[float] = None, + production_wait_end: Optional[float] = None, + data_transfer_start: Optional[float] = None, + data_transfer_end: Optional[float] = None, + batching_start: Optional[float] = None, + batching_end: Optional[float] = None, + format_start: Optional[float] = None, + format_end: Optional[float] = None, + collate_start: Optional[float] = None, + collate_end: Optional[float] = None, + finalize_start: Optional[float] = None, + finalize_end: Optional[float] = None, + num_rows: int = 0, ): """Helper to construct a Batch with specific stage timing windows.""" timings = BatchTimings() @@ -283,7 +283,7 @@ def test_accumulation_across_batches(self): assert stats.iter_rows_total == 30 def test_overlap_invariant_sum_leq_total(self): - """sum(iter_blocked_*) <= iter_total_blocked_s always holds.""" + """sum(iter_blocked_*) <= iter_total_blocked_s holds for non-overlapping stages.""" stats = DatasetStats(metadata={}, parent=None) it = _make_report_iterator(stats) stats.iter_total_blocked_s.add(5.0) @@ -372,9 +372,9 @@ class TestMergeFetch: def test_merge_single_block(self): """Merging a single block preserves its fetch window.""" dst = BatchTimings() - src = BatchTimings() - src.production_wait = TimeSpan(start_s=1.0, end_s=2.0) - dst.merge_fetch(src) + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=1.0, end_s=2.0)) + ) assert dst.production_wait.start_s == 1.0 assert dst.production_wait.end_s == 2.0 @@ -383,31 +383,28 @@ def test_merge_multiple_blocks_expands_window(self): dst = BatchTimings() # Block 1: fetched [1.0, 2.0] - src1 = BatchTimings() - src1.production_wait = TimeSpan(start_s=1.0, end_s=2.0) - dst.merge_fetch(src1) - + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=1.0, end_s=2.0)) + ) # Block 2: fetched [3.0, 4.0] - src2 = BatchTimings() - src2.production_wait = TimeSpan(start_s=3.0, end_s=4.0) - dst.merge_fetch(src2) - + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=3.0, end_s=4.0)) + ) # Block 3: fetched [5.0, 6.0] - src3 = BatchTimings() - src3.production_wait = TimeSpan(start_s=5.0, end_s=6.0) - dst.merge_fetch(src3) + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=5.0, end_s=6.0)) + ) # Union: [1.0, 6.0] assert dst.production_wait.start_s == 1.0 assert dst.production_wait.end_s == 6.0 def test_merge_unrecorded_block_ignored(self): - """Merging a block with no fetch timing (start_s=0) is a no-op.""" + """Merging a block with no fetch timing (both fields None) is a no-op.""" dst = BatchTimings() dst.production_wait = TimeSpan(start_s=2.0, end_s=3.0) - src = BatchTimings() # fetch defaults to (0.0, 0.0) - dst.merge_fetch(src) + dst.merge_fetch(BlockFetchTiming()) # fetch fields default to None assert dst.production_wait.start_s == 2.0 assert dst.production_wait.end_s == 3.0 @@ -416,13 +413,12 @@ def test_merge_overlapping_blocks(self): """Overlapping fetch windows are correctly merged.""" dst = BatchTimings() - src1 = BatchTimings() - src1.production_wait = TimeSpan(start_s=1.0, end_s=5.0) - dst.merge_fetch(src1) - - src2 = BatchTimings() - src2.production_wait = TimeSpan(start_s=3.0, end_s=7.0) - dst.merge_fetch(src2) + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=1.0, end_s=5.0)) + ) + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=3.0, end_s=7.0)) + ) # Union: [1.0, 7.0] assert dst.production_wait.start_s == 1.0 @@ -430,10 +426,10 @@ def test_merge_overlapping_blocks(self): def test_merge_into_empty_destination(self): """Merging into an empty BatchTimings takes the source window.""" - dst = BatchTimings() # fetch = (0.0, 0.0) - src = BatchTimings() - src.production_wait = TimeSpan(start_s=10.0, end_s=20.0) - dst.merge_fetch(src) + dst = BatchTimings() + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=10.0, end_s=20.0)) + ) assert dst.production_wait.start_s == 10.0 assert dst.production_wait.end_s == 20.0 diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index e2204fde354d..e1530983e0c8 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -1925,21 +1925,15 @@ def set(self, value, tags): "iter_batch_finalizing_s", "time_to_first_batch_s", "iter_total_blocked_s", - "iter_batches_total", - "iter_rows_total", - "iter_user_s", - ]: - setattr(actor, attr, FakeGauge(attr)) - - # Blocked gauges are stored as individual attributes (matching the - # other iteration gauges above). - for attr in [ "iter_blocked_production_wait_s", "iter_blocked_data_transfer_s", "iter_blocked_batching_s", "iter_blocked_format_s", "iter_blocked_collate_s", "iter_blocked_finalize_s", + "iter_batches_total", + "iter_rows_total", + "iter_user_s", ]: setattr(actor, attr, FakeGauge(attr)) @@ -2017,7 +2011,6 @@ def test_iter_stats_to_string_omits_zero_stages(): # Zero stages should not appear assert "batching" not in text assert "collate" not in text - assert "restore order" not in text def test_iter_stats_to_string_no_breakdown_when_all_zero(): @@ -2649,7 +2642,7 @@ def test_from_dict_handles_none_values(self): assert t.max() == 0.0 -class TestTimerTimerSpan: +class TestTimerSpan: """Tests for Timer.timer() returning a TimeSpan and accumulating.""" def test_timer_yields_timespan(self): From 2e8f64d1edeac90b2b2de02e123a42ed407ff289 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sun, 28 Jun 2026 23:42:26 +0800 Subject: [PATCH 27/32] [Data] Trim docstrings and comments for readability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Review feedback noted the documentation read as agent-generated — excessive on trivial functions, missing on core ones, and heavy on Sphinx markup / em-dashes / "Note:" callouts that humans don't write. interfaces.py: - BlockFetchTiming: drop lifecycle sentence ("Produced by ... merged into ... by ..."). A 2-field dataclass doesn't need it. - BlockFetchResult: tighten the None-description. - BatchTimings: drop "overlap-based attribution" rationale and the :meth: cross-reference; keep just what the fields mean. - stages(): drop the redundant docstring restating the signature. - _merge_span: 5-line docstring → 1 line. The code is 7 lines. - BatchMetadata.timings: "Iteration-stage timing windows for this batch" → "Per-stage timing windows." stats.py: - IterationStage: drop the 3-line overlap-mechanism explanation (that's _attribute_blocked_time's job, not the enum's). - TimeSpan: drop "Created by Timer.timer and carried per-batch for overlap-based blocked attribution" — system-level rationale doesn't belong on a 2-field dataclass. - _maybe_time: 3-line docstring → 1 line; drop :class: markup and the em-dash aside. util.py: - resolve_block_refs: drop "Note:" callout prefix; drop :func: markup on prefetch_batches_locally.get_next_ref_bundle. - Inline comments: drop "(1)"/"(2)" numbering; compress the ray.get TODO to one line. iter_batches.py: - _iter_batches: add a short docstring (was missing). - _attribute_blocked_time: drop "recorded by background threads" implementation detail; drop the _iter_batches cross-ref; drop the "Interval lists would give precise..." suggestion from the TODO; plain-backtick Args instead of ``-wrapped. - get_next_batch_context / yield_batch_context: add short docstrings (were missing). Net: -22 docstring lines, no behavior change. black 22.10.0 + ruff clean; 38 affected tests pass. Signed-off-by: OneSizeFitsQuorum --- .../_internal/block_batching/interfaces.py | 34 ++++++------------- .../_internal/block_batching/iter_batches.py | 30 +++++++++------- .../ray/data/_internal/block_batching/util.py | 19 +++++------ python/ray/data/_internal/stats.py | 21 +++--------- 4 files changed, 41 insertions(+), 63 deletions(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index ee646982113b..ebcfed59bbc4 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -9,11 +9,7 @@ @dataclass class BlockFetchTiming: - """Fetch timing for a single block (production_wait + data_transfer). - - Produced by :func:`resolve_block_refs` and merged into - :class:`BatchTimings` by :class:`_BatchingIterator`. - """ + """Fetch timing for a single block (production_wait + data_transfer).""" production_wait: Optional[TimeSpan] = None data_transfer: Optional[TimeSpan] = None @@ -23,8 +19,8 @@ class BlockFetchTiming: class BlockFetchResult: """A resolved block paired with its per-block fetch timing. - When ``fetch`` is ``None``, no fetch timing was recorded (e.g. blocks - that were already resolved before entering the pipeline). + ``fetch`` is None when no timing was recorded (e.g. blocks already + resolved before entering the pipeline). """ block: Block @@ -33,16 +29,11 @@ class BlockFetchResult: @dataclass class BatchTimings: - """Per-batch iteration-stage timing windows for overlap-based attribution. - - Each field records the ``(start_s, end_s)`` wall-clock window during which - a particular iteration stage was active for this batch. The training thread - later compares these windows against its own blocked window to determine - how much each stage contributed to training-thread stall (see - :meth:`BatchIterator._attribute_blocked_time`). + """Per-batch timing windows for each iteration stage. - A field value of ``None`` indicates the stage did not execute for this - batch (e.g. no ``collate_fn`` provided). + Each field is the ``(start_s, end_s)`` window a stage was active, or + None if the stage didn't run. Compared against the training thread's + blocked window to attribute stall (see ``_attribute_blocked_time``). """ production_wait: Optional[TimeSpan] = None @@ -53,7 +44,7 @@ class BatchTimings: finalize: Optional[TimeSpan] = None def stages(self) -> Iterable[Tuple[IterationStage, Optional[TimeSpan]]]: - """Iterate over ``(stage, timing)`` pairs for all iteration stages.""" + """Yield (stage, timing) pairs.""" return ( (IterationStage.PRODUCTION_WAIT, self.production_wait), (IterationStage.DATA_TRANSFER, self.data_transfer), @@ -70,12 +61,7 @@ def merge_fetch(self, src: BlockFetchTiming) -> None: def _merge_span(dst: Optional[TimeSpan], src: Optional[TimeSpan]) -> Optional[TimeSpan]: - """Merge two optional ``TimeSpan`` windows into a spanning window. - - Returns ``dst`` unchanged if ``src`` is ``None`` (stage didn't run). - Returns a copy of ``src`` if ``dst`` is ``None`` (first block). - Otherwise returns a new ``TimeSpan`` spanning both windows. - """ + """Return the union of two optional windows (or the non-None one).""" if src is None: return dst if dst is None: @@ -94,7 +80,7 @@ class BatchMetadata: batch_idx: The global index of this batch so that downstream operations can maintain ordering. num_rows: Number of rows in this batch (for ``iter_rows_total``). - timings: Iteration-stage timing windows for this batch. + timings: Per-stage timing windows. """ batch_idx: int diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 6d90087e0294..6711105ac6a7 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -243,6 +243,12 @@ def _pipeline(self, ref_bundles: Iterator[RefBundle]) -> Iterator[Batch]: yield from batch_iter def _iter_batches(self) -> Iterator[DataBatch]: + """Pull batches from the pipeline and yield batch data. + + Captures the training thread's blocked window around each ``next()`` + call and attributes it to pipeline stages via + ``_attribute_blocked_time``. + """ batch_iter = iter_threaded(self._ref_bundles, fn=self._pipeline) self.before_epoch_start() @@ -266,23 +272,19 @@ def _attribute_blocked_time( ) -> None: """Attribute per-stage blocked time via overlap with the training window. - Each stage's window ``[timing.start_s, timing.end_s]`` (recorded by - background threads onto ``batch.metadata.timings``) is intersected - with the training thread's blocked window ``[blocked_start_s, - blocked_end_s]`` (captured around ``next()`` in ``_iter_batches``):: + Each stage's window on ``batch.metadata.timings`` is intersected with + the training thread's blocked window ``[blocked_start_s, blocked_end_s]``:: overlap = min(timing.end, blocked_end) - max(timing.start, blocked_start) - TODO: ``sum(iter_blocked_*)`` only approximates - ``iter_total_blocked_s`` — split fetch stages overlap for - multi-block batches, and reorder buffer wait under - ``preserve_order`` is unattributed. Interval lists would give - precise non-overlapping attribution. + TODO: ``sum(iter_blocked_*)`` only approximates ``iter_total_blocked_s`` + — split fetch stages overlap for multi-block batches, and reorder + buffer wait under ``preserve_order`` is unattributed. Args: - batch: The batch whose per-stage timings should be attributed. - blocked_start_s: ``perf_counter()`` just before ``next()``. - blocked_end_s: ``perf_counter()`` just after ``next()`` returned. + batch: Batch whose per-stage timings should be attributed. + blocked_start_s: perf_counter() just before next(). + blocked_end_s: perf_counter() just after next() returned. """ if self._stats is None: return @@ -316,6 +318,8 @@ def after_epoch_end(self): @contextmanager def get_next_batch_context(self): + """Context around ``next(batch_iter)``: tracks total blocked time + and time-to-first-batch.""" try: if self._stats: # Always track total blocked time @@ -335,6 +339,8 @@ def get_next_batch_context(self): @contextmanager def yield_batch_context(self, batch: Batch): + """Context around yielding a batch to the user: tracks user time + and periodically flushes metrics.""" with self._stats.iter_user_s.timer() if self._stats else nullcontext(): yield diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 4a7366a6bc54..6c01e82c442c 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -179,17 +179,17 @@ def resolve_block_refs( ) -> Iterator[BlockFetchResult]: """Resolves the block references for each logical batch. - Each resolved block is wrapped in a :class:`BlockFetchResult` that carries - the per-block fetch window. The fetch window spans from the moment we + Each resolved block is wrapped in a ``BlockFetchResult`` that carries + the per-block fetch window. The fetch window spans from the moment we start waiting for the upstream iterator (blocked on the data pipeline or cross-node transfer) until ``ray.get()`` returns the resolved block. When *stats* is provided, the cumulative fetch time is also recorded in ``stats.iter_get_s``. - Note: ``production_wait`` is captured for overlap attribution only - and does not accumulate into ``iter_get_ref_bundles_s`` — that - Timer is driven by :func:`prefetch_batches_locally.get_next_ref_bundle` - when prefetching is enabled; accumulating here would double-count. + ``production_wait`` is captured for attribution only and not accumulated + into ``iter_get_ref_bundles_s`` — that Timer is driven by + ``prefetch_batches_locally.get_next_ref_bundle`` when prefetch is enabled; + accumulating here would double-count. Args: block_ref_iter: An iterator over block object references. @@ -204,7 +204,7 @@ def resolve_block_refs( unknowns = 0 while True: - # (1) production_wait: upstream wait (not accumulated; see docstring). + # production_wait: upstream wait (not accumulated; see docstring). prod_start_s = time.perf_counter() if stats else 0.0 try: block_ref = next(block_ref_iter) @@ -219,9 +219,8 @@ def resolve_block_refs( misses += current_miss unknowns += current_unknown - # (2) data_transfer: cross-node transfer via ray.get(). - # TODO(amogkam): Optimized further by batching multiple references - # in a single `ray.get()` call. + # data_transfer: cross-node transfer via ray.get(). + # TODO(amogkam): batch multiple references in one ray.get() call. with _maybe_time(stats.iter_get_s if stats else None) as xfer_span: block = ray.get(block_ref) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 9998bdabacc4..9f426a21eabc 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -163,12 +163,8 @@ def get( class IterationStage(Enum): - """Stages of the iter_batches pipeline used to attribute training-thread - blocked time. Each stage's wall-clock window is overlapped with the - training thread's ``next()`` blocked window to credit stall to the - responsible stage (see ``_attribute_blocked_time``). - - Each value is the Prometheus label for the corresponding + """Stages of the iter_batches pipeline, used to attribute training-thread + blocked time. Each value is the Prometheus label for the corresponding ``data_iter_blocked__seconds`` gauge. """ @@ -182,12 +178,7 @@ class IterationStage(Enum): @dataclass class TimeSpan: - """A measured wall-clock interval. - - Created by :meth:`Timer.timer` and carried per-batch for overlap-based - blocked attribution. ``None`` (when used as ``Optional[TimeSpan]``) - indicates the stage did not execute. - """ + """A measured wall-clock interval (start_s, end_s).""" start_s: float = 0.0 end_s: float = 0.0 @@ -199,11 +190,7 @@ def duration(self) -> float: @contextmanager def _maybe_time(timer: Optional["Timer"]) -> Iterator[Optional[TimeSpan]]: - """Time a block of code, yielding a :class:`TimeSpan` (or ``None``). - - When *timer* is ``None`` (e.g. ``stats`` is not configured), yields - ``None`` and skips timing entirely — no ``perf_counter`` calls. - """ + """Time a block, yielding a TimeSpan (or None if timer is None).""" if timer is None: yield None else: From e4d5df97a7e4244613653226461f9f37fd326d20 Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sun, 28 Jun 2026 23:50:47 +0800 Subject: [PATCH 28/32] [Data] Drop circular "see docstring" pointer in resolve_block_refs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The inline comment "# production_wait: upstream wait (not accumulated; see docstring)" pointed back at the docstring above it — a circular reference that reads as filler. State the fact inline instead. Signed-off-by: OneSizeFitsQuorum --- python/ray/data/_internal/block_batching/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 6c01e82c442c..92198b9434ec 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -204,7 +204,7 @@ def resolve_block_refs( unknowns = 0 while True: - # production_wait: upstream wait (not accumulated; see docstring). + # production_wait: upstream wait (not accumulated here). prod_start_s = time.perf_counter() if stats else 0.0 try: block_ref = next(block_ref_iter) From 3d1694e098b1051d9803221a989fc63ceb08a48d Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sun, 28 Jun 2026 23:51:45 +0800 Subject: [PATCH 29/32] [Data] Drop circular "see _attribute_blocked_time" pointer in BatchTimings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Same LLM-tell pattern as the "see docstring" pointer fixed in the previous commit — a "(see )" cross-reference that reads as filler. The reader can find the consumer by grepping. Signed-off-by: OneSizeFitsQuorum --- python/ray/data/_internal/block_batching/interfaces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index ebcfed59bbc4..f98bde36e88a 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -33,7 +33,7 @@ class BatchTimings: Each field is the ``(start_s, end_s)`` window a stage was active, or None if the stage didn't run. Compared against the training thread's - blocked window to attribute stall (see ``_attribute_blocked_time``). + blocked window to attribute stall. """ production_wait: Optional[TimeSpan] = None From dc78b7b007525f22607113dc4332152cc0fda1ed Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Sun, 28 Jun 2026 23:53:56 +0800 Subject: [PATCH 30/32] [Data] Drop Sphinx markup from Timer.timer() docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous docstring cleanup pass missed this one — Timer.timer() still had ``:class:`TimeSpan`` and ``:meth:`add`` Sphinx markup while the rest of the PR's internal docstrings use plain backticks. Switched to plain backticks for consistency. Signed-off-by: OneSizeFitsQuorum --- python/ray/data/_internal/stats.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 9f426a21eabc..5186c04f74d0 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -225,11 +225,11 @@ def __init__(self): @contextmanager def timer(self) -> Iterator[TimeSpan]: - """Time a block, yielding a fresh :class:`TimeSpan` per call. + """Time a block, yielding a fresh ``TimeSpan`` per call. The returned span is a distinct instance each call, so multiple threads sharing the same ``Timer`` don't race on span fields. - The duration is also accumulated into ``self`` via :meth:`add`. + The duration is also accumulated into ``self`` via ``add``. """ span = TimeSpan(start_s=time.perf_counter()) try: From 6ec8c268a8f60c88ac628577f082c946fb624e7b Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 29 Jun 2026 00:04:19 +0800 Subject: [PATCH 31/32] [Data] Naming and test cleanup from self-review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Naming (util.py): - prod_start_s / prod_span → production_wait_start / production_wait_span - xfer_span → data_transfer_span Variable names now match the BlockFetchTiming field names they're assigned to, so readers don't have to map abbreviations. Stale test names (test_iter_batches.py): - TestReportBatchTimingsEdgeCases → TestAttributeBlockedTimeEdgeCases (method was renamed _report_batch_timings → _attribute_blocked_time in an earlier commit; the test class name was missed). - _make_report_iterator → _make_test_iterator (same stale "report"). Test hygiene (test_util.py): - Move `from ray.data._internal.stats import DatasetStats` from inside two test functions to a top-level import. - Trim the 8-line regression-test docstring to 2 lines (the source docstring already has the rationale). - Drop the "Pairs with " cross-reference (brittle to renames, reads as filler). - Drop the weak `assert stats.iter_get_s.get() >= 0.0` (always true for non-negative Timer values; wasn't verifying anything). The structural checks (data_transfer span exists per block) remain. Test hygiene (test_stats.py): - test_each_call_returns_fresh_span: drop the unnecessary `spans` list; use s1/s2 directly. - test_maybe_time_skips_when_timer_none: docstring claimed "skips perf_counter entirely" but the test only checks `span is None`; dropped the unverified claim. All affected tests pass (39). black 22.10.0 + ruff clean. Signed-off-by: OneSizeFitsQuorum --- .../ray/data/_internal/block_batching/util.py | 15 ++++++---- .../tests/block_batching/test_iter_batches.py | 28 +++++++++---------- .../data/tests/block_batching/test_util.py | 27 ++++-------------- python/ray/data/tests/test_stats.py | 11 ++++---- 4 files changed, 34 insertions(+), 47 deletions(-) diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 92198b9434ec..47c5752efca2 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -205,13 +205,15 @@ def resolve_block_refs( while True: # production_wait: upstream wait (not accumulated here). - prod_start_s = time.perf_counter() if stats else 0.0 + production_wait_start = time.perf_counter() if stats else 0.0 try: block_ref = next(block_ref_iter) except StopIteration: break - prod_span = ( - TimeSpan(start_s=prod_start_s, end_s=time.perf_counter()) if stats else None + production_wait_span = ( + TimeSpan(start_s=production_wait_start, end_s=time.perf_counter()) + if stats + else None ) current_hit, current_miss, current_unknown = _calculate_ref_hits([block_ref]) @@ -221,10 +223,13 @@ def resolve_block_refs( # data_transfer: cross-node transfer via ray.get(). # TODO(amogkam): batch multiple references in one ray.get() call. - with _maybe_time(stats.iter_get_s if stats else None) as xfer_span: + with _maybe_time(stats.iter_get_s if stats else None) as data_transfer_span: block = ray.get(block_ref) - fetch = BlockFetchTiming(production_wait=prod_span, data_transfer=xfer_span) + fetch = BlockFetchTiming( + production_wait=production_wait_span, + data_transfer=data_transfer_span, + ) yield BlockFetchResult(block=block, fetch=fetch) if stats: diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index 4b5fde3854b3..3bb0089bc290 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -184,20 +184,20 @@ def _make_batch_with_timings( return Batch(BatchMetadata(batch_idx=0, num_rows=num_rows, timings=timings), None) -def _make_report_iterator(stats): +def _make_test_iterator(stats): """Create a BatchIterator wired to the given stats without a real pipeline.""" it = BatchIterator.__new__(BatchIterator) it._stats = stats return it -class TestReportBatchTimingsEdgeCases: +class TestAttributeBlockedTimeEdgeCases: """Edge case tests for overlap-based blocked attribution.""" def test_zero_overlap_stage_finished_before_blocked(self): """Fetch [0, 1.5] finished before training blocked at t=2 → 0 attribution.""" stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) + it = _make_test_iterator(stats) batch = _make_batch_with_timings( production_wait_start=0.0, production_wait_end=1.5 ) @@ -207,7 +207,7 @@ def test_zero_overlap_stage_finished_before_blocked(self): def test_zero_overlap_blocked_before_stage(self): """Training blocked [0, 1], stage ran [2, 3] → 0 attribution.""" stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) + it = _make_test_iterator(stats) batch = _make_batch_with_timings(format_start=2.0, format_end=3.0) it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=1.0) assert stats.iter_blocked_format_s.get() == 0.0 @@ -215,7 +215,7 @@ def test_zero_overlap_blocked_before_stage(self): def test_partial_overlap(self): """Fetch [0, 2], blocked [1, 3] → overlap = min(2,3)-max(0,1) = 1.0.""" stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) + it = _make_test_iterator(stats) batch = _make_batch_with_timings( production_wait_start=0.0, production_wait_end=2.0 ) @@ -225,7 +225,7 @@ def test_partial_overlap(self): def test_full_overlap_stage_inside_blocked(self): """Stage [1, 2] entirely inside blocked [0, 3] → full 1.0 credit.""" stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) + it = _make_test_iterator(stats) batch = _make_batch_with_timings(batching_start=1.0, batching_end=2.0) it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=3.0) assert stats.iter_blocked_batching_s.get() == pytest.approx(1.0) @@ -233,7 +233,7 @@ def test_full_overlap_stage_inside_blocked(self): def test_no_collate_fn_zero_attribution(self): """collate stage has start_s=0 → skipped, 0 attribution.""" stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) + it = _make_test_iterator(stats) batch = _make_batch_with_timings(format_start=1.0, format_end=2.0) it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=3.0) assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) @@ -242,7 +242,7 @@ def test_no_collate_fn_zero_attribution(self): def test_no_finalize_fn_zero_attribution(self): """finalize stage has start_s=0 → skipped, 0 attribution.""" stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) + it = _make_test_iterator(stats) batch = _make_batch_with_timings(collate_start=1.0, collate_end=2.0) it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=3.0) assert stats.iter_blocked_collate_s.get() == pytest.approx(1.0) @@ -251,7 +251,7 @@ def test_no_finalize_fn_zero_attribution(self): def test_prefetch_hides_fetch_from_training(self): """Effective prefetch: fetch done before training blocks → 0 fetch attribution.""" stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) + it = _make_test_iterator(stats) batch = _make_batch_with_timings( production_wait_start=0.0, production_wait_end=1.5, @@ -266,7 +266,7 @@ def test_prefetch_hides_fetch_from_training(self): def test_accumulation_across_batches(self): """Two batches each contribute to fetch — values accumulate.""" stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) + it = _make_test_iterator(stats) # Batch 1: fetch [0,1], blocked [0,2] → overlap 1.0 b1 = _make_batch_with_timings( production_wait_start=0.0, production_wait_end=1.0, num_rows=10 @@ -285,7 +285,7 @@ def test_accumulation_across_batches(self): def test_overlap_invariant_sum_leq_total(self): """sum(iter_blocked_*) <= iter_total_blocked_s holds for non-overlapping stages.""" stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) + it = _make_test_iterator(stats) stats.iter_total_blocked_s.add(5.0) batch = _make_batch_with_timings( production_wait_start=0.0, @@ -311,7 +311,7 @@ def test_overlap_invariant_sum_leq_total(self): def test_blocked_inside_stage(self): """Stage [0, 10] fully contains blocked [3, 5] → overlap = 2.0.""" stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) + it = _make_test_iterator(stats) batch = _make_batch_with_timings( production_wait_start=0.0, production_wait_end=10.0 ) @@ -321,7 +321,7 @@ def test_blocked_inside_stage(self): def test_all_stages_simultaneous_overlap(self): """Multiple stages overlap with blocked window simultaneously.""" stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) + it = _make_test_iterator(stats) batch = _make_batch_with_timings( production_wait_start=0.0, production_wait_end=1.0, @@ -527,7 +527,7 @@ def test_batch_carries_timings_through_pipeline(self): def test_full_pipeline_attribution(self): """End-to-end: all 5 stages with realistic timing, full overlap.""" stats = DatasetStats(metadata={}, parent=None) - it = _make_report_iterator(stats) + it = _make_test_iterator(stats) stats.iter_total_blocked_s.add(5.0) batch = _make_batch_with_timings( diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py index d78cd9e7cf4c..ec40f8623553 100644 --- a/python/ray/data/tests/block_batching/test_util.py +++ b/python/ray/data/tests/block_batching/test_util.py @@ -27,6 +27,7 @@ iter_threaded, resolve_block_refs, ) +from ray.data._internal.stats import DatasetStats from ray.data._internal.util import make_async_gen logger = logging.getLogger(__file__) @@ -49,16 +50,8 @@ def test_resolve_block_refs(ray_start_regular_shared): def test_resolve_block_refs_does_not_accumulate_ref_bundles_timer( ray_start_regular_shared, ): - """Regression test for nested-timer double-counting. - - ``resolve_block_refs`` must NOT accumulate into - ``iter_get_ref_bundles_s``. When prefetching is enabled, that Timer is - already driven by ``prefetch_batches_locally.get_next_ref_bundle``; - accumulating here too would double-count the upstream wait. The - ``production_wait`` TimeSpan is still captured per block for overlap - attribution. - """ - from ray.data._internal.stats import DatasetStats + """Regression test: resolve_block_refs must not accumulate into + iter_get_ref_bundles_s (prefetch_batches_locally owns that Timer).""" def slow_block_ref_iter(): for i in range(3): @@ -83,15 +76,8 @@ def slow_block_ref_iter(): def test_resolve_block_refs_accumulates_data_transfer_timer( ray_start_regular_shared, ): - """``resolve_block_refs`` accumulates ``ray.get()`` time into - ``iter_get_s`` (the data_transfer stage) and captures a per-block - ``data_transfer`` TimeSpan for overlap attribution. - - Pairs with ``test_resolve_block_refs_does_not_accumulate_ref_bundles_timer`` - which verifies the production_wait path does NOT accumulate. - """ - from ray.data._internal.stats import DatasetStats - + """resolve_block_refs accumulates ray.get() time into iter_get_s and + captures a per-block data_transfer TimeSpan.""" block_refs = [ray.put(i) for i in range(3)] stats = DatasetStats(metadata={}, parent=None) @@ -105,9 +91,6 @@ def test_resolve_block_refs_accumulates_data_transfer_timer( assert r.fetch.data_transfer is not None assert r.fetch.data_transfer.duration >= 0.0 - # iter_get_s (ray.get time) IS accumulated by resolve_block_refs. - assert stats.iter_get_s.get() >= 0.0 - @pytest.mark.parametrize("block_size", [1, 10]) @pytest.mark.parametrize("drop_last", [True, False]) diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index e1530983e0c8..00a48db461d1 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -2660,16 +2660,15 @@ def test_timer_yields_timespan(self): def test_each_call_returns_fresh_span(self): """Each timer() call yields a distinct TimeSpan instance.""" t = Timer() - spans = [] with t.timer() as s1: - spans.append(s1) + pass with t.timer() as s2: - spans.append(s2) - assert spans[0] is not spans[1] - assert t.get() == pytest.approx(spans[0].duration + spans[1].duration, rel=0.5) + pass + assert s1 is not s2 + assert t.get() == pytest.approx(s1.duration + s2.duration, rel=0.5) def test_maybe_time_skips_when_timer_none(self): - """_maybe_time(None) yields None and skips perf_counter entirely.""" + """_maybe_time(None) yields None.""" with _maybe_time(None) as span: assert span is None assert span is None From 9519cdf6fb8a1f9660665eeb774a32b31fc3bb1d Mon Sep 17 00:00:00 2001 From: OneSizeFitsQuorum Date: Mon, 29 Jun 2026 09:30:47 +0800 Subject: [PATCH 32/32] [Data] Fix import sort order in test_stats.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ruff's isort check (--select I) flagged the imports I added (IterationStage, TimeSpan, _maybe_time) as out of alphabetical order. pre-commit runs `ruff --select I --fix --exit-non-zero-on-fix`, which fixes and exits non-zero — failing microcheck CI. Reordered to satisfy ruff's case-insensitive isort. Signed-off-by: OneSizeFitsQuorum --- python/ray/data/tests/test_stats.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index 00a48db461d1..96a6b914d4f8 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -32,14 +32,14 @@ from ray.data._internal.stats import ( DatasetStats, DatasetStatsSummary, + IterationStage, NodeMetrics, OperatorStatsSummary, - IterationStage, StatsSummary, Timer, TimeSpan, - _StatsActor, _maybe_time, + _StatsActor, get_or_create_stats_actor, ) from ray.data._internal.util import MemoryProfiler