diff --git a/init2winit/dataset_lib/imagenet_dataset.py b/init2winit/dataset_lib/imagenet_dataset.py index 5193ae40..5e0eb585 100644 --- a/init2winit/dataset_lib/imagenet_dataset.py +++ b/init2winit/dataset_lib/imagenet_dataset.py @@ -309,7 +309,12 @@ def load_split( # entirely to the end of it on the last host, because otherwise we will drop # the last `{train,valid}_size % split_size` elements. if jax.process_index() == jax.process_count() - 1: - end = '' + if split in ['train', 'eval_train']: + end = hps.train_size + elif split == 'valid': + end = hps.valid_size + else: + end = hps.test_size logging.info('Loaded data [%d: %s] from %s', start, str(end), split) if split in ['train', 'eval_train']: