Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,21 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
* Adds profiling of multiprocess workers when using XProf profiler. To enable,
set flag `grain_enable_multiprocess_worker_profiling=true` and add
`"profile_subprocesses" = True` in advanced profiler options.
* Adds automatic, thread-safe reader connection pooling (`BoundedReaderPool`)
per shard inside `ArrayRecordDataSource` to support high-performance,
multi-threaded parallel dataset prefetching without exhausting file
descriptors. Exposes a context-manager API, `borrow()`, for safely leasing a
pooled reader, and a `reader_pool_size` constructor parameter /
`grain_reader_pool_size` flag for configuring the pool size.

* Breaking changes:
* Upgrades `ArrayRecordDataSource` to implement the new
`RandomAccessDataSource` single-indexing protocol. The standard index
method `__getitem__` now accepts only a single `SupportsIndex` key and
returns a single byte string. Callers performing sequential batch loading
must migrate to the batch method `__getitems__`, which continues to perform
highly optimized, direct counting-sort parallel fetches.

* Deprecations:

Expand Down
25 changes: 15 additions & 10 deletions grain/_src/python/data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(self, *args, **kwargs):
PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction]
]


ArrayRecordReaderOptions = dict[str, str] | None


Expand All @@ -71,6 +72,7 @@ def __init__(
self,
paths: ArrayRecordDataSourcePaths,
reader_options: ArrayRecordReaderOptions = None,
reader_pool_size: int | None = None,
):
"""Creates a new ArrayRecordDataSource object.

Expand All @@ -82,18 +84,21 @@ def __init__(
example, {index_storage_option:"in_memory"} stores the reader indices in
memory versus {index_storage_option:"offloaded"} stores the indices on
disk to save memory usage.
reader_pool_size: The number of readers to pool per shard.
"""
array_record_signature = inspect.signature(ARDataSource.__init__)
if "reader_options" in array_record_signature.parameters:
super().__init__(paths, reader_options)
elif reader_options is not None:
# Reader options should not be set if they are not supported by the
# current version of ArrayRecord.
raise ValueError(
"reader_options is not supported in this version of ArrayRecord."
)
else:
super().__init__(paths)
kwargs = {}
if (
"reader_options" in array_record_signature.parameters
and reader_options is not None
):
kwargs["reader_options"] = reader_options
if (
"reader_pool_size" in array_record_signature.parameters
and reader_pool_size is not None
):
kwargs["reader_pool_size"] = reader_pool_size
super().__init__(paths, **kwargs)
_api_usage_counter.Increment("ArrayRecordDataSource")

@dataset_stats.trace_input_pipeline(stage_category=dataset_stats.IPL_CAT_READ)
Expand Down
Loading