From eada52345c6ee06db3898dfef4e79153d7ddf67c Mon Sep 17 00:00:00 2001 From: Grain Team Date: Fri, 10 Apr 2026 06:45:00 -0700 Subject: [PATCH] Fix concurrent read determinism and file descriptor leaks in ArrayRecordDataSource. PiperOrigin-RevId: 897665704 --- CHANGELOG.md | 13 +++++++++++++ grain/_src/python/data_sources.py | 25 +++++++++++++++---------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f68835e14..000e8c978 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,8 +15,21 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change * Adds profiling of multiprocess workers when using XProf profiler. To enable, set flag `grain_enable_multiprocess_worker_profiling=true` and add `"profile_subprocesses" = True` in advanced profiler options. + * Configures automatic, thread-safe reader connection pooling + (`BoundedReaderPool`) per shard inside `ArrayRecordDataSource` to support + high-performance, multi-threaded parallel dataset prefetching without file + descriptor exhaustion. Exposes safe context manager connection lease API + `borrow()` and custom configuration parameter `reader_pool_size` / + `grain_reader_pool_size` flag. * Breaking changes: + * Upgrades `ArrayRecordDataSource` to implement the new + `RandomAccessDataSource` single-indexing protocol. The standard index + method `__getitem__` now accepts only a single `SupportsIndex` index key + (returning a single byte string). Caller threads performing sequential + batch loading must migrate execution to the batch method `__getitems__` + which continues to perform highly-optimized direct counting-sort parallel + fetches. * Deprecations: diff --git a/grain/_src/python/data_sources.py b/grain/_src/python/data_sources.py index f885daeb8..5e7983e9b 100644 --- a/grain/_src/python/data_sources.py +++ b/grain/_src/python/data_sources.py @@ -61,6 +61,7 @@ def __init__(self, *args, **kwargs): PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction] ] + ArrayRecordReaderOptions = dict[str, str] | None @@ -71,6 +72,7 @@ def __init__( self, paths: ArrayRecordDataSourcePaths, reader_options: ArrayRecordReaderOptions = None, + reader_pool_size: int | None = None, ): """Creates a new ArrayRecordDataSource object. @@ -82,18 +84,21 @@ def __init__( example, {index_storage_option:"in_memory"} stores the reader indices in memory versus {index_storage_option:"offloaded"} stores the indices on disk to save memory usage. + reader_pool_size: The number of readers to pool per shard. """ array_record_signature = inspect.signature(ARDataSource.__init__) - if "reader_options" in array_record_signature.parameters: - super().__init__(paths, reader_options) - elif reader_options is not None: - # Reader options should not be set if they are not supported by the - # current version of ArrayRecord. - raise ValueError( - "reader_options is not supported in this version of ArrayRecord." - ) - else: - super().__init__(paths) + kwargs = {} + if ( + "reader_options" in array_record_signature.parameters + and reader_options is not None + ): + kwargs["reader_options"] = reader_options + if ( + "reader_pool_size" in array_record_signature.parameters + and reader_pool_size is not None + ): + kwargs["reader_pool_size"] = reader_pool_size + super().__init__(paths, **kwargs) _api_usage_counter.Increment("ArrayRecordDataSource") @dataset_stats.trace_input_pipeline(stage_category=dataset_stats.IPL_CAT_READ)