From 6b160aa1f02e8a2af07b7a946f31f434e012c36c Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Fri, 22 Nov 2019 12:02:27 +0000 Subject: [PATCH 01/19] Added sortagrad strategy --- src/myrtlespeech/builders/task_config.py | 42 +++++++---- .../configs/deep_speech_2_en.config | 1 + src/myrtlespeech/data/batch.py | 72 +++++++++++++++++++ src/myrtlespeech/data/sampler.py | 70 ++++++++++++++++-- src/myrtlespeech/protos/train_config.proto | 2 + 5 files changed, 169 insertions(+), 18 deletions(-) diff --git a/src/myrtlespeech/builders/task_config.py b/src/myrtlespeech/builders/task_config.py index 0d4e44f5..d99cd04d 100644 --- a/src/myrtlespeech/builders/task_config.py +++ b/src/myrtlespeech/builders/task_config.py @@ -5,7 +5,9 @@ from myrtlespeech.builders.dataset import build as build_dataset from myrtlespeech.builders.speech_to_text import build as build_stt from myrtlespeech.data.batch import seq_to_seq_collate_fn -from myrtlespeech.data.sampler import RandomBatchSampler +from myrtlespeech.data.batch import seq_to_seq_collate_fn_sorted +from myrtlespeech.data.sampler import SequentialRandomSampler +from myrtlespeech.data.sampler import SortaGrad from myrtlespeech.model.seq_to_seq import SeqToSeq from myrtlespeech.protos import task_config_pb2 @@ -110,18 +112,32 @@ def target_transform(target): ) shuffle = task_config.train_config.shuffle_batches_before_every_epoch - train_loader = torch.utils.data.DataLoader( - dataset=train_dataset, - batch_sampler=RandomBatchSampler( - indices=range(len(train_dataset)), - batch_size=task_config.train_config.batch_size, - shuffle=shuffle, - drop_last=False, - ), - num_workers=num_workers, - collate_fn=seq_to_seq_collate_fn, - pin_memory=torch.cuda.is_available(), - ) + if task_config.train_config.sortagrad: + train_loader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_sampler=SortaGrad( + indices=range(len(train_dataset)), + batch_size=task_config.train_config.batch_size, + shuffle=shuffle, + drop_last=False, + ), + num_workers=num_workers, + collate_fn=seq_to_seq_collate_fn_sorted, + pin_memory=torch.cuda.is_available(), + ) + else: + train_loader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_sampler=SequentialRandomSampler( + indices=range(len(train_dataset)), + batch_size=task_config.train_config.batch_size, + shuffle=shuffle, + drop_last=False, + ), + num_workers=num_workers, + collate_fn=seq_to_seq_collate_fn, + pin_memory=torch.cuda.is_available(), + ) # eval eval_dataset = build_dataset( diff --git a/src/myrtlespeech/configs/deep_speech_2_en.config b/src/myrtlespeech/configs/deep_speech_2_en.config index 7a9c75a9..239255e9 100644 --- a/src/myrtlespeech/configs/deep_speech_2_en.config +++ b/src/myrtlespeech/configs/deep_speech_2_en.config @@ -109,6 +109,7 @@ train_config { } } shuffle_batches_before_every_epoch: true; + sortagrad: true; } eval_config { diff --git a/src/myrtlespeech/data/batch.py b/src/myrtlespeech/data/batch.py index 1474d3c6..766f16c4 100644 --- a/src/myrtlespeech/data/batch.py +++ b/src/myrtlespeech/data/batch.py @@ -100,3 +100,75 @@ def seq_to_seq_collate_fn( target_seq_lens = torch.tensor(target_seq_lens, requires_grad=False) return (inputs, in_seq_lens), (targets, target_seq_lens) + + +def seq_to_seq_collate_fn_sorted( + batch: List[ + Tuple[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor], + ] + ] +) -> Tuple[ + Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor] +]: + r"""Collates a list of ``((tensor, tensor_len), (target, target_len))``. + + A ``collate_fn`` for sequence-to-sequence tasks. + + Args: + batch: A list of ``((tensor, tensor_len), (target, target_len))`` + where: + + tensor: + A :py:class:`torch.Tensor` of input for a model. The sequence + length dimension must be last. + + tensor_len: + A scalar, integer :py:class:`torch.Tensor` giving the length of + ``tensor``. + + target: + A :py:class:`torch.Tensor` target for the model. The sequence + length dimension must be last. + + target_len: + A scalar, integer :py:class:`torch.Tensor` giving the length of + ``target``. + + Returns: + A tuple of ``((batch_tensor, batch_tensor_len), (batch_target, + batch_target_len))`` where ``batch_tensor`` is the + result of applying :py:func:`.pad_sequence` to all ``tensor``\s in + ascending order by tensor length, ``batch_tensor_lens`` is the result + of stacking all ``tensor_len``\s, ``batch_target`` is the result of + appying :py:func:`.pad_sequence` to all ``target``\s (in an order that + corresponds to the samples in `batch_tensor`) and ``batch_target_len`` + is the result of stacking all ``target_len``\s. + """ + + inputs, in_seq_lens = [], [] + targets, target_seq_lens = [], [] + + for (input, in_seq_len), (target, target_seq_len) in batch: + inputs.append(input) + in_seq_lens.append(in_seq_len) + targets.append(target) + target_seq_lens.append(target_seq_len) + + # Sort the samples + samples = [ + (input, in_seq_len, target, target_seq_len) + for input, in_seq_len, target, target_seq_len in zip( + inputs, in_seq_lens, targets, target_seq_lens + ) + ] + sorted_samples = sorted(samples, key=lambda s: s[0].size(-1)) + inputs, in_seq_lens, targets, target_seq_lens = zip(*sorted_samples) + + inputs = pad_sequence(inputs) + in_seq_lens = torch.tensor(in_seq_lens, requires_grad=False) + targets = pad_sequence(targets) + target_seq_lens = torch.tensor(target_seq_lens, requires_grad=False) + + return (inputs, in_seq_lens), (targets, target_seq_lens) diff --git a/src/myrtlespeech/data/sampler.py b/src/myrtlespeech/data/sampler.py index 1cd069ff..770cbf11 100644 --- a/src/myrtlespeech/data/sampler.py +++ b/src/myrtlespeech/data/sampler.py @@ -1,14 +1,37 @@ import random -class RandomBatchSampler: - """TODO""" +class SequentialRandomSampler: + """A sequential or random iterable over batches. + The iterator used each time this iterable is iterated over will yield + batches either sequentially (i.e. in-order) or randomly (uniform without + replacement) from `batches`. + This iterable records the number of times it has returned an iterator. A + sequential iterator is returned if the current count is in `sequential`. + Args: + batches: A list of batches. + n_iterators (optional, int): Number of iterators returned so far. + sequential (optional, set of int): Counts at which to return a + sequential iterator. + Yields: + Batches from `batches`. + """ - def __init__(self, indices, batch_size, shuffle, drop_last=False): + def __init__( + self, + indices, + batch_size, + shuffle, + drop_last=False, + n_iterators=0, + sequential=None, + ): self.shuffle = shuffle self.batch_indices = self._batch_indices( indices, batch_size, drop_last ) + self._n_iterators = n_iterators + self._sequential = sequential or {} def _batch_indices(self, indices, batch_size, drop_last): batches = [] @@ -23,10 +46,47 @@ def _batch_indices(self, indices, batch_size, drop_last): return batches def __iter__(self): - if self.shuffle: - random.shuffle(self.batch_indices) + if self._n_iterators in self._sequential: + iter_ = self._seq_iter() + else: + iter_ = self._rnd_iter() + self._n_iterators += 1 + return iter_ + + def _seq_iter(self): for b in self.batch_indices: yield b + def _rnd_iter(self): + indices = list(range(len(self.batch_indices))) + if self.shuffle: + random.shuffle(indices) + for index in indices: + yield self.batch_indices[index] + def __len__(self): return len(self.batch_indices) + + +class SortaGrad(SequentialRandomSampler): + """An iterable over batch indices according to the SortaGrad strategy. + The SortaGrad curriculum learning strategy iterates over batches from the + batched dataset sequentially for the first pass and then randomly for all + other passes. + Args: + batches: A list of batches. + Yields: + Batches from `batches`. + """ + + def __init__( + self, indices, batch_size, shuffle, drop_last=False, start_epoch=0 + ): + super().__init__( + indices, + batch_size, + shuffle, + drop_last, + n_iterators=start_epoch, + sequential={0}, + ) diff --git a/src/myrtlespeech/protos/train_config.proto b/src/myrtlespeech/protos/train_config.proto index 0f0be08e..0d218771 100644 --- a/src/myrtlespeech/protos/train_config.proto +++ b/src/myrtlespeech/protos/train_config.proto @@ -25,4 +25,6 @@ message TrainConfig { // Shuffle batches before every epoch. bool shuffle_batches_before_every_epoch = 6; } + + bool sortagrad = 11; } From 3751ddeddbd899e6ee9e2566e4fd899ea7b910a1 Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Fri, 22 Nov 2019 12:05:21 +0000 Subject: [PATCH 02/19] Added tests for sortagrad --- tests/data/test_batch.py | 39 ++++++++++++ tests/data/test_sampler.py | 99 +++++++++++++++++++++++++++++++ tests/protos/test_train_config.py | 2 + 3 files changed, 140 insertions(+) create mode 100644 tests/data/test_sampler.py diff --git a/tests/data/test_batch.py b/tests/data/test_batch.py index 037b2a5f..7bebcedc 100644 --- a/tests/data/test_batch.py +++ b/tests/data/test_batch.py @@ -127,3 +127,42 @@ def test_seq_to_seq_collate_fn() -> None: assert torch.all(y[0] == pad_sequence(targets)) assert torch.all(y[1] == target_lengths) + + +def test_seq_to_seq_collate_fn_sorted() -> None: + """Unit test to ensure seq_to_seq_collate_fn_sorted returns correct + values.""" + inputs = [rand([1, 2, 3]), rand([1, 2, 5])] + seq_lens = tensor([3, 5]) + + targets = [rand([10]), rand([7])] + target_lengths = tensor([10, 7]) + + batch = [ + ((inputs[0], seq_lens[0]), (targets[0], target_lengths[0])), + ((inputs[1], seq_lens[1]), (targets[1], target_lengths[1])), + ] + + x, y = seq_to_seq_collate_fn(batch) + + assert isinstance(x, tuple) + assert len(x) == 2 + + assert isinstance(y, tuple) + assert len(y) == 2 + + # Sort the input samples + samples = [ + (input, seq_lens, target, target_lengths) + for input, seq_len, target, target_length in zip( + inputs, seq_lens, targets, target_lengths + ) + ] + sorted_samples = sorted(samples, key=lambda s: s[0].size(-1)) + inputs, in_seq_lens, targets, target_seq_lens = zip(*sorted_samples) + + assert torch.all(x[0] == pad_sequence(inputs)) + assert torch.all(x[1] == seq_lens) + + assert torch.all(y[0] == pad_sequence(targets)) + assert torch.all(y[1] == target_lengths) diff --git a/tests/data/test_sampler.py b/tests/data/test_sampler.py new file mode 100644 index 00000000..9dcb28f4 --- /dev/null +++ b/tests/data/test_sampler.py @@ -0,0 +1,99 @@ +from myrtlespeech.data.sampler import SequentialRandomSampler +from myrtlespeech.data.sampler import SortaGrad + + +def dataset_gen(n_batches, batch_size, full_last_batch): + """Returns [0, ..., (n_batches*batch_size)-1-int(not(full_last_batch))]""" + indices = list(range(n_batches * batch_size)) + if not full_last_batch: + del indices[-1] + return indices + + +def test_sorta_grad_correct_len(): + n_batches = 10 + batch_size = 16 + + for full_last_batch in [True, False]: + dataset = sorted(dataset_gen(n_batches, batch_size, full_last_batch)) + sampler = SortaGrad( + dataset, + drop_last=full_last_batch, + batch_size=batch_size, + shuffle=False, + ) + assert len(sampler) == n_batches + + +def test_sorta_grad_batches_non_empty(): + n_batches = 10 + batch_size = 16 + + for full_last_batch in [True, False]: + dataset = sorted(dataset_gen(n_batches, batch_size, full_last_batch)) + sampler = SortaGrad( + dataset, + drop_last=full_last_batch, + batch_size=batch_size, + shuffle=False, + ) + for batch in sampler: + assert len(batch) > 0 + + +def test_sorta_grad_first_pass_sequential_remaining_random(): + n_batches = 10 + batch_size = 16 + dataset = sorted(dataset_gen(n_batches, batch_size, False)) + + sortagrad = SortaGrad( + dataset, drop_last=False, batch_size=batch_size, shuffle=True + ) + + for pass_ in range(100): + indices = [] + for batch in sortagrad: + indices.extend(batch) + + assert sorted(indices) == dataset + + if pass_ == 0: + assert indices == sorted(indices) + else: + assert indices != sorted(indices) + + +def test_sequential_strategy_seq_iter_when_epoch_in_seq_epochs(): + dataset = list(range(10)) + dataset_batches = [[elem] for elem in dataset] + n_iterators = 5 + sequential = {0, 2, 3, 7, 8, 10} + + seq_strat = SequentialRandomSampler( + dataset, + batch_size=1, + shuffle=True, + n_iterators=n_iterators, + sequential=sequential, + ) + + for epoch in range(n_iterators, max(sequential) + 2): + sampler_batches = [batch for batch in iter(seq_strat)] + + assert len(sampler_batches) == len(dataset_batches) + assert sorted(sampler_batches) == sorted(dataset_batches) + + if epoch in sequential: + assert all( + sample_batch == dataset_batch + for sample_batch, dataset_batch in zip( + sampler_batches, dataset_batches + ) + ) + else: + assert not all( + sample_batch == dataset_batch + for sample_batch, dataset_batch in zip( + sampler_batches, dataset_batches + ) + ) diff --git a/tests/protos/test_train_config.py b/tests/protos/test_train_config.py index a8fc4312..34fd8013 100644 --- a/tests/protos/test_train_config.py +++ b/tests/protos/test_train_config.py @@ -60,6 +60,8 @@ def train_configs( else: raise ValueError(f"unknown shuffle type {shuffle_str}") + kwargs["sortagrad"] = draw(st.booleans()) + # initialise and return all_fields_set(train_config_pb2.TrainConfig, kwargs) train_config = train_config_pb2.TrainConfig(**kwargs) From e054ed95f295ccef18490d4044e957790fb7cc38 Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Wed, 27 Nov 2019 11:12:50 +0000 Subject: [PATCH 03/19] Delete seq_to_seq_collate_fn_sorted and add sort flag in seq_to_seq_collate_fn --- src/myrtlespeech/builders/task_config.py | 5 +- src/myrtlespeech/data/batch.py | 88 +++++------------------- tests/data/test_batch.py | 11 +-- 3 files changed, 24 insertions(+), 80 deletions(-) diff --git a/src/myrtlespeech/builders/task_config.py b/src/myrtlespeech/builders/task_config.py index d99cd04d..71338944 100644 --- a/src/myrtlespeech/builders/task_config.py +++ b/src/myrtlespeech/builders/task_config.py @@ -5,7 +5,6 @@ from myrtlespeech.builders.dataset import build as build_dataset from myrtlespeech.builders.speech_to_text import build as build_stt from myrtlespeech.data.batch import seq_to_seq_collate_fn -from myrtlespeech.data.batch import seq_to_seq_collate_fn_sorted from myrtlespeech.data.sampler import SequentialRandomSampler from myrtlespeech.data.sampler import SortaGrad from myrtlespeech.model.seq_to_seq import SeqToSeq @@ -122,7 +121,7 @@ def target_transform(target): drop_last=False, ), num_workers=num_workers, - collate_fn=seq_to_seq_collate_fn_sorted, + collate_fn=lambda batch: seq_to_seq_collate_fn(batch, sort=True), pin_memory=torch.cuda.is_available(), ) else: @@ -135,7 +134,7 @@ def target_transform(target): drop_last=False, ), num_workers=num_workers, - collate_fn=seq_to_seq_collate_fn, + collate_fn=lambda batch: seq_to_seq_collate_fn(batch, sort=False), pin_memory=torch.cuda.is_available(), ) diff --git a/src/myrtlespeech/data/batch.py b/src/myrtlespeech/data/batch.py index 766f16c4..f11c2659 100644 --- a/src/myrtlespeech/data/batch.py +++ b/src/myrtlespeech/data/batch.py @@ -48,7 +48,8 @@ def seq_to_seq_collate_fn( Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], ] - ] + ], + sort: bool = False, ) -> Tuple[ Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor] ]: @@ -76,6 +77,9 @@ def seq_to_seq_collate_fn( A scalar, integer :py:class:`torch.Tensor` giving the length of ``target``. + sort: A boolean value used to decide whether the batch should be sorted + by the input tensor_len + Returns: A tuple of ``((batch_tensor, batch_tensor_len), (batch_target, batch_target_len))`` where ``batch_tensor`` is the @@ -83,7 +87,8 @@ def seq_to_seq_collate_fn( ``batch_tensor_lens`` is the result of stacking all ``tensor_len``\s, ``batch_target`` is the result of appying :py:func:`.pad_sequence` to all ``target``\s and ``batch_target_len`` is the result of stacking all - ``target_len``\s. + ``target_len``\s. If sort is set to True then the output is sorted by + the input tensor_len. """ inputs, in_seq_lens = [], [] targets, target_seq_lens = [], [] @@ -94,77 +99,16 @@ def seq_to_seq_collate_fn( targets.append(target) target_seq_lens.append(target_seq_len) - inputs = pad_sequence(inputs) - in_seq_lens = torch.tensor(in_seq_lens, requires_grad=False) - targets = pad_sequence(targets) - target_seq_lens = torch.tensor(target_seq_lens, requires_grad=False) - - return (inputs, in_seq_lens), (targets, target_seq_lens) - - -def seq_to_seq_collate_fn_sorted( - batch: List[ - Tuple[ - Tuple[torch.Tensor, torch.Tensor], - Tuple[torch.Tensor, torch.Tensor], + if sort: + # Sort the samples + samples = [ + (input, in_seq_len, target, target_seq_len) + for input, in_seq_len, target, target_seq_len in zip( + inputs, in_seq_lens, targets, target_seq_lens + ) ] - ] -) -> Tuple[ - Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor] -]: - r"""Collates a list of ``((tensor, tensor_len), (target, target_len))``. - - A ``collate_fn`` for sequence-to-sequence tasks. - - Args: - batch: A list of ``((tensor, tensor_len), (target, target_len))`` - where: - - tensor: - A :py:class:`torch.Tensor` of input for a model. The sequence - length dimension must be last. - - tensor_len: - A scalar, integer :py:class:`torch.Tensor` giving the length of - ``tensor``. - - target: - A :py:class:`torch.Tensor` target for the model. The sequence - length dimension must be last. - - target_len: - A scalar, integer :py:class:`torch.Tensor` giving the length of - ``target``. - - Returns: - A tuple of ``((batch_tensor, batch_tensor_len), (batch_target, - batch_target_len))`` where ``batch_tensor`` is the - result of applying :py:func:`.pad_sequence` to all ``tensor``\s in - ascending order by tensor length, ``batch_tensor_lens`` is the result - of stacking all ``tensor_len``\s, ``batch_target`` is the result of - appying :py:func:`.pad_sequence` to all ``target``\s (in an order that - corresponds to the samples in `batch_tensor`) and ``batch_target_len`` - is the result of stacking all ``target_len``\s. - """ - - inputs, in_seq_lens = [], [] - targets, target_seq_lens = [], [] - - for (input, in_seq_len), (target, target_seq_len) in batch: - inputs.append(input) - in_seq_lens.append(in_seq_len) - targets.append(target) - target_seq_lens.append(target_seq_len) - - # Sort the samples - samples = [ - (input, in_seq_len, target, target_seq_len) - for input, in_seq_len, target, target_seq_len in zip( - inputs, in_seq_lens, targets, target_seq_lens - ) - ] - sorted_samples = sorted(samples, key=lambda s: s[0].size(-1)) - inputs, in_seq_lens, targets, target_seq_lens = zip(*sorted_samples) + sorted_samples = sorted(samples, key=lambda s: s[0].size(-1)) + inputs, in_seq_lens, targets, target_seq_lens = zip(*sorted_samples) inputs = pad_sequence(inputs) in_seq_lens = torch.tensor(in_seq_lens, requires_grad=False) diff --git a/tests/data/test_batch.py b/tests/data/test_batch.py index 7bebcedc..171a5dc3 100644 --- a/tests/data/test_batch.py +++ b/tests/data/test_batch.py @@ -102,7 +102,8 @@ def test_pad_sequences_returns_tensor_with_correct_values( def test_seq_to_seq_collate_fn() -> None: - """Unit test to ensure seq_to_seq_collate_fn returns correct values.""" + """Unit test to ensure seq_to_seq_collate_fn returns correct values when + sort=False.""" inputs = [rand([1, 2, 3]), rand([1, 2, 5])] seq_lens = tensor([3, 5]) @@ -114,7 +115,7 @@ def test_seq_to_seq_collate_fn() -> None: ((inputs[1], seq_lens[1]), (targets[1], target_lengths[1])), ] - x, y = seq_to_seq_collate_fn(batch) + x, y = seq_to_seq_collate_fn(batch, sort=False) assert isinstance(x, tuple) assert len(x) == 2 @@ -130,8 +131,8 @@ def test_seq_to_seq_collate_fn() -> None: def test_seq_to_seq_collate_fn_sorted() -> None: - """Unit test to ensure seq_to_seq_collate_fn_sorted returns correct - values.""" + """Unit test to ensure seq_to_seq_collate_fn returns correct values when + sort=True.""" inputs = [rand([1, 2, 3]), rand([1, 2, 5])] seq_lens = tensor([3, 5]) @@ -143,7 +144,7 @@ def test_seq_to_seq_collate_fn_sorted() -> None: ((inputs[1], seq_lens[1]), (targets[1], target_lengths[1])), ] - x, y = seq_to_seq_collate_fn(batch) + x, y = seq_to_seq_collate_fn(batch, sort=True) assert isinstance(x, tuple) assert len(x) == 2 From 8bdddb215cb0b8c5b68c13457925bf050b520cce Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Wed, 27 Nov 2019 12:34:01 +0000 Subject: [PATCH 04/19] Fix doc strings + add sortagrad reference --- src/myrtlespeech/data/sampler.py | 9 ++++++++- tests/data/test_batch.py | 6 ++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/myrtlespeech/data/sampler.py b/src/myrtlespeech/data/sampler.py index 770cbf11..05b8bcf0 100644 --- a/src/myrtlespeech/data/sampler.py +++ b/src/myrtlespeech/data/sampler.py @@ -3,16 +3,19 @@ class SequentialRandomSampler: """A sequential or random iterable over batches. + The iterator used each time this iterable is iterated over will yield batches either sequentially (i.e. in-order) or randomly (uniform without replacement) from `batches`. This iterable records the number of times it has returned an iterator. A sequential iterator is returned if the current count is in `sequential`. + Args: batches: A list of batches. n_iterators (optional, int): Number of iterators returned so far. sequential (optional, set of int): Counts at which to return a sequential iterator. + Yields: Batches from `batches`. """ @@ -70,11 +73,15 @@ def __len__(self): class SortaGrad(SequentialRandomSampler): """An iterable over batch indices according to the SortaGrad strategy. + The SortaGrad curriculum learning strategy iterates over batches from the batched dataset sequentially for the first pass and then randomly for all - other passes. + other passes. See Deep Speech 2 paper for more information on this: + https://arxiv.org/abs/1512.02595 + Args: batches: A list of batches. + Yields: Batches from `batches`. """ diff --git a/tests/data/test_batch.py b/tests/data/test_batch.py index 171a5dc3..f4150e1d 100644 --- a/tests/data/test_batch.py +++ b/tests/data/test_batch.py @@ -102,8 +102,7 @@ def test_pad_sequences_returns_tensor_with_correct_values( def test_seq_to_seq_collate_fn() -> None: - """Unit test to ensure seq_to_seq_collate_fn returns correct values when - sort=False.""" + """Ensure seq_to_seq_collate_fn returns correct values when sort=False.""" inputs = [rand([1, 2, 3]), rand([1, 2, 5])] seq_lens = tensor([3, 5]) @@ -131,8 +130,7 @@ def test_seq_to_seq_collate_fn() -> None: def test_seq_to_seq_collate_fn_sorted() -> None: - """Unit test to ensure seq_to_seq_collate_fn returns correct values when - sort=True.""" + """Ensure seq_to_seq_collate_fn returns correct values when sort=True.""" inputs = [rand([1, 2, 3]), rand([1, 2, 5])] seq_lens = tensor([3, 5]) From a0b3afea18156f7cb771db242e894d216b4b9fc0 Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Wed, 27 Nov 2019 15:09:27 +0000 Subject: [PATCH 05/19] Add docs about arguments in samplers functions --- src/myrtlespeech/data/sampler.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/myrtlespeech/data/sampler.py b/src/myrtlespeech/data/sampler.py index 05b8bcf0..ac5d7707 100644 --- a/src/myrtlespeech/data/sampler.py +++ b/src/myrtlespeech/data/sampler.py @@ -11,7 +11,14 @@ class SequentialRandomSampler: sequential iterator is returned if the current count is in `sequential`. Args: - batches: A list of batches. + indices: data with which batches are created. + batch_size: Batch dimension. + shuffle: Set to True to have the data reshuffled at every epoch if a + random iterator is used. + drop_last: (optional, bool): Set to True to drop the last incomplete + batch, if the dataset size is not divisible by the batch size. If + False and the size of dataset is not divisible by the batch size, + then the last batch will be smaller. n_iterators (optional, int): Number of iterators returned so far. sequential (optional, set of int): Counts at which to return a sequential iterator. @@ -80,7 +87,16 @@ class SortaGrad(SequentialRandomSampler): https://arxiv.org/abs/1512.02595 Args: - batches: A list of batches. + indices: data with which batches are created. + batch_size: Batch dimension. + shuffle: Set to True to have the data reshuffled at every epoch if a + random iterator is used. + drop_last: (optional, bool): Set to True to drop the last incomplete + batch, if the dataset size is not divisible by the batch size. If + False and the size of dataset is not divisible by the batch size, + then the last batch will be smaller. + start_epoch (optional, int): Number of iterators returned so far by the + sampler. Yields: Batches from `batches`. From 60929210e64b5223e647fc49f8d5fb6dac654704 Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Wed, 27 Nov 2019 15:49:37 +0000 Subject: [PATCH 06/19] Avoid dataloader repetition --- src/myrtlespeech/builders/task_config.py | 44 ++++++++++++------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/myrtlespeech/builders/task_config.py b/src/myrtlespeech/builders/task_config.py index 71338944..08503aa5 100644 --- a/src/myrtlespeech/builders/task_config.py +++ b/src/myrtlespeech/builders/task_config.py @@ -1,5 +1,6 @@ import multiprocessing from typing import Tuple +from typing import Union import torch from myrtlespeech.builders.dataset import build as build_dataset @@ -111,32 +112,31 @@ def target_transform(target): ) shuffle = task_config.train_config.shuffle_batches_before_every_epoch + batch_sampler: Union[SortaGrad, SequentialRandomSampler] if task_config.train_config.sortagrad: - train_loader = torch.utils.data.DataLoader( - dataset=train_dataset, - batch_sampler=SortaGrad( - indices=range(len(train_dataset)), - batch_size=task_config.train_config.batch_size, - shuffle=shuffle, - drop_last=False, - ), - num_workers=num_workers, - collate_fn=lambda batch: seq_to_seq_collate_fn(batch, sort=True), - pin_memory=torch.cuda.is_available(), + batch_sampler = SortaGrad( + indices=range(len(train_dataset)), + batch_size=task_config.train_config.batch_size, + shuffle=shuffle, + drop_last=False, ) + sort = True else: - train_loader = torch.utils.data.DataLoader( - dataset=train_dataset, - batch_sampler=SequentialRandomSampler( - indices=range(len(train_dataset)), - batch_size=task_config.train_config.batch_size, - shuffle=shuffle, - drop_last=False, - ), - num_workers=num_workers, - collate_fn=lambda batch: seq_to_seq_collate_fn(batch, sort=False), - pin_memory=torch.cuda.is_available(), + batch_sampler = SequentialRandomSampler( + indices=range(len(train_dataset)), + batch_size=task_config.train_config.batch_size, + shuffle=shuffle, + drop_last=False, ) + sort = False + + train_loader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_sampler=batch_sampler, + num_workers=num_workers, + collate_fn=lambda batch: seq_to_seq_collate_fn(batch, sort=sort), + pin_memory=torch.cuda.is_available(), + ) # eval eval_dataset = build_dataset( From 852c5744ee0241f09285ad40c896b4352be917cb Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Thu, 28 Nov 2019 16:41:00 +0000 Subject: [PATCH 07/19] Change API to define shuffle strategy --- src/myrtlespeech/builders/task_config.py | 20 ++++++++++++++----- .../configs/deep_speech_1_en.config | 3 ++- .../configs/deep_speech_2_en.config | 4 ++-- src/myrtlespeech/data/sampler.py | 15 ++------------ src/myrtlespeech/protos/train_config.proto | 11 ++++++---- tests/protos/test_train_config.py | 15 ++++++++------ 6 files changed, 37 insertions(+), 31 deletions(-) diff --git a/src/myrtlespeech/builders/task_config.py b/src/myrtlespeech/builders/task_config.py index 08503aa5..63018e9c 100644 --- a/src/myrtlespeech/builders/task_config.py +++ b/src/myrtlespeech/builders/task_config.py @@ -111,24 +111,34 @@ def target_transform(target): add_seq_len_to_transforms=True, ) - shuffle = task_config.train_config.shuffle_batches_before_every_epoch + shuffle_str = task_config.train_config.WhichOneof("shuffle_strategy") batch_sampler: Union[SortaGrad, SequentialRandomSampler] - if task_config.train_config.sortagrad: + if shuffle_str == "sortagrad": batch_sampler = SortaGrad( indices=range(len(train_dataset)), batch_size=task_config.train_config.batch_size, - shuffle=shuffle, + shuffle=True, drop_last=False, ) sort = True - else: + elif shuffle_str == "random_batches": + batch_sampler = SequentialRandomSampler( + indices=range(len(train_dataset)), + batch_size=task_config.train_config.batch_size, + shuffle=True, + drop_last=False, + ) + sort = False + elif shuffle_str == "sequential_batches": batch_sampler = SequentialRandomSampler( indices=range(len(train_dataset)), batch_size=task_config.train_config.batch_size, - shuffle=shuffle, + shuffle=False, drop_last=False, ) sort = False + else: + raise ValueError(f"unsupported shuffle strategy {shuffle_str}") train_loader = torch.utils.data.DataLoader( dataset=train_dataset, diff --git a/src/myrtlespeech/configs/deep_speech_1_en.config b/src/myrtlespeech/configs/deep_speech_1_en.config index 0d3e7f79..ed96e122 100644 --- a/src/myrtlespeech/configs/deep_speech_1_en.config +++ b/src/myrtlespeech/configs/deep_speech_1_en.config @@ -60,7 +60,8 @@ train_config { } } } - shuffle_batches_before_every_epoch: true; + random_batches { + } } eval_config { diff --git a/src/myrtlespeech/configs/deep_speech_2_en.config b/src/myrtlespeech/configs/deep_speech_2_en.config index 239255e9..853cbf35 100644 --- a/src/myrtlespeech/configs/deep_speech_2_en.config +++ b/src/myrtlespeech/configs/deep_speech_2_en.config @@ -108,8 +108,8 @@ train_config { } } } - shuffle_batches_before_every_epoch: true; - sortagrad: true; + sortagrad { + } } eval_config { diff --git a/src/myrtlespeech/data/sampler.py b/src/myrtlespeech/data/sampler.py index ac5d7707..4ac8d344 100644 --- a/src/myrtlespeech/data/sampler.py +++ b/src/myrtlespeech/data/sampler.py @@ -56,23 +56,12 @@ def _batch_indices(self, indices, batch_size, drop_last): return batches def __iter__(self): - if self._n_iterators in self._sequential: - iter_ = self._seq_iter() - else: - iter_ = self._rnd_iter() - self._n_iterators += 1 - return iter_ - - def _seq_iter(self): - for b in self.batch_indices: - yield b - - def _rnd_iter(self): indices = list(range(len(self.batch_indices))) - if self.shuffle: + if self.shuffle and self._n_iterators not in self._sequential: random.shuffle(indices) for index in indices: yield self.batch_indices[index] + self._n_iterators += 1 def __len__(self): return len(self.batch_indices) diff --git a/src/myrtlespeech/protos/train_config.proto b/src/myrtlespeech/protos/train_config.proto index 0d218771..4b8892d0 100644 --- a/src/myrtlespeech/protos/train_config.proto +++ b/src/myrtlespeech/protos/train_config.proto @@ -4,6 +4,7 @@ package myrtlespeech.protos; import "myrtlespeech/protos/dataset.proto"; import "myrtlespeech/protos/optimizer.proto"; +import "myrtlespeech/protos/shuffle_strategy.proto"; // Configuration for training. @@ -21,10 +22,12 @@ message TrainConfig { Dataset dataset = 5; - oneof supported_shuffles { + oneof shuffle_strategy { + // Mantain a sequential batch order. + SequentialBatches sequential_batches = 6; // Shuffle batches before every epoch. - bool shuffle_batches_before_every_epoch = 6; + RandomBatches random_batches = 7; + // Sequential for the first epoch and random for the following ones. + SortaGrad sortagrad = 8; } - - bool sortagrad = 11; } diff --git a/tests/protos/test_train_config.py b/tests/protos/test_train_config.py index 34fd8013..ca33fdce 100644 --- a/tests/protos/test_train_config.py +++ b/tests/protos/test_train_config.py @@ -3,6 +3,7 @@ from typing import Union import hypothesis.strategies as st +from myrtlespeech.protos import shuffle_strategy_pb2 from myrtlespeech.protos import train_config_pb2 from tests.protos.test_dataset import datasets @@ -51,16 +52,18 @@ def train_configs( st.sampled_from( [ f.name - for f in descript.oneofs_by_name["supported_shuffles"].fields + for f in descript.oneofs_by_name["shuffle_strategy"].fields ] ) ) - if shuffle_str == "shuffle_batches_before_every_epoch": - kwargs[shuffle_str] = draw(st.booleans()) + if shuffle_str == "sequential_batches": + kwargs[shuffle_str] = shuffle_strategy_pb2.SequentialBatches() + elif shuffle_str == "random_batches": + kwargs[shuffle_str] = shuffle_strategy_pb2.RandomBatches() + elif shuffle_str == "sortagrad": + kwargs[shuffle_str] = shuffle_strategy_pb2.SortaGrad() else: - raise ValueError(f"unknown shuffle type {shuffle_str}") - - kwargs["sortagrad"] = draw(st.booleans()) + raise ValueError(f"unknown shuffle strategy type {shuffle_str}") # initialise and return all_fields_set(train_config_pb2.TrainConfig, kwargs) From 971552f3ca14ae3fdb5734e06f32cb137f21a596 Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Thu, 28 Nov 2019 17:29:33 +0000 Subject: [PATCH 08/19] Add proto file for shuffle strategy --- src/myrtlespeech/protos/shuffle_strategy.proto | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 src/myrtlespeech/protos/shuffle_strategy.proto diff --git a/src/myrtlespeech/protos/shuffle_strategy.proto b/src/myrtlespeech/protos/shuffle_strategy.proto new file mode 100644 index 00000000..748865e0 --- /dev/null +++ b/src/myrtlespeech/protos/shuffle_strategy.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package myrtlespeech.protos; + +import "google/protobuf/wrappers.proto"; + + +message SequentialBatches { +} + + +message RandomBatches { +} + + +message SortaGrad { +} From 29976284a2eec9a9374e4bea9b61afa50604147f Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Mon, 2 Dec 2019 15:49:20 +0000 Subject: [PATCH 09/19] rename sortagrad variable using snake_case + remove unused import --- src/myrtlespeech/protos/shuffle_strategy.proto | 2 -- src/myrtlespeech/protos/train_config.proto | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/myrtlespeech/protos/shuffle_strategy.proto b/src/myrtlespeech/protos/shuffle_strategy.proto index 748865e0..e477edc8 100644 --- a/src/myrtlespeech/protos/shuffle_strategy.proto +++ b/src/myrtlespeech/protos/shuffle_strategy.proto @@ -2,8 +2,6 @@ syntax = "proto3"; package myrtlespeech.protos; -import "google/protobuf/wrappers.proto"; - message SequentialBatches { } diff --git a/src/myrtlespeech/protos/train_config.proto b/src/myrtlespeech/protos/train_config.proto index ba0bb974..2a3e6c55 100644 --- a/src/myrtlespeech/protos/train_config.proto +++ b/src/myrtlespeech/protos/train_config.proto @@ -36,6 +36,6 @@ message TrainConfig { // Shuffle batches before every epoch. RandomBatches random_batches = 11; // Sequential for the first epoch and random for the following ones. - SortaGrad sortagrad = 12; + SortaGrad sorta_grad = 12; } } From 25b787e3001089ddcb0897fb89d36c566146f9df Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Mon, 2 Dec 2019 17:28:40 +0000 Subject: [PATCH 10/19] Fix test error caused by changed sorta_grad variable name --- src/myrtlespeech/configs/deep_speech_2_en.config | 2 +- tests/protos/test_train_config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/myrtlespeech/configs/deep_speech_2_en.config b/src/myrtlespeech/configs/deep_speech_2_en.config index e118902d..edc69ebd 100644 --- a/src/myrtlespeech/configs/deep_speech_2_en.config +++ b/src/myrtlespeech/configs/deep_speech_2_en.config @@ -111,7 +111,7 @@ train_config { } } } - sortagrad { + sorta_grad { } } diff --git a/tests/protos/test_train_config.py b/tests/protos/test_train_config.py index 7306a538..db39d996 100644 --- a/tests/protos/test_train_config.py +++ b/tests/protos/test_train_config.py @@ -86,7 +86,7 @@ def train_configs( kwargs[shuffle_str] = shuffle_strategy_pb2.SequentialBatches() elif shuffle_str == "random_batches": kwargs[shuffle_str] = shuffle_strategy_pb2.RandomBatches() - elif shuffle_str == "sortagrad": + elif shuffle_str == "sorta_grad": kwargs[shuffle_str] = shuffle_strategy_pb2.SortaGrad() else: raise ValueError(f"unknown shuffle strategy type {shuffle_str}") From 7965f44dc6ce1612784a101da1677eaccaad5eed Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Tue, 3 Dec 2019 09:36:23 +0000 Subject: [PATCH 11/19] Move update of num_iterators before for loop in sampler + change sorta_grad variable name --- src/myrtlespeech/builders/task_config.py | 2 +- src/myrtlespeech/data/sampler.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/myrtlespeech/builders/task_config.py b/src/myrtlespeech/builders/task_config.py index d255753a..ec73e883 100644 --- a/src/myrtlespeech/builders/task_config.py +++ b/src/myrtlespeech/builders/task_config.py @@ -118,7 +118,7 @@ def target_transform(target): shuffle_str = task_config.train_config.WhichOneof("shuffle_strategy") batch_sampler: Union[SortaGrad, SequentialRandomSampler] - if shuffle_str == "sortagrad": + if shuffle_str == "sorta_grad": batch_sampler = SortaGrad( indices=range(len(train_dataset)), batch_size=task_config.train_config.batch_size, diff --git a/src/myrtlespeech/data/sampler.py b/src/myrtlespeech/data/sampler.py index 4ac8d344..d8d09d89 100644 --- a/src/myrtlespeech/data/sampler.py +++ b/src/myrtlespeech/data/sampler.py @@ -59,9 +59,9 @@ def __iter__(self): indices = list(range(len(self.batch_indices))) if self.shuffle and self._n_iterators not in self._sequential: random.shuffle(indices) + self._n_iterators += 1 for index in indices: yield self.batch_indices[index] - self._n_iterators += 1 def __len__(self): return len(self.batch_indices) From 26282f12abec4b2e1bd488118cd3681df1081896 Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Tue, 3 Dec 2019 12:44:23 +0000 Subject: [PATCH 12/19] Remove sort within single batches --- src/myrtlespeech/builders/task_config.py | 5 +-- src/myrtlespeech/data/batch.py | 20 ++--------- tests/data/test_batch.py | 42 ++---------------------- 3 files changed, 5 insertions(+), 62 deletions(-) diff --git a/src/myrtlespeech/builders/task_config.py b/src/myrtlespeech/builders/task_config.py index ec73e883..7d914610 100644 --- a/src/myrtlespeech/builders/task_config.py +++ b/src/myrtlespeech/builders/task_config.py @@ -125,7 +125,6 @@ def target_transform(target): shuffle=True, drop_last=False, ) - sort = True elif shuffle_str == "random_batches": batch_sampler = SequentialRandomSampler( indices=range(len(train_dataset)), @@ -133,7 +132,6 @@ def target_transform(target): shuffle=True, drop_last=False, ) - sort = False elif shuffle_str == "sequential_batches": batch_sampler = SequentialRandomSampler( indices=range(len(train_dataset)), @@ -141,7 +139,6 @@ def target_transform(target): shuffle=False, drop_last=False, ) - sort = False else: raise ValueError(f"unsupported shuffle strategy {shuffle_str}") @@ -149,7 +146,7 @@ def target_transform(target): dataset=train_dataset, batch_sampler=batch_sampler, num_workers=num_workers, - collate_fn=lambda batch: seq_to_seq_collate_fn(batch, sort=sort), + collate_fn=seq_to_seq_collate_fn, pin_memory=torch.cuda.is_available(), ) diff --git a/src/myrtlespeech/data/batch.py b/src/myrtlespeech/data/batch.py index f11c2659..1474d3c6 100644 --- a/src/myrtlespeech/data/batch.py +++ b/src/myrtlespeech/data/batch.py @@ -48,8 +48,7 @@ def seq_to_seq_collate_fn( Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], ] - ], - sort: bool = False, + ] ) -> Tuple[ Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor] ]: @@ -77,9 +76,6 @@ def seq_to_seq_collate_fn( A scalar, integer :py:class:`torch.Tensor` giving the length of ``target``. - sort: A boolean value used to decide whether the batch should be sorted - by the input tensor_len - Returns: A tuple of ``((batch_tensor, batch_tensor_len), (batch_target, batch_target_len))`` where ``batch_tensor`` is the @@ -87,8 +83,7 @@ def seq_to_seq_collate_fn( ``batch_tensor_lens`` is the result of stacking all ``tensor_len``\s, ``batch_target`` is the result of appying :py:func:`.pad_sequence` to all ``target``\s and ``batch_target_len`` is the result of stacking all - ``target_len``\s. If sort is set to True then the output is sorted by - the input tensor_len. + ``target_len``\s. """ inputs, in_seq_lens = [], [] targets, target_seq_lens = [], [] @@ -99,17 +94,6 @@ def seq_to_seq_collate_fn( targets.append(target) target_seq_lens.append(target_seq_len) - if sort: - # Sort the samples - samples = [ - (input, in_seq_len, target, target_seq_len) - for input, in_seq_len, target, target_seq_len in zip( - inputs, in_seq_lens, targets, target_seq_lens - ) - ] - sorted_samples = sorted(samples, key=lambda s: s[0].size(-1)) - inputs, in_seq_lens, targets, target_seq_lens = zip(*sorted_samples) - inputs = pad_sequence(inputs) in_seq_lens = torch.tensor(in_seq_lens, requires_grad=False) targets = pad_sequence(targets) diff --git a/tests/data/test_batch.py b/tests/data/test_batch.py index f4150e1d..6310efcd 100644 --- a/tests/data/test_batch.py +++ b/tests/data/test_batch.py @@ -102,7 +102,7 @@ def test_pad_sequences_returns_tensor_with_correct_values( def test_seq_to_seq_collate_fn() -> None: - """Ensure seq_to_seq_collate_fn returns correct values when sort=False.""" + """Ensure seq_to_seq_collate_fn returns correct values""" inputs = [rand([1, 2, 3]), rand([1, 2, 5])] seq_lens = tensor([3, 5]) @@ -114,7 +114,7 @@ def test_seq_to_seq_collate_fn() -> None: ((inputs[1], seq_lens[1]), (targets[1], target_lengths[1])), ] - x, y = seq_to_seq_collate_fn(batch, sort=False) + x, y = seq_to_seq_collate_fn(batch) assert isinstance(x, tuple) assert len(x) == 2 @@ -127,41 +127,3 @@ def test_seq_to_seq_collate_fn() -> None: assert torch.all(y[0] == pad_sequence(targets)) assert torch.all(y[1] == target_lengths) - - -def test_seq_to_seq_collate_fn_sorted() -> None: - """Ensure seq_to_seq_collate_fn returns correct values when sort=True.""" - inputs = [rand([1, 2, 3]), rand([1, 2, 5])] - seq_lens = tensor([3, 5]) - - targets = [rand([10]), rand([7])] - target_lengths = tensor([10, 7]) - - batch = [ - ((inputs[0], seq_lens[0]), (targets[0], target_lengths[0])), - ((inputs[1], seq_lens[1]), (targets[1], target_lengths[1])), - ] - - x, y = seq_to_seq_collate_fn(batch, sort=True) - - assert isinstance(x, tuple) - assert len(x) == 2 - - assert isinstance(y, tuple) - assert len(y) == 2 - - # Sort the input samples - samples = [ - (input, seq_lens, target, target_lengths) - for input, seq_len, target, target_length in zip( - inputs, seq_lens, targets, target_lengths - ) - ] - sorted_samples = sorted(samples, key=lambda s: s[0].size(-1)) - inputs, in_seq_lens, targets, target_seq_lens = zip(*sorted_samples) - - assert torch.all(x[0] == pad_sequence(inputs)) - assert torch.all(x[1] == seq_lens) - - assert torch.all(y[0] == pad_sequence(targets)) - assert torch.all(y[1] == target_lengths) From b46ed75486072ce70f95ea754841d5a1e8916c44 Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Fri, 6 Dec 2019 10:42:04 +0000 Subject: [PATCH 13/19] Add types + change doc syntaxt for hyperlink --- src/myrtlespeech/data/sampler.py | 53 ++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/src/myrtlespeech/data/sampler.py b/src/myrtlespeech/data/sampler.py index d8d09d89..a20d8729 100644 --- a/src/myrtlespeech/data/sampler.py +++ b/src/myrtlespeech/data/sampler.py @@ -1,4 +1,8 @@ import random +from typing import Dict +from typing import Optional +from typing import Set +from typing import Union class SequentialRandomSampler: @@ -15,13 +19,12 @@ class SequentialRandomSampler: batch_size: Batch dimension. shuffle: Set to True to have the data reshuffled at every epoch if a random iterator is used. - drop_last: (optional, bool): Set to True to drop the last incomplete - batch, if the dataset size is not divisible by the batch size. If - False and the size of dataset is not divisible by the batch size, - then the last batch will be smaller. - n_iterators (optional, int): Number of iterators returned so far. - sequential (optional, set of int): Counts at which to return a - sequential iterator. + drop_last: Set to True to drop the last incomplete batch, if the + dataset size is not divisible by the batch size. If False and the + size of dataset is not divisible by the batch size, then the last + batch will be smaller. + n_iterators: Number of iterators returned so far. + sequential: Counts at which to return a sequential iterator. Yields: Batches from `batches`. @@ -29,19 +32,19 @@ class SequentialRandomSampler: def __init__( self, - indices, - batch_size, - shuffle, - drop_last=False, - n_iterators=0, - sequential=None, + indices: range, + batch_size: int, + shuffle: bool, + drop_last: Optional[bool] = False, + n_iterators: Optional[int] = 0, + sequential: Optional[set] = None, ): self.shuffle = shuffle self.batch_indices = self._batch_indices( indices, batch_size, drop_last ) - self._n_iterators = n_iterators - self._sequential = sequential or {} + self._n_iterators: Optional[int] = n_iterators + self._sequential: Union[Set, Dict] = sequential or {} def _batch_indices(self, indices, batch_size, drop_last): batches = [] @@ -73,26 +76,30 @@ class SortaGrad(SequentialRandomSampler): The SortaGrad curriculum learning strategy iterates over batches from the batched dataset sequentially for the first pass and then randomly for all other passes. See Deep Speech 2 paper for more information on this: - https://arxiv.org/abs/1512.02595 + `Deep Speech 2 `_ Args: indices: data with which batches are created. batch_size: Batch dimension. shuffle: Set to True to have the data reshuffled at every epoch if a random iterator is used. - drop_last: (optional, bool): Set to True to drop the last incomplete - batch, if the dataset size is not divisible by the batch size. If - False and the size of dataset is not divisible by the batch size, - then the last batch will be smaller. - start_epoch (optional, int): Number of iterators returned so far by the - sampler. + drop_last: Set to True to drop the last incomplete batch, if the + dataset size is not divisible by the batch size. If False and the + size of dataset is not divisible by the batch size, then the last + batch will be smaller. + start_epoch: Number of iterators returned so far by the sampler. Yields: Batches from `batches`. """ def __init__( - self, indices, batch_size, shuffle, drop_last=False, start_epoch=0 + self, + indices: range, + batch_size: int, + shuffle: bool, + drop_last: Optional[bool] = False, + start_epoch: Optional[int] = 0, ): super().__init__( indices, From ac09f5a9e3dc03a6c29e250e60a95c17e11ea5ff Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Fri, 6 Dec 2019 11:29:31 +0000 Subject: [PATCH 14/19] Add sampler documentation to docs --- docs/source/myrtlespeech/data/index.rst | 1 + docs/source/myrtlespeech/data/sampler.rst | 12 ++++++++++++ src/myrtlespeech/data/sampler.py | 4 ++-- 3 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 docs/source/myrtlespeech/data/sampler.rst diff --git a/docs/source/myrtlespeech/data/index.rst b/docs/source/myrtlespeech/data/index.rst index 46c185c1..5f4a72e8 100644 --- a/docs/source/myrtlespeech/data/index.rst +++ b/docs/source/myrtlespeech/data/index.rst @@ -9,3 +9,4 @@ alphabet dataset/index preprocess + sampler diff --git a/docs/source/myrtlespeech/data/sampler.rst b/docs/source/myrtlespeech/data/sampler.rst new file mode 100644 index 00000000..be12d73f --- /dev/null +++ b/docs/source/myrtlespeech/data/sampler.rst @@ -0,0 +1,12 @@ +============ + sampler +============ + +.. autoclass:: myrtlespeech.data.sampler.SequentialRandomSampler + :members: + :show-inheritance: + + +.. autoclass:: myrtlespeech.data.sampler.SortaGrad + :members: + :show-inheritance: diff --git a/src/myrtlespeech/data/sampler.py b/src/myrtlespeech/data/sampler.py index a20d8729..5480a0cb 100644 --- a/src/myrtlespeech/data/sampler.py +++ b/src/myrtlespeech/data/sampler.py @@ -27,7 +27,7 @@ class SequentialRandomSampler: sequential: Counts at which to return a sequential iterator. Yields: - Batches from `batches`. + Batches from `indices`. """ def __init__( @@ -90,7 +90,7 @@ class SortaGrad(SequentialRandomSampler): start_epoch: Number of iterators returned so far by the sampler. Yields: - Batches from `batches`. + Batches from `indices`. """ def __init__( From b4e03670bb43d9d2080fa63802808ab5bb0f66c2 Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Fri, 6 Dec 2019 16:30:35 +0000 Subject: [PATCH 15/19] Update sampler tests to use Hypothesis --- src/myrtlespeech/data/sampler.py | 5 +- tests/data/test_sampler.py | 133 +++++++++++++++++++++---------- 2 files changed, 95 insertions(+), 43 deletions(-) diff --git a/src/myrtlespeech/data/sampler.py b/src/myrtlespeech/data/sampler.py index 5480a0cb..9ee447d3 100644 --- a/src/myrtlespeech/data/sampler.py +++ b/src/myrtlespeech/data/sampler.py @@ -1,5 +1,6 @@ import random from typing import Dict +from typing import List from typing import Optional from typing import Set from typing import Union @@ -32,7 +33,7 @@ class SequentialRandomSampler: def __init__( self, - indices: range, + indices: Union[range, List], batch_size: int, shuffle: bool, drop_last: Optional[bool] = False, @@ -95,7 +96,7 @@ class SortaGrad(SequentialRandomSampler): def __init__( self, - indices: range, + indices: Union[range, List], batch_size: int, shuffle: bool, drop_last: Optional[bool] = False, diff --git a/tests/data/test_sampler.py b/tests/data/test_sampler.py index 9dcb28f4..39966780 100644 --- a/tests/data/test_sampler.py +++ b/tests/data/test_sampler.py @@ -1,55 +1,85 @@ +import random +from typing import Dict +from typing import List +from typing import Tuple +from typing import Union + +import hypothesis.strategies as st +from hypothesis import given from myrtlespeech.data.sampler import SequentialRandomSampler from myrtlespeech.data.sampler import SortaGrad -def dataset_gen(n_batches, batch_size, full_last_batch): - """Returns [0, ..., (n_batches*batch_size)-1-int(not(full_last_batch))]""" - indices = list(range(n_batches * batch_size)) - if not full_last_batch: +# Fixtures and Strategies ----------------------------------------------------- + + +@st.composite +def dataset_gen( + draw, return_kwargs: bool = False +) -> Union[st.SearchStrategy[List], st.SearchStrategy[Tuple[List, Dict]]]: + """Returns a SearchStrategy for a list of dataset indices.""" + kwargs = {} + kwargs["n_batches"] = draw(st.integers(10, 20)) + kwargs["batch_size"] = draw(st.integers(2, 16)) + kwargs["full_last_batch"] = draw(st.booleans()) + + # indices = [0, ..., (n_batches*batch_size)-1-int(not(full_last_batch))] + indices = list(range(kwargs["n_batches"] * kwargs["batch_size"])) + if not kwargs["full_last_batch"]: del indices[-1] - return indices + if not return_kwargs: + return indices + return indices, kwargs -def test_sorta_grad_correct_len(): - n_batches = 10 - batch_size = 16 - for full_last_batch in [True, False]: - dataset = sorted(dataset_gen(n_batches, batch_size, full_last_batch)) - sampler = SortaGrad( - dataset, - drop_last=full_last_batch, - batch_size=batch_size, - shuffle=False, - ) - assert len(sampler) == n_batches +# Tests ----------------------------------------------------------------------- -def test_sorta_grad_batches_non_empty(): - n_batches = 10 - batch_size = 16 +@given(dataset_kwargs=dataset_gen(return_kwargs=True)) +def test_sorta_grad_correct_len(dataset_kwargs: Tuple[List, Dict]): + dataset, kwargs = dataset_kwargs - for full_last_batch in [True, False]: - dataset = sorted(dataset_gen(n_batches, batch_size, full_last_batch)) - sampler = SortaGrad( - dataset, - drop_last=full_last_batch, - batch_size=batch_size, - shuffle=False, - ) - for batch in sampler: - assert len(batch) > 0 + dataset = sorted(dataset) + sampler = SortaGrad( + dataset, + drop_last=kwargs["full_last_batch"], + batch_size=kwargs["batch_size"], + shuffle=False, + ) + assert len(sampler) == kwargs["n_batches"] -def test_sorta_grad_first_pass_sequential_remaining_random(): - n_batches = 10 - batch_size = 16 - dataset = sorted(dataset_gen(n_batches, batch_size, False)) +@given(dataset_kwargs=dataset_gen(return_kwargs=True)) +def test_sorta_grad_batches_non_empty(dataset_kwargs: Tuple[List, Dict]): + dataset, kwargs = dataset_kwargs + dataset = sorted(dataset) + sampler = SortaGrad( + dataset, + drop_last=kwargs["full_last_batch"], + batch_size=kwargs["batch_size"], + shuffle=False, + ) + for batch in sampler: + assert len(batch) > 0 + + +@given(dataset_kwargs=dataset_gen(return_kwargs=True)) +def test_sorta_grad_first_pass_sequential_remaining_random( + dataset_kwargs: Tuple[List, Dict] +): + dataset, kwargs = dataset_kwargs + + dataset = sorted(dataset) sortagrad = SortaGrad( - dataset, drop_last=False, batch_size=batch_size, shuffle=True + dataset, + drop_last=kwargs["full_last_batch"], + batch_size=kwargs["batch_size"], + shuffle=True, ) + indices: list for pass_ in range(100): indices = [] for batch in sortagrad: @@ -63,15 +93,36 @@ def test_sorta_grad_first_pass_sequential_remaining_random(): assert indices != sorted(indices) -def test_sequential_strategy_seq_iter_when_epoch_in_seq_epochs(): - dataset = list(range(10)) - dataset_batches = [[elem] for elem in dataset] - n_iterators = 5 - sequential = {0, 2, 3, 7, 8, 10} +@given( + dataset_kwargs=dataset_gen(return_kwargs=True), + n_iterators=st.integers(min_value=1, max_value=10), + n_sequential=st.integers(min_value=1, max_value=10), + max_sequential=st.integers(min_value=11, max_value=20), +) +def test_sequential_strategy_seq_iter_when_epoch_in_seq_epochs( + dataset_kwargs: Tuple[List, Dict], + n_iterators: int, + n_sequential: int, + max_sequential: int, +): + dataset, kwargs = dataset_kwargs + + dataset_batches = [] + batch = [] + for elem in dataset: + batch.append(elem) + if len(batch) == kwargs["batch_size"]: + dataset_batches.append(batch) + batch = [] + if batch and not kwargs["full_last_batch"]: + dataset_batches.append(batch) + sequential = set( + sorted(random.sample(range(max_sequential), n_sequential)) + ) seq_strat = SequentialRandomSampler( dataset, - batch_size=1, + batch_size=kwargs["batch_size"], shuffle=True, n_iterators=n_iterators, sequential=sequential, From 4b683faeeec6ca357ebc79a4c94dd17b288f0133 Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Thu, 12 Dec 2019 10:27:11 +0000 Subject: [PATCH 16/19] Fix doc strings and parameter types --- src/myrtlespeech/data/sampler.py | 35 +++++++++++++++----------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/src/myrtlespeech/data/sampler.py b/src/myrtlespeech/data/sampler.py index 9ee447d3..8a0e08c2 100644 --- a/src/myrtlespeech/data/sampler.py +++ b/src/myrtlespeech/data/sampler.py @@ -1,22 +1,19 @@ import random -from typing import Dict -from typing import List +from typing import Iterable from typing import Optional -from typing import Set -from typing import Union class SequentialRandomSampler: - """A sequential or random iterable over batches. + """A sequential or random iterable over batches of indices. The iterator used each time this iterable is iterated over will yield - batches either sequentially (i.e. in-order) or randomly (uniform without - replacement) from `batches`. + batches of indices either sequentially (i.e. in-order) or randomly (uniform + without replacement). This iterable records the number of times it has returned an iterator. A sequential iterator is returned if the current count is in `sequential`. Args: - indices: data with which batches are created. + indices: Data with which batches are created. batch_size: Batch dimension. shuffle: Set to True to have the data reshuffled at every epoch if a random iterator is used. @@ -33,19 +30,19 @@ class SequentialRandomSampler: def __init__( self, - indices: Union[range, List], + indices: Iterable[int], batch_size: int, shuffle: bool, - drop_last: Optional[bool] = False, - n_iterators: Optional[int] = 0, + drop_last: bool = False, + n_iterators: int = 0, sequential: Optional[set] = None, ): self.shuffle = shuffle self.batch_indices = self._batch_indices( indices, batch_size, drop_last ) - self._n_iterators: Optional[int] = n_iterators - self._sequential: Union[Set, Dict] = sequential or {} + self._n_iterators = n_iterators + self._sequential = sequential or set() def _batch_indices(self, indices, batch_size, drop_last): batches = [] @@ -76,11 +73,11 @@ class SortaGrad(SequentialRandomSampler): The SortaGrad curriculum learning strategy iterates over batches from the batched dataset sequentially for the first pass and then randomly for all - other passes. See Deep Speech 2 paper for more information on this: - `Deep Speech 2 `_ + other passes. See `Deep Speech 2 `_ paper + for more information. Args: - indices: data with which batches are created. + indices: Data with which batches are created. batch_size: Batch dimension. shuffle: Set to True to have the data reshuffled at every epoch if a random iterator is used. @@ -96,11 +93,11 @@ class SortaGrad(SequentialRandomSampler): def __init__( self, - indices: Union[range, List], + indices: Iterable[int], batch_size: int, shuffle: bool, - drop_last: Optional[bool] = False, - start_epoch: Optional[int] = 0, + drop_last: bool = False, + start_epoch: int = 0, ): super().__init__( indices, From ea6c2f69a42d3728ead06a081ee79ccaaa2726b7 Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Thu, 12 Dec 2019 10:37:10 +0000 Subject: [PATCH 17/19] Let the sampler doc be auto-generated --- docs/source/myrtlespeech/data/sampler.rst | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/myrtlespeech/data/sampler.rst b/docs/source/myrtlespeech/data/sampler.rst index be12d73f..e8eedbe9 100644 --- a/docs/source/myrtlespeech/data/sampler.rst +++ b/docs/source/myrtlespeech/data/sampler.rst @@ -1,6 +1,10 @@ -============ +========= sampler -============ +========= + +.. automodule:: myrtlespeech.data.sampler + :members: + :show-inheritance: .. autoclass:: myrtlespeech.data.sampler.SequentialRandomSampler :members: From e31db684592f29c4ed62aa588155033d9fb7e82e Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Thu, 12 Dec 2019 11:53:41 +0000 Subject: [PATCH 18/19] Make Hypothesis generate set of sequential epoch numbers --- tests/data/test_sampler.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/data/test_sampler.py b/tests/data/test_sampler.py index 39966780..1d559dc1 100644 --- a/tests/data/test_sampler.py +++ b/tests/data/test_sampler.py @@ -1,4 +1,3 @@ -import random from typing import Dict from typing import List from typing import Tuple @@ -96,17 +95,21 @@ def test_sorta_grad_first_pass_sequential_remaining_random( @given( dataset_kwargs=dataset_gen(return_kwargs=True), n_iterators=st.integers(min_value=1, max_value=10), + sequential=st.lists( + range(st.integers(min_value=11, max_value=20)), + min_size=1, + max_size=st.integers(min_value=1, max_value=10), + unique=True, + ), n_sequential=st.integers(min_value=1, max_value=10), max_sequential=st.integers(min_value=11, max_value=20), ) def test_sequential_strategy_seq_iter_when_epoch_in_seq_epochs( - dataset_kwargs: Tuple[List, Dict], - n_iterators: int, - n_sequential: int, - max_sequential: int, + dataset_kwargs: Tuple[List, Dict], n_iterators: int, sequential: list ): dataset, kwargs = dataset_kwargs - + print("type(sequential):", type(sequential)) + print("sequential:", sequential) dataset_batches = [] batch = [] for elem in dataset: @@ -116,16 +119,14 @@ def test_sequential_strategy_seq_iter_when_epoch_in_seq_epochs( batch = [] if batch and not kwargs["full_last_batch"]: dataset_batches.append(batch) - sequential = set( - sorted(random.sample(range(max_sequential), n_sequential)) - ) + sequential_epochs = set(sorted(sequential)) seq_strat = SequentialRandomSampler( dataset, batch_size=kwargs["batch_size"], shuffle=True, n_iterators=n_iterators, - sequential=sequential, + sequential=sequential_epochs, ) for epoch in range(n_iterators, max(sequential) + 2): From e21db25c3fdb80c16a71fa210ab7ac29cb09dc0e Mon Sep 17 00:00:00 2001 From: giuseppeCoccia Date: Thu, 12 Dec 2019 15:15:15 +0000 Subject: [PATCH 19/19] Add separate function to create list of sequential epoch numbers --- tests/data/test_sampler.py | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/tests/data/test_sampler.py b/tests/data/test_sampler.py index 1d559dc1..d33742b9 100644 --- a/tests/data/test_sampler.py +++ b/tests/data/test_sampler.py @@ -32,6 +32,23 @@ def dataset_gen( return indices, kwargs +@st.composite +def sequential_epochs_gen(draw) -> st.SearchStrategy[List]: + """Returns a SearchStrategy for a list of sequential epoch numbers.""" + max_size = draw(st.integers(min_value=1, max_value=10)) + + sequential = draw( + st.lists( + elements=st.integers(min_value=11, max_value=20), + min_size=1, + max_size=max_size, + unique=True, + ) + ) + + return sequential + + # Tests ----------------------------------------------------------------------- @@ -95,21 +112,13 @@ def test_sorta_grad_first_pass_sequential_remaining_random( @given( dataset_kwargs=dataset_gen(return_kwargs=True), n_iterators=st.integers(min_value=1, max_value=10), - sequential=st.lists( - range(st.integers(min_value=11, max_value=20)), - min_size=1, - max_size=st.integers(min_value=1, max_value=10), - unique=True, - ), - n_sequential=st.integers(min_value=1, max_value=10), - max_sequential=st.integers(min_value=11, max_value=20), + sequential=sequential_epochs_gen(), ) def test_sequential_strategy_seq_iter_when_epoch_in_seq_epochs( - dataset_kwargs: Tuple[List, Dict], n_iterators: int, sequential: list + dataset_kwargs: Tuple[List, Dict], n_iterators: int, sequential: List ): dataset, kwargs = dataset_kwargs - print("type(sequential):", type(sequential)) - print("sequential:", sequential) + dataset_batches = [] batch = [] for elem in dataset: @@ -129,13 +138,13 @@ def test_sequential_strategy_seq_iter_when_epoch_in_seq_epochs( sequential=sequential_epochs, ) - for epoch in range(n_iterators, max(sequential) + 2): + for epoch in range(n_iterators, max(sequential_epochs) + 2): sampler_batches = [batch for batch in iter(seq_strat)] assert len(sampler_batches) == len(dataset_batches) assert sorted(sampler_batches) == sorted(dataset_batches) - if epoch in sequential: + if epoch in sequential_epochs: assert all( sample_batch == dataset_batch for sample_batch, dataset_batch in zip(