diff --git a/python/ray/data/_internal/block_batching/block_batching.py b/python/ray/data/_internal/block_batching/block_batching.py index ef54a593920b..268347d92f3a 100644 --- a/python/ray/data/_internal/block_batching/block_batching.py +++ b/python/ray/data/_internal/block_batching/block_batching.py @@ -1,5 +1,6 @@ from typing import Callable, Iterator, Optional, TypeVar +from ray.data._internal.block_batching.interfaces import BlockFetchResult from ray.data._internal.block_batching.util import ( _MappingIterator, blocks_to_batches, @@ -29,10 +30,19 @@ def batch_blocks( This function takes in an iterator of already fetched blocks. Consequently, this function doesn't support block prefetching. """ + # Wrap raw blocks in BlockFetchResult with no fetch timing (these + # blocks were already resolved before entering the pipeline). + # Use map() instead of a generator expression to avoid holding + # references to blocks. + # + # TODO: make fetch timing optional at the _BatchingIterator level so + # this BlockFetchResult wrapping shim can be removed. + wrapped_blocks = map(lambda b: BlockFetchResult(block=b), blocks) + # Build the processing pipeline batch_iter = format_batches( blocks_to_batches( - block_iter=blocks, + block_iter=wrapped_blocks, stats=stats, batch_size=batch_size, drop_last=drop_last, diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 4f0bed6b3dd4..f98bde36e88a 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -1,11 +1,77 @@ import abc -from dataclasses import dataclass -from typing import Any, List +from dataclasses import dataclass, field +from typing import Any, Iterable, List, Optional, Tuple +from ray.data._internal.stats import IterationStage, TimeSpan from ray.data.block import Block, DataBatch from ray.types import ObjectRef +@dataclass +class BlockFetchTiming: + """Fetch timing for a single block (production_wait + data_transfer).""" + + production_wait: Optional[TimeSpan] = None + data_transfer: Optional[TimeSpan] = None + + +@dataclass +class BlockFetchResult: + """A resolved block paired with its per-block fetch timing. + + ``fetch`` is None when no timing was recorded (e.g. blocks already + resolved before entering the pipeline). + """ + + block: Block + fetch: Optional[BlockFetchTiming] = None + + +@dataclass +class BatchTimings: + """Per-batch timing windows for each iteration stage. + + Each field is the ``(start_s, end_s)`` window a stage was active, or + None if the stage didn't run. Compared against the training thread's + blocked window to attribute stall. + """ + + production_wait: Optional[TimeSpan] = None + data_transfer: Optional[TimeSpan] = None + batching: Optional[TimeSpan] = None + format: Optional[TimeSpan] = None + collate: Optional[TimeSpan] = None + finalize: Optional[TimeSpan] = None + + def stages(self) -> Iterable[Tuple[IterationStage, Optional[TimeSpan]]]: + """Yield (stage, timing) pairs.""" + return ( + (IterationStage.PRODUCTION_WAIT, self.production_wait), + (IterationStage.DATA_TRANSFER, self.data_transfer), + (IterationStage.BATCHING, self.batching), + (IterationStage.FORMAT, self.format), + (IterationStage.COLLATE, self.collate), + (IterationStage.FINALIZE, self.finalize), + ) + + def merge_fetch(self, src: BlockFetchTiming) -> None: + """Merge per-block fetch timings into this batch's fetch windows.""" + self.production_wait = _merge_span(self.production_wait, src.production_wait) + self.data_transfer = _merge_span(self.data_transfer, src.data_transfer) + + +def _merge_span(dst: Optional[TimeSpan], src: Optional[TimeSpan]) -> Optional[TimeSpan]: + """Return the union of two optional windows (or the non-None one).""" + if src is None: + return dst + if dst is None: + return TimeSpan(start_s=src.start_s, end_s=src.end_s) + return TimeSpan( + start_s=min(dst.start_s, src.start_s), + end_s=max(dst.end_s, src.end_s), + ) + + @dataclass class BatchMetadata: """Metadata associated with a batch. @@ -13,9 +79,13 @@ class BatchMetadata: Attributes: batch_idx: The global index of this batch so that downstream operations can maintain ordering. + num_rows: Number of rows in this batch (for ``iter_rows_total``). + timings: Per-stage timing windows. """ batch_idx: int + num_rows: int = 0 + timings: BatchTimings = field(default_factory=BatchTimings) @dataclass diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index f9bf0076d2af..6711105ac6a7 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -175,7 +175,7 @@ def _prefetch_blocks( def _resolve_block_refs( self, block_refs: Iterator[ObjectRef[Block]] - ) -> Iterator[Block]: + ) -> Iterator[Any]: return resolve_block_refs(block_ref_iter=block_refs, stats=self._stats) def _blocks_to_batches(self, blocks: Iterator[Block]) -> Iterator[Batch]: @@ -243,21 +243,63 @@ def _pipeline(self, ref_bundles: Iterator[RefBundle]) -> Iterator[Batch]: yield from batch_iter def _iter_batches(self) -> Iterator[DataBatch]: + """Pull batches from the pipeline and yield batch data. + + Captures the training thread's blocked window around each ``next()`` + call and attributes it to pipeline stages via + ``_attribute_blocked_time``. + """ batch_iter = iter_threaded(self._ref_bundles, fn=self._pipeline) self.before_epoch_start() while True: with self.get_next_batch_context(): + blocked_start_s = time.perf_counter() try: batch = next(batch_iter) except StopIteration: break + blocked_end_s = time.perf_counter() + self._attribute_blocked_time(batch, blocked_start_s, blocked_end_s) with self.yield_batch_context(batch): yield batch.data self.after_epoch_end() + def _attribute_blocked_time( + self, batch: Batch, blocked_start_s: float, blocked_end_s: float + ) -> None: + """Attribute per-stage blocked time via overlap with the training window. + + Each stage's window on ``batch.metadata.timings`` is intersected with + the training thread's blocked window ``[blocked_start_s, blocked_end_s]``:: + + overlap = min(timing.end, blocked_end) - max(timing.start, blocked_start) + + TODO: ``sum(iter_blocked_*)`` only approximates ``iter_total_blocked_s`` + — split fetch stages overlap for multi-block batches, and reorder + buffer wait under ``preserve_order`` is unattributed. + + Args: + batch: Batch whose per-stage timings should be attributed. + blocked_start_s: perf_counter() just before next(). + blocked_end_s: perf_counter() just after next() returned. + """ + if self._stats is None: + return + timings = batch.metadata.timings + for stage, timing in timings.stages(): + if timing is None: + continue + overlap_s = min(timing.end_s, blocked_end_s) - max( + timing.start_s, blocked_start_s + ) + if overlap_s > 0: + self._stats.get_blocked_timer(stage).add(overlap_s) + self._stats.iter_batches_total += 1 + self._stats.iter_rows_total += batch.metadata.num_rows + def __iter__(self) -> Iterator[DataBatch]: return self._iter_batches() @@ -276,6 +318,8 @@ def after_epoch_end(self): @contextmanager def get_next_batch_context(self): + """Context around ``next(batch_iter)``: tracks total blocked time + and time-to-first-batch.""" try: if self._stats: # Always track total blocked time @@ -295,6 +339,8 @@ def get_next_batch_context(self): @contextmanager def yield_batch_context(self, batch: Batch): + """Context around yielding a batch to the user: tracks user time + and periodically flushes metrics.""" with self._stats.iter_user_s.timer() if self._stats else nullcontext(): yield diff --git a/python/ray/data/_internal/block_batching/util.py b/python/ray/data/_internal/block_batching/util.py index 8a42cde7871e..47c5752efca2 100644 --- a/python/ray/data/_internal/block_batching/util.py +++ b/python/ray/data/_internal/block_batching/util.py @@ -3,7 +3,7 @@ import logging import queue import threading -from contextlib import nullcontext +import time from typing import ( Any, Callable, @@ -22,10 +22,13 @@ from ray.data._internal.block_batching.interfaces import ( Batch, BatchMetadata, + BatchTimings, + BlockFetchResult, + BlockFetchTiming, BlockPrefetcher, CollatedBatch, ) -from ray.data._internal.stats import DatasetStats +from ray.data._internal.stats import DatasetStats, TimeSpan, _maybe_time from ray.data.block import Block, BlockAccessor, DataBatch from ray.types import ObjectRef @@ -173,31 +176,61 @@ def _calculate_ref_hits(refs: List[ObjectRef[Any]]) -> Tuple[int, int, int]: def resolve_block_refs( block_ref_iter: Iterator[ObjectRef[Block]], stats: Optional[DatasetStats] = None, -) -> Iterator[Block]: +) -> Iterator[BlockFetchResult]: """Resolves the block references for each logical batch. + Each resolved block is wrapped in a ``BlockFetchResult`` that carries + the per-block fetch window. The fetch window spans from the moment we + start waiting for the upstream iterator (blocked on the data pipeline or + cross-node transfer) until ``ray.get()`` returns the resolved block. + When *stats* is provided, the cumulative fetch time is also recorded in + ``stats.iter_get_s``. + + ``production_wait`` is captured for attribution only and not accumulated + into ``iter_get_ref_bundles_s`` — that Timer is driven by + ``prefetch_batches_locally.get_next_ref_bundle`` when prefetch is enabled; + accumulating here would double-count. + Args: block_ref_iter: An iterator over block object references. - stats: An optional stats object to recording block hits and misses. + stats: An optional stats object to record block hits, misses, and + cumulative fetch time. Yields: - Block: The resolved blocks for each block reference. + BlockFetchResult: Each resolved block with its fetch timing window. """ hits = 0 misses = 0 unknowns = 0 - for block_ref in block_ref_iter: + while True: + # production_wait: upstream wait (not accumulated here). + production_wait_start = time.perf_counter() if stats else 0.0 + try: + block_ref = next(block_ref_iter) + except StopIteration: + break + production_wait_span = ( + TimeSpan(start_s=production_wait_start, end_s=time.perf_counter()) + if stats + else None + ) + current_hit, current_miss, current_unknown = _calculate_ref_hits([block_ref]) hits += current_hit misses += current_miss unknowns += current_unknown - # TODO(amogkam): Optimized further by batching multiple references in a single - # `ray.get()` call. - with stats.iter_get_s.timer() if stats else nullcontext(): + # data_transfer: cross-node transfer via ray.get(). + # TODO(amogkam): batch multiple references in one ray.get() call. + with _maybe_time(stats.iter_get_s if stats else None) as data_transfer_span: block = ray.get(block_ref) - yield block + + fetch = BlockFetchTiming( + production_wait=production_wait_span, + data_transfer=data_transfer_span, + ) + yield BlockFetchResult(block=block, fetch=fetch) if stats: stats.iter_blocks_local = hits @@ -206,7 +239,7 @@ def resolve_block_refs( def blocks_to_batches( - block_iter: Iterator[Block], + block_iter: Iterator[BlockFetchResult], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -235,7 +268,7 @@ class _BatchingIterator(Iterator[Batch]): def __init__( self, - block_iter: Iterator[Block], + block_iter: Iterator[BlockFetchResult], stats: Optional[DatasetStats] = None, batch_size: Optional[int] = None, drop_last: bool = False, @@ -248,6 +281,8 @@ def __init__( self._drop_last = drop_last self._global_counter = 0 self._done_adding = False + # Accumulates per-block fetch timings until a batch is yielded. + self._pending_timings = BatchTimings() if shuffle_buffer_min_size is not None: self._batcher = ShufflingBatcher( @@ -262,8 +297,6 @@ def __iter__(self) -> "_BatchingIterator": return self def __next__(self) -> Batch: - timer = self._stats.iter_next_batch_s.timer() if self._stats else nullcontext() - # Try to get a batch from current batcher state while True: can_yield = self._batcher.has_batch() or ( @@ -271,13 +304,21 @@ def __next__(self) -> Batch: ) if can_yield: - with timer: + with _maybe_time( + self._stats.iter_next_batch_s if self._stats else None + ) as span: next_batch = self._batcher.next_batch() + self._pending_timings.batching = span res = Batch( - metadata=BatchMetadata(batch_idx=self._global_counter), + metadata=BatchMetadata( + batch_idx=self._global_counter, + num_rows=BlockAccessor.for_block(next_batch).num_rows(), + timings=self._pending_timings, + ), data=next_batch, ) + self._pending_timings = BatchTimings() self._global_counter += 1 return res @@ -286,8 +327,10 @@ def __next__(self) -> Batch: # If can't yield try adding more blocks try: # NOTE: Block ref is released immediately - block = next(self._block_iter) - self._batcher.add(block) + block_result = next(self._block_iter) + if block_result.fetch is not None: + self._pending_timings.merge_fetch(block_result.fetch) + self._batcher.add(block_result.block) except StopIteration: self._batcher.done_adding() self._done_adding = True @@ -306,12 +349,13 @@ def _format_batch( stats: Optional[DatasetStats], ensure_copy: bool = False, ) -> Batch: - with stats.iter_format_batch_s.timer() if stats else nullcontext(): + with _maybe_time(stats.iter_format_batch_s if stats else None) as span: formatted_data = BlockAccessor.for_block(batch.data).to_batch_format( batch_format ) if ensure_copy: formatted_data = _copy_batch(formatted_data) + batch.metadata.timings.format = span return dataclasses.replace(batch, data=formatted_data) @@ -359,8 +403,9 @@ def _collate_batch( collate_fn: Callable[[DataBatch], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: - with stats.iter_collate_batch_s.timer() if stats else nullcontext(): + with _maybe_time(stats.iter_collate_batch_s if stats else None) as span: collated_data = collate_fn(batch.data) + batch.metadata.timings.collate = span return CollatedBatch(metadata=batch.metadata, data=collated_data) @@ -384,8 +429,9 @@ def _finalize_batch( finalize_fn: Callable[[Any], Any], stats: Optional[DatasetStats], ) -> CollatedBatch: - with stats.iter_finalize_batch_s.timer() if stats else nullcontext(): + with _maybe_time(stats.iter_finalize_batch_s if stats else None) as span: finalized_data = finalize_fn(batch.data) + batch.metadata.timings.finalize = span return dataclasses.replace(batch, data=finalized_data) diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index 618ed8d6a376..5186c04f74d0 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -5,11 +5,13 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass, fields +from enum import Enum from typing import ( TYPE_CHECKING, Any, DefaultDict, Dict, + Iterator, List, Mapping, Optional, @@ -160,6 +162,42 @@ def get( ) +class IterationStage(Enum): + """Stages of the iter_batches pipeline, used to attribute training-thread + blocked time. Each value is the Prometheus label for the corresponding + ``data_iter_blocked__seconds`` gauge. + """ + + PRODUCTION_WAIT = "production_wait" # waiting on upstream data production + DATA_TRANSFER = "data_transfer" # cross-node ray.get() transfer + BATCHING = "batching" # slicing/shuffling blocks into batches + FORMAT = "format" # converting blocks to batch format + COLLATE = "collate" # applying user collate_fn + FINALIZE = "finalize" # applying user finalize_fn + + +@dataclass +class TimeSpan: + """A measured wall-clock interval (start_s, end_s).""" + + start_s: float = 0.0 + end_s: float = 0.0 + + @property + def duration(self) -> float: + return self.end_s - self.start_s + + +@contextmanager +def _maybe_time(timer: Optional["Timer"]) -> Iterator[Optional[TimeSpan]]: + """Time a block, yielding a TimeSpan (or None if timer is None).""" + if timer is None: + yield None + else: + with timer.timer() as span: + yield span + + class Timer: """Helper class for tracking accumulated time (in seconds). @@ -186,12 +224,19 @@ def __init__(self): self._distribution: DistributionTracker = DistributionTracker() @contextmanager - def timer(self) -> None: - time_start = time.perf_counter() + def timer(self) -> Iterator[TimeSpan]: + """Time a block, yielding a fresh ``TimeSpan`` per call. + + The returned span is a distinct instance each call, so multiple + threads sharing the same ``Timer`` don't race on span fields. + The duration is also accumulated into ``self`` via ``add``. + """ + span = TimeSpan(start_s=time.perf_counter()) try: - yield + yield span finally: - self.add(time.perf_counter() - time_start) + span.end_s = time.perf_counter() + self.add(span.duration) def add(self, value: float) -> None: self._total += value @@ -449,6 +494,9 @@ def __init__(self, max_stats=1000): self.per_node_metrics = self._create_prometheus_metrics_for_per_node_metrics() iter_tag_keys = ("dataset",) + # TODO: add a per-streaming-split-worker ``rank`` label to iteration + # metrics so users can distinguish which split worker blocked on + # which stage. self.time_to_first_batch_s = Gauge( "data_iter_time_to_first_batch_seconds", @@ -488,6 +536,51 @@ def __init__(self, max_stats=1000): description="Seconds user thread is blocked by iter_batches()", tag_keys=iter_tag_keys, ) + self.iter_total_s = Gauge( + "data_iter_total_seconds", + description="Total wall-clock seconds spent in the dataset iterator", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_production_wait_s = Gauge( + "data_iter_blocked_production_wait_seconds", + description="Seconds user thread is blocked on upstream data production", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_data_transfer_s = Gauge( + "data_iter_blocked_data_transfer_seconds", + description="Seconds user thread is blocked on cross-node data transfer (ray.get)", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_batching_s = Gauge( + "data_iter_blocked_batching_seconds", + description="Seconds user thread is blocked on batch creation", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_format_s = Gauge( + "data_iter_blocked_format_seconds", + description="Seconds user thread is blocked on batch formatting", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_collate_s = Gauge( + "data_iter_blocked_collate_seconds", + description="Seconds user thread is blocked on batch collation", + tag_keys=iter_tag_keys, + ) + self.iter_blocked_finalize_s = Gauge( + "data_iter_blocked_finalize_seconds", + description="Seconds user thread is blocked on batch finalization", + tag_keys=iter_tag_keys, + ) + self.iter_batches_total = Gauge( + "data_iter_batches_total", + description="Total batches delivered to the user thread", + tag_keys=iter_tag_keys, + ) + self.iter_rows_total = Gauge( + "data_iter_rows_total", + description="Total rows delivered to the user thread", + tag_keys=iter_tag_keys, + ) self.iter_user_s = Gauge( "data_iter_user_seconds", description="Seconds spent in user code", @@ -728,6 +821,7 @@ def update_iteration_metrics( tags = self._create_tags(dataset_tag) self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags) + self.iter_total_s.set(stats.iter_total_s.get(), tags) self.iter_get_ref_bundles_s.set(stats.iter_get_ref_bundles_s.get(), tags) self.iter_get_s.set(stats.iter_get_s.get(), tags) self.iter_next_batch_s.set(stats.iter_next_batch_s.get(), tags) @@ -748,6 +842,26 @@ def update_iteration_metrics( self.time_to_first_batch_s.set(stats.iter_time_to_first_batch_s.get(), tags) self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags) + self.iter_blocked_production_wait_s.set( + stats.get_blocked_timer(IterationStage.PRODUCTION_WAIT).get(), tags + ) + self.iter_blocked_data_transfer_s.set( + stats.get_blocked_timer(IterationStage.DATA_TRANSFER).get(), tags + ) + self.iter_blocked_batching_s.set( + stats.get_blocked_timer(IterationStage.BATCHING).get(), tags + ) + self.iter_blocked_format_s.set( + stats.get_blocked_timer(IterationStage.FORMAT).get(), tags + ) + self.iter_blocked_collate_s.set( + stats.get_blocked_timer(IterationStage.COLLATE).get(), tags + ) + self.iter_blocked_finalize_s.set( + stats.get_blocked_timer(IterationStage.FINALIZE).get(), tags + ) + self.iter_batches_total.set(stats.iter_batches_total, tags) + self.iter_rows_total.set(stats.iter_rows_total, tags) self.iter_user_s.set(stats.iter_user_s.get(), tags) def register_dataset( @@ -1138,9 +1252,17 @@ def __init__( self.iter_finalize_batch_s: Timer = Timer() self.iter_time_to_first_batch_s: Timer = Timer() self.iter_total_blocked_s: Timer = Timer() + self.iter_blocked_production_wait_s: Timer = Timer() + self.iter_blocked_data_transfer_s: Timer = Timer() + self.iter_blocked_batching_s: Timer = Timer() + self.iter_blocked_format_s: Timer = Timer() + self.iter_blocked_collate_s: Timer = Timer() + self.iter_blocked_finalize_s: Timer = Timer() self.iter_user_s: Timer = Timer() self.iter_initialize_s: Timer = Timer() self.iter_total_s: Timer = Timer() + self.iter_batches_total: int = 0 + self.iter_rows_total: int = 0 self.extra_metrics = {} # Block fetch stats during iteration. @@ -1162,6 +1284,22 @@ def __init__( # Streaming split coordinator stats (dataset level) self.streaming_split_coordinator_s: Timer = Timer() + def get_blocked_timer(self, stage: IterationStage) -> Timer: + """Return the blocked-attribution Timer for the given iteration stage.""" + match stage: + case IterationStage.PRODUCTION_WAIT: + return self.iter_blocked_production_wait_s + case IterationStage.DATA_TRANSFER: + return self.iter_blocked_data_transfer_s + case IterationStage.BATCHING: + return self.iter_blocked_batching_s + case IterationStage.FORMAT: + return self.iter_blocked_format_s + case IterationStage.COLLATE: + return self.iter_blocked_collate_s + case IterationStage.FINALIZE: + return self.iter_blocked_finalize_s + @property def stats_actor(self): return get_or_create_stats_actor() @@ -1196,6 +1334,14 @@ def to_summary(self) -> "DatasetStatsSummary": self.iter_blocks_remote, self.iter_unknown_location, self.iter_prefetched_bytes, + self.iter_blocked_production_wait_s, + self.iter_blocked_data_transfer_s, + self.iter_blocked_batching_s, + self.iter_blocked_format_s, + self.iter_blocked_collate_s, + self.iter_blocked_finalize_s, + self.iter_batches_total, + self.iter_rows_total, ) stats_summary_parents = [] @@ -1878,6 +2024,16 @@ class IterStatsSummary: iter_unknown_location: int # Current bytes of prefetched blocks in the iterator iter_prefetched_bytes: int + # Per-stage training-thread blocked attribution timers. + blocked_production_wait_time: Timer + blocked_data_transfer_time: Timer + blocked_batching_time: Timer + blocked_format_time: Timer + blocked_collate_time: Timer + blocked_finalize_time: Timer + # Cumulative batch and row counters. + batches_total: int + rows_total: int def __str__(self) -> str: return self.to_string() @@ -1984,6 +2140,25 @@ def to_string(self) -> str: out += "Streaming split coordinator overhead time: " out += f"{fmt(self.streaming_split_coord_time.get())}\n" + # Per-stage training-thread blocked attribution. + stage_totals = [ + ("production wait", self.blocked_production_wait_time), + ("data transfer (ray.get)", self.blocked_data_transfer_time), + ("batching", self.blocked_batching_time), + ("format", self.blocked_format_time), + ("collate", self.blocked_collate_time), + ("finalize (host->device)", self.blocked_finalize_time), + ] + active_stages = [(name, t) for name, t in stage_totals if t.get() > 0] + if active_stages: + out += "\nPer-stage training-thread blocked time breakdown:\n" + for stage_name, timer in active_stages: + out += " * {}: {}\n".format(stage_name, fmt(timer.get())) + if self.batches_total: + out += "Total batches consumed: {}\n".format(self.batches_total) + if self.rows_total: + out += "Total rows consumed: {}\n".format(self.rows_total) + return out def __repr__(self, level=0) -> str: diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 1f69962435ad..d2a0275f366e 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -22,7 +22,7 @@ from ray.data._internal.execution.interfaces import RefBundle from ray.data._internal.logical.interfaces import LogicalPlan from ray.data._internal.logical.operators import InputData -from ray.data._internal.stats import DatasetStats +from ray.data._internal.stats import DatasetStats, _StatsManager from ray.data.block import BlockAccessor, DataBatch, _apply_batch_format from ray.data.collate_fn import ( ArrowBatchCollateFn, @@ -295,9 +295,14 @@ def callback(num_bytes: int) -> None: try: yield from batch_iterator + finally: + # Runs on both normal completion and early exit (e.g. `break` + # in the training loop). `iter_total_s` and the final metrics + # flush must happen here so partial iteration is still + # recorded; `_on_iteration_end` shuts down the executor after. if stats: stats.iter_total_s.add(time.perf_counter() - time_start) - finally: + _StatsManager.update_iteration_metrics(stats, dataset_tag) # On early exit (e.g. ``break`` in the for-loop), the inner # ``_ClosingIterator`` would only shut down the executor via # its ``__del__``, which is non-deterministic. The hook diff --git a/python/ray/data/tests/block_batching/test_iter_batches.py b/python/ray/data/tests/block_batching/test_iter_batches.py index b0b55d1cb8fb..3bb0089bc290 100644 --- a/python/ray/data/tests/block_batching/test_iter_batches.py +++ b/python/ray/data/tests/block_batching/test_iter_batches.py @@ -1,7 +1,7 @@ import queue import threading import time -from typing import Iterator, List +from typing import Iterator, List, Optional import pandas as pd import pyarrow as pa @@ -11,6 +11,8 @@ from ray.data._internal.block_batching.interfaces import ( Batch, BatchMetadata, + BatchTimings, + BlockFetchTiming, BlockPrefetcher, ) from ray.data._internal.block_batching.iter_batches import ( @@ -20,7 +22,7 @@ ) from ray.data._internal.block_batching.util import WaitBlockPrefetcher from ray.data._internal.execution.interfaces.ref_bundle import BlockEntry, RefBundle -from ray.data._internal.stats import DatasetStats +from ray.data._internal.stats import DatasetStats, IterationStage, TimeSpan from ray.data.block import Block, BlockMetadata from ray.types import ObjectRef @@ -114,6 +116,458 @@ def test_restore_from_original_order(): assert idx == [0, 1, 2, 3] +def test_restore_original_order_stats(): + base_iterator = [ + Batch(BatchMetadata(batch_idx=2), None), + Batch(BatchMetadata(batch_idx=0), None), + Batch(BatchMetadata(batch_idx=1), None), + ] + + ordered = list(restore_original_order(iter(base_iterator))) + + assert [batch.metadata.batch_idx for batch in ordered] == [0, 1, 2] + + +def test_attribute_blocked_time_overlap_attribution(): + stats = DatasetStats(metadata={}, parent=None) + batch_iterator = BatchIterator(iter([]), stats=stats) + timings = BatchTimings() + timings.production_wait = TimeSpan(start_s=10.0, end_s=20.0) + timings.batching = TimeSpan(start_s=20.0, end_s=30.0) + timings.format = TimeSpan(start_s=30.0, end_s=40.0) + timings.finalize = TimeSpan(start_s=50.0, end_s=60.0) + batch = Batch(BatchMetadata(batch_idx=0, num_rows=8, timings=timings), None) + + batch_iterator._attribute_blocked_time( + batch, blocked_start_s=15.0, blocked_end_s=35.0 + ) + + assert stats.iter_blocked_production_wait_s.get() == pytest.approx(5.0) + assert stats.iter_blocked_batching_s.get() == pytest.approx(10.0) + assert stats.iter_blocked_format_s.get() == pytest.approx(5.0) + assert stats.iter_blocked_collate_s.get() == 0 + assert stats.iter_blocked_finalize_s.get() == 0 + assert stats.iter_batches_total == 1 + assert stats.iter_rows_total == 8 + + +def _make_span(start: Optional[float], end: Optional[float]) -> Optional[TimeSpan]: + """Create a TimeSpan, or None if the stage didn't execute (either is None).""" + if start is None or end is None: + return None + return TimeSpan(start_s=start, end_s=end) + + +def _make_batch_with_timings( + production_wait_start: Optional[float] = None, + production_wait_end: Optional[float] = None, + data_transfer_start: Optional[float] = None, + data_transfer_end: Optional[float] = None, + batching_start: Optional[float] = None, + batching_end: Optional[float] = None, + format_start: Optional[float] = None, + format_end: Optional[float] = None, + collate_start: Optional[float] = None, + collate_end: Optional[float] = None, + finalize_start: Optional[float] = None, + finalize_end: Optional[float] = None, + num_rows: int = 0, +): + """Helper to construct a Batch with specific stage timing windows.""" + timings = BatchTimings() + timings.production_wait = _make_span(production_wait_start, production_wait_end) + timings.data_transfer = _make_span(data_transfer_start, data_transfer_end) + timings.batching = _make_span(batching_start, batching_end) + timings.format = _make_span(format_start, format_end) + timings.collate = _make_span(collate_start, collate_end) + timings.finalize = _make_span(finalize_start, finalize_end) + return Batch(BatchMetadata(batch_idx=0, num_rows=num_rows, timings=timings), None) + + +def _make_test_iterator(stats): + """Create a BatchIterator wired to the given stats without a real pipeline.""" + it = BatchIterator.__new__(BatchIterator) + it._stats = stats + return it + + +class TestAttributeBlockedTimeEdgeCases: + """Edge case tests for overlap-based blocked attribution.""" + + def test_zero_overlap_stage_finished_before_blocked(self): + """Fetch [0, 1.5] finished before training blocked at t=2 → 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_test_iterator(stats) + batch = _make_batch_with_timings( + production_wait_start=0.0, production_wait_end=1.5 + ) + it._attribute_blocked_time(batch, blocked_start_s=2.0, blocked_end_s=3.0) + assert stats.iter_blocked_production_wait_s.get() == 0.0 + + def test_zero_overlap_blocked_before_stage(self): + """Training blocked [0, 1], stage ran [2, 3] → 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_test_iterator(stats) + batch = _make_batch_with_timings(format_start=2.0, format_end=3.0) + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=1.0) + assert stats.iter_blocked_format_s.get() == 0.0 + + def test_partial_overlap(self): + """Fetch [0, 2], blocked [1, 3] → overlap = min(2,3)-max(0,1) = 1.0.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_test_iterator(stats) + batch = _make_batch_with_timings( + production_wait_start=0.0, production_wait_end=2.0 + ) + it._attribute_blocked_time(batch, blocked_start_s=1.0, blocked_end_s=3.0) + assert stats.iter_blocked_production_wait_s.get() == pytest.approx(1.0) + + def test_full_overlap_stage_inside_blocked(self): + """Stage [1, 2] entirely inside blocked [0, 3] → full 1.0 credit.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_test_iterator(stats) + batch = _make_batch_with_timings(batching_start=1.0, batching_end=2.0) + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_batching_s.get() == pytest.approx(1.0) + + def test_no_collate_fn_zero_attribution(self): + """collate stage has start_s=0 → skipped, 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_test_iterator(stats) + batch = _make_batch_with_timings(format_start=1.0, format_end=2.0) + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_collate_s.get() == 0.0 + + def test_no_finalize_fn_zero_attribution(self): + """finalize stage has start_s=0 → skipped, 0 attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_test_iterator(stats) + batch = _make_batch_with_timings(collate_start=1.0, collate_end=2.0) + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=3.0) + assert stats.iter_blocked_collate_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_finalize_s.get() == 0.0 + + def test_prefetch_hides_fetch_from_training(self): + """Effective prefetch: fetch done before training blocks → 0 fetch attribution.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_test_iterator(stats) + batch = _make_batch_with_timings( + production_wait_start=0.0, + production_wait_end=1.5, + collate_start=2.3, + collate_end=2.6, + ) + # Training only starts blocking at t=2 (prefetch worked) + it._attribute_blocked_time(batch, blocked_start_s=2.0, blocked_end_s=2.6) + assert stats.iter_blocked_production_wait_s.get() == 0.0 + assert stats.iter_blocked_collate_s.get() == pytest.approx(0.3) + + def test_accumulation_across_batches(self): + """Two batches each contribute to fetch — values accumulate.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_test_iterator(stats) + # Batch 1: fetch [0,1], blocked [0,2] → overlap 1.0 + b1 = _make_batch_with_timings( + production_wait_start=0.0, production_wait_end=1.0, num_rows=10 + ) + it._attribute_blocked_time(b1, blocked_start_s=0.0, blocked_end_s=2.0) + # Batch 2: fetch [5,6], blocked [5,7] → overlap 1.0 + b2 = _make_batch_with_timings( + production_wait_start=5.0, production_wait_end=6.0, num_rows=20 + ) + it._attribute_blocked_time(b2, blocked_start_s=5.0, blocked_end_s=7.0) + + assert stats.iter_blocked_production_wait_s.get() == pytest.approx(2.0) + assert stats.iter_batches_total == 2 + assert stats.iter_rows_total == 30 + + def test_overlap_invariant_sum_leq_total(self): + """sum(iter_blocked_*) <= iter_total_blocked_s holds for non-overlapping stages.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_test_iterator(stats) + stats.iter_total_blocked_s.add(5.0) + batch = _make_batch_with_timings( + production_wait_start=0.0, + production_wait_end=1.0, + batching_start=1.0, + batching_end=2.0, + format_start=2.0, + format_end=3.0, + num_rows=5, + ) + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=5.0) + + total = stats.iter_total_blocked_s.get() + sum_stages = ( + stats.iter_blocked_production_wait_s.get() + + stats.iter_blocked_batching_s.get() + + stats.iter_blocked_format_s.get() + + stats.iter_blocked_collate_s.get() + + stats.iter_blocked_finalize_s.get() + ) + assert sum_stages <= total + 1e-9 + + def test_blocked_inside_stage(self): + """Stage [0, 10] fully contains blocked [3, 5] → overlap = 2.0.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_test_iterator(stats) + batch = _make_batch_with_timings( + production_wait_start=0.0, production_wait_end=10.0 + ) + it._attribute_blocked_time(batch, blocked_start_s=3.0, blocked_end_s=5.0) + assert stats.iter_blocked_production_wait_s.get() == pytest.approx(2.0) + + def test_all_stages_simultaneous_overlap(self): + """Multiple stages overlap with blocked window simultaneously.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_test_iterator(stats) + batch = _make_batch_with_timings( + production_wait_start=0.0, + production_wait_end=1.0, + batching_start=1.0, + batching_end=2.0, + format_start=2.0, + format_end=3.0, + collate_start=3.0, + collate_end=4.0, + finalize_start=4.0, + finalize_end=5.0, + num_rows=100, + ) + # Blocked window covers all stages + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=5.0) + assert stats.iter_blocked_production_wait_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_batching_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_collate_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_finalize_s.get() == pytest.approx(1.0) + assert stats.iter_batches_total == 1 + assert stats.iter_rows_total == 100 + + +class TestTimeSpan: + """Tests for TimeSpan dataclass.""" + + def test_default_values(self): + """Default TimeSpan has start_s=0 and end_s=0.""" + t = TimeSpan() + assert t.start_s == 0.0 + assert t.end_s == 0.0 + + def test_duration(self): + """Duration is end_s - start_s.""" + t = TimeSpan(start_s=1.0, end_s=3.5) + assert t.duration == pytest.approx(2.5) + + def test_zero_duration(self): + """Default TimeSpan has zero duration.""" + t = TimeSpan() + assert t.duration == 0.0 + + +class TestMergeFetch: + """Tests for BatchTimings.merge_fetch() with multiple blocks per batch.""" + + def test_merge_single_block(self): + """Merging a single block preserves its fetch window.""" + dst = BatchTimings() + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=1.0, end_s=2.0)) + ) + assert dst.production_wait.start_s == 1.0 + assert dst.production_wait.end_s == 2.0 + + def test_merge_multiple_blocks_expands_window(self): + """Merging multiple blocks produces the union window.""" + dst = BatchTimings() + + # Block 1: fetched [1.0, 2.0] + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=1.0, end_s=2.0)) + ) + # Block 2: fetched [3.0, 4.0] + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=3.0, end_s=4.0)) + ) + # Block 3: fetched [5.0, 6.0] + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=5.0, end_s=6.0)) + ) + + # Union: [1.0, 6.0] + assert dst.production_wait.start_s == 1.0 + assert dst.production_wait.end_s == 6.0 + + def test_merge_unrecorded_block_ignored(self): + """Merging a block with no fetch timing (both fields None) is a no-op.""" + dst = BatchTimings() + dst.production_wait = TimeSpan(start_s=2.0, end_s=3.0) + + dst.merge_fetch(BlockFetchTiming()) # fetch fields default to None + + assert dst.production_wait.start_s == 2.0 + assert dst.production_wait.end_s == 3.0 + + def test_merge_overlapping_blocks(self): + """Overlapping fetch windows are correctly merged.""" + dst = BatchTimings() + + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=1.0, end_s=5.0)) + ) + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=3.0, end_s=7.0)) + ) + + # Union: [1.0, 7.0] + assert dst.production_wait.start_s == 1.0 + assert dst.production_wait.end_s == 7.0 + + def test_merge_into_empty_destination(self): + """Merging into an empty BatchTimings takes the source window.""" + dst = BatchTimings() + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=10.0, end_s=20.0)) + ) + assert dst.production_wait.start_s == 10.0 + assert dst.production_wait.end_s == 20.0 + + def test_merge_data_transfer_multiple_blocks(self): + """data_transfer windows are unioned across multiple blocks.""" + dst = BatchTimings() + + src1 = BlockFetchTiming(data_transfer=TimeSpan(start_s=1.0, end_s=2.0)) + dst.merge_fetch(src1) + + src2 = BlockFetchTiming(data_transfer=TimeSpan(start_s=3.0, end_s=4.0)) + dst.merge_fetch(src2) + + # Union: [1.0, 4.0] + assert dst.data_transfer.start_s == 1.0 + assert dst.data_transfer.end_s == 4.0 + + def test_merge_data_transfer_overlapping_blocks(self): + """Overlapping data_transfer windows are correctly merged.""" + dst = BatchTimings() + + dst.merge_fetch( + BlockFetchTiming(data_transfer=TimeSpan(start_s=1.0, end_s=5.0)) + ) + dst.merge_fetch( + BlockFetchTiming(data_transfer=TimeSpan(start_s=3.0, end_s=7.0)) + ) + + assert dst.data_transfer.start_s == 1.0 + assert dst.data_transfer.end_s == 7.0 + + def test_merge_both_stages_independent(self): + """production_wait and data_transfer merge independently.""" + dst = BatchTimings() + + # Block 1: prod [1,2], xfer [2,3] + dst.merge_fetch( + BlockFetchTiming( + production_wait=TimeSpan(start_s=1.0, end_s=2.0), + data_transfer=TimeSpan(start_s=2.0, end_s=3.0), + ) + ) + # Block 2: prod [5,6], xfer [6,7] + dst.merge_fetch( + BlockFetchTiming( + production_wait=TimeSpan(start_s=5.0, end_s=6.0), + data_transfer=TimeSpan(start_s=6.0, end_s=7.0), + ) + ) + + # Each stage unions independently. + assert dst.production_wait.start_s == 1.0 + assert dst.production_wait.end_s == 6.0 + assert dst.data_transfer.start_s == 2.0 + assert dst.data_transfer.end_s == 7.0 + + def test_merge_data_transfer_none_preserves_destination(self): + """Merging a block with no data_transfer timing leaves dst unchanged.""" + dst = BatchTimings() + dst.data_transfer = TimeSpan(start_s=2.0, end_s=3.0) + + # src has only production_wait, data_transfer is None + dst.merge_fetch( + BlockFetchTiming(production_wait=TimeSpan(start_s=1.0, end_s=2.0)) + ) + + assert dst.data_transfer.start_s == 2.0 + assert dst.data_transfer.end_s == 3.0 + + +class TestEndToEndTimingPropagation: + """Tests that stage timings propagate correctly through the full pipeline.""" + + def test_batch_carries_timings_through_pipeline(self): + """A Batch's metadata.timings carries all stage windows.""" + timings = BatchTimings() + timings.production_wait = TimeSpan(start_s=1.0, end_s=2.0) + timings.batching = TimeSpan(start_s=2.0, end_s=3.0) + timings.format = TimeSpan(start_s=3.0, end_s=4.0) + timings.collate = TimeSpan(start_s=4.0, end_s=5.0) + timings.finalize = TimeSpan(start_s=5.0, end_s=6.0) + + batch = Batch(BatchMetadata(batch_idx=0, num_rows=50, timings=timings), None) + + # Verify all stages are accessible via stages() iterator + stage_dict = dict(batch.metadata.timings.stages()) + assert len(stage_dict) == 6 + assert stage_dict[IterationStage.PRODUCTION_WAIT].start_s == 1.0 + assert stage_dict[IterationStage.BATCHING].end_s == 3.0 + assert stage_dict[IterationStage.FORMAT].start_s == 3.0 + assert stage_dict[IterationStage.COLLATE].end_s == 5.0 + assert stage_dict[IterationStage.FINALIZE].start_s == 5.0 + assert batch.metadata.num_rows == 50 + + def test_full_pipeline_attribution(self): + """End-to-end: all 5 stages with realistic timing, full overlap.""" + stats = DatasetStats(metadata={}, parent=None) + it = _make_test_iterator(stats) + stats.iter_total_blocked_s.add(5.0) + + batch = _make_batch_with_timings( + production_wait_start=0.0, + production_wait_end=0.5, + batching_start=0.5, + batching_end=1.0, + format_start=1.0, + format_end=2.0, + collate_start=2.0, + collate_end=2.5, + finalize_start=2.5, + finalize_end=3.0, + num_rows=256, + ) + + # Blocked window covers all stages + it._attribute_blocked_time(batch, blocked_start_s=0.0, blocked_end_s=5.0) + + # Each stage gets its full duration + assert stats.iter_blocked_production_wait_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_batching_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_format_s.get() == pytest.approx(1.0) + assert stats.iter_blocked_collate_s.get() == pytest.approx(0.5) + assert stats.iter_blocked_finalize_s.get() == pytest.approx(0.5) + assert stats.iter_batches_total == 1 + assert stats.iter_rows_total == 256 + + # Invariant: sum = 3.0 <= total_blocked = 5.0 + sum_stages = ( + stats.iter_blocked_production_wait_s.get() + + stats.iter_blocked_batching_s.get() + + stats.iter_blocked_format_s.get() + + stats.iter_blocked_collate_s.get() + + stats.iter_blocked_finalize_s.get() + ) + assert sum_stages == pytest.approx(3.0) + assert sum_stages <= stats.iter_total_blocked_s.get() + 1e-9 + + def test_finalize_fn_uses_single_thread(ray_start_regular_shared): """Tests that finalize_fn is not run with multiple threads.""" ref_bundles_iter = ref_bundle_generator(num_blocks=20, num_rows=2) @@ -194,6 +648,27 @@ def collate_fn(batch: pd.DataFrame): assert concat_df["foo"].iloc[i + 1] >= concat_df["foo"].iloc[i] +def test_iter_batches_counts_rows_at_pipeline_exit(ray_start_regular_shared): + stats = DatasetStats(metadata={}, parent=None) + ref_bundles_iter = ref_bundle_generator(num_blocks=4, num_rows=2) + + output_batches = list( + BatchIterator( + ref_bundles_iter, + stats=stats, + batch_size=3, + prefetch_batches=0, + batch_format="pandas", + drop_last=True, + ) + ) + + assert len(output_batches) == 2 + assert [len(batch) for batch in output_batches] == [3, 3] + assert stats.iter_batches_total == 2 + assert stats.iter_rows_total == 6 + + def test_iter_batches_e2e_async(ray_start_regular_shared): """We add time.sleep in 3 places: 1. In the base generator to simulate streaming executor blocking on next results. diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py index 6ead5741f0e1..ec40f8623553 100644 --- a/python/ray/data/tests/block_batching/test_util.py +++ b/python/ray/data/tests/block_batching/test_util.py @@ -13,7 +13,11 @@ import pytest import ray -from ray.data._internal.block_batching.interfaces import Batch, BatchMetadata +from ray.data._internal.block_batching.interfaces import ( + Batch, + BatchMetadata, + BlockFetchResult, +) from ray.data._internal.block_batching.util import ( _calculate_ref_hits, blocks_to_batches, @@ -23,6 +27,7 @@ iter_threaded, resolve_block_refs, ) +from ray.data._internal.stats import DatasetStats from ray.data._internal.util import make_async_gen logger = logging.getLogger(__file__) @@ -37,7 +42,54 @@ def test_resolve_block_refs(ray_start_regular_shared): block_refs = [ray.put(0), ray.put(1), ray.put(2)] resolved_iter = resolve_block_refs(iter(block_refs)) - assert list(resolved_iter) == [0, 1, 2] + resolved = list(resolved_iter) + assert all(isinstance(b, BlockFetchResult) for b in resolved) + assert [b.block for b in resolved] == [0, 1, 2] + + +def test_resolve_block_refs_does_not_accumulate_ref_bundles_timer( + ray_start_regular_shared, +): + """Regression test: resolve_block_refs must not accumulate into + iter_get_ref_bundles_s (prefetch_batches_locally owns that Timer).""" + + def slow_block_ref_iter(): + for i in range(3): + time.sleep(0.05) + yield ray.put(i) + + stats = DatasetStats(metadata={}, parent=None) + resolved = list(resolve_block_refs(slow_block_ref_iter(), stats=stats)) + + assert len(resolved) == 3 + + # production_wait TimeSpan captured per block for overlap attribution. + for r in resolved: + assert r.fetch is not None + assert r.fetch.production_wait is not None + assert r.fetch.production_wait.duration >= 0.0 + + # iter_get_ref_bundles_s must NOT be accumulated here. + assert stats.iter_get_ref_bundles_s.get() == 0.0 + + +def test_resolve_block_refs_accumulates_data_transfer_timer( + ray_start_regular_shared, +): + """resolve_block_refs accumulates ray.get() time into iter_get_s and + captures a per-block data_transfer TimeSpan.""" + block_refs = [ray.put(i) for i in range(3)] + + stats = DatasetStats(metadata={}, parent=None) + resolved = list(resolve_block_refs(iter(block_refs), stats=stats)) + + assert len(resolved) == 3 + + # data_transfer TimeSpan captured per block. + for r in resolved: + assert r.fetch is not None + assert r.fetch.data_transfer is not None + assert r.fetch.data_transfer.duration >= 0.0 @pytest.mark.parametrize("block_size", [1, 10]) @@ -45,10 +97,12 @@ def test_resolve_block_refs(ray_start_regular_shared): def test_blocks_to_batches(block_size, drop_last): num_blocks = 5 block_iter = block_generator(num_rows=block_size, num_blocks=num_blocks) + # Wrap raw blocks in BlockFetchResult (fetch=None) as blocks_to_batches now expects + wrapped_blocks = (BlockFetchResult(block=b) for b in block_iter) batch_size = 3 batch_iter = list( - blocks_to_batches(block_iter, batch_size=batch_size, drop_last=drop_last) + blocks_to_batches(wrapped_blocks, batch_size=batch_size, drop_last=drop_last) ) if drop_last: diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index cb4c31553541..96a6b914d4f8 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -32,10 +32,13 @@ from ray.data._internal.stats import ( DatasetStats, DatasetStatsSummary, + IterationStage, NodeMetrics, OperatorStatsSummary, StatsSummary, Timer, + TimeSpan, + _maybe_time, _StatsActor, get_or_create_stats_actor, ) @@ -1878,6 +1881,147 @@ def test_stats_actor_iter_metrics(): assert f"dataset_{ds._uuid}_0" == update_fn.call_args_list[-1].args[1] +def test_update_iteration_metrics_exports_new_iter_metrics(): + stats = DatasetStats(metadata={}, parent=None) + stats.iter_total_s.add(11.0) + stats.iter_blocked_production_wait_s.add(1.0) + stats.iter_blocked_data_transfer_s.add(1.5) + stats.iter_blocked_batching_s.add(2.0) + stats.iter_blocked_format_s.add(3.0) + stats.iter_blocked_collate_s.add(4.0) + stats.iter_blocked_finalize_s.add(5.0) + stats.iter_batches_total = 7 + stats.iter_rows_total = 8 + + actor = _StatsActor.__ray_metadata__.modified_class.__new__( + _StatsActor.__ray_metadata__.modified_class + ) + recorded = {} + + class FakeGauge: + def __init__(self, name): + self.name = name + + def set(self, value, tags): + recorded[self.name] = (value, tags) + + for attr in [ + "iter_initialize_s", + "iter_total_s", + "iter_get_ref_bundles_s", + "iter_get_s", + "iter_next_batch_s", + "iter_format_batch_s", + "iter_collate_batch_s", + "iter_finalize_batch_s", + "iter_blocks_local", + "iter_blocks_remote", + "iter_unknown_location", + "iter_prefetched_bytes", + "iter_block_fetching_s", + "iter_batch_shaping_s", + "iter_batch_formatting_s", + "iter_batch_collating_s", + "iter_batch_finalizing_s", + "time_to_first_batch_s", + "iter_total_blocked_s", + "iter_blocked_production_wait_s", + "iter_blocked_data_transfer_s", + "iter_blocked_batching_s", + "iter_blocked_format_s", + "iter_blocked_collate_s", + "iter_blocked_finalize_s", + "iter_batches_total", + "iter_rows_total", + "iter_user_s", + ]: + setattr(actor, attr, FakeGauge(attr)) + + actor.update_iteration_metrics(stats, "train_dataset_split_3") + + expected_tags = {"dataset": "train_dataset_split_3"} + assert recorded["iter_total_s"] == (11.0, expected_tags) + assert recorded["iter_blocked_production_wait_s"] == (1.0, expected_tags) + assert recorded["iter_blocked_data_transfer_s"] == (1.5, expected_tags) + assert recorded["iter_blocked_batching_s"] == (2.0, expected_tags) + assert recorded["iter_blocked_format_s"] == (3.0, expected_tags) + assert recorded["iter_blocked_collate_s"] == (4.0, expected_tags) + assert recorded["iter_blocked_finalize_s"] == (5.0, expected_tags) + assert recorded["iter_batches_total"] == (7, expected_tags) + assert recorded["iter_rows_total"] == (8, expected_tags) + + +def test_iter_stats_summary_has_new_fields(): + """IterStatsSummary includes per-stage blocked timers and counters.""" + stats = DatasetStats(metadata={}, parent=None) + summary = stats.to_summary() + iter_summary = summary.iter_stats + + assert hasattr(iter_summary, "blocked_production_wait_time") + assert hasattr(iter_summary, "blocked_data_transfer_time") + assert hasattr(iter_summary, "blocked_batching_time") + assert hasattr(iter_summary, "blocked_format_time") + assert hasattr(iter_summary, "blocked_collate_time") + assert hasattr(iter_summary, "blocked_finalize_time") + assert hasattr(iter_summary, "batches_total") + assert hasattr(iter_summary, "rows_total") + + +def test_iter_stats_summary_reflects_accumulated_values(): + """IterStatsSummary carries the accumulated timer values.""" + stats = DatasetStats(metadata={}, parent=None) + stats.iter_blocked_production_wait_s.add(0.5) + stats.iter_blocked_batching_s.add(0.2) + stats.iter_batches_total = 10 + stats.iter_rows_total = 320 + + summary = stats.to_summary().iter_stats + assert summary.blocked_production_wait_time.get() == pytest.approx(0.5) + assert summary.blocked_data_transfer_time.get() == pytest.approx(0.0) + assert summary.blocked_batching_time.get() == pytest.approx(0.2) + assert summary.batches_total == 10 + assert summary.rows_total == 320 + + +def test_iter_stats_to_string_shows_stage_breakdown(): + """to_string() renders per-stage breakdown when values are non-zero.""" + stats = DatasetStats(metadata={}, parent=None) + stats.iter_blocked_production_wait_s.add(1.5) + stats.iter_blocked_format_s.add(0.8) + stats.iter_batches_total = 5 + stats.iter_rows_total = 160 + stats.iter_total_blocked_s.add(2.3) + + text = str(stats.to_summary().iter_stats) + assert "production wait" in text + assert "format" in text + assert "Total batches consumed: 5" in text + assert "Total rows consumed: 160" in text + assert "Per-stage training-thread blocked time breakdown" in text + + +def test_iter_stats_to_string_omits_zero_stages(): + """to_string() omits stages with zero values from the breakdown.""" + stats = DatasetStats(metadata={}, parent=None) + stats.iter_blocked_production_wait_s.add(0.5) + stats.iter_total_blocked_s.add(0.5) + + text = str(stats.to_summary().iter_stats) + assert "production wait" in text + # Zero stages should not appear + assert "batching" not in text + assert "collate" not in text + + +def test_iter_stats_to_string_no_breakdown_when_all_zero(): + """When all blocked_* stages are zero, no breakdown section appears.""" + stats = DatasetStats(metadata={}, parent=None) + text = str(stats.to_summary().iter_stats) + assert "Per-stage training-thread blocked time breakdown" not in text + assert "Total batches consumed" not in text + assert "Total rows consumed" not in text + + def test_dataset_name_and_id(): # Test deprecated APIs: _set_name and _name ds = ray.data.range(1) @@ -2498,6 +2642,72 @@ def test_from_dict_handles_none_values(self): assert t.max() == 0.0 +class TestTimerSpan: + """Tests for Timer.timer() returning a TimeSpan and accumulating.""" + + def test_timer_yields_timespan(self): + """timer() yields a fresh TimeSpan whose duration is accumulated.""" + t = Timer() + with t.timer() as span: + time.sleep(0.01) + assert isinstance(span, TimeSpan) + assert span.duration > 0 + # The span's duration is accumulated into the Timer. + assert t.get() == pytest.approx(span.duration, rel=0.5) + assert t.max() > 0 + assert t.min() == pytest.approx(span.duration, rel=0.5) + + def test_each_call_returns_fresh_span(self): + """Each timer() call yields a distinct TimeSpan instance.""" + t = Timer() + with t.timer() as s1: + pass + with t.timer() as s2: + pass + assert s1 is not s2 + assert t.get() == pytest.approx(s1.duration + s2.duration, rel=0.5) + + def test_maybe_time_skips_when_timer_none(self): + """_maybe_time(None) yields None.""" + with _maybe_time(None) as span: + assert span is None + assert span is None + + def test_maybe_time_yields_span_when_timer_given(self): + """_maybe_time(Timer) yields a TimeSpan backed by the Timer.""" + t = Timer() + with _maybe_time(t) as span: + time.sleep(0.01) + assert isinstance(span, TimeSpan) + assert span.duration > 0 + assert t.get() == pytest.approx(span.duration, rel=0.5) + + +@pytest.mark.parametrize( + "stage,attr", + [ + (IterationStage.PRODUCTION_WAIT, "iter_blocked_production_wait_s"), + (IterationStage.DATA_TRANSFER, "iter_blocked_data_transfer_s"), + (IterationStage.BATCHING, "iter_blocked_batching_s"), + (IterationStage.FORMAT, "iter_blocked_format_s"), + (IterationStage.COLLATE, "iter_blocked_collate_s"), + (IterationStage.FINALIZE, "iter_blocked_finalize_s"), + ], +) +class TestGetBlockedTimer: + """Tests for DatasetStats.get_blocked_timer() stage->Timer mapping.""" + + def test_get_blocked_timer_returns_correct_attribute(self, stage, attr): + """get_blocked_timer(stage) returns the Timer matching the stage.""" + stats = DatasetStats(metadata={}, parent=None) + assert stats.get_blocked_timer(stage) is getattr(stats, attr) + + def test_get_blocked_timer_returns_timer_instance(self, stage, attr): + """get_blocked_timer returns a real Timer (not None).""" + stats = DatasetStats(metadata={}, parent=None) + assert isinstance(stats.get_blocked_timer(stage), Timer) + + def test_streaming_exec_schedule_percentiles_populated(ray_start_regular_shared): # KLL-sketch percentile tracking is always on (bounded memory), so # the percentile fields are populated end-to-end with no env-var