diff --git a/docs/source/myrtlespeech/data/index.rst b/docs/source/myrtlespeech/data/index.rst index 46c185c1..5f4a72e8 100644 --- a/docs/source/myrtlespeech/data/index.rst +++ b/docs/source/myrtlespeech/data/index.rst @@ -9,3 +9,4 @@ alphabet dataset/index preprocess + sampler diff --git a/docs/source/myrtlespeech/data/sampler.rst b/docs/source/myrtlespeech/data/sampler.rst new file mode 100644 index 00000000..e8eedbe9 --- /dev/null +++ b/docs/source/myrtlespeech/data/sampler.rst @@ -0,0 +1,16 @@ +========= + sampler +========= + +.. automodule:: myrtlespeech.data.sampler + :members: + :show-inheritance: + +.. autoclass:: myrtlespeech.data.sampler.SequentialRandomSampler + :members: + :show-inheritance: + + +.. autoclass:: myrtlespeech.data.sampler.SortaGrad + :members: + :show-inheritance: diff --git a/src/myrtlespeech/builders/task_config.py b/src/myrtlespeech/builders/task_config.py index 9ad0c092..7d914610 100644 --- a/src/myrtlespeech/builders/task_config.py +++ b/src/myrtlespeech/builders/task_config.py @@ -1,11 +1,13 @@ import multiprocessing from typing import Tuple +from typing import Union import torch from myrtlespeech.builders.dataset import build as build_dataset from myrtlespeech.builders.speech_to_text import build as build_stt from myrtlespeech.data.batch import seq_to_seq_collate_fn -from myrtlespeech.data.sampler import RandomBatchSampler +from myrtlespeech.data.sampler import SequentialRandomSampler +from myrtlespeech.data.sampler import SortaGrad from myrtlespeech.model.seq_to_seq import SeqToSeq from myrtlespeech.protos import task_config_pb2 @@ -114,15 +116,35 @@ def target_transform(target): add_seq_len_to_transforms=True, ) - shuffle = task_config.train_config.shuffle_batches_before_every_epoch - train_loader = torch.utils.data.DataLoader( - dataset=train_dataset, - batch_sampler=RandomBatchSampler( + shuffle_str = task_config.train_config.WhichOneof("shuffle_strategy") + batch_sampler: Union[SortaGrad, SequentialRandomSampler] + if 
shuffle_str == "sorta_grad": + batch_sampler = SortaGrad( + indices=range(len(train_dataset)), + batch_size=task_config.train_config.batch_size, + shuffle=True, + drop_last=False, + ) + elif shuffle_str == "random_batches": + batch_sampler = SequentialRandomSampler( indices=range(len(train_dataset)), batch_size=task_config.train_config.batch_size, - shuffle=shuffle, + shuffle=True, drop_last=False, - ), + ) + elif shuffle_str == "sequential_batches": + batch_sampler = SequentialRandomSampler( + indices=range(len(train_dataset)), + batch_size=task_config.train_config.batch_size, + shuffle=False, + drop_last=False, + ) + else: + raise ValueError(f"unsupported shuffle strategy {shuffle_str}") + + train_loader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=seq_to_seq_collate_fn, pin_memory=torch.cuda.is_available(), diff --git a/src/myrtlespeech/configs/deep_speech_1_en.config b/src/myrtlespeech/configs/deep_speech_1_en.config index 0d3e7f79..ed96e122 100644 --- a/src/myrtlespeech/configs/deep_speech_1_en.config +++ b/src/myrtlespeech/configs/deep_speech_1_en.config @@ -60,7 +60,8 @@ train_config { } } } - shuffle_batches_before_every_epoch: true; + random_batches { + } } eval_config { diff --git a/src/myrtlespeech/configs/deep_speech_2_en.config b/src/myrtlespeech/configs/deep_speech_2_en.config index 7bb4db62..edc69ebd 100644 --- a/src/myrtlespeech/configs/deep_speech_2_en.config +++ b/src/myrtlespeech/configs/deep_speech_2_en.config @@ -111,7 +111,8 @@ train_config { } } } - shuffle_batches_before_every_epoch: true; + sorta_grad { + } } eval_config { diff --git a/src/myrtlespeech/data/sampler.py b/src/myrtlespeech/data/sampler.py index 1cd069ff..8a0e08c2 100644 --- a/src/myrtlespeech/data/sampler.py +++ b/src/myrtlespeech/data/sampler.py @@ -1,14 +1,48 @@ import random +from typing import Iterable +from typing import Optional -class RandomBatchSampler: - """TODO""" +class 
SequentialRandomSampler: + """A sequential or random iterable over batches of indices. - def __init__(self, indices, batch_size, shuffle, drop_last=False): + The iterator used each time this iterable is iterated over will yield + batches of indices either sequentially (i.e. in-order) or randomly (uniform + without replacement). + This iterable records the number of times it has returned an iterator. A + sequential iterator is returned if the current count is in `sequential`. + + Args: + indices: Data with which batches are created. + batch_size: Batch dimension. + shuffle: Set to True to have the data reshuffled at every epoch if a + random iterator is used. + drop_last: Set to True to drop the last incomplete batch, if the + dataset size is not divisible by the batch size. If False and the + size of dataset is not divisible by the batch size, then the last + batch will be smaller. + n_iterators: Number of iterators returned so far. + sequential: Counts at which to return a sequential iterator. + + Yields: + Batches from `indices`. 
+ """ + + def __init__( + self, + indices: Iterable[int], + batch_size: int, + shuffle: bool, + drop_last: bool = False, + n_iterators: int = 0, + sequential: Optional[set] = None, + ): self.shuffle = shuffle self.batch_indices = self._batch_indices( indices, batch_size, drop_last ) + self._n_iterators = n_iterators + self._sequential = sequential or set() def _batch_indices(self, indices, batch_size, drop_last): batches = [] @@ -23,10 +57,53 @@ def _batch_indices(self, indices, batch_size, drop_last): return batches def __iter__(self): - if self.shuffle: - random.shuffle(self.batch_indices) - for b in self.batch_indices: - yield b + indices = list(range(len(self.batch_indices))) + if self.shuffle and self._n_iterators not in self._sequential: + random.shuffle(indices) + self._n_iterators += 1 + for index in indices: + yield self.batch_indices[index] def __len__(self): return len(self.batch_indices) + + +class SortaGrad(SequentialRandomSampler): + """An iterable over batch indices according to the SortaGrad strategy. + + The SortaGrad curriculum learning strategy iterates over batches from the + batched dataset sequentially for the first pass and then randomly for all + other passes. See `Deep Speech 2 <https://arxiv.org/abs/1512.02595>`_ paper + for more information. + + Args: + indices: Data with which batches are created. + batch_size: Batch dimension. + shuffle: Set to True to have the data reshuffled at every epoch if a + random iterator is used. + drop_last: Set to True to drop the last incomplete batch, if the + dataset size is not divisible by the batch size. If False and the + size of dataset is not divisible by the batch size, then the last + batch will be smaller. + start_epoch: Number of iterators returned so far by the sampler. + + Yields: + Batches from `indices`. 
+ """ + + def __init__( + self, + indices: Iterable[int], + batch_size: int, + shuffle: bool, + drop_last: bool = False, + start_epoch: int = 0, + ): + super().__init__( + indices, + batch_size, + shuffle, + drop_last, + n_iterators=start_epoch, + sequential={0}, + ) diff --git a/src/myrtlespeech/protos/shuffle_strategy.proto b/src/myrtlespeech/protos/shuffle_strategy.proto new file mode 100644 index 00000000..e477edc8 --- /dev/null +++ b/src/myrtlespeech/protos/shuffle_strategy.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package myrtlespeech.protos; + + +message SequentialBatches { +} + + +message RandomBatches { +} + + +message SortaGrad { +} diff --git a/src/myrtlespeech/protos/train_config.proto b/src/myrtlespeech/protos/train_config.proto index 7c13cbb0..2a3e6c55 100644 --- a/src/myrtlespeech/protos/train_config.proto +++ b/src/myrtlespeech/protos/train_config.proto @@ -5,6 +5,7 @@ package myrtlespeech.protos; import "myrtlespeech/protos/dataset.proto"; import "myrtlespeech/protos/lr_scheduler.proto"; import "myrtlespeech/protos/optimizer.proto"; +import "myrtlespeech/protos/shuffle_strategy.proto"; // Configuration for training. @@ -29,8 +30,12 @@ message TrainConfig { Dataset dataset = 9; - oneof supported_shuffles { + oneof shuffle_strategy { + // Maintain a sequential batch order. + SequentialBatches sequential_batches = 10; // Shuffle batches before every epoch. - bool shuffle_batches_before_every_epoch = 10; + RandomBatches random_batches = 11; + // Sequential for the first epoch and random for the following ones. 
+ SortaGrad sorta_grad = 12; } } diff --git a/tests/data/test_batch.py b/tests/data/test_batch.py index 037b2a5f..6310efcd 100644 --- a/tests/data/test_batch.py +++ b/tests/data/test_batch.py @@ -102,7 +102,7 @@ def test_pad_sequences_returns_tensor_with_correct_values( def test_seq_to_seq_collate_fn() -> None: - """Unit test to ensure seq_to_seq_collate_fn returns correct values.""" + """Ensure seq_to_seq_collate_fn returns correct values""" inputs = [rand([1, 2, 3]), rand([1, 2, 5])] seq_lens = tensor([3, 5]) diff --git a/tests/data/test_sampler.py b/tests/data/test_sampler.py new file mode 100644 index 00000000..d33742b9 --- /dev/null +++ b/tests/data/test_sampler.py @@ -0,0 +1,160 @@ +from typing import Dict +from typing import List +from typing import Tuple +from typing import Union + +import hypothesis.strategies as st +from hypothesis import given +from myrtlespeech.data.sampler import SequentialRandomSampler +from myrtlespeech.data.sampler import SortaGrad + + +# Fixtures and Strategies ----------------------------------------------------- + + +@st.composite +def dataset_gen( + draw, return_kwargs: bool = False +) -> Union[st.SearchStrategy[List], st.SearchStrategy[Tuple[List, Dict]]]: + """Returns a SearchStrategy for a list of dataset indices.""" + kwargs = {} + kwargs["n_batches"] = draw(st.integers(10, 20)) + kwargs["batch_size"] = draw(st.integers(2, 16)) + kwargs["full_last_batch"] = draw(st.booleans()) + + # indices = [0, ..., (n_batches*batch_size)-1-int(not(full_last_batch))] + indices = list(range(kwargs["n_batches"] * kwargs["batch_size"])) + if not kwargs["full_last_batch"]: + del indices[-1] + + if not return_kwargs: + return indices + return indices, kwargs + + +@st.composite +def sequential_epochs_gen(draw) -> st.SearchStrategy[List]: + """Returns a SearchStrategy for a list of sequential epoch numbers.""" + max_size = draw(st.integers(min_value=1, max_value=10)) + + sequential = draw( + st.lists( + elements=st.integers(min_value=11, 
max_value=20), + min_size=1, + max_size=max_size, + unique=True, + ) + ) + + return sequential + + +# Tests ----------------------------------------------------------------------- + + +@given(dataset_kwargs=dataset_gen(return_kwargs=True)) +def test_sorta_grad_correct_len(dataset_kwargs: Tuple[List, Dict]): + dataset, kwargs = dataset_kwargs + + dataset = sorted(dataset) + sampler = SortaGrad( + dataset, + drop_last=kwargs["full_last_batch"], + batch_size=kwargs["batch_size"], + shuffle=False, + ) + assert len(sampler) == kwargs["n_batches"] + + +@given(dataset_kwargs=dataset_gen(return_kwargs=True)) +def test_sorta_grad_batches_non_empty(dataset_kwargs: Tuple[List, Dict]): + dataset, kwargs = dataset_kwargs + + dataset = sorted(dataset) + sampler = SortaGrad( + dataset, + drop_last=kwargs["full_last_batch"], + batch_size=kwargs["batch_size"], + shuffle=False, + ) + for batch in sampler: + assert len(batch) > 0 + + +@given(dataset_kwargs=dataset_gen(return_kwargs=True)) +def test_sorta_grad_first_pass_sequential_remaining_random( + dataset_kwargs: Tuple[List, Dict] +): + dataset, kwargs = dataset_kwargs + + dataset = sorted(dataset) + sortagrad = SortaGrad( + dataset, + drop_last=kwargs["full_last_batch"], + batch_size=kwargs["batch_size"], + shuffle=True, + ) + + indices: list + for pass_ in range(100): + indices = [] + for batch in sortagrad: + indices.extend(batch) + + assert sorted(indices) == dataset + + if pass_ == 0: + assert indices == sorted(indices) + else: + assert indices != sorted(indices) + + +@given( + dataset_kwargs=dataset_gen(return_kwargs=True), + n_iterators=st.integers(min_value=1, max_value=10), + sequential=sequential_epochs_gen(), +) +def test_sequential_strategy_seq_iter_when_epoch_in_seq_epochs( + dataset_kwargs: Tuple[List, Dict], n_iterators: int, sequential: List +): + dataset, kwargs = dataset_kwargs + + dataset_batches = [] + batch = [] + for elem in dataset: + batch.append(elem) + if len(batch) == kwargs["batch_size"]: + 
dataset_batches.append(batch) + batch = [] + if batch and not kwargs["full_last_batch"]: + dataset_batches.append(batch) + sequential_epochs = set(sorted(sequential)) + + seq_strat = SequentialRandomSampler( + dataset, + batch_size=kwargs["batch_size"], + shuffle=True, + n_iterators=n_iterators, + sequential=sequential_epochs, + ) + + for epoch in range(n_iterators, max(sequential_epochs) + 2): + sampler_batches = [batch for batch in iter(seq_strat)] + + assert len(sampler_batches) == len(dataset_batches) + assert sorted(sampler_batches) == sorted(dataset_batches) + + if epoch in sequential_epochs: + assert all( + sample_batch == dataset_batch + for sample_batch, dataset_batch in zip( + sampler_batches, dataset_batches + ) + ) + else: + assert not all( + sample_batch == dataset_batch + for sample_batch, dataset_batch in zip( + sampler_batches, dataset_batches + ) + ) diff --git a/tests/protos/test_train_config.py b/tests/protos/test_train_config.py index 99fef759..db39d996 100644 --- a/tests/protos/test_train_config.py +++ b/tests/protos/test_train_config.py @@ -3,6 +3,7 @@ from typing import Union import hypothesis.strategies as st +from myrtlespeech.protos import shuffle_strategy_pb2 from myrtlespeech.protos import train_config_pb2 from tests.protos.test_dataset import datasets @@ -77,14 +78,18 @@ def train_configs( st.sampled_from( [ f.name - for f in descript.oneofs_by_name["supported_shuffles"].fields + for f in descript.oneofs_by_name["shuffle_strategy"].fields ] ) ) - if shuffle_str == "shuffle_batches_before_every_epoch": - kwargs[shuffle_str] = draw(st.booleans()) + if shuffle_str == "sequential_batches": + kwargs[shuffle_str] = shuffle_strategy_pb2.SequentialBatches() + elif shuffle_str == "random_batches": + kwargs[shuffle_str] = shuffle_strategy_pb2.RandomBatches() + elif shuffle_str == "sorta_grad": + kwargs[shuffle_str] = shuffle_strategy_pb2.SortaGrad() else: - raise ValueError(f"unknown shuffle type {shuffle_str}") + raise ValueError(f"unknown 
shuffle strategy type {shuffle_str}") # initialise and return all_fields_set(train_config_pb2.TrainConfig, kwargs)