From 9ab257183e2c46cd492fac93fd00dc95fcd4683c Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Tue, 30 Jun 2026 16:14:19 -0700 Subject: [PATCH] Fix multi-host ImageNet sharding when subset sizes are used PiperOrigin-RevId: 940713067 --- init2winit/dataset_lib/imagenet_dataset.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/init2winit/dataset_lib/imagenet_dataset.py b/init2winit/dataset_lib/imagenet_dataset.py index 5193ae40..5e0eb585 100644 --- a/init2winit/dataset_lib/imagenet_dataset.py +++ b/init2winit/dataset_lib/imagenet_dataset.py @@ -309,7 +309,12 @@ def load_split( # entirely to the end of it on the last host, because otherwise we will drop # the last `{train,valid}_size % split_size` elements. if jax.process_index() == jax.process_count() - 1: - end = '' + if split in ['train', 'eval_train']: + end = hps.train_size + elif split == 'valid': + end = hps.valid_size + else: + end = hps.test_size logging.info('Loaded data [%d: %s] from %s', start, str(end), split) if split in ['train', 'eval_train']: